Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step A Beginner’s Guide-leanpub

"values") in the decoder.

• decode(): It takes the shifted target sequence and uses its projection, together with both source and target masks, to generate a target sequence that goes through the last linear layer to be transformed back to the feature space. It is used for training only!
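To make the "shifted target sequence" concrete, here is a minimal sketch of how it could be sliced out of a full sequence for teacher forcing. It assumes (this is not shown in the excerpt) that each data point is the source sequence followed by the target sequence along the time dimension; the tensor values and the names full_seq, source_seq, and target_seq are made up for illustration.

```python
import torch

# Hypothetical full sequence: batch of 1, four time steps, two features.
# Assumption: the first `input_len` steps are the source, the rest the target.
full_seq = torch.tensor([[[-1., -1.], [-1., 1.], [1., 1.], [1., -1.]]])
input_len, target_len = 2, 2

source_seq = full_seq[:, :input_len]   # steps 0 and 1, fed to the encoder
target_seq = full_seq[:, input_len:]   # steps 2 and 3, what the model should output

# Shifted target: the last source step followed by every target step
# except the final one, so each input position precedes its label.
shifted_target_seq = full_seq[:, input_len - 1:-1]   # steps 1 and 2

print(shifted_target_seq.shape)   # same length as the target sequence
```

Each position of the shifted sequence lines up with the target element that follows it, which is why it can serve as the decoder's input during training.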

The parent class is reproduced below for your convenience:

Encoder + Decoder + Self-Attention

import torch
import torch.nn as nn

class EncoderDecoderSelfAttn(nn.Module):
    def __init__(self, encoder, decoder, input_len, target_len):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.input_len = input_len
        self.target_len = target_len
        self.trg_masks = self.subsequent_mask(self.target_len)

    @staticmethod
    def subsequent_mask(size):
        # Lower-triangular boolean mask: position i may attend
        # to positions 0..i only
        attn_shape = (1, size, size)
        subsequent_mask = (
            1 - torch.triu(torch.ones(attn_shape), diagonal=1)
        ).bool()
        return subsequent_mask

    def encode(self, source_seq, source_mask):
        # Encodes the source sequence and uses the result
        # to initialize the decoder
        encoder_states = self.encoder(source_seq, source_mask)
        self.decoder.init_keys(encoder_states)

    def decode(self, shifted_target_seq,
               source_mask=None, target_mask=None):
        # Decodes/generates a sequence using the shifted (masked)
        # target sequence - used in TRAIN mode
        outputs = self.decoder(shifted_target_seq,
                               source_mask=source_mask,
                               target_mask=target_mask)
        return outputs
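As a quick check of what subsequent_mask() produces, the sketch below copies its computation into a standalone function and applies the result to some made-up attention scores. How the decoder actually consumes the mask is not shown in this excerpt; filling masked positions with negative infinity before the softmax is a common approach and is only assumed here.

```python
import torch

def subsequent_mask(size):
    # Same computation as the static method above, as a standalone function
    attn_shape = (1, size, size)
    return (1 - torch.triu(torch.ones(attn_shape), diagonal=1)).bool()

mask = subsequent_mask(3)
print(mask)
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])

# Hypothetical uniform attention scores (all zeros) for three target positions
scores = torch.zeros(1, 3, 3)

# Assumed usage: block "future" positions before normalizing the scores
masked_scores = scores.masked_fill(~mask, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)
print(weights)   # row i spreads its weight evenly over positions 0..i
```

The mask is lower triangular, so the first target position can only attend to itself, while the last one can attend to every position up to and including its own.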
