0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

如何進(jìn)行MLM訓(xùn)練

深度學(xué)習(xí)自然語言處理 ? 來源:CSDN ? 作者:常鴻宇 ? 2022-08-13 10:54 ? 次閱讀

1. 關(guān)于MLM

1.1 背景

作為 Bert 預(yù)訓(xùn)練的兩大任務(wù)之一,MLMNSP 大家應(yīng)該并不陌生。其中,NSP 任務(wù)在后續(xù)的一些預(yù)訓(xùn)練任務(wù)中經(jīng)常被嫌棄,例如 Roberta 中將 NSP 任務(wù)直接放棄,Albert 中將 NSP 替換成了句子順序預(yù)測(cè)。

這主要是因?yàn)?NSP 作為一個(gè)分類任務(wù)過于簡(jiǎn)單,對(duì)模型的學(xué)習(xí)并沒有太大的幫助,而 MLM 則被多數(shù)預(yù)訓(xùn)練模型保留下來。由 Roberta的實(shí)驗(yàn)結(jié)果也可以證明,Bert 的主要能力應(yīng)該是來自于 MLM 任務(wù)的訓(xùn)練。

Bert為代表的預(yù)訓(xùn)練語言模型是在大規(guī)模語料的基礎(chǔ)上訓(xùn)練以獲得的基礎(chǔ)的學(xué)習(xí)能力,而實(shí)際應(yīng)用時(shí),我們所面臨的語料或許具有某些特殊性,這就使得重新進(jìn)行 MLM 訓(xùn)練具有了必要性。

1.2 如何進(jìn)行MLM訓(xùn)練

1.2.1 什么是MLM

MLM 的訓(xùn)練,在不同的預(yù)訓(xùn)練模型中其實(shí)是有所不同的。今天介紹的內(nèi)容以最基礎(chǔ)的 Bert 為例。

Bert的MLM是靜態(tài)mask,而在后續(xù)的其他預(yù)訓(xùn)練模型中,這一策略通常被替換成了動(dòng)態(tài)mask。除此之外還有 whole word mask 的模型,這些都不在今天的討論范圍內(nèi)。

所謂 mask language model 的任務(wù),通俗來講,就是將句子中的一部分token替換掉,然后根據(jù)句子的剩余部分,試圖去還原這部分被mask的token。

1.2.2 如何Mask

mask 的比例一般是15%,這一比例也被后續(xù)的多數(shù)模型所繼承,而在最初BERT 的論文中,沒有對(duì)這一比例的界定給出具體的說明。在我的印象中,似乎是知道后來同樣是Google提出的 T5 模型的論文中,對(duì)此進(jìn)行了解釋,對(duì) mask 的比例進(jìn)行了實(shí)驗(yàn),最終得出結(jié)論,15%的比例是最合理的(如果我記錯(cuò)了,還請(qǐng)指正)。

15%的token選出之后,并不是所有的都替換成[mask]標(biāo)記符。實(shí)際操作是:

  • 從這15%選出的部分中,將其中的80%替換成[mask];
  • 10%替換成一個(gè)隨機(jī)的token;
  • 剩下的10%保留原來的token。

這樣做可以提高模型的魯棒性。這個(gè)比例也可以自己控制。

到這里可能有同學(xué)要問了,既然有10%保留不變的話,為什么不干脆只選擇15%*90% = 13.5%的token呢?如果看完后面的代碼,就會(huì)很清楚地理解這個(gè)問題了。

先說結(jié)論:因?yàn)?MLM 的任務(wù)是將選出的這15%的token全部進(jìn)行預(yù)測(cè),不管這個(gè)token是否被替換成了[mask],也就是說,即使它被保留了原樣,也還是需要被預(yù)測(cè)的。

2. 代碼部分

2.1 背景

介紹完了基礎(chǔ)內(nèi)容之后,接下來的內(nèi)容,我將基于 transformers 模塊,介紹如何進(jìn)行 mask language model 的訓(xùn)練。

