Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide

Figure 9.27 - Encoder with self- and cross-attentions

If you're wondering why we removed the concatenation part, here comes the answer: we're using self-attention as a decoder too.


Decoder

There is one main difference (in the code) between the encoder and the decoder—the latter includes a cross-attention mechanism, as you can see below:

Decoder + Self-Attention

class DecoderSelfAttn(nn.Module):
    def __init__(self, n_heads, d_model,
                 ff_units, n_features=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.ff_units = ff_units
        self.n_features = d_model if n_features is None \
                          else n_features
        self.self_attn_heads = \
            MultiHeadAttention(n_heads, d_model,
                               input_dim=self.n_features)
        self.cross_attn_heads = \
            MultiHeadAttention(n_heads, d_model)            # (1)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_units),
            nn.ReLU(),
            nn.Linear(ff_units, self.n_features))

    def init_keys(self, states):                            # (1)
        self.cross_attn_heads.init_keys(states)

    def forward(self, query, source_mask=None, target_mask=None):
        self.self_attn_heads.init_keys(query)
        att1 = self.self_attn_heads(query, target_mask)
        att2 = self.cross_attn_heads(att1, source_mask)     # (1)
        out = self.ffn(att2)
        return out

(1) Including cross-attention

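To see how the pieces fit together, here is a minimal usage sketch. The hyperparameters and tensor shapes below are illustrative only, and it assumes DecoderSelfAttn and the MultiHeadAttention class from earlier in the chapter are in scope, with the mask convention that zeros mark positions to be ignored. The keys and values for cross-attention come from the encoder's output (via init_keys()), while the target mask keeps the self-attention causal.

import torch

# Illustrative shapes: one source sequence already encoded into
# "states" (N, L_source, d_model), and a shifted target sequence
# used as the query (N, L_target, n_features)
states = torch.randn(1, 2, 6)
shifted_target = torch.randn(1, 2, 6)

# Causal (subsequent) mask: position i may only attend to positions <= i
# (assumes zeros in the mask mark positions to be ignored)
target_mask = torch.tril(torch.ones(1, 2, 2)).bool()

decoder = DecoderSelfAttn(n_heads=3, d_model=6, ff_units=10)
decoder.init_keys(states)    # keys/values for cross-attention
out = decoder(shifted_target, target_mask=target_mask)
print(out.shape)             # torch.Size([1, 2, 6])

Calling init_keys() with the encoder's output is what turns the second attention mechanism into cross-attention: its keys and values describe the source sequence, while its queries come from the self-attended target sequence.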
