
Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide (Leanpub)


Loading Weights

First things first: To load the weights into our model, we need to retrieve them.

Sure, an easy way of retrieving them would be to set the argument pretrained=True while creating the model. But you can also download the weights from a given URL, which gives you the flexibility to use pre-trained weights from wherever you want!

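PyTorch offers the load_state_dict_from_url() function (in torch.hub). It retrieves the weights from the specified URL and saves them to a local folder. You can choose that folder with the model_dir argument; by default, it is a checkpoints folder inside PyTorch's hub directory.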
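The downloaded file is cached under model_dir, so a second call with the same URL reuses the local copy instead of fetching it again. The minimal sketch below shows this offline. It makes two illustrative assumptions. First, a tiny nn.Linear layer stands in for third-party weights. Second, a file:// URL replaces a real download link; this relies on torch.hub downloading through Python's urllib, which also accepts file:// URLs. The file name tiny-weights.pth is hypothetical.

```python
import os
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

# Hypothetical "third-party" weights: a tiny linear layer saved to disk
workdir = Path(tempfile.mkdtemp())
source_file = workdir / 'tiny-weights.pth'
torch.manual_seed(42)
torch.save(nn.Linear(4, 2).state_dict(), source_file)

# Assumption for illustration: a file:// URL stands in for a real download link
url = source_file.as_uri()
model_dir = workdir / 'pretrained'

# First call: the file is copied ("downloaded") into model_dir, then loaded
state_dict = load_state_dict_from_url(
    url, model_dir=str(model_dir), progress=False, map_location='cpu'
)
print(os.listdir(model_dir))

# Remove the "remote" file: the next call can only succeed from the cache
source_file.unlink()

# Second call: the cached copy in model_dir is found and loaded directly
cached = load_state_dict_from_url(
    url, model_dir=str(model_dir), progress=False, map_location='cpu'
)
print(torch.equal(state_dict['weight'], cached['weight']))
```

The cache is keyed on the file name at the end of the URL. Two different URLs ending in the same file name would therefore share one cached file in the same model_dir.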

"Great, but what’s the URL for AlexNet’s weights?"

You can get the URL from the model_urls variable in torchvision.models.alexnet:

URL for AlexNet’s Pre-trained Weights

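# model_urls is exposed by older torchvision releases
from torchvision.models.alexnet import model_urls

url = model_urls['alexnet']
url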

Output

'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth'

Of course, it doesn't make any sense to do this manually for models in PyTorch's library. But, assuming you're using pre-trained weights from a third party, you'd be able to load them like this:

Loading Pre-trained Weights

from torch.hub import load_state_dict_from_url

state_dict = load_state_dict_from_url(
    url, model_dir='pretrained', progress=True
)

Chapter 7: Transfer Learning