其實(shí) transformers 模塊中,本身是提供了 MLM 訓(xùn)練任務(wù)的,模型都寫好了,只需要調(diào)用它內(nèi)置的 trainerdatasets模塊即可。感興趣的同學(xué)可以去 huggingface 的官網(wǎng)搜索相關(guān)教程。

然而我覺得 datasets 每次調(diào)用的時(shí)候都要去寫數(shù)據(jù)集的py文件,對(duì)arrow的數(shù)據(jù)格式不熟悉的話還很容易出錯(cuò),而且 trainer 我覺得也不是很好用,任何一點(diǎn)小小的修改都挺費(fèi)勁(就是它以為它寫的很完備,考慮了用戶的所有需求,但是實(shí)際上有一些冗余的部分)。

所以我就參考它的實(shí)現(xiàn)方式,把它的代碼拆解,又按照自己的方式重新組織了一下。

2.2 準(zhǔn)備工作

首先在寫核心代碼之前,先做好準(zhǔn)備工作。
import 所有需要的模塊:

import os
import json
import copy
from tqdm.notebook import tqdm

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizerFast

然后寫一個(gè)config類,將所有參數(shù)集中起來:

class Config:
    def __init__(self):
        pass
    
    def mlm_config(
        self, 
        mlm_probability=0.15, 
        special_tokens_mask=None,
        prob_replace_mask=0.8,
        prob_replace_rand=0.1,
        prob_keep_ori=0.1,
    ):
        """
        :param mlm_probability: 被mask的token總數(shù)
        :param special_token_mask: 特殊token
        :param prob_replace_mask: 被替換成[MASK]的token比率
        :param prob_replace_rand: 被隨機(jī)替換成其他token比率
        :param prob_keep_ori: 保留原token的比率
        """
        assert sum([prob_replace_mask, prob_replace_rand, prob_keep_ori]) == 1,                 ValueError("Sum of the probs must equal to 1.")
        self.mlm_probability = mlm_probability
        self.special_tokens_mask = special_tokens_mask
        self.prob_replace_mask = prob_replace_mask
        self.prob_replace_rand = prob_replace_rand
        self.prob_keep_ori = prob_keep_ori
        
    def training_config(
        self,
        batch_size,
        epochs,
        learning_rate,
        weight_decay,
        device,
    ):
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device
        
    def io_config(
        self,
        from_path,
        save_path,
    ):
        self.from_path = from_path
        self.save_path = save_path

接著就是設(shè)置各種配置:

config = Config()
config.mlm_config()
config.training_config(batch_size=4, epochs=10, learning_rate=1e-5, weight_decay=0, device='cuda:0')
config.io_config(from_path='/data/BERTmodels/huggingface/chinese_wwm/', 
                 save_path='./finetune_embedding_model/mlm/')

最后創(chuàng)建BERT模型。注意,這里的 tokenizer 就是一個(gè)普通的 tokenizer,而BERT模型則是帶了下游任務(wù)的 BertForMaskedLM,它是 transformers 中寫好的一個(gè)類,

bert_tokenizer = BertTokenizerFast.from_pretrained(config.from_path)
bert_mlm_model = BertForMaskedLM.from_pretrained(config.from_path)

2.3 數(shù)據(jù)集

因?yàn)樯釛壛?code style="font-size:14px;padding:2px 4px;margin:0 2px;color:#1e6bb8;background-color:rgba(27,31,35,.05);font-family:'Operator Mono', Consolas, Monaco, Menlo, monospace;">datasets這個(gè)包,所以我們現(xiàn)在需要自己實(shí)現(xiàn)數(shù)據(jù)的輸入了。方案就是使用 torchDataset 類。這個(gè)類一般在構(gòu)建 DataLoader 的時(shí)候,會(huì)與一個(gè)聚合函數(shù)一起使用,以實(shí)現(xiàn)對(duì)batch的組織。而我這里偷個(gè)懶,就沒有寫聚合函數(shù),batch的組織方法放在dataset中進(jìn)行。

