Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide (Leanpub)


Output

OrderedDict([('running_mean', tensor([0.9070, 1.0931])),
             ('running_var', tensor([1.2592, 0.9192])),
             ('num_batches_tracked', tensor(2))])

Both running mean and running variance are simple averages over the mini-batches:

mean2, var2 = batch2[0].mean(axis=0), batch2[0].var(axis=0)
running_mean, running_var = (mean1 + mean2) / 2, (var1 + var2) / 2
running_mean, running_var

Output

(tensor([0.9070, 1.0931]), tensor([1.2592, 0.9192]))

Now, let's pretend we have finished training (even though we don't have an actual model), and we're using the third mini-batch for evaluation.

Evaluation Phase

Just like dropout, batch normalization exhibits different behaviors depending on the mode: train or eval. We've already seen what it does during the training phase. We've also realized that it doesn't make sense to compute statistics for any data that isn't training data.

So, in the evaluation phase, it will use the running statistics computed during training to standardize the new data (the third mini-batch, in our small example):

batch_normalizer.eval()
normed3 = batch_normalizer(batch3[0])
normed3.mean(axis=0), normed3.var(axis=0, unbiased=False)

Output

(tensor([ 0.1590, -0.0970]), tensor([1.0134, 1.4166]))
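If you want to double-check this, a minimal sketch (reusing batch3, running_mean, and running_var from the snippets above, and assuming BatchNorm1d's default eps of 1e-5) reproduces the layer's eval-mode output by hand:

import torch

eps = 1e-5  # BatchNorm1d's default epsilon
# In eval mode (and with affine=False), the layer simply standardizes the
# inputs using the running statistics accumulated during training
manual3 = (batch3[0] - running_mean) / torch.sqrt(running_var + eps)
print(torch.allclose(manual3, normed3))  # expected: True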

"Is it a bit off again?"Actually, no—since it is standardizing unseen data using statistics computed ontraining data, the results above are expected. The mean will be around zero andthe standard deviation will be around one.MomentumThere is an alternative way of computing running statistics: Instead of using asimple average, it uses an exponentially weighted moving average (EWMA) of thestatistics.The naming convention, though, is very unfortunate: The alpha parameter of theEWMA was named momentum, adding to the confusion. There is even a note inPyTorch’s documentation warning about this:"This momentum argument is different from one used in optimizer classes and theconventional notion of momentum." [128]The bottom line is: Ignore the confusing naming convention andthink of the "momentum" argument as the alpha parameter of aregular EWMA.The documentation also uses x to refer to a particular statistic when introducingthe mathematical formula of "momentum," which does not help at all.So, to make it abundantly clear what is being computed, I present the formulasbelow:Equation 7.7 - Running statisticLet’s try it out in practice:batch_normalizer_mom = nn.BatchNorm1d(num_features=2, affine=False, momentum=0.1)batch_normalizer_mom.state_dict()542 | Chapter 7: Transfer Learning

"Is it a bit off again?"

Actually, no—since it is standardizing unseen data using statistics computed on

training data, the results above are expected. The mean will be around zero and

the standard deviation will be around one.

Momentum

There is an alternative way of computing running statistics: instead of using a simple average, it uses an exponentially weighted moving average (EWMA) of the statistics.

The naming convention, though, is very unfortunate: the alpha parameter of the EWMA was named momentum, adding to the confusion. There is even a note in PyTorch's documentation warning about this:

"This momentum argument is different from one used in optimizer classes and the conventional notion of momentum." [128]

The bottom line is: ignore the confusing naming convention and think of the "momentum" argument as the alpha parameter of a regular EWMA.

The documentation also uses x to refer to a particular statistic when introducing the mathematical formula of "momentum," which does not help at all. So, to make it abundantly clear what is being computed, I present the formulas below:

Equation 7.7 - Running statistic

running_mean = (1 - momentum) * running_mean + momentum * mean
running_var  = (1 - momentum) * running_var  + momentum * var
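To make the arithmetic concrete, here is a minimal sketch of Equation 7.7 in plain Python. The numbers are made up for illustration: the running statistic starts at zero (which is how PyTorch initializes running_mean), and the mini-batch statistic is assumed to be 0.8.

momentum = 0.1
running_stat = 0.0  # PyTorch initializes running_mean to zero
batch_stat = 0.8    # hypothetical statistic computed on the current mini-batch

# Equation 7.7: keep 90% of the old value, blend in 10% of the new one
running_stat = (1 - momentum) * running_stat + momentum * batch_stat
print(running_stat)  # 0.08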

Let’s try it out in practice:

batch_normalizer_mom = nn.BatchNorm1d(
    num_features=2, affine=False, momentum=0.1
)
batch_normalizer_mom.state_dict()
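As a sanity check, a short sketch can confirm that a single forward pass in train mode updates the running statistics exactly as Equation 7.7 prescribes. It assumes batch1[0] is the first mini-batch of features used earlier; also note that PyTorch tracks running_var using the unbiased mini-batch variance:

import torch

batch_normalizer_mom.train()
batch_normalizer_mom(batch1[0])  # one forward pass updates the running statistics

# Manual EWMA updates (Equation 7.7) with momentum=0.1, starting from the
# initial values running_mean=0 and running_var=1
manual_mean = 0.9 * torch.zeros(2) + 0.1 * batch1[0].mean(axis=0)
manual_var = 0.9 * torch.ones(2) + 0.1 * batch1[0].var(axis=0, unbiased=True)

print(torch.allclose(manual_mean, batch_normalizer_mom.running_mean))  # expected: True
print(torch.allclose(manual_var, batch_normalizer_mom.running_var))    # expected: True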

