Then, let’s create an instance of the data collator and apply it to our mini-batch of one:

from transformers import DataCollatorForLanguageModeling

torch.manual_seed(41)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=bert_tokenizer, mlm_probability=0.15
)
mlm_tokens = data_collator([tokens])
mlm_tokens

Output

{'input_ids': tensor([[  101,  5650,  2003,  1999, 10288, 24759,
            103,  6321,  2206,  1996,  2317, 10442,   102]]),
 'labels': tensor([[-100, -100, -100, -100, -100, -100, 5555, -100,
          -100, -100, -100, -100, -100]])}

If you look closely, you’ll see that the seventh token (5555 in the original input) was replaced by some other token (103). Moreover, the labels contain the replaced tokens in their original positions (and -100 everywhere else to indicate these tokens are irrelevant for computing the loss). It’s actually easier to visualize the difference if we convert the IDs back to tokens:

print(bert_tokenizer.convert_ids_to_tokens(tokens['input_ids']))
print(bert_tokenizer.convert_ids_to_tokens(
    mlm_tokens['input_ids'][0]
))

Output

['[CLS]', 'alice', 'is', 'in', '##ex', '##pl', '##ica', '##bly',
 'following', 'the', 'white', 'rabbit', '[SEP]']
['[CLS]', 'alice', 'is', 'in', '##ex', '##pl', '[MASK]', '##bly',
 'following', 'the', 'white', 'rabbit', '[SEP]']

See? The seventh token (##ica) got masked!
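By the way, 103 is the ID of the special [MASK] token. Not every selected token ends up masked, though: under the hood, the collator roughly follows BERT’s original recipe. It selects 15% of the non-special tokens, replaces 80% of those with [MASK], replaces 10% with a random token, and leaves the remaining 10% untouched, while every unselected position gets the -100 label (which happens to be the default ignore_index of PyTorch’s cross-entropy loss, so those positions don’t contribute to it). The snippet below is a simplified sketch of that procedure, not the actual HuggingFace implementation; the mask_tokens() helper is ours, for illustration only:

import torch

def mask_tokens(input_ids, tokenizer, mlm_probability=0.15):
    # input_ids: a (batch_size, seq_len) tensor of token IDs
    input_ids = input_ids.clone()
    labels = input_ids.clone()
    # Every non-special token has a 15% chance of being selected
    probs = torch.full(labels.shape, mlm_probability)
    special = torch.tensor([
        tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
        for row in labels.tolist()
    ]).bool()
    probs.masked_fill_(special, value=0.0)
    selected = torch.bernoulli(probs).bool()
    # Unselected positions get the -100 label and are ignored by the loss
    labels[~selected] = -100
    # 80% of the selected tokens are replaced by [MASK]
    masked = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & selected
    input_ids[masked] = tokenizer.mask_token_id
    # 10% are replaced by a random token; the remaining 10% stay as they are
    randomized = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                  & selected & ~masked)
    random_ids = torch.randint(len(tokenizer), labels.shape)
    input_ids[randomized] = random_ids[randomized]
    return input_ids, labels

# e.g., assuming tokens['input_ids'] is a plain list of IDs:
# masked_ids, labels = mask_tokens(
#     torch.as_tensor([tokens['input_ids']]), bert_tokenizer
# )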

We’re not using collators in our example, but they can be used together with HuggingFace’s Trainer (more on that in the "Fine-Tuning with HuggingFace" section) if you’re into training some BERT from scratch on the MLM task.

But that’s not the only thing that BERT is trained to do…

Next Sentence Prediction (NSP)

The second pre-training task is a binary classification task: BERT was trained to predict if a second sentence is actually the next sentence in the original text or not. The purpose of this task is to give BERT the ability to understand the relationship between sentences, which can be useful for some of the tasks BERT can be fine-tuned for, like question answering.

So, BERT takes two sentences as inputs (with the special separator token [SEP] between them):

• 50% of the time, the second sentence is indeed the next sentence (the positive class).
• 50% of the time, the second sentence is a randomly chosen one (the negative class).

This task uses the special classifier token [CLS], taking the values of the corresponding final hidden state as features for a classifier. For example, let’s take two sentences and tokenize them:

sentence1 = 'alice follows the white rabbit'
sentence2 = 'follow the white rabbit neo'
bert_tokenizer(sentence1, sentence2, return_tensors='pt')

Output

{'input_ids': tensor([[  101,  5650,  4076,  1996,  2317, 10442,
            102,  3582,  1996,  2317, 10442,  9253,   102]]),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
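Notice the token_type_ids above: they are zero for the first sentence (and its special tokens) and one for the second, and that’s how BERT tells the two sentences apart. Since BERT was pre-trained on this very task, we could even take its NSP head for a spin. The snippet below is just a minimal sketch (assuming the bert-base-uncased checkpoint, the same one behind bert_tokenizer), for illustration only:

import torch
from transformers import BertForNextSentencePrediction

# Loads BERT together with its pre-trained next-sentence-prediction head
nsp_model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
nsp_model.eval()

nsp_inputs = bert_tokenizer(sentence1, sentence2, return_tensors='pt')
with torch.no_grad():
    logits = nsp_model(**nsp_inputs).logits
# Index 0 stands for "the second sentence follows the first one"
# Index 1 stands for "the second sentence is a random one"
probs = torch.softmax(logits, dim=-1)
probs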

