Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide (Leanpub)

Wx, bx = lstm_state['weight_ih'], lstm_state['bias_ih']
Wh, bh = lstm_state['weight_hh'], lstm_state['bias_hh']

# Split weights and biases for data points
Wxi, Wxf, Wxg, Wxo = Wx.split(hidden_dim, dim=0)
bxi, bxf, bxg, bxo = bx.split(hidden_dim, dim=0)

# Split weights and biases for hidden state
Whi, Whf, Whg, Who = Wh.split(hidden_dim, dim=0)
bhi, bhf, bhg, bho = bh.split(hidden_dim, dim=0)
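For readers following along, the lstm_state dictionary and hidden_dim used above can be pulled from an actual nn.LSTMCell. A minimal sketch (the seed and dimensions here are illustrative, not necessarily the book's) shows why the split produces four equally sized chunks:

```python
import torch
import torch.nn as nn

torch.manual_seed(19)  # illustrative seed
n_features, hidden_dim = 2, 2  # illustrative dimensions

# PyTorch stacks the parameters of the four components (i, f, g, o)
# along dim 0, so weight_ih has shape (4 * hidden_dim, n_features)
lstm_cell = nn.LSTMCell(n_features, hidden_dim)
lstm_state = lstm_cell.state_dict()

Wx, bx = lstm_state['weight_ih'], lstm_state['bias_ih']
Wxi, Wxf, Wxg, Wxo = Wx.split(hidden_dim, dim=0)
print(Wx.shape)   # torch.Size([8, 2])
print(Wxi.shape)  # torch.Size([2, 2])
```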

# Creates linear layers for the components
# input gate - green
i_hidden, i_input = linear_layers(Wxi, bxi, Whi, bhi)
# forget gate - red
f_hidden, f_input = linear_layers(Wxf, bxf, Whf, bhf)
# output gate - blue
o_hidden, o_input = linear_layers(Wxo, bxo, Who, bho)
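The linear_layers() helper is defined earlier in the chapter and does not appear in this excerpt; a plausible sketch of what such a helper could do (an assumption, not the book's verbatim code) is to build one linear layer for the data point and one for the hidden state, then load the split weights and biases into them:

```python
import torch
import torch.nn as nn

def linear_layers(Wx, bx, Wh, bh):
    # Hypothetical reconstruction: one linear layer for the data
    # point (x) and one for the hidden state (h), with the split
    # weights/biases loaded via load_state_dict()
    hidden_dim, n_features = Wx.size()
    lin_input = nn.Linear(n_features, hidden_dim)
    lin_hidden = nn.Linear(hidden_dim, hidden_dim)
    lin_input.load_state_dict({'weight': Wx, 'bias': bx})
    lin_hidden.load_state_dict({'weight': Wh, 'bias': bh})
    # Returned in (hidden, input) order, matching the unpacking above
    return lin_hidden, lin_input
```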

"Wait! Isn’t there a component missing? You mentioned four of them; where are the linear layers for g?"

Good catch! It turns out we don’t need linear layers for g because it is an RNN cell on its own! We can simply use load_state_dict() to create the corresponding cell:

g_cell = nn.RNNCell(n_features, hidden_dim) # black
g_cell.load_state_dict({'weight_ih': Wxg, 'bias_ih': bxg,
                        'weight_hh': Whg, 'bias_hh': bhg})

Output

<All keys matched successfully>
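To see why an RNN cell can stand in for g: with its default tanh nonlinearity, nn.RNNCell computes tanh(W_ih x + b_ih + W_hh h + b_hh), which is exactly the equation for the candidate hidden state g. A quick sanity check with illustrative tensors (not the book's values):

```python
import torch
import torch.nn as nn

torch.manual_seed(19)  # illustrative seed
n_features, hidden_dim = 2, 2
Wxg, bxg = torch.randn(hidden_dim, n_features), torch.randn(hidden_dim)
Whg, bhg = torch.randn(hidden_dim, hidden_dim), torch.randn(hidden_dim)

g_cell = nn.RNNCell(n_features, hidden_dim)  # default nonlinearity='tanh'
g_cell.load_state_dict({'weight_ih': Wxg, 'bias_ih': bxg,
                        'weight_hh': Whg, 'bias_hh': bhg})

x, h = torch.randn(1, n_features), torch.randn(1, hidden_dim)
manual_g = torch.tanh(x @ Wxg.t() + bxg + h @ Whg.t() + bhg)
print(torch.allclose(g_cell(x, h), manual_g, atol=1e-6))  # True
```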

That was easy, right? Since the other components are gates, we need to create functions for them:
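As a preview of where this is headed, here is a self-contained sketch (the gate() helper and variable names are illustrative, not the book's code): each gate applies a sigmoid to the sum of its hidden and input transformations, and combining the four components with the standard LSTM equations reproduces nn.LSTMCell's output:

```python
import torch
import torch.nn as nn

torch.manual_seed(19)  # illustrative seed
n_features, hidden_dim = 2, 2
lstm_cell = nn.LSTMCell(n_features, hidden_dim)
state = lstm_cell.state_dict()

# Same splits as above: i, f, g, o stacked along dim 0
Wx, bx = state['weight_ih'], state['bias_ih']
Wh, bh = state['weight_hh'], state['bias_hh']
Wxi, Wxf, Wxg, Wxo = Wx.split(hidden_dim, dim=0)
bxi, bxf, bxg, bxo = bx.split(hidden_dim, dim=0)
Whi, Whf, Whg, Who = Wh.split(hidden_dim, dim=0)
bhi, bhf, bhg, bho = bh.split(hidden_dim, dim=0)

def gate(Wh_, bh_, Wx_, bx_, h, x):
    # sigma(Wh h + bh + Wx x + bx)
    return torch.sigmoid(h @ Wh_.t() + bh_ + x @ Wx_.t() + bx_)

x = torch.randn(1, n_features)
h = torch.zeros(1, hidden_dim)
c = torch.zeros(1, hidden_dim)

i = gate(Whi, bhi, Wxi, bxi, h, x)            # input gate
f = gate(Whf, bhf, Wxf, bxf, h, x)            # forget gate
o = gate(Who, bho, Wxo, bxo, h, x)            # output gate
g = torch.tanh(h @ Whg.t() + bhg + x @ Wxg.t() + bxg)  # candidate state

c_new = f * c + i * g                         # new cell state
h_new = o * torch.tanh(c_new)                 # new hidden state

h_ref, c_ref = lstm_cell(x, (h, c))
print(torch.allclose(h_new, h_ref, atol=1e-6))  # True
```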

646 | Chapter 8: Sequences