在這個(gè)類中,有一個(gè) mask tokens 的方法,作用是從數(shù)據(jù)中選擇出所有需要mask 的token,并且采用三種mask方式中的一個(gè)。這個(gè)方法是從transformers 中拿出來的,將其從類方法轉(zhuǎn)為靜態(tài)方法測(cè)試之后,再將其放在自己的這個(gè)類中為我們所用。仔細(xì)閱讀這一段代碼,也就可以回答1.2.2 中提出的那個(gè)問題了。

取batch的原理很簡(jiǎn)單,一開始我們將原始數(shù)據(jù)deepcopy備份一下,然后每次從中截取一個(gè)batch的大小,這個(gè)時(shí)候的當(dāng)前數(shù)據(jù)就少了一個(gè)batch,我們定義這個(gè)類的長(zhǎng)度為當(dāng)前長(zhǎng)度除以batch size向下取整,所以當(dāng)類的長(zhǎng)度變?yōu)?的時(shí)候,就說明這一個(gè)epoch的所有step都已經(jīng)執(zhí)行結(jié)束,要進(jìn)行下一個(gè)epoch的訓(xùn)練,此時(shí),再將當(dāng)前數(shù)據(jù)變?yōu)樵紨?shù)據(jù),就可以實(shí)現(xiàn)對(duì)epoch的循環(huán)了。

class TrainDataset(Dataset):
    """
    注意:由于沒有使用data_collator,batch放在dataset里邊做,
    因而在dataloader出來的結(jié)果會(huì)多套一層batch維度,傳入模型時(shí)注意squeeze掉
    """
    def __init__(self, input_texts, tokenizer, config):
        self.input_texts = input_texts
        self.tokenizer = tokenizer
        self.config = config
        self.ori_inputs = copy.deepcopy(input_texts)
        
    def __len__(self):
        return len(self.input_texts) // self.config.batch_size
    
    def __getitem__(self, idx):
        batch_text = self.input_texts[: self.config.batch_size]
        features = self.tokenizer(batch_text, max_length=512, truncation=True, padding=True, return_tensors='pt')
        inputs, labels = self.mask_tokens(features['input_ids'])
        batch = {"inputs": inputs, "labels": labels}
        self.input_texts = self.input_texts[self.config.batch_size: ]
        if not len(self):
            self.input_texts = self.ori_inputs
        
        return batch
        
    def mask_tokens(self, inputs):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.config.mlm_probability)
        if self.config.special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = self.config.special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, self.config.prob_replace_mask)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        current_prob = self.config.prob_replace_rand / (1 - self.config.prob_replace_mask)
        indices_random = torch.bernoulli(torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

然后取一些用于訓(xùn)練的語料,格式很簡(jiǎn)單,就是把所有文本放在一個(gè)list里邊,注意長(zhǎng)度不要超過512個(gè)token,不然多出來的部分就浪費(fèi)掉了??梢宰鲞m當(dāng)?shù)念A(yù)處理。

[
"這是一條文本",
"這是另一條文本",
...,
]

然后構(gòu)建dataloader:

train_dataset = TrainDataset(training_texts, bert_tokenizer, config)
train_dataloader = DataLoader(train_dataset)

2.4 訓(xùn)練

構(gòu)建一個(gè)訓(xùn)練方法,輸入?yún)?shù)分別是我們實(shí)例化好的待訓(xùn)練模型,數(shù)據(jù)集,還有config:

