Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide (Leanpub)




Two changes were made to the model class to arrive at SquareModelLSTM (a sketch of such a class follows the list):

1. First change: from RNN to LSTM
2. Second change: including the cell state as output
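The class definition itself isn't reproduced in this excerpt. A minimal sketch of a model along these lines, assuming the same pattern as the chapter's other square models (a recurrent layer followed by a linear classifier over the last hidden state), could look like the code below; names and details are illustrative rather than a copy of the book's implementation.

import torch.nn as nn

class SquareModelLSTM(nn.Module):
    def __init__(self, n_features, hidden_dim, n_outputs):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        self.n_outputs = n_outputs
        self.hidden = None
        self.cell = None  # second change: keep the cell state around as well
        # first change: nn.LSTM instead of nn.RNN
        self.basic_rnn = nn.LSTM(self.n_features, self.hidden_dim, batch_first=True)
        self.classifier = nn.Linear(self.hidden_dim, self.n_outputs)

    def forward(self, X):
        # X: (N, L, F) -> output: (N, L, H)
        # second change: the LSTM returns both the hidden and the cell states
        batch_first_output, (self.hidden, self.cell) = self.basic_rnn(X)
        last_output = batch_first_output[:, -1]  # hidden state at the last step: (N, H)
        out = self.classifier(last_output)       # (N, n_outputs)
        return out.view(-1, self.n_outputs)

In a sketch like this, the two numbered changes show up as the nn.LSTM layer and the extra cell-state attribute populated in forward().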

Model Configuration & Training

Model Configuration

torch.manual_seed(21)
model = SquareModelLSTM(n_features=2, hidden_dim=2, n_outputs=1)
loss = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
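As a quick, hypothetical sanity check of the shapes involved (not from the book; it assumes the class sketch above and the chapter's dataset of length-four sequences of two-dimensional corners), the configured pieces could be exercised like this:

import torch

dummy_x = torch.randn(16, 4, 2)              # (N, L, F): 16 squares, 4 corners, 2 coordinates each
dummy_y = (torch.rand(16, 1) > 0.5).float()  # one binary label per sequence, as floats
logits = model(dummy_x)                      # (16, 1): a single raw logit per sequence
print(loss(logits, dummy_y))                 # BCEWithLogitsLoss takes raw logits and float labels

Using nn.BCEWithLogitsLoss means no sigmoid is applied inside the model; the loss function handles it, which is more numerically stable.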

Model Training

sbs_lstm = StepByStep(model, loss, optimizer)
sbs_lstm.set_loaders(train_loader, test_loader)
sbs_lstm.train(100)
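StepByStep is the helper class built earlier in the book, so train(100) runs the usual mini-batch loop over the loaders. A stripped-down sketch of roughly the equivalent loop (omitting the validation pass, loss tracking, and other bookkeeping the class takes care of), assuming the loaders yield (features, labels) tensor pairs, could look like this:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

for epoch in range(100):
    model.train()
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        batch_loss = loss(model(x_batch), y_batch)  # forward pass + BCE-with-logits loss
        batch_loss.backward()                       # backward pass
        optimizer.step()                            # parameter update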

fig = sbs_lstm.plot_losses()

Figure 8.24 - Losses - SquareModelLSTM

StepByStep.loader_apply(test_loader, sbs_lstm.correct)

Output

tensor([[53, 53],
        [75, 75]])

And that's 100% accuracy again!

Visualizing the Hidden States

Once again, if we use the "perfect" square as the input to our latest trained model, we get the following final hidden states for each of the eight sequences (plotted side-by-side with the previous models for easier comparison):

Figure 8.25 - Final hidden states for eight sequences of the "perfect" square

The LSTM model achieves a separation that is not necessarily better than the GRU's. What about the actual sequences?

Figure 8.26 - Sequence of hidden states
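For reference, here is a rough sketch of how those hidden states could be retrieved for plotting; this is not the book's code. square_sequences stands for a hypothetical tensor holding the eight sequences of the "perfect" square, shaped (8, 4, 2), and the hidden and basic_rnn attribute names follow the class sketch shown earlier.

import torch

model.eval()
with torch.no_grad():
    _ = model(square_sequences)                            # forward pass stores the final states
    final_hidden = model.hidden.squeeze(0)                 # (1, 8, H) -> (8, H): one point per sequence

    # hidden states at every step of every sequence (what the sequence plot is based on)
    per_step_hidden, _ = model.basic_rnn(square_sequences)  # (8, 4, H)

Since hidden_dim=2, each hidden state is a two-dimensional point, which is what makes these states easy to plot directly.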
