為了預(yù)訓(xùn)練第 15.8 節(jié)中實現(xiàn)的 BERT 模型,我們需要以理想的格式生成數(shù)據(jù)集,以促進(jìn)兩項預(yù)訓(xùn)練任務(wù):掩碼語言建模和下一句預(yù)測。一方面,原始的 BERT 模型是在兩個巨大的語料庫 BookCorpus 和英文維基百科(參見第15.8.5 節(jié))的串聯(lián)上進(jìn)行預(yù)訓(xùn)練的,這使得本書的大多數(shù)讀者難以運行。另一方面,現(xiàn)成的預(yù)訓(xùn)練 BERT 模型可能不適合醫(yī)學(xué)等特定領(lǐng)域的應(yīng)用。因此,在自定義數(shù)據(jù)集上預(yù)訓(xùn)練 BERT 變得越來越流行。為了便于演示 BERT 預(yù)訓(xùn)練,我們使用較小的語料庫 WikiText-2 ( Merity et al. , 2016 )。
與 15.3節(jié)用于預(yù)訓(xùn)練word2vec的PTB數(shù)據(jù)集相比,WikiText-2(i)保留了原有的標(biāo)點符號,適合下一句預(yù)測;(ii) 保留原始案例和編號;(iii) 大兩倍以上。
import os import random import torch from d2l import torch as d2l
import os import random from mxnet import gluon, np, npx from d2l import mxnet as d2l npx.set_np()
在 WikiText-2 數(shù)據(jù)集中,每一行代表一個段落,其中在任何標(biāo)點符號及其前面的標(biāo)記之間插入空格。保留至少兩句話的段落。為了簡單起見,為了拆分句子,我們只使用句點作為分隔符。我們將在本節(jié)末尾的練習(xí)中討論更復(fù)雜的句子拆分技術(shù)。
#@save d2l.DATA_HUB['wikitext-2'] = ( 'https://s3.amazonaws.com/research.metamind.io/wikitext/' 'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe') #@save def _read_wiki(data_dir): file_name = os.path.join(data_dir, 'wiki.train.tokens') with open(file_name, 'r') as f: lines = f.readlines() # Uppercase letters are converted to lowercase ones paragraphs = [line.strip().lower().split(' . ') for line in lines if len(line.split(' . ')) >= 2] random.shuffle(paragraphs) return paragraphs
#@save d2l.DATA_HUB['wikitext-2'] = ( 'https://s3.amazonaws.com/research.metamind.io/wikitext/' 'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe') #@save def _read_wiki(data_dir): file_name = os.path.join(data_dir, 'wiki.train.tokens') with open(file_name, 'r') as f: lines = f.readlines() # Uppercase letters are converted to lowercase ones paragraphs = [line.strip().lower().split(' . ') for line in lines if len(line.split(' . ')) >= 2] random.shuffle(paragraphs) return paragraphs
15.9.1。為預(yù)訓(xùn)練任務(wù)定義輔助函數(shù)
下面,我們首先為兩個 BERT 預(yù)訓(xùn)練任務(wù)實現(xiàn)輔助函數(shù):下一句預(yù)測和掩碼語言建模。這些輔助函數(shù)將在稍后將原始文本語料庫轉(zhuǎn)換為理想格式的數(shù)據(jù)集以預(yù)訓(xùn)練 BERT 時調(diào)用。
15.9.1.1。生成下一句預(yù)測任務(wù)
根據(jù)15.8.5.2 節(jié)的描述,該 _get_next_sentence函數(shù)為二元分類任務(wù)生成一個訓(xùn)練樣例。
#@save def _get_next_sentence(sentence, next_sentence, paragraphs): if random.random() < 0.5: is_next = True else: # `paragraphs` is a list of lists of lists next_sentence = random.choice(random.choice(paragraphs)) is_next = False return sentence, next_sentence, is_next
#@save def _get_next_sentence(sentence, next_sentence, paragraphs): if random.random() < 0.5: is_next = True else: # `paragraphs` is a list of lists of lists next_sentence = random.choice(random.choice(paragraphs)) is_next = False return sentence, next_sentence, is_next
以下函數(shù)paragraph通過調(diào)用該 _get_next_sentence函數(shù)從輸入生成用于下一句預(yù)測的訓(xùn)練示例。這paragraph是一個句子列表,其中每個句子都是一個標(biāo)記列表。該參數(shù) max_len指定預(yù)訓(xùn)練期間 BERT 輸入序列的最大長度。
#@save def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len): nsp_data_from_paragraph = [] for i in range(len(paragraph) - 1): tokens_a, tokens_b, is_next = _get_next_sentence( paragraph[i], paragraph[i + 1], paragraphs) # Consider 1 '' token and 2 '' tokens if len(tokens_a) + len(tokens_b) + 3 > max_len: continue tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b) nsp_data_from_paragraph.append((tokens, segments, is_next)) return nsp_data_from_paragraph
#@save def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len): nsp_data_from_paragraph = [] for i in range(len(paragraph) - 1): tokens_a, tokens_b, is_next = _get_next_sentence( paragraph[i], paragraph[i + 1], paragraphs) # Consider 1 '' token and 2 '' tokens if len(tokens_a) + len(tokens_b) + 3 > max_len: continue tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b) nsp_data_from_paragraph.append((tokens, segments, is_next)) return nsp_data_from_paragraph
15.9.1.2。生成掩碼語言建模任務(wù)
為了從 BERT 輸入序列為掩碼語言建模任務(wù)生成訓(xùn)練示例,我們定義了以下 _replace_mlm_tokens函數(shù)。在它的輸入中,tokens是代表BERT輸入序列的token列表,candidate_pred_positions 是BERT輸入序列的token索引列表,不包括特殊token(masked語言建模任務(wù)中不預(yù)測特殊token),num_mlm_preds表示預(yù)測(召回 15% 的隨機標(biāo)記來預(yù)測)。遵循第 15.8.5.1 節(jié)中屏蔽語言建模任務(wù)的定義 ,在每個預(yù)測位置,輸入可能被特殊的“”標(biāo)記或隨機標(biāo)記替換,或者保持不變。最后,該函數(shù)返回可能替換后的輸入標(biāo)記、發(fā)生預(yù)測的標(biāo)記索引以及這些預(yù)測的標(biāo)簽。
#@save def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab): # For the input of a masked language model, make a new copy of tokens and # replace some of them by '' or random tokens mlm_input_tokens = [token for token in tokens] pred_positions_and_labels = [] # Shuffle for getting 15% random tokens for prediction in the masked # language modeling task random.shuffle(candidate_pred_positions) for mlm_pred_position in candidate_pred_positions: if len(pred_positions_and_labels) >= num_mlm_preds: break masked_token = None # 80% of the time: replace the word with the '' token if random.random() < 0.8: masked_token = '' else: # 10% of the time: keep the word unchanged if random.random() < 0.5: masked_token = tokens[mlm_pred_position] # 10% of the time: replace the word with a random word else: masked_token = random.choice(vocab.idx_to_token) mlm_input_tokens[mlm_pred_position] = masked_token pred_positions_and_labels.append( (mlm_pred_position, tokens[mlm_pred_position])) return mlm_input_tokens, pred_positions_and_labels
#@save def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab): # For the input of a masked language model, make a new copy of tokens and # replace some of them by '' or random tokens mlm_input_tokens = [token for token in tokens] pred_positions_and_labels = [] # Shuffle for getting 15% random tokens for prediction in the masked # language modeling task random.shuffle(candidate_pred_positions) for mlm_pred_position in candidate_pred_positions: if len(pred_positions_and_labels) >= num_mlm_preds: break masked_token = None # 80% of the time: replace the word with the '' token if random.random() < 0.8: masked_token = '' else: # 10% of the time: keep the word unchanged if random.random() < 0.5: masked_token = tokens[mlm_pred_position] # 10% of the time: replace the word with a random word else: masked_token = random.choice(vocab.idx_to_token) mlm_input_tokens[mlm_pred_position] = masked_token pred_positions_and_labels.append( (mlm_pred_position, tokens[mlm_pred_position])) return mlm_input_tokens, pred_positions_and_labels
通過調(diào)用上述_replace_mlm_tokens函數(shù),以下函數(shù)將 BERT 輸入序列 ( tokens) 作為輸入并返回輸入標(biāo)記的索引(在可能的標(biāo)記替換之后,如第15.8.5.1 節(jié)所述)、發(fā)生預(yù)測的標(biāo)記索引和標(biāo)簽這些預(yù)測的指標(biāo)。
#@save def _get_mlm_data_from_tokens(tokens, vocab): candidate_pred_positions = [] # `tokens` is a list of strings for i, token in enumerate(tokens): # Special tokens are not predicted in the masked language modeling # task if token in ['', '']: continue candidate_pred_positions.append(i) # 15% of random tokens are predicted in the masked language modeling task num_mlm_preds = max(1, round(len(tokens) * 0.15)) mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens( tokens, candidate_pred_positions, num_mlm_preds, vocab) pred_positions_and_labels = sorted(pred_positions_and_labels, key=lambda x: x[0]) pred_positions = [v[0] for v in pred_positions_and_labels] mlm_pred_labels = [v[1] for v in pred_positions_and_labels] return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]
#@save def _get_mlm_data_from_tokens(tokens, vocab): candidate_pred_positions = [] # `tokens` is a list of strings for i, token in enumerate(tokens): # Special tokens are not predicted in the masked language modeling # task if token in ['', '']: continue candidate_pred_positions.append(i) # 15% of random tokens are predicted in the masked language modeling task num_mlm_preds = max(1, round(len(tokens) * 0.15)) mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens( tokens, candidate_pred_positions, num_mlm_preds, vocab) pred_positions_and_labels = sorted(pred_positions_and_labels, key=lambda x: x[0]) pred_positions = [v[0] for v in pred_positions_and_labels] mlm_pred_labels = [v[1] for v in pred_positions_and_labels] return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]
15.9.2。將文本轉(zhuǎn)換為預(yù)訓(xùn)練數(shù)據(jù)集
現(xiàn)在我們幾乎準(zhǔn)備好定制一個Dataset用于預(yù)訓(xùn)練 BERT 的類。在此之前,我們?nèi)匀恍枰x一個輔助函數(shù) _pad_bert_inputs來將特殊的“”標(biāo)記附加到輸入中。它的參數(shù)examples包含輔助函數(shù) _get_nsp_data_from_paragraph和_get_mlm_data_from_tokens兩個預(yù)訓(xùn)練任務(wù)的輸出。
#@save def _pad_bert_inputs(examples, max_len, vocab): max_num_mlm_preds = round(max_len * 0.15) all_token_ids, all_segments, valid_lens, = [], [], [] all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], [] nsp_labels = [] for (token_ids, pred_positions, mlm_pred_label_ids, segments, is_next) in examples: all_token_ids.append(torch.tensor(token_ids + [vocab['']] * ( max_len - len(token_ids)), dtype=torch.long)) all_segments.append(torch.tensor(segments + [0] * ( max_len - len(segments)), dtype=torch.long)) # `valid_lens` excludes count of '' tokens valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32)) all_pred_positions.append(torch.tensor(pred_positions + [0] * ( max_num_mlm_preds - len(pred_positions)), dtype=torch.long)) # Predictions of padded tokens will be filtered out in the loss via # multiplication of 0 weights all_mlm_weights.append( torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * ( max_num_mlm_preds - len(pred_positions)), dtype=torch.float32)) all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * ( max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long)) nsp_labels.append(torch.tensor(is_next, dtype=torch.long)) return (all_token_ids, all_segments, valid_lens, all_pred_positions, all_mlm_weights, all_mlm_labels, nsp_labels)
#@save def _pad_bert_inputs(examples, max_len, vocab): max_num_mlm_preds = round(max_len * 0.15) all_token_ids, all_segments, valid_lens, = [], [], [] all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], [] nsp_labels = [] for (token_ids, pred_positions, mlm_pred_label_ids, segments, is_next) in examples: all_token_ids.append(np.array(token_ids + [vocab['']] * ( max_len - len(token_ids)), dtype='int32')) all_segments.append(np.array(segments + [0] * ( max_len - len(segments)), dtype='int32')) # `valid_lens` excludes count of '' tokens valid_lens.append(np.array(len(token_ids), dtype='float32')) all_pred_positions.append(np.array(pred_positions + [0] * ( max_num_mlm_preds - len(pred_positions)), dtype='int32')) # Predictions of padded tokens will be filtered out in the loss via # multiplication of 0 weights all_mlm_weights.append( np.array([1.0] * len(mlm_pred_label_ids) + [0.0] * ( max_num_mlm_preds - len(pred_positions)), dtype='float32')) all_mlm_labels.append(np.array(mlm_pred_label_ids + [0] * ( max_num_mlm_preds - len(mlm_pred_label_ids)), dtype='int32')) nsp_labels.append(np.array(is_next)) return (all_token_ids, all_segments, valid_lens, all_pred_positions, all_mlm_weights, all_mlm_labels, nsp_labels)
將兩個預(yù)訓(xùn)練任務(wù)生成訓(xùn)練樣例的輔助函數(shù)和填充輸入的輔助函數(shù)放在一起,我們自定義如下類_WikiTextDataset作為預(yù)訓(xùn)練 BERT 的 WikiText-2 數(shù)據(jù)集。通過實現(xiàn)該 __getitem__功能,我們可以任意訪問從 WikiText-2 語料庫中的一對句子生成的預(yù)訓(xùn)練(掩碼語言建模和下一句預(yù)測)示例。
原始 BERT 模型使用詞匯量為 30000 的 WordPiece 嵌入( Wu et al. , 2016 )。WordPiece 的標(biāo)記化方法是對15.6.2 節(jié)中原始字節(jié)對編碼算法的輕微修改。為簡單起見,我們使用該d2l.tokenize函數(shù)進(jìn)行標(biāo)記化。過濾掉出現(xiàn)次數(shù)少于五次的不常見標(biāo)記。
#@save class _WikiTextDataset(torch.utils.data.Dataset): def __init__(self, paragraphs, max_len): # Input `paragraphs[i]` is a list of sentence strings representing a # paragraph; while output `paragraphs[i]` is a list of sentences # representing a paragraph, where each sentence is a list of tokens paragraphs = [d2l.tokenize( paragraph, token='word') for paragraph in paragraphs] sentences = [sentence for paragraph in paragraphs for sentence in paragraph] self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[ '', '', '', '']) # Get data for the next sentence prediction task examples = [] for paragraph in paragraphs: examples.extend(_get_nsp_data_from_paragraph( paragraph, paragraphs, self.vocab, max_len)) # Get data for the masked language model task examples = [(_get_mlm_data_from_tokens(tokens, self.vocab) + (segments, is_next)) for tokens, segments, is_next in examples] # Pad inputs (self.all_token_ids, self.all_segments, self.valid_lens, self.all_pred_positions, self.all_mlm_weights, self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs( examples, max_len, self.vocab) def __getitem__(self, idx): return (self.all_token_ids[idx], self.all_segments[idx], self.valid_lens[idx], self.all_pred_positions[idx], self.all_mlm_weights[idx], self.all_mlm_labels[idx], self.nsp_labels[idx]) def __len__(self): return len(self.all_token_ids)
#@save class _WikiTextDataset(gluon.data.Dataset): def __init__(self, paragraphs, max_len): # Input `paragraphs[i]` is a list of sentence strings representing a # paragraph; while output `paragraphs[i]` is a list of sentences # representing a paragraph, where each sentence is a list of tokens paragraphs = [d2l.tokenize( paragraph, token='word') for paragraph in paragraphs] sentences = [sentence for paragraph in paragraphs for sentence in paragraph] self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[ '', '', '', '']) # Get data for the next sentence prediction task examples = [] for paragraph in paragraphs: examples.extend(_get_nsp_data_from_paragraph( paragraph, paragraphs, self.vocab, max_len)) # Get data for the masked language model task examples = [(_get_mlm_data_from_tokens(tokens, self.vocab) + (segments, is_next)) for tokens, segments, is_next in examples] # Pad inputs (self.all_token_ids, self.all_segments, self.valid_lens, self.all_pred_positions, self.all_mlm_weights, self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs( examples, max_len, self.vocab) def __getitem__(self, idx): return (self.all_token_ids[idx], self.all_segments[idx], self.valid_lens[idx], self.all_pred_positions[idx], self.all_mlm_weights[idx], self.all_mlm_labels[idx], self.nsp_labels[idx]) def __len__(self): return len(self.all_token_ids)
通過使用_read_wiki函數(shù)和_WikiTextDataset類,我們定義了以下內(nèi)容load_data_wiki來下載 WikiText-2 數(shù)據(jù)集并從中生成預(yù)訓(xùn)練示例。
#@save def load_data_wiki(batch_size, max_len): """Load the WikiText-2 dataset.""" num_workers = d2l.get_dataloader_workers() data_dir = d2l.download_extract('wikitext-2', 'wikitext-2') paragraphs = _read_wiki(data_dir) train_set = _WikiTextDataset(paragraphs, max_len) train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers) return train_iter, train_set.vocab
#@save def load_data_wiki(batch_size, max_len): """Load the WikiText-2 dataset.""" num_workers = d2l.get_dataloader_workers() data_dir = d2l.download_extract('wikitext-2', 'wikitext-2') paragraphs = _read_wiki(data_dir) train_set = _WikiTextDataset(paragraphs, max_len) train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers) return train_iter, train_set.vocab
將批量大小設(shè)置為 512,將 BERT 輸入序列的最大長度設(shè)置為 64,我們打印出 BERT 預(yù)訓(xùn)練示例的小批量形狀。請注意,在每個 BERT 輸入序列中,10 (64×0.15) 位置是為掩碼語言建模任務(wù)預(yù)測的。
batch_size, max_len = 512, 64 train_iter, vocab = load_data_wiki(batch_size, max_len) for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y) in train_iter: print(tokens_X.shape, segments_X.shape, valid_lens_x.shape, pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape, nsp_y.shape) break
Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip... torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])
batch_size, max_len = 512, 64 train_iter, vocab = load_data_wiki(batch_size, max_len) for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y) in train_iter: print(tokens_X.shape, segments_X.shape, valid_lens_x.shape, pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape, nsp_y.shape) break
(512, 64) (512, 64) (512,) (512, 10) (512, 10) (512, 10) (512,)
最后,讓我們看一下詞匯量。即使在過濾掉不常見的標(biāo)記后,它仍然比 PTB 數(shù)據(jù)集大兩倍以上。
len(vocab)
20256
len(vocab)
20256
15.9.3。概括
與 PTB 數(shù)據(jù)集相比,WikiText-2 數(shù)據(jù)集保留了原始標(biāo)點符號、大小寫和數(shù)字,并且大了一倍多。
我們可以任意訪問從 WikiText-2 語料庫中的一對句子生成的預(yù)訓(xùn)練(掩碼語言建模和下一句預(yù)測)示例。
15.9.4。練習(xí)
為簡單起見,句點用作拆分句子的唯一分隔符。嘗試其他句子拆分技術(shù),例如 spaCy 和 NLTK。以 NLTK 為例。您需要先安裝 NLTK:. 在代碼中,首先. 然后,下載 Punkt 句子分詞器: 。要拆分諸如 之類的句子 ,調(diào)用 將返回兩個句子字符串的列表:。pip install nltkimport nltknltk.download('punkt')sentences = 'This is great ! Why not ?'nltk.tokenize.sent_tokenize(sentences)['This is great !', 'Why not ?']
如果我們不過濾掉任何不常見的標(biāo)記,詞匯表的大小是多少?
-
數(shù)據(jù)集
+關(guān)注
關(guān)注
4文章
1200瀏覽量
24621 -
pytorch
+關(guān)注
關(guān)注
2文章
802瀏覽量
13115
發(fā)布評論請先 登錄
相關(guān)推薦
評論