Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide (Leanpub)

Saving and Loading Models

Training a model successfully is great, no doubt about that, but not all models will train quickly, and training may get interrupted (computer crashing, timeout after 12 hours of continuous GPU usage on Google Colab, etc.). It would be a pity to have to start over, right?

So, it is important to be able to checkpoint or save our model, that is, save it to disk, in case we'd like to restart training later or deploy it as an application to make predictions.

Model State

To checkpoint a model, we basically have to save its state to a file so that it can be loaded back later—nothing special, actually.

What defines the state of a model?

• model.state_dict(): kinda obvious, right?

• optimizer.state_dict(): remember, optimizers have a state_dict() as well

• losses: after all, you should keep track of their evolution

• epoch: it is just a number, so why not? :-)

• anything else you’d like to have restored later
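To see what the first two items actually contain, here is a minimal sketch using a dummy nn.Linear model and a plain SGD optimizer (not the book's own model), just to inspect both state dictionaries:

```python
import torch
import torch.nn as nn

# Dummy single-parameter model, only to illustrate the state dictionaries
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The model's state_dict maps each parameter name to its tensor
print(list(model.state_dict().keys()))      # ['weight', 'bias']

# The optimizer's state_dict holds its internal buffers and hyperparameters
print(list(optimizer.state_dict().keys()))  # ['state', 'param_groups']
```

Both are regular Python dictionaries (well, ordered ones), which is exactly why they are so easy to bundle together and save.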

Saving

Now, we wrap everything into a Python dictionary and use torch.save() to dump it all into a file. Easy peasy! We have just saved our model to a file named model_checkpoint.pth.

Notebook Cell 2.4 - Saving checkpoint

checkpoint = {'epoch': n_epochs,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': losses,
              'val_loss': val_losses}

torch.save(checkpoint, 'model_checkpoint.pth')
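Loading is the mirror image: torch.load() restores the dictionary, and the load_state_dict() methods restore the model and optimizer. Here is a minimal round-trip sketch, using a dummy nn.Linear model and dummy loss lists in place of the book's real ones:

```python
import torch
import torch.nn as nn

# Dummy model and optimizer standing in for the book's real ones
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save a checkpoint shaped like the one above (dummy values)
checkpoint = {'epoch': 2,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': [0.5, 0.3],
              'val_loss': [0.6, 0.4]}
torch.save(checkpoint, 'model_checkpoint.pth')

# Loading: rebuild the model/optimizer with the same architecture first,
# then restore their states from the loaded dictionary
loaded = torch.load('model_checkpoint.pth')
model.load_state_dict(loaded['model_state_dict'])
optimizer.load_state_dict(loaded['optimizer_state_dict'])
epoch = loaded['epoch']
losses = loaded['loss']
print(epoch, losses)  # 2 [0.5, 0.3]
```

Notice that loading does not rebuild the model for you: you must instantiate the same architecture first and only then fill in its weights.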

164 | Chapter 2: Rethinking the Training Loop