def train(model, train_dataloader, config):
    """
    訓(xùn)練
    :param model: nn.Module
    :param train_dataloader: DataLoader
    :param config: Config
    ---------------
    ver: 2021-11-08
    by: changhongyu
    """
    assert config.device.startswith('cuda') or config.device == 'cpu', ValueError("Invalid device.")
    device = torch.device(config.device)
    
    model.to(device)
    
    if not len(train_dataloader):
        raise EOFError("Empty train_dataloader.")
        
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]
    
    optimizer = AdamW(params=optimizer_grouped_parameters, lr=config.learning_rate, weight_decay=config.weight_decay)
    
    for cur_epc in tqdm(range(int(config.epochs)), desc="Epoch"):
        training_loss = 0
        print("Epoch: {}".format(cur_epc+1))
        model.train()
        for step, batch in enumerate(tqdm(train_dataloader, desc='Step')):
            input_ids = batch['inputs'].squeeze(0).to(device)
            labels = batch['labels'].squeeze(0).to(device)
            loss = model(input_ids=input_ids, labels=labels).loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            model.zero_grad()
            training_loss += loss.item()
        print("Training loss: ", training_loss)

調(diào)用它訓(xùn)練幾輪:

train(model=bert_mlm_model, train_dataloader=train_dataloader, config=config)

2.5 保存和加載

使用過預(yù)訓(xùn)練模型的同學(xué)應(yīng)該都了解,普通的bert有兩項(xiàng)輸出,分別是:

  • 每一個(gè)token對(duì)應(yīng)的768維編碼結(jié)果;
  • 以及用于表征整個(gè)句子的句子特征。

其中,這個(gè)句子特征是由模型中的一個(gè) Pooler 模塊對(duì)原句池化得來的??墒沁@個(gè)Pooler的訓(xùn)練,并不是由 MLM 任務(wù)來的,而是由 NSP任務(wù)中來的。

由于沒有 NSP 任務(wù),所以無法對(duì) Pooler 進(jìn)行訓(xùn)練,故而沒有必要在模型中加入 Pooler。所以在保存的時(shí)候需要分別保存 embedding和 encoder, 加載的時(shí)候也需要分別讀取 embedding 和 encoder,這樣訓(xùn)練出來的模型拿不到 CLS 層的句子表征。如果需要的話,可以手動(dòng)pooling 。

