1. 關(guān)于MLM
1.1 背景
作為 Bert
預(yù)訓(xùn)練的兩大任務(wù)之一,MLM 和 NSP 大家應(yīng)該并不陌生。其中,NSP 任務(wù)在后續(xù)的一些預(yù)訓(xùn)練任務(wù)中經(jīng)常被嫌棄,例如 Roberta
中將 NSP 任務(wù)直接放棄,Albert
中將 NSP 替換成了句子順序預(yù)測(cè)。
這主要是因?yàn)?NSP 作為一個(gè)分類任務(wù)過于簡(jiǎn)單,對(duì)模型的學(xué)習(xí)并沒有太大的幫助,而 MLM 則被多數(shù)預(yù)訓(xùn)練模型保留下來。由 Roberta
的實(shí)驗(yàn)結(jié)果也可以證明,Bert
的主要能力應(yīng)該是來自于 MLM 任務(wù)的訓(xùn)練。
Bert
為代表的預(yù)訓(xùn)練語言模型是在大規(guī)模語料的基礎(chǔ)上訓(xùn)練以獲得的基礎(chǔ)的學(xué)習(xí)能力,而實(shí)際應(yīng)用時(shí),我們所面臨的語料或許具有某些特殊性,這就使得重新進(jìn)行 MLM 訓(xùn)練具有了必要性。
1.2 如何進(jìn)行MLM訓(xùn)練
1.2.1 什么是MLM
MLM 的訓(xùn)練,在不同的預(yù)訓(xùn)練模型中其實(shí)是有所不同的。今天介紹的內(nèi)容以最基礎(chǔ)的 Bert
為例。
Bert的MLM是靜態(tài)mask,而在后續(xù)的其他預(yù)訓(xùn)練模型中,這一策略通常被替換成了動(dòng)態(tài)mask。除此之外還有 whole word mask 的模型,這些都不在今天的討論范圍內(nèi)。
所謂 mask language model 的任務(wù),通俗來講,就是將句子中的一部分token替換掉,然后根據(jù)句子的剩余部分,試圖去還原這部分被mask的token。
1.2.2 如何Mask
mask 的比例一般是15%
,這一比例也被后續(xù)的多數(shù)模型所繼承,而在最初BERT
的論文中,沒有對(duì)這一比例的界定給出具體的說明。在我的印象中,似乎是知道后來同樣是Google提出的 T5
模型的論文中,對(duì)此進(jìn)行了解釋,對(duì) mask 的比例進(jìn)行了實(shí)驗(yàn),最終得出結(jié)論,15%的比例是最合理的(如果我記錯(cuò)了,還請(qǐng)指正)。
將15%
的token選出之后,并不是所有的都替換成[mask]標(biāo)記符。實(shí)際操作是:
-
從這
15%
選出的部分中,將其中的80%
替換成[mask]; -
10%
替換成一個(gè)隨機(jī)的token; -
剩下的
10%
保留原來的token。
這樣做可以提高模型的魯棒性。這個(gè)比例也可以自己控制。
到這里可能有同學(xué)要問了,既然有10%保留不變的話,為什么不干脆只選擇15%*90% = 13.5%的token呢?如果看完后面的代碼,就會(huì)很清楚地理解這個(gè)問題了。
先說結(jié)論:因?yàn)?MLM 的任務(wù)是將選出的這15%的token全部進(jìn)行預(yù)測(cè),不管這個(gè)token是否被替換成了[mask],也就是說,即使它被保留了原樣,也還是需要被預(yù)測(cè)的。
2. 代碼部分
2.1 背景
介紹完了基礎(chǔ)內(nèi)容之后,接下來的內(nèi)容,我將基于 transformers
模塊,介紹如何進(jìn)行 mask language model 的訓(xùn)練。
其實(shí) transformers
模塊中,本身是提供了 MLM 訓(xùn)練任務(wù)的,模型都寫好了,只需要調(diào)用它內(nèi)置的 trainer
和datasets
模塊即可。感興趣的同學(xué)可以去 huggingface
的官網(wǎng)搜索相關(guān)教程。
然而我覺得 datasets
每次調(diào)用的時(shí)候都要去寫數(shù)據(jù)集的py文件,對(duì)arrow
的數(shù)據(jù)格式不熟悉的話還很容易出錯(cuò),而且 trainer
我覺得也不是很好用,任何一點(diǎn)小小的修改都挺費(fèi)勁(就是它以為它寫的很完備,考慮了用戶的所有需求,但是實(shí)際上有一些冗余的部分)。
所以我就參考它的實(shí)現(xiàn)方式,把它的代碼拆解,又按照自己的方式重新組織了一下。
2.2 準(zhǔn)備工作
首先在寫核心代碼之前,先做好準(zhǔn)備工作。
import 所有需要的模塊:
import os
import json
import copy
from tqdm.notebook import tqdm
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizerFast
然后寫一個(gè)config類,將所有參數(shù)集中起來:
class Config:
def __init__(self):
pass
def mlm_config(
self,
mlm_probability=0.15,
special_tokens_mask=None,
prob_replace_mask=0.8,
prob_replace_rand=0.1,
prob_keep_ori=0.1,
):
"""
:param mlm_probability: 被mask的token總數(shù)
:param special_token_mask: 特殊token
:param prob_replace_mask: 被替換成[MASK]的token比率
:param prob_replace_rand: 被隨機(jī)替換成其他token比率
:param prob_keep_ori: 保留原token的比率
"""
assert sum([prob_replace_mask, prob_replace_rand, prob_keep_ori]) == 1, ValueError("Sum of the probs must equal to 1.")
self.mlm_probability = mlm_probability
self.special_tokens_mask = special_tokens_mask
self.prob_replace_mask = prob_replace_mask
self.prob_replace_rand = prob_replace_rand
self.prob_keep_ori = prob_keep_ori
def training_config(
self,
batch_size,
epochs,
learning_rate,
weight_decay,
device,
):
self.batch_size = batch_size
self.epochs = epochs
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.device = device
def io_config(
self,
from_path,
save_path,
):
self.from_path = from_path
self.save_path = save_path
接著就是設(shè)置各種配置:
config = Config()
config.mlm_config()
config.training_config(batch_size=4, epochs=10, learning_rate=1e-5, weight_decay=0, device='cuda:0')
config.io_config(from_path='/data/BERTmodels/huggingface/chinese_wwm/',
save_path='./finetune_embedding_model/mlm/')
最后創(chuàng)建BERT模型。注意,這里的 tokenizer 就是一個(gè)普通的 tokenizer,而BERT模型則是帶了下游任務(wù)的 BertForMaskedLM
,它是 transformers
中寫好的一個(gè)類,
bert_tokenizer = BertTokenizerFast.from_pretrained(config.from_path)
bert_mlm_model = BertForMaskedLM.from_pretrained(config.from_path)
2.3 數(shù)據(jù)集
因?yàn)樯釛壛?code style="font-size:14px;padding:2px 4px;margin:0 2px;color:#1e6bb8;background-color:rgba(27,31,35,.05);font-family:'Operator Mono', Consolas, Monaco, Menlo, monospace;">datasets這個(gè)包,所以我們現(xiàn)在需要自己實(shí)現(xiàn)數(shù)據(jù)的輸入了。方案就是使用 torch
的 Dataset
類。這個(gè)類一般在構(gòu)建 DataLoader
的時(shí)候,會(huì)與一個(gè)聚合函數(shù)一起使用,以實(shí)現(xiàn)對(duì)batch的組織。而我這里偷個(gè)懶,就沒有寫聚合函數(shù),batch的組織方法放在dataset中進(jìn)行。
在這個(gè)類中,有一個(gè) mask tokens 的方法,作用是從數(shù)據(jù)中選擇出所有需要mask 的token,并且采用三種mask方式中的一個(gè)。這個(gè)方法是從transformers
中拿出來的,將其從類方法轉(zhuǎn)為靜態(tài)方法測(cè)試之后,再將其放在自己的這個(gè)類中為我們所用。仔細(xì)閱讀這一段代碼,也就可以回答1.2.2 中提出的那個(gè)問題了。
取batch的原理很簡(jiǎn)單,一開始我們將原始數(shù)據(jù)deepcopy備份一下,然后每次從中截取一個(gè)batch的大小,這個(gè)時(shí)候的當(dāng)前數(shù)據(jù)就少了一個(gè)batch,我們定義這個(gè)類的長(zhǎng)度為當(dāng)前長(zhǎng)度除以batch size向下取整,所以當(dāng)類的長(zhǎng)度變?yōu)?的時(shí)候,就說明這一個(gè)epoch的所有step都已經(jīng)執(zhí)行結(jié)束,要進(jìn)行下一個(gè)epoch的訓(xùn)練,此時(shí),再將當(dāng)前數(shù)據(jù)變?yōu)樵紨?shù)據(jù),就可以實(shí)現(xiàn)對(duì)epoch的循環(huán)了。
class TrainDataset(Dataset):
"""
注意:由于沒有使用data_collator,batch放在dataset里邊做,
因而在dataloader出來的結(jié)果會(huì)多套一層batch維度,傳入模型時(shí)注意squeeze掉
"""
def __init__(self, input_texts, tokenizer, config):
self.input_texts = input_texts
self.tokenizer = tokenizer
self.config = config
self.ori_inputs = copy.deepcopy(input_texts)
def __len__(self):
return len(self.input_texts) // self.config.batch_size
def __getitem__(self, idx):
batch_text = self.input_texts[: self.config.batch_size]
features = self.tokenizer(batch_text, max_length=512, truncation=True, padding=True, return_tensors='pt')
inputs, labels = self.mask_tokens(features['input_ids'])
batch = {"inputs": inputs, "labels": labels}
self.input_texts = self.input_texts[self.config.batch_size: ]
if not len(self):
self.input_texts = self.ori_inputs
return batch
def mask_tokens(self, inputs):
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
labels = inputs.clone()
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = torch.full(labels.shape, self.config.mlm_probability)
if self.config.special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = self.config.special_tokens_mask.bool()
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.config.prob_replace_mask)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
current_prob = self.config.prob_replace_rand / (1 - self.config.prob_replace_mask)
indices_random = torch.bernoulli(torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
然后取一些用于訓(xùn)練的語料,格式很簡(jiǎn)單,就是把所有文本放在一個(gè)list里邊,注意長(zhǎng)度不要超過512個(gè)token,不然多出來的部分就浪費(fèi)掉了??梢宰鲞m當(dāng)?shù)念A(yù)處理。
[
"這是一條文本",
"這是另一條文本",
...,
]
然后構(gòu)建dataloader:
train_dataset = TrainDataset(training_texts, bert_tokenizer, config)
train_dataloader = DataLoader(train_dataset)
2.4 訓(xùn)練
構(gòu)建一個(gè)訓(xùn)練方法,輸入?yún)?shù)分別是我們實(shí)例化好的待訓(xùn)練模型,數(shù)據(jù)集,還有config:
def train(model, train_dataloader, config):
"""
訓(xùn)練
:param model: nn.Module
:param train_dataloader: DataLoader
:param config: Config
---------------
ver: 2021-11-08
by: changhongyu
"""
assert config.device.startswith('cuda') or config.device == 'cpu', ValueError("Invalid device.")
device = torch.device(config.device)
model.to(device)
if not len(train_dataloader):
raise EOFError("Empty train_dataloader.")
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]
optimizer = AdamW(params=optimizer_grouped_parameters, lr=config.learning_rate, weight_decay=config.weight_decay)
for cur_epc in tqdm(range(int(config.epochs)), desc="Epoch"):
training_loss = 0
print("Epoch: {}".format(cur_epc+1))
model.train()
for step, batch in enumerate(tqdm(train_dataloader, desc='Step')):
input_ids = batch['inputs'].squeeze(0).to(device)
labels = batch['labels'].squeeze(0).to(device)
loss = model(input_ids=input_ids, labels=labels).loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.zero_grad()
training_loss += loss.item()
print("Training loss: ", training_loss)
調(diào)用它訓(xùn)練幾輪:
train(model=bert_mlm_model, train_dataloader=train_dataloader, config=config)
2.5 保存和加載
使用過預(yù)訓(xùn)練模型的同學(xué)應(yīng)該都了解,普通的bert有兩項(xiàng)輸出,分別是:
- 每一個(gè)token對(duì)應(yīng)的768維編碼結(jié)果;
- 以及用于表征整個(gè)句子的句子特征。
其中,這個(gè)句子特征是由模型中的一個(gè) Pooler
模塊對(duì)原句池化得來的??墒沁@個(gè)Pooler的訓(xùn)練,并不是由 MLM 任務(wù)來的,而是由 NSP任務(wù)中來的。
由于沒有 NSP 任務(wù),所以無法對(duì) Pooler
進(jìn)行訓(xùn)練,故而沒有必要在模型中加入 Pooler
。所以在保存的時(shí)候需要分別保存 embedding和 encoder, 加載的時(shí)候也需要分別讀取 embedding 和 encoder,這樣訓(xùn)練出來的模型拿不到 CLS 層的句子表征。如果需要的話,可以手動(dòng)pooling 。
torch.save(bert_mlm_model.bert.embeddings.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_eb.bin'.format(config.epochs)))
torch.save(bert_mlm_model.bert.encoder.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_ec.bin'.format(config.epochs)))
加載的話,也是實(shí)例化完bert模型之后,用bert的 embedding 組件和 encoder 組件分別讀取這兩個(gè)權(quán)重文件即可。
到這里,本期內(nèi)容就全部結(jié)束了,希望看完這篇博客的同學(xué),能夠?qū)?Bert 的基礎(chǔ)原理有更深入的了解。
審核編輯 :李倩
-
模型
+關(guān)注
關(guān)注
1文章
3123瀏覽量
48664 -
語言模型
+關(guān)注
關(guān)注
0文章
502瀏覽量
10239 -
mask
+關(guān)注
關(guān)注
0文章
9瀏覽量
2896
原文標(biāo)題:2. 代碼部分
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論