Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step A Beginner’s Guide-leanpub


Figure 9.13 - Computing the context vector

The "query" (Q) is matched to both "keys" (K) to compute the attention scores (s) used to compute the context vector, which is simply the weighted average of the "values" (V).
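What the figure depicts can be sketched in a few lines of code. The tensors below are made-up stand-ins (the shapes, seed, and values are assumptions for illustration, not the book's actual numbers):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# One "query" (1 x H) and two "keys"/"values" (2 x H), with H = 2
Q = torch.randn(1, 2)  # query, e.g. a decoder hidden state (made up)
K = torch.randn(2, 2)  # keys, e.g. encoder hidden states (made up)
V = K.clone()          # values — here, the same hidden states as the keys

# Attention scores: match the query against every key
s = Q @ K.T  # shape (1, 2) — one score per key

# Turn the scores into weights that sum to one
alphas = F.softmax(s, dim=-1)

# Context vector: the weighted average of the values
context = alphas @ V  # shape (1, 2)
```

The softmax guarantees the weights are non-negative and sum to one, which is what makes the context vector a proper weighted average of the "values".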

Computing the Context Vector

Let’s go over a simple example in code using our own sequence-to-sequence problem, and the "perfect" square as input:

full_seq = (torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]])
            .float()
            .view(1, 4, 2))
source_seq = full_seq[:, :2]
target_seq = full_seq[:, 2:]

The source sequence is the input of the encoder, and the hidden states it outputs are going to be both "values" (V) and "keys" (K):

torch.manual_seed(21)
encoder = Encoder(n_features=2, hidden_dim=2)
hidden_seq = encoder(source_seq)

values = hidden_seq  # N, L, H
values
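Since the Encoder class itself isn't defined in this excerpt, the sketch below builds a hidden_seq of the same shape (N=1, L=2, H=2) directly from random numbers (an assumption, not the encoder's real output), just to show the hidden states playing both roles, "keys" and "values", once a "query" arrives:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(21)

# Stand-in for the encoder's output: N=1 sequence, L=2 source steps, H=2 dims
hidden_seq = torch.randn(1, 2, 2)
values = hidden_seq  # N, L, H
keys = hidden_seq    # N, L, H — the same hidden states, used twice

# Made-up "query", e.g. a decoder hidden state: N, 1, H
query = torch.randn(1, 1, 2)

# One score per source step, then weights, then the context vector
scores = torch.bmm(query, keys.transpose(1, 2))  # N, 1, L
alphas = F.softmax(scores, dim=-1)               # weights sum to one
context = torch.bmm(alphas, values)              # N, 1, H
```

Using batched matrix multiplication (torch.bmm) keeps the batch dimension (N) in place, so the same code works for any number of sequences at once.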

710 | Chapter 9 — Part I: Sequence-to-Sequence
