Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide (Leanpub)


Figure 10.20 - Sample image, label "2"

"But this is a classification problem, not a sequence-to-sequence one. Why are we using a Transformer, then?"

Well, we're not using the full Transformer architecture, only its encoder. In Chapter 8, we used recurrent neural networks to generate a final hidden state that we used as the input for classification. Similarly, the encoder generates a sequence of "hidden states" (the memory, in Transformer lingo), and we're using one "hidden state" as the input for classification again.

"Which one? The last 'hidden state'?"

No, not the last one, but a special one. We'll prepend a special classifier token [CLS] to our sequence and use its corresponding "hidden state" as input to a classifier. The figure below illustrates the idea.

Figure 10.21 - Hidden states and the special classifier token [CLS]

But I'm jumping the gun here; we'll get back to that in the "Special Classifier Token" section.
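The [CLS] mechanics can be sketched in a few lines of PyTorch. The shapes and variable names below are illustrative only (they are not the book's actual model): a learnable [CLS] embedding is expanded to one copy per item in the mini-batch, prepended to the sequence, and, after the encoder runs, only its "hidden state" (position zero) is fed to the classifier.

```python
import torch
import torch.nn as nn

# Illustrative dimensions (assumptions, not the book's values)
batch_size, seq_len, d_model, n_classes = 32, 16, 64, 10

# A single learnable [CLS] embedding, shared across the whole dataset
cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
classifier = nn.Linear(d_model, n_classes)

x = torch.randn(batch_size, seq_len, d_model)  # stand-in patch embeddings
cls = cls_token.expand(batch_size, -1, -1)     # one [CLS] copy per item
x = torch.cat([cls, x], dim=1)                 # (N, seq_len + 1, d_model)

# ...the sequence would go through the Transformer encoder here...
memory = x  # stand-in for the encoder's output ("hidden states")

# Only the [CLS] token's "hidden state" reaches the classifier
logits = classifier(memory[:, 0])
print(logits.shape)  # torch.Size([32, 10])
```

Notice that the sequence length grows by one, and that every other "hidden state" in the memory is simply ignored by the classification head.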


The data preparation step is exactly the same one we used in Chapter 5:

Data Preparation

class TransformedTensorDataset(Dataset):
    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.transform = transform

    def __getitem__(self, index):
        x = self.x[index]
        if self.transform:
            x = self.transform(x)
        return x, self.y[index]

    def __len__(self):
        return len(self.x)

# Builds tensors from numpy arrays BEFORE split
# Modifies the scale of pixel values from [0, 255] to [0, 1]
x_tensor = torch.as_tensor(images / 255).float()
y_tensor = torch.as_tensor(labels).long()

# Uses index_splitter to generate indices for training and
# validation sets
train_idx, val_idx = index_splitter(len(x_tensor), [80, 20])
# Uses indices to perform the split
x_train_tensor = x_tensor[train_idx]
y_train_tensor = y_tensor[train_idx]
x_val_tensor = x_tensor[val_idx]
y_val_tensor = y_tensor[val_idx]

# We're not doing any data augmentation now
train_composer = Compose([Normalize(mean=(.5,), std=(.5,))])
val_composer = Compose([Normalize(mean=(.5,), std=(.5,))])

# Uses custom dataset to apply composed transforms to each set
train_dataset = TransformedTensorDataset(
    x_train_tensor, y_train_tensor, transform=train_composer)
val_dataset = TransformedTensorDataset(
    x_val_tensor, y_val_tensor, transform=val_composer)
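The listing above relies on the index_splitter helper introduced earlier in the book. As a reminder of what it does, here is a minimal sketch under the assumption that it shuffles the indices and splits them in the given proportions; the book's actual implementation may differ in details such as the random seed handling.

```python
import torch

def index_splitter(n, splits, seed=13):
    # Minimal sketch: splits the indices 0..n-1 into random,
    # non-overlapping subsets proportional to `splits`
    splits = torch.as_tensor(splits).float()
    # Converts proportions (e.g. [80, 20]) into lengths summing to n
    sizes = (splits / splits.sum() * n).long()
    sizes[-1] = n - sizes[:-1].sum()  # makes sure lengths add up to n
    torch.manual_seed(seed)
    return torch.randperm(n).split(sizes.tolist())

train_idx, val_idx = index_splitter(100, [80, 20])
print(len(train_idx), len(val_idx))  # 80 20
```

Because the indices are drawn from a single shuffled permutation, the two subsets are guaranteed to be disjoint and to cover every data point exactly once.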