torch.save(bert_mlm_model.bert.embeddings.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_eb.bin'.format(config.epochs)))
torch.save(bert_mlm_model.bert.encoder.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_ec.bin'.format(config.epochs)))

加載的話,也是實(shí)例化完bert模型之后,用bert的 embedding 組件和 encoder 組件分別讀取這兩個(gè)權(quán)重文件即可。

到這里,本期內(nèi)容就全部結(jié)束了,希望看完這篇博客的同學(xué),能夠?qū)?Bert 的基礎(chǔ)原理有更深入的了解。

審核編輯 :李倩


聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 模型
    +關(guān)注

    關(guān)注

    1

    文章

    3123

    瀏覽量

    48664
  • 語言模型
    +關(guān)注

    關(guān)注

    0

    文章

    502

    瀏覽量

    10239
  • mask
    +關(guān)注

    關(guān)注

    0

    文章

    9

    瀏覽量

    2896

原文標(biāo)題:2. 代碼部分

文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    何進(jìn)行電源供應(yīng)設(shè)計(jì) – 第 4 部分

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源供應(yīng)設(shè)計(jì) – 第 4 部分.pdf》資料免費(fèi)下載
    發(fā)表于 09-09 10:34 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源供應(yīng)設(shè)計(jì) – 第 4 部分

    何進(jìn)行電源供應(yīng)設(shè)計(jì)

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源供應(yīng)設(shè)計(jì).pdf》資料免費(fèi)下載
    發(fā)表于 09-09 10:33 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源供應(yīng)設(shè)計(jì)

    何進(jìn)行電源設(shè)計(jì)–第5部分

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源設(shè)計(jì)–第5部分.pdf》資料免費(fèi)下載
    發(fā)表于 09-07 11:11 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源設(shè)計(jì)–第5部分

    何進(jìn)行電源設(shè)計(jì)-第1部分

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源設(shè)計(jì)-第1部分.pdf》資料免費(fèi)下載
    發(fā)表于 09-07 11:10 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源設(shè)計(jì)-第1部分

    何進(jìn)行電源設(shè)計(jì)–第2部分

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源設(shè)計(jì)–第2部分.pdf》資料免費(fèi)下載
    發(fā)表于 09-07 11:09 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源設(shè)計(jì)–第2部分

    何進(jìn)行電源設(shè)計(jì)–第3部分

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源設(shè)計(jì)–第3部分.pdf》資料免費(fèi)下載
    發(fā)表于 09-07 11:08 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源設(shè)計(jì)–第3部分

    何進(jìn)行電源設(shè)計(jì)–第6部分

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源設(shè)計(jì)–第6部分.pdf》資料免費(fèi)下載
    發(fā)表于 09-06 15:05 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源設(shè)計(jì)–第6部分

    何進(jìn)行電源設(shè)計(jì)–第4部分

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源設(shè)計(jì)–第4部分.pdf》資料免費(fèi)下載
    發(fā)表于 09-06 15:04 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源設(shè)計(jì)–第4部分

    何進(jìn)行電源供應(yīng)設(shè)計(jì)-第3部分

    電子發(fā)燒友網(wǎng)站提供《如何進(jìn)行電源供應(yīng)設(shè)計(jì)-第3部分.pdf》資料免費(fèi)下載
    發(fā)表于 08-30 09:16 ?0次下載
    如<b class='flag-5'>何進(jìn)行</b>電源供應(yīng)設(shè)計(jì)-第3部分

    何進(jìn)行IP檢測(cè)

    排查網(wǎng)絡(luò)連接問題,并及時(shí)的防范潛在的網(wǎng)絡(luò)攻擊。 那么,如何進(jìn)行 IP 地址檢測(cè)呢?接下來我將進(jìn)行圖示哦~ 使用操作系統(tǒng)自帶的工具 ① Windows 系統(tǒng)中,按win+R,輸入“ipconfig”命令。 ② Mac 系統(tǒng)中,則可以在“系統(tǒng)偏好設(shè)置”中的“網(wǎng)絡(luò)”
    的頭像 發(fā)表于 07-26 14:09 ?438次閱讀
    如<b class='flag-5'>何進(jìn)行</b>IP檢測(cè)

    何進(jìn)行RF PA Ruggedness的測(cè)試和評(píng)估呢?

    關(guān)于PA ruggedness設(shè)計(jì)測(cè)試問題,先介紹一下原理,如何進(jìn)行ruggedness的測(cè)試和評(píng)估。
    的頭像 發(fā)表于 03-27 10:19 ?2028次閱讀
    如<b class='flag-5'>何進(jìn)行</b>RF PA Ruggedness的測(cè)試和評(píng)估呢?

    ADXL355如何進(jìn)行自測(cè)及補(bǔ)償?

    ADXL355如何進(jìn)行自測(cè)及補(bǔ)償,官方的數(shù)據(jù)手冊(cè),寫的很簡(jiǎn)單,對(duì)比ADXL345的數(shù)據(jù)手冊(cè)寫的就很詳細(xì)!!!
    發(fā)表于 12-29 06:46

    Android APP如何進(jìn)行訪問硬件驅(qū)動(dòng)

    本文我們要講的是在用 i.MX8 平臺(tái)開發(fā)時(shí),Android APP 如何進(jìn)行訪問硬件驅(qū)動(dòng)。
    的頭像 發(fā)表于 12-04 13:50 ?1433次閱讀
    Android APP如<b class='flag-5'>何進(jìn)行</b>訪問硬件驅(qū)動(dòng)

    西門子伺服電機(jī)維修如何進(jìn)行調(diào)試?

    西門子伺服電機(jī)維修如何進(jìn)行調(diào)試?
    的頭像 發(fā)表于 11-23 11:00 ?1349次閱讀

    新apcups電源如何進(jìn)行初充電

    電子發(fā)燒友網(wǎng)站提供《新apcups電源如何進(jìn)行初充電.doc》資料免費(fèi)下載
    發(fā)表于 11-15 09:55 ?0次下載
    新apcups電源如<b class='flag-5'>何進(jìn)行</b>初充電