Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide


train_loader = DataLoader(
    dataset=train_dataset, batch_size=16, sampler=sampler
)
val_loader = DataLoader(dataset=val_dataset, batch_size=16)

Once again, if we're using a sampler, we cannot use the shuffle argument.

There is a lot of boilerplate code here, right? Let's build yet another function, Helper Function #5, to wrap it all up:

Helper Function #5

def make_balanced_sampler(y):
    # Computes weights for compensating imbalanced classes
    classes, counts = y.unique(return_counts=True)
    weights = 1.0 / counts.float()
    sample_weights = weights[y.squeeze().long()]
    # Builds sampler with computed weights
    generator = torch.Generator()
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        generator=generator,
        replacement=True
    )
    return sampler

sampler = make_balanced_sampler(y_train_tensor)

Much better! Its only argument is the tensor containing the labels: the function will compute the weights and build the corresponding weighted sampler on its own.
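To see what those weights actually look like, here is a minimal, self-contained sketch that reproduces the computation on a toy label tensor with the same 80/160 imbalance (y_toy and the printed values are illustrative assumptions, not part of the code above):

import torch
from torch.utils.data import WeightedRandomSampler

# Toy, imbalanced label tensor: 80 negative (0) and 160 positive (1) points
y_toy = torch.cat([torch.zeros(80), torch.ones(160)]).view(-1, 1)

# Class weights are the inverse of the class counts
classes, counts = y_toy.unique(return_counts=True)
weights = 1.0 / counts.float()
print(counts)   # tensor([ 80, 160])
print(weights)  # tensor([0.0125, 0.0063]) -> the minority class gets the larger weight

# Each data point gets the weight of its own class
sample_weights = weights[y_toy.squeeze().long()]
print(sample_weights.shape)  # torch.Size([240])

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

Each negative point carries twice the weight of a positive one, so both classes end up contributing roughly half of the 240 sampled points, which is exactly the balancing effect we want.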

Seeds and more (seeds)

Time to set the seed for the generator used in the sampler assigned to the data loader. It is a long sequence of objects, but we can work our way through it to retrieve the generator and call its manual_seed() method:

train_loader.sampler.generator.manual_seed(42)
random.seed(42)

Now we can check if our sampler is doing its job correctly. Let's have it sample a full run (240 data points in 15 mini-batches of 16 points each), and sum up the labels so we know how many points are in the positive class:

torch.tensor([t[1].sum() for t in iter(train_loader)]).sum()

Output

tensor(123.)

Close enough! We have 160 images of the positive class, and now, thanks to the weighted sampler, we're sampling only 123 of them. It means we're oversampling the negative class (which has 80 images) to a total of 117 images, adding up to 240 images. Mission accomplished: our dataset is balanced now.

"Wait a minute! Why on Earth was there an extra seed (random.seed(42)) in the code above? Don't we have enough already?"

I agree, too many seeds. Besides one specific seed for the generator, we also have to set yet another seed for Python's random module.

Honestly, this came as a surprise to me too when I found out about it! As weird as it may sound, in Torchvision versions prior to 0.8, there was still some code that depended upon Python's native random module, instead of PyTorch's own random generators. The problem happened when some of the random transformations for data augmentation were used, like RandomRotation(), RandomAffine(), and others.

It's better to be safe than sorry, so we'd better set yet another seed to ensure the reproducibility of our code.

And that's exactly what we're going to do! Remember the set_seed() method we
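As a rough idea of what such a helper could look like, here is a minimal sketch that sets the PyTorch, NumPy, and Python seeds in one call; set_all_seeds() is an illustrative name and not the book's actual set_seed() method:

import random
import numpy as np
import torch

def set_all_seeds(seed=42):
    # Illustrative sketch only, not the book's set_seed() method:
    # seeds PyTorch's default generator, NumPy, and Python's random module
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_all_seeds(42)

Note that the sampler's own generator still needs its manual_seed() call, as shown above, because it is a separate torch.Generator() object and is not affected by torch.manual_seed().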

