Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner’s Guide (Leanpub)

peiying410632
22.02.2024

Putting It All Together | 863

1. Encoder-Decoder

The encoder-decoder architecture was actually extended from the one developed in the previous chapter (EncoderDecoderSelfAttn), which handled training and prediction using greedy decoding. There are no changes here, except that both the encode() and decode() methods are omitted, since they are going to be overridden anyway:

Encoder + Decoder + Self-Attention

```python
import torch
import torch.nn as nn

class EncoderDecoderSelfAttn(nn.Module):
    def __init__(self, encoder, decoder, input_len, target_len):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.input_len = input_len
        self.target_len = target_len
        self.trg_masks = self.subsequent_mask(self.target_len)

    @staticmethod
    def subsequent_mask(size):
        attn_shape = (1, size, size)
        subsequent_mask = (
            1 - torch.triu(torch.ones(attn_shape), diagonal=1)
        ).bool()
        return subsequent_mask

    # encode() and decode() are omitted here: they are
    # overridden by the classes built on top of this one

    def predict(self, source_seq, source_mask):
        # Decodes/generates a sequence using one input
        # at a time - used in EVAL mode
        inputs = source_seq[:, -1:]
        for i in range(self.target_len):
            out = self.decode(inputs,
                              source_mask,
                              self.trg_masks[:, :i+1, :i+1])
            out = torch.cat([inputs, out[:, -1:, :]], dim=-2)
            inputs = out.detach()
        outputs = inputs[:, 1:, :]
        return outputs

    def forward(self, X, source_mask=None):
        # Sends the mask to the same device as the inputs
        self.trg_masks = self.trg_masks.type_as(X).bool()
        # Slices the input to get source sequence
        source_seq = X[:, :self.input_len, :]
        # Encodes source sequence AND initializes decoder
        self.encode(source_seq, source_mask)
        if self.training:
            # Slices the input to get the shifted target seq
            shifted_target_seq = X[:, self.input_len-1:-1, :]
            # Decodes using the mask to prevent cheating
            outputs = self.decode(shifted_target_seq,
                                  source_mask,
                                  self.trg_masks)
        else:
            # Decodes using its own predictions
            outputs = self.predict(source_seq, source_mask)

        return outputs
```
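To see what the target mask actually contains, here is a minimal standalone sketch that reuses the body of subsequent_mask() as a free function (the free function exists only for illustration). Entries on and below the diagonal are True, so each position may attend to itself and to earlier positions only. At decoding step i, predict() slices the top-left (i+1) x (i+1) block of that mask.

```python
import torch

def subsequent_mask(size):
    # Same computation as EncoderDecoderSelfAttn.subsequent_mask():
    # ones strictly above the diagonal become False, the rest True
    attn_shape = (1, size, size)
    return (1 - torch.triu(torch.ones(attn_shape), diagonal=1)).bool()

trg_masks = subsequent_mask(3)
print(trg_masks.squeeze(0).int().tolist())
# [[1, 0, 0], [1, 1, 0], [1, 1, 1]]

# predict() uses a growing square block of the mask at each step i
for i in range(3):
    step_mask = trg_masks[:, :i+1, :i+1]
    print(i, tuple(step_mask.shape))
# 0 (1, 1, 1)
# 1 (1, 2, 2)
# 2 (1, 3, 3)
```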

864 | Chapter 10: Transform and Roll Out
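The two branches of forward() differ only in what the decoder receives. The sketch below assumes, as the slicing in forward() implies, that X holds the source points followed by the target points, and uses a toy batch with a single sequence of four 2D points (input_len=2, target_len=2). It shows the training-time slices and then reproduces the loop from predict() as a standalone function. toy_decode() and greedy_predict() are hypothetical helpers for illustration only: toy_decode() just adds 2 to every coordinate and checks the mask shape, whereas a real decoder would use attention and the masks.

```python
import torch

def subsequent_mask(size):
    # Same mask as EncoderDecoderSelfAttn.subsequent_mask()
    return (1 - torch.triu(torch.ones((1, size, size)), diagonal=1)).bool()

def toy_decode(inputs, source_mask, target_mask):
    # Hypothetical stand-in for decode(): one output per input position.
    # The target mask must cover exactly the positions fed so far.
    assert target_mask.shape[-1] == inputs.shape[1]
    return inputs + 2

def greedy_predict(decode, source_seq, source_mask, target_len):
    # Same loop as predict(): start from the last source point and
    # append the decoder's newest prediction at every step
    trg_masks = subsequent_mask(target_len)
    inputs = source_seq[:, -1:]
    for i in range(target_len):
        out = decode(inputs, source_mask, trg_masks[:, :i+1, :i+1])
        out = torch.cat([inputs, out[:, -1:, :]], dim=-2)
        inputs = out.detach()
    # Drop the seed (last source point), keep only the predictions
    return inputs[:, 1:, :]

input_len, target_len = 2, 2
# One sequence (N=1) of four 2D points: two source, two target
X = torch.arange(8.).view(1, 4, 2)  # [[0,1],[2,3],[4,5],[6,7]]

# TRAIN mode: same slices as forward()
source_seq = X[:, :input_len, :]              # points 0 and 1
shifted_target_seq = X[:, input_len-1:-1, :]  # points 1 and 2
target_seq = X[:, input_len:, :]              # points 2 and 3
train_out = toy_decode(shifted_target_seq, None, subsequent_mask(target_len))

# EVAL mode: the decoder is fed its own predictions
preds = greedy_predict(toy_decode, source_seq, None, target_len)
print(preds.tolist())  # [[[4.0, 5.0], [6.0, 7.0]]]
```

Because toy_decode() is deterministic and ignores the mask contents, both branches return the same two points here. With a real model, the training branch always feeds the ground-truth shifted target in a single call, while the greedy branch feeds back its own, possibly wrong, predictions one step at a time.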
