
Notes on Training a GPT2 Model with DeepSpeed and Megatron-LM

jf_pmFSk4VX · Source: GiantPandaCV · 2023-06-19 14:45

Installing Dependencies

Preparing the Training Data

Detailed Training Procedure and Pitfalls

Parameter Count Estimate

Training Memory Estimate

2-GPU Data Parallelism

2-GPU Model Parallelism

0x0. Introduction

This post explores the workflow of training a GPT2 model based on the Megatron examples in the DeepSpeedExamples repository. The plan has several parts: the first is how to train GPT2 with the original Megatron, and the second is how to train Megatron GPT2 together with DeepSpeed's features. For reasons of length, this article only covers the first part: a very detailed record of the problems I ran into while getting the Megatron GPT2 training pipeline running and how I solved them. The writing follows that codebase.

0x1. Training GPT2 on a Single GPU with Megatron

First, read the README at https://github.com/microsoft/DeepSpeedExamples/tree/bdf8e59aede8c8e0577e8d4d557298ca8515268f/Megatron-LM . We ignore the BERT part here; the goal is to get GPT2 training and inference running.

The README first describes Megatron as a large, powerful Transformer codebase for ongoing research on large Transformer language models. Megatron currently supports model-parallel, multi-node training of GPT2 and BERT with mixed precision. The codebase can efficiently train a 72-layer, 8.3-billion-parameter GPT2 language model using 8-way model parallelism and 64-way data parallelism across 512 GPUs. The authors found that this larger language model (the 8.3B GPT2) surpasses the GPT2-1.5B wikitext perplexities within just 5 training epochs.

Installing Dependencies

First cd into the Megatron-LM directory and install the dependencies with pip install -r requirements.txt. Note that requirements.txt lists TensorFlow, which is only needed for BERT training; since I don't care about that here, I skip installing TensorFlow. The contents of requirements.txt are:

nltk>=3.4
numpy>=1.15.4
pandas>=0.24.0
sentencepiece>=0.1.8
# tensorflow>=1.12.0
boto3==1.11.11
regex==2020.1.8

Installation fails with:

ERROR: Could not find a version that satisfies the requirement boto3==1.11.11 (from versions: none)
ERROR: No matching distribution found for boto3==1.11.11

I simply ran pip install boto3 to install the latest version instead.

Next, following the tutorial, run bash scripts/pretrain_gpt2.sh. This hits a PyTorch error:

ModuleNotFoundError: No module named 'torch._six'

This error comes from a PyTorch version change. A quick search shows that changing the line from torch._six import inf to from torch import inf is enough. Continuing, the next error is: AssertionError: make sure to set PATH for wikipedia data_utils/corpora.py. This is because scripts/pretrain_gpt2.sh specifies wikipedia as the training dataset, so PATH = 'data/wikipedia/wikidump_lines.json' in DeepSpeedExamples/Megatron-LM/data_utils/corpora.py has to be pointed at the locally downloaded wikipedia data.
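A slightly more robust version of that one-line fix is to fall back between the old and new import locations (a minimal sketch; the actual change in the Megatron source is just the replacement described above):

# Version-agnostic import of `inf`: older PyTorch still exposes torch._six,
# while newer releases expose `inf` directly under torch.
try:
    from torch._six import inf
except ImportError:
    from torch import inf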

Preparing the Training Data

While downloading, I found the wikipedia dataset far too large, so I switched to the webtext dataset. The Megatron README describes it as follows:

"We utilize the publicly available OpenWebText (https://github.com/eukaryote31/openwebtext) library, developed by jcpeterson (https://github.com/jcpeterson/openwebtext) and eukaryote31 (https://github.com/eukaryote31/openwebtext), to download URLs. All downloaded content is then filtered, cleaned, and deduplicated following the procedure described in our openwebtext directory. For the content corresponding to Reddit URLs up to October 2018, this yields roughly 37GB of content." 37GB is still too much for this training run, so I only downloaded the first of the few dozen URL files.

Then copy this file into Megatron-LM's openwebtext directory.


Next, follow the openwebtext README and run:

pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
git clone https://github.com/mattilyra/LSH
cd LSH
python setup.py install

Installing LSH ran into two problems caused by Python version incompatibility:

lsh/cMinhash.cpp:19292:21: error: ‘PyThreadState’ {aka ‘struct _ts’} has no member named ‘exc_type’; did you mean ‘curexc_type’?
19292 | *type = tstate->exc_type;

Replacing exc_type with curexc_type resolves this.

lsh/cMinhash.cpp:17704:26: error: ‘PyTypeObject’ {aka ‘struct _typeobject’} has no member named ‘tp_print’
17704 | __pyx_type___pyx_array.tp_print = 0;

Replacing tp_print with tp_vectorcall_offset resolves this.

Next, run the URL deduplication command:

python3 blacklist_urls.py RS_2011-01.bz2.deduped.txt clean_urls.txt

I found that clean_urls.txt was empty after running this. Looking at the code, the script expects the URL files to be deduplicated to live inside a directory, and the path of that directory to be passed to the script.


So create a urls directory in the current folder and move the URL file from earlier into it.


Then run python3 blacklist_urls.py urls clean_urls.txt to finish the deduplication. Next, use https://github.com/eukaryote31/openwebtext/blob/master/download.py to download the text for the deduplicated URLs.


Downloading everything would take a very long time, so I only fetched the data for 50 URLs as a demonstration. To save each URL's content as a json file, the defaults of --sqlite_meta and --save_uncompressed in download.py need to be changed to False and True respectively. After that, running python3 openwebtext/download.py clean_urls.txt produces a scraped folder, with each URL's downloaded text stored under its data subfolder.

We then use the following script (merge_jsons.py) to merge all the txt files in that folder into a single json file, where each line becomes the content of a text field:

import glob
import sys
import json
import argparse

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default=".",
                        help="path where all the json files are located")

    parser.add_argument("--output_file", type=str, default="merged_output.json",
                        help="filename where the merged json should go")

    args = parser.parse_args()

    data_path = args.data_path
    out_file = args.output_file

    text_files = glob.glob(data_path + '/*.txt')

    counter = 0

    with open(out_file, 'w') as outfile:
        for fname in text_files:
            counter += 1

            if counter % 1024 == 0:
                print("Merging at ", counter, flush=True)

            with open(fname, 'r') as infile:
                for row in infile:
                    tmp = {}
                    tmp['text'] = row
                    outfile.write(json.dumps(tmp))
                    outfile.write('\n')

    print("Merged file", out_file, flush=True)

Run the script to obtain merged_output.json: python3 merge_jsons.py --data_path DeepSpeedExamples/Megatron-LM/openwebtext/scraped/data.

Next, in the openwebtext folder, run cleanup_dataset.py to drop all texts with fewer than 128 tokens: python3 cleanup_dataset.py merged_output.json merged_cleand.json.
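To give a rough idea of what this filtering step does, here is a minimal stand-in for cleanup_dataset.py (not the actual script; its real tokenization may differ, and a plain whitespace split is used here purely for illustration):

import json

# Keep only documents whose approximate token count is at least 128.
with open("merged_output.json") as fin, open("merged_cleand.json", "w") as fout:
    for line in fin:
        doc = json.loads(line)
        if len(doc["text"].split()) >= 128:
            fout.write(json.dumps(doc) + "\n")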

Detailed Training Procedure and Pitfalls

With the data ready, change --train-data in DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2.sh to webtext. Also set the path of the webtext class in DeepSpeedExamples/Megatron-LM/data_utils/corpora.py to the location of the merged_cleand.json we just produced.

In addition, since I'm only using a few dozen documents to demonstrate the training process, the --split argument in DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2.sh also needs to be changed to 400,300,300, i.e. a 4:3:3 ratio for the train, test, and validation sets. This avoids the test split being rounded down to 0 documents.
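Below is a minimal sketch of how a split string like 400,300,300 is typically normalized into fractions (the exact rounding inside Megatron's data_utils may differ); it shows why these weights give a 4:3:3 split and keep every subset non-empty even with only around 50 documents:

def normalize_split(split_str):
    # "400,300,300" -> [0.4, 0.3, 0.3]
    weights = [float(x) for x in split_str.split(",")]
    total = sum(weights)
    return [w / total for w in weights]

fractions = normalize_split("400,300,300")
print(fractions)                         # [0.4, 0.3, 0.3]
print([int(f * 50) for f in fractions])  # with ~50 documents: [20, 15, 15]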

Now training can be launched with bash scripts/pretrain_gpt2.sh. Part of the training log:

Setting ds_accelerator to cuda (auto detect)
using world size: 1 and model-parallel size: 1 
 > using dynamic loss scaling
> initializing model parallel with size 1
Pretrain GPT2 model
arguments:
  pretrained_bert .............. False
  attention_dropout ............ 0.1
  num_attention_heads .......... 16
  hidden_size .................. 1024
  intermediate_size ............ None
  num_layers ................... 24
  layernorm_epsilon ............ 1e-05
  hidden_dropout ............... 0.1
  max_position_embeddings ...... 1024
  vocab_size ................... 30522
  deep_init .................... False
  make_vocab_size_divisible_by . 128
  cpu_optimizer ................ False
  cpu_torch_adam ............... False
  fp16 ......................... True
  fp32_embedding ............... False
  fp32_layernorm ............... False
  fp32_tokentypes .............. False
  fp32_allreduce ............... False
  hysteresis ................... 2
  loss_scale ................... None
  loss_scale_window ............ 1000
  min_scale .................... 1
  batch_size ................... 8
  weight_decay ................. 0.01
  checkpoint_activations ....... True
  checkpoint_num_layers ........ 1
  deepspeed_activation_checkpointing  False
  clip_grad .................... 1.0
  train_iters .................. 320000
  log_interval ................. 100
  exit_interval ................ None
  seed ......................... 1234
  reset_position_ids ........... False
  reset_attention_mask ......... False
  lr_decay_iters ............... None
  lr_decay_style ............... cosine
  lr ........................... 0.00015
  warmup ....................... 0.01
  save ......................... checkpoints/gpt2_345m
  save_interval ................ 5000
  no_save_optim ................ False
  no_save_rng .................. False
  load ......................... checkpoints/gpt2_345m
  no_load_optim ................ False
  no_load_rng .................. False
  finetune ..................... False
  resume_dataloader ............ True
  distributed_backend .......... nccl
  local_rank ................... None
  eval_batch_size .............. None
  eval_iters ................... 100
  eval_interval ................ 1000
  eval_seq_length .............. None
  eval_max_preds_per_seq ....... None
  overlapping_eval ............. 32
  cloze_eval ................... False
  eval_hf ...................... False
  load_openai .................. False
  temperature .................. 1.0
  top_p ........................ 0.0
  top_k ........................ 0
  out_seq_length ............... 256
  model_parallel_size .......... 1
  shuffle ...................... False
  train_data ................... ['webtext']
  use_npy_data_loader .......... False
  train_data_path .............. 
  val_data_path ................ 
  test_data_path ............... 
  input_data_sizes_file ........ sizes.txt
  delim ........................ ,
  text_key ..................... sentence
  eval_text_key ................ None
  valid_data ................... None
  split ........................ 400,300,300
  test_data .................... None
  lazy_loader .................. True
  loose_json ................... False
  presplit_sentences ........... False
  num_workers .................. 2
  tokenizer_model_type ......... bert-large-uncased
  tokenizer_path ............... tokenizer.model
  tokenizer_type ............... GPT2BPETokenizer
  cache_dir .................... cache
  use_tfrecords ................ False
  seq_length ................... 1024
  max_preds_per_seq ............ None
  deepspeed .................... False
  deepspeed_config ............. None
  deepscale .................... False
  deepscale_config ............. None
  deepspeed_mpi ................ False
  cuda ......................... True
  rank ......................... 0
  world_size ................... 1
  dynamic_loss_scale ........... True
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
configuring data
> padded vocab (size: 50257) with 47 dummy tokens (new size: 50304)
> found end-of-document token: 50256
building GPT2 model ...
 > number of parameters on model parallel rank 0: 354871296
Optimizer = FusedAdam
learning rate decaying cosine
WARNING: could not find the metadata file checkpoints/gpt2_345m/latest_checkpointed_iteration.txt 
    will not load any checkpoints and will start from random
Partition Activations False and Correctness Check False
 iteration      100/  320000 | elapsed time per iteration (ms): 963.3 | learning rate 3.937E-06 | lm loss 8.995377E+00 | loss scale 131072.0 |
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
  warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
  warnings.warn(
after 100 iterations memory (MB) | allocated: 6784.88427734375 | max allocated: 11927.470703125 | cached: 13826.0 | max cached: 13826.0
time (ms) | forward: 276.11 | backward: 672.99 | allreduce: 13.96 | optimizer: 14.00 | batch generator: 5.22 | data loader: 4.53
 iteration      200/  320000 | elapsed time per iteration (ms): 950.6 | learning rate 8.625E-06 | lm loss 3.041360E+00 | loss scale 131072.0 |
time (ms) | forward: 259.24 | backward: 674.56 | allreduce: 13.45 | optimizer: 16.63 | batch generator: 0.78 | data loader: 0.14

The nvidia-smi screenshot also confirms that the Megatron training is running on GPU 0.

During training you may hit the following StopIteration error:

time (ms) | forward: 259.07 | backward: 671.87 | allreduce: 13.03 | optimizer: 16.64 | batch generator: 0.76 | data loader: 0.13
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:713 in                  │
│                                                                                                  │
│   710                                                                                            │
│   711                                                                                            │
│   712 if __name__ == "__main__":                                                                 │
│ ? 713 │   main()                                                                                 │
│   714                                                                                            │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:686 in main                     │
│                                                                                                  │
│   683 │   iteration = 0                                                                          │
│   684 │   if args.train_iters > 0:                                                               │
│   685 │   │   if args.do_train:                                                                  │
│ ? 686 │   │   │   iteration, skipped = train(model, optimizer,                                   │
│   687 │   │   │   │   │   │   │   │   │      lr_scheduler,                                       │
│   688 │   │   │   │   │   │   │   │   │      train_data_iterator,                                │
│   689 │   │   │   │   │   │   │   │   │      val_data_iterator,                                  │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:415 in train                    │
│                                                                                                  │
│   412 │   report_memory_flag = True                                                              │
│   413 │   while iteration < args.train_iters:                                                    │
│   414 │   │                                                                                      │
│ ? 415 │   │   lm_loss, skipped_iter = train_step(train_data_iterator,                            │
│   416 │   │   │   │   │   │   │   │   │   │      model,                                          │
│   417 │   │   │   │   │   │   │   │   │   │      optimizer,                                      │
│   418 │   │   │   │   │   │   │   │   │   │      lr_scheduler,                                   │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:369 in train_step               │
│                                                                                                  │
│   366 │                                                                                          │
│   367 │   # Forward model for one step.                                                          │
│   368 │   timers('forward').start()                                                              │
│ ? 369 │   lm_loss = forward_step(data_iterator, model, args, timers)                             │
│   370 │   timers('forward').stop()                                                               │
│   371 │                                                                                          │
│   372 │   #print_rank_0("loss is {}".format(lm_loss))                                            │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:286 in forward_step             │
│                                                                                                  │
│   283 │                                                                                          │
│   284 │   # Get the batch.                                                                       │
│   285 │   timers('batch generator').start()                                                      │
│ ? 286 │   tokens, labels, loss_mask, attention_mask, position_ids = get_batch(                   │
│   287 │   │   data_iterator, args, timers)                                                       │
│   288 │   timers('batch generator').stop()                                                       │
│   289                                                                                            │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/pretrain_gpt2.py:257 in get_batch                │
│                                                                                                  │
│   254 │   # Broadcast data.                                                                      │
│   255 │   timers('data loader').start()                                                          │
│   256 │   if data_iterator is not None:                                                          │
│ ? 257 │   │   data = next(data_iterator)                                                         │
│   258 │   else:                                                                                  │
│   259 │   │   data = None                                                                        │
│   260 │   timers('data loader').stop()                                                           │
│                                                                                                  │
│ /home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/utils/data/dataloader.p │
│ y:633 in __next__                                                                                │
│                                                                                                  │
│    630 │   │   │   if self._sampler_iter is None:                                                │
│    631 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   │
│    632 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
│ ?  633 │   │   │   data = self._next_data()                                                      │
│    634 │   │   │   self._num_yielded += 1                                                        │
│    635 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and                           │
│    636 │   │   │   │   │   self._IterableDataset_len_called is not None and                     │
│                                                                                                  │
│ /home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/utils/data/dataloader.p │
│ y:1318 in _next_data                                                                             │
│                                                                                                  │
│   1315 │   │   │   │   # no valid `self._rcvd_idx` is found (i.e., didn't break)                 │
│   1316 │   │   │   │   if not self._persistent_workers:                                          │
│   1317 │   │   │   │   │   self._shutdown_workers()                                              │
│ ? 1318 │   │   │   │   raise StopIteration                                                       │
│   1319 │   │   │                                                                                 │
│   1320 │   │   │   # Now `self._rcvd_idx` is the batch index we want to fetch                    │
│   1321                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
StopIteration

Don't worry: this error simply means there isn't enough data to train for that many iterations. It happens because the dataloader is built with torch.utils.data.SequentialSampler, which samples according to the dataset's length and is not tied to args.train_iters, so once the data is exhausted after many iterations a StopIteration is raised.
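The behaviour is easy to reproduce outside Megatron with a tiny self-contained example (a sketch, not Megatron code): a DataLoader backed by the default SequentialSampler yields exactly len(dataset) / batch_size batches, so driving its iterator for more steps than that raises StopIteration.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(40))    # 40 samples
loader = DataLoader(dataset, batch_size=8)   # default SequentialSampler -> only 5 batches
data_iterator = iter(loader)

train_iters = 600
for step in range(train_iters):
    batch = next(data_iterator)              # raises StopIteration once the 5 batches are exhausted

A common workaround is to re-create or cycle the iterator once it is exhausted; here we instead simply lower the iteration count.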

We tweak the script: set the iteration count to 600 and the checkpoint save interval to 500, so Megatron is guaranteed to save at least one checkpoint. Then run the script again.


0x2. Single-GPU Inference with the Trained GPT2 Model in Megatron

Change CHECKPOINT_PATH in DeepSpeedExamples/Megatron-LM/scripts/generate_text.sh to the path of the model we just trained, here DeepSpeedExamples/Megatron-LM/checkpoints/gpt2_345m, then run bash scripts/generate_text.sh from the Megatron root directory. It fails with:

Setting ds_accelerator to cuda (auto detect)
Generate Samples
WARNING: No training data specified
using world size: 1 and model-parallel size: 1 
 > using dynamic loss scaling
> initializing model parallel with size 1
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
prepare tokenizer done
building GPT2 model ...
 > number of parameters on model parallel rank 0: 354823168
global rank 0 is loading checkpoint /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/checkpoints/gpt2_345m/iter_0000600/mp_rank_00/model_optim_rng.pt
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/generate_samples.py:277 in               │
│                                                                                                  │
│   274                                                                                            │
│   275                                                                                            │
│   276 if __name__ == "__main__":                                                                 │
│ ? 277 │   main()                                                                                 │
│   278                                                                                            │
│   279                                                                                            │
│   280                                                                                            │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/generate_samples.py:267 in main                  │
│                                                                                                  │
│   264 │   tokenizer = prepare_tokenizer(args)                                                    │
│   265 │                                                                                          │
│   266 │   # Model, optimizer, and learning rate.                                                 │
│ ? 267 │   model = setup_model(args)                                                              │
│   268 │                                                                                          │
│   269 │   #setting default batch size to 1                                                       │
│   270 │   args.batch_size = 1                                                                    │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/generate_samples.py:80 in setup_model            │
│                                                                                                  │
│    77 │   model = get_model(args)                                                                │
│    78 │                                                                                          │
│    79 │   if args.load is not None:                                                              │
│ ?  80 │   │   _ = load_checkpoint(                                                               │
│    81 │   │   │   model, None, None, args)                                                       │
│    82 │                                                                                          │
│    83 │   return model                                                                           │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/utils.py:305 in load_checkpoint                  │
│                                                                                                  │
│   302 │   │                                                                                      │
│   303 │   │   # Model.                                                                           │
│   304 │   │   try:                                                                               │
│ ? 305 │   │   │   model.load_state_dict(sd['model'])                                             │
│   306 │   │   except KeyError:                                                                   │
│   307 │   │   │   print_rank_0('A metadata file exists but unable to load model '                │
│   308 │   │   │   │   │   │   'from checkpoint {}, exiting'.format(checkpoint_name))             │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/model/distributed.py:90 in load_state_dict       │
│                                                                                                  │
│    87 │   │   return sd                                                                          │
│    88 │                                                                                          │
│    89 │   def load_state_dict(self, state_dict, strict=True):                                    │
│ ?  90 │   │   self.module.load_state_dict(state_dict, strict=strict)                             │
│    91 │                                                                                          │
│    92 │   '''                                                                                    │
│    93 │   def _sync_buffers(self):                                                               │
│                                                                                                  │
│ /home/zhangxiaoyu/DeepSpeedExamples/Megatron-LM/fp16/fp16.py:71 in load_state_dict               │
│                                                                                                  │
│    68 │   │   return self.module.state_dict(destination, prefix, keep_vars)                      │
│    69 │                                                                                          │
│    70 │   def load_state_dict(self, state_dict, strict=True):                                    │
│ ?  71 │   │   self.module.load_state_dict(state_dict, strict=strict)                             │
│    72                                                                                            │
│    73 # TODO:  Update overflow check + downscale to use Carl's fused kernel.                     │
│    74 class FP16_Optimizer(object):                                                              │
│                                                                                                  │
│ /home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/nn/modules/module.py:20 │
│ 41 in load_state_dict                                                                            │
│                                                                                                  │
│   2038 │   │   │   │   │   │   ', '.join('"{}"'.format(k) for k in missing_keys)))               │
│   2039 │   │                                                                                     │
│   2040 │   │   if len(error_msgs) > 0:                                                           │
│ ? 2041 │   │   │   raise RuntimeError('Error(s) in loading state_dict for {}:
	{}'.format(     │
│   2042 │   │   │   │   │   │   │      self.__class__.__name__, "
	".join(error_msgs)))         │
│   2043 │   │   return _IncompatibleKeys(missing_keys, unexpected_keys)                           │
│   2044                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Error(s) in loading state_dict for GPT2Model:
        size mismatch for word_embeddings.weight: copying a param with shape torch.Size([50304, 1024]) from checkpoint, the shape in current model is 
torch.Size([50257, 1024]).

Loading the model complains that the shape of word_embeddings.weight doesn't match: its first dimension is the vocabulary size, 50304 in the checkpoint versus 50257 in the freshly built model, as the definition of word_embeddings in GPT2 shows.

So this problem is caused by the vocab_size differing between training and inference. After some digging, the reason is that during training the token count num_tokens is padded up so it is divisible by args.make_vocab_size_divisible_by=128, whereas inference has no such constraint, hence the embedding dimension mismatch. We modify the num_tokens handling in DeepSpeedExamples/Megatron-LM/generate_samples.py to match training.
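The padded vocabulary sizes printed in the training logs can be reproduced with the following sketch, assuming the padding rule is "round num_tokens up to a multiple of make_vocab_size_divisible_by × model_parallel_size" (pad_vocab is a name of my own, not a Megatron function):

def pad_vocab(num_tokens, divisible_by=128, model_parallel_size=1):
    # Round num_tokens up to a multiple of divisible_by * model_parallel_size.
    multiple = divisible_by * model_parallel_size
    while num_tokens % multiple != 0:
        num_tokens += 1
    return num_tokens

print(pad_vocab(50257, model_parallel_size=1))   # 50304 = 50257 + 47 dummy tokens (single-GPU log)
print(pad_vocab(50257, model_parallel_size=2))   # 50432 = 50257 + 175 dummy tokens (2-way model-parallel log)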

Run bash scripts/generate_text.sh again and we can chat with GPT2: type a prompt and the model produces different completions; type stop to end the conversation.

Since this model was trained on very little data purely as a demonstration, the completions are basically useless; training on more data later would give a much better GPT2 chat model.

0x3. Parameter Count and Memory Estimates

The article at https://zhuanlan.zhihu.com/p/624740065 derives the parameter count and training memory footprint of a GPT2-style Transformer. We apply the formulas summarized there to our current GPT2 model.

Parameter Count Estimate

Plugging into the formula from that article (roughly 12lh^2 parameters for the transformer blocks): with l=24 and hidden_size=1024, 12lh^2 = 12 × 24 × 1024 × 1024 = 301,989,888 ≈ 0.3B. So the GPT2 model we are training has roughly 0.3B parameters, which is consistent with the 345M in the model's name.
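As a quick sanity check against the parameter count printed in the training log (354871296), we can add the embedding terms on top of the 12lh^2 estimate:

l, h = 24, 1024
V, s = 50304, 1024                       # padded vocab size and sequence length from the log

transformer_params = 12 * l * h * h      # 301,989,888, the 0.3B estimate
embedding_params = V * h + s * h         # word embeddings + position embeddings

print(transformer_params)                     # 301989888
print(transformer_params + embedding_params)  # 354549760, close to the logged 354871296;
                                              # the remainder is biases and LayerNorm weights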

Training Memory Estimate

Following the formula there, the model parameters, gradients, and optimizer states take roughly 301,989,888 × 20 bytes = 6,039,797,760 bytes = 5,898,240 KB = 5760 MB ≈ 5.6 GB during training. The activation memory is then:

With batch_size=8, s=1024, h=1024, a=num-attention-heads=16, and l=24, the per-layer activation formula from the same article, (34·b·s·h + 5·a·s^2·b) bytes, gives 24 × (34×8×1024×1024 + 5×16×1024×1024×8) = 22,951,231,488 bytes ≈ 21 GB.

So training this 0.3B GPT2 should, in theory, take about 5.6 GB + 21 GB = 26.6 GB of GPU memory.
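Both estimates can be reproduced in a few lines; the 20-bytes-per-parameter rule and the per-layer activation formula are the ones quoted from the referenced article:

b, s, h, a, l = 8, 1024, 1024, 16, 24
params = 12 * l * h * h                               # 301,989,888

model_states_bytes = params * 20                      # mixed-precision weights, gradients and Adam states
activation_bytes = l * (34 * b * s * h + 5 * a * s * s * b)

print(model_states_bytes / 1024**3)                   # ≈ 5.6 GB
print(activation_bytes / 1024**3)                     # ≈ 21.4 GB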

But in section 0x1 we saw that the GPU has 24 GB of memory and training only consumed 15107 MiB ≈ 14.75 GB, which means the activations took only about 14.75 - 5.6 = 9.15 GB instead of the 21 GB we computed. Why? Because --checkpoint-activations is enabled in DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2.sh, i.e. activation checkpointing is used. The relevant code is at DeepSpeedExamples/Megatron-LM/mpu/transformer.py:406-413.


With this, each Transformer layer no longer keeps the intermediate activations that its Self-Attention and MLP would otherwise need to store for backward; they are recomputed instead, which reduces memory usage.
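The idea can be illustrated with a tiny stand-alone sketch using torch.utils.checkpoint and toy sizes (this is not Megatron's own checkpoint wrapper): only the layer input is saved for backward, and the layer is re-run during the backward pass to regenerate its internal activations.

import torch
from torch.nn import TransformerEncoderLayer
from torch.utils.checkpoint import checkpoint

layer = TransformerEncoderLayer(d_model=64, nhead=4)   # toy-sized transformer layer
x = torch.randn(128, 2, 64, requires_grad=True)        # (seq_len, batch, hidden)

# Forward under checkpointing: the layer's intermediate activations are
# discarded and recomputed when the backward pass needs them.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()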

0x4. Multi-GPU GPT2 Training with Megatron

2-GPU Data Parallelism

With single-GPU GPT2 training done, launching multi-GPU training is straightforward. Change --train-data in DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2_distributed.sh to webtext and set --train-iters to 600/num_gpus. This script launches data-parallel training, so setting the iteration count to 600/num_gpus sweeps the same amount of data as the single-GPU run. The train/validation/test split also needs the same adjustment as before, because with so little demo data the original ratio would round the test set down to 0 documents and error out. Finally, set GPUS_PER_NODE to 2 to train with 2-way data parallelism, then launch training with bash scripts/pretrain_gpt2_distributed.sh. A quick check of the iteration math and the resulting training log follow.
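The "same amount of data" claim is easy to verify: the number of samples swept is iterations × per-GPU batch size × data-parallel degree, so halving the iteration count on 2 GPUs matches the single-GPU run.

batch_size = 8
print(600 * batch_size * 1)   # single GPU:          4800 samples
print(300 * batch_size * 2)   # 2-way data parallel: 4800 samples

The training log: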

/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/distributed/launch.py FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use-env is set by default in torchrun.
If your script expects `--local-rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See 
https://pytorch.org/docs/stable/distributed.html#launch-utility for 
further instructions

  warnings.warn(
WARNING
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
Setting ds_accelerator to cuda (auto detect)
Setting ds_accelerator to cuda (auto detect)
using world size: 2 and model-parallel size: 1 
 > using dynamic loss scaling
> initializing model parallel with size 1
Pretrain GPT2 model
arguments:
  pretrained_bert .............. False
  attention_dropout ............ 0.1
  num_attention_heads .......... 16
  hidden_size .................. 1024
  intermediate_size ............ None
  num_layers ................... 24
  layernorm_epsilon ............ 1e-05
  hidden_dropout ............... 0.1
  max_position_embeddings ...... 1024
  vocab_size ................... 30522
  deep_init .................... False
  make_vocab_size_divisible_by . 128
  cpu_optimizer ................ False
  cpu_torch_adam ............... False
  fp16 ......................... True
  fp32_embedding ............... False
  fp32_layernorm ............... False
  fp32_tokentypes .............. False
  fp32_allreduce ............... False
  hysteresis ................... 2
  loss_scale ................... None
  loss_scale_window ............ 1000
  min_scale .................... 1
  batch_size ................... 8
  weight_decay ................. 0.01
  checkpoint_activations ....... True
  checkpoint_num_layers ........ 1
  deepspeed_activation_checkpointing  False
  clip_grad .................... 1.0
  train_iters .................. 300
  log_interval ................. 100
  exit_interval ................ None
  seed ......................... 1234
  reset_position_ids ........... False
  reset_attention_mask ......... False
  lr_decay_iters ............... None
  lr_decay_style ............... cosine
  lr ........................... 0.00015
  warmup ....................... 0.01
  save ......................... checkpoints/gpt2_345m
  save_interval ................ 5000
  no_save_optim ................ False
  no_save_rng .................. False
  load ......................... checkpoints/gpt2_345m
  no_load_optim ................ False
  no_load_rng .................. False
  finetune ..................... False
  resume_dataloader ............ True
  distributed_backend .......... nccl
  local_rank ................... 0
  eval_batch_size .............. None
  eval_iters ................... 100
  eval_interval ................ 1000
  eval_seq_length .............. None
  eval_max_preds_per_seq ....... None
  overlapping_eval ............. 32
  cloze_eval ................... False
  eval_hf ...................... False
  load_openai .................. False
  temperature .................. 1.0
  top_p ........................ 0.0
  top_k ........................ 0
  out_seq_length ............... 256
  model_parallel_size .......... 1
  shuffle ...................... False
  train_data ................... ['webtext']
  use_npy_data_loader .......... False
  train_data_path .............. 
  val_data_path ................ 
  test_data_path ............... 
  input_data_sizes_file ........ sizes.txt
  delim ........................ ,
  text_key ..................... sentence
  eval_text_key ................ None
  valid_data ................... None
  split ........................ 400,300,300
  test_data .................... None
  lazy_loader .................. True
  loose_json ................... False
  presplit_sentences ........... False
  num_workers .................. 2
  tokenizer_model_type ......... bert-large-uncased
  tokenizer_path ............... tokenizer.model
  tokenizer_type ............... GPT2BPETokenizer
  cache_dir .................... cache
  use_tfrecords ................ False
  seq_length ................... 1024
  max_preds_per_seq ............ None
  deepspeed .................... False
  deepspeed_config ............. None
  deepscale .................... False
  deepscale_config ............. None
  deepspeed_mpi ................ False
  cuda ......................... True
  rank ......................... 0
  world_size ................... 2
  dynamic_loss_scale ........... True
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
configuring data
> padded vocab (size: 50257) with 47 dummy tokens (new size: 50304)
> found end-of-document token: 50256
building GPT2 model ...
 > number of parameters on model parallel rank 0: 354871296
Optimizer = FusedAdam
Optimizer = FusedAdam
learning rate decaying cosine
WARNING: could not find the metadata file checkpoints/gpt2_345m/latest_checkpointed_iteration.txt 
    will not load any checkpoints and will start from random
Partition Activations False and Correctness Check False
 iteration      100/     300 | elapsed time per iteration (ms): 1048.5 | learning rate 1.258E-04 | lm loss 4.799004E+00 | loss scale 32768.0 |
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
  warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
  warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
  warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
  warnings.warn(
after 100 iterations memory (MB) | allocated: 6784.88427734375 | max allocated: 11927.470703125 | cached: 13826.0 | max cached: 13826.0
time (ms) | forward: 284.78 | backward: 749.95 | allreduce: 93.32 | optimizer: 13.60 | batch generator: 14.88 | data loader: 14.19
 iteration      200/     300 | elapsed time per iteration (ms): 1020.9 | learning rate 5.257E-05 | lm loss 7.708308E-02 | loss scale 32768.0 |
time (ms) | forward: 256.87 | backward: 747.37 | allreduce: 93.08 | optimizer: 16.52 | batch generator: 0.71 | data loader: 0.11
 iteration      300/     300 | elapsed time per iteration (ms): 1018.4 | learning rate 1.806E-06 | lm loss 4.669175E-03 | loss scale 32768.0 |
time (ms) | forward: 256.74 | backward: 744.96 | allreduce: 93.51 | optimizer: 16.53 | batch generator: 0.73 | data loader: 0.12
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
 validation loss at the end of training for val data | LM loss: 1.170473E+01 | LM PPL: 1.211437E+05
----------------------------------------------------------------------------------------------------
global rank 0 is saving checkpoint at iteration     300 to checkpoints/gpt2_345m/iter_0000300/mp_rank_00/model_optim_rng.pt
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/nn/modules/module.py UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
  warnings.warn(
  successfully saved checkpoints/gpt2_345m/iter_0000300/mp_rank_00/model_optim_rng.pt
Evaluating iter 100/100
----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
 validation loss at the end of training for test data | LM loss: 1.169765E+01 | LM PPL: 1.202885E+05
-----------------------------------------------------------------------------------------------------

GPU memory usage:

Since this is data parallelism, per-GPU memory usage is about the same as in single-GPU training.

Inference with the model trained via data parallelism also runs fine.


2-GPU Model Parallelism

We use DeepSpeedExamples/Megatron-LM/scripts/pretrain_gpt2_model_parallel.sh for 2-GPU model-parallel training. Besides the same changes as for 2-GPU data parallelism, we also need to remove the --deepspeed flag from this script, because using DeepSpeed additionally requires a DeepSpeed config file; the DeepSpeed-related training features are left for the next article.

Launch the 2-GPU model-parallel training with bash scripts/pretrain_gpt2_model_parallel.sh. The log:

/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/distributed/launch.py FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use-env is set by default in torchrun.
If your script expects `--local-rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See 
https://pytorch.org/docs/stable/distributed.html#launch-utility for 
further instructions

  warnings.warn(
WARNING
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
Setting ds_accelerator to cuda (auto detect)
Setting ds_accelerator to cuda (auto detect)
using world size: 2 and model-parallel size: 2 
 > using dynamic loss scaling
> initializing model parallel with size 2
Pretrain GPT2 model
arguments:
  pretrained_bert .............. False
  attention_dropout ............ 0.1
  num_attention_heads .......... 16
  hidden_size .................. 1024
  intermediate_size ............ None
  num_layers ................... 24
  layernorm_epsilon ............ 1e-05
  hidden_dropout ............... 0.1
  max_position_embeddings ...... 1024
  vocab_size ................... 30522
  deep_init .................... False
  make_vocab_size_divisible_by . 128
  cpu_optimizer ................ False
  cpu_torch_adam ............... False
  fp16 ......................... True
  fp32_embedding ............... False
  fp32_layernorm ............... False
  fp32_tokentypes .............. False
  fp32_allreduce ............... False
  hysteresis ................... 2
  loss_scale ................... None
  loss_scale_window ............ 1000
  min_scale .................... 1
  batch_size ................... 8
  weight_decay ................. 0.01
  checkpoint_activations ....... True
  checkpoint_num_layers ........ 1
  deepspeed_activation_checkpointing  False
  clip_grad .................... 1.0
  train_iters .................. 600
  log_interval ................. 100
  exit_interval ................ None
  seed ......................... 1234
  reset_position_ids ........... False
  reset_attention_mask ......... False
  lr_decay_iters ............... None
  lr_decay_style ............... cosine
  lr ........................... 0.00015
  warmup ....................... 0.01
  save ......................... checkpoints/gpt2_345m_mp2
  save_interval ................ 5000
  no_save_optim ................ False
  no_save_rng .................. False
  load ......................... checkpoints/gpt2_345m_mp2
  no_load_optim ................ True
  no_load_rng .................. False
  finetune ..................... False
  resume_dataloader ............ True
  distributed_backend .......... nccl
  local_rank ................... 0
  eval_batch_size .............. None
  eval_iters ................... 100
  eval_interval ................ 1000
  eval_seq_length .............. None
  eval_max_preds_per_seq ....... None
  overlapping_eval ............. 32
  cloze_eval ................... False
  eval_hf ...................... False
  load_openai .................. False
  temperature .................. 1.0
  top_p ........................ 0.0
  top_k ........................ 0
  out_seq_length ............... 256
  model_parallel_size .......... 2
  shuffle ...................... False
  train_data ................... ['webtext']
  use_npy_data_loader .......... False
  train_data_path .............. 
  val_data_path ................ 
  test_data_path ............... 
  input_data_sizes_file ........ sizes.txt
  delim ........................ ,
  text_key ..................... sentence
  eval_text_key ................ None
  valid_data ................... None
  split ........................ 400,300,300
  test_data .................... None
  lazy_loader .................. True
  loose_json ................... False
  presplit_sentences ........... False
  num_workers .................. 2
  tokenizer_model_type ......... bert-large-uncased
  tokenizer_path ............... tokenizer.model
  tokenizer_type ............... GPT2BPETokenizer
  cache_dir .................... None
  use_tfrecords ................ False
  seq_length ................... 1024
  max_preds_per_seq ............ None
  deepspeed .................... False
  deepspeed_config ............. None
  deepscale .................... False
  deepscale_config ............. None
  deepspeed_mpi ................ False
  cuda ......................... True
  rank ......................... 0
  world_size ................... 2
  dynamic_loss_scale ........... True
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
configuring data
> padded vocab (size: 50257) with 175 dummy tokens (new size: 50432)
> found end-of-document token: 50256
building GPT2 model ...
 > number of parameters on model parallel rank 0: 178100224
 > number of parameters on model parallel rank 1: 178100224
Optimizer = FusedAdam
learning rate decaying cosine
WARNING: could not find the metadata file checkpoints/gpt2_345m_mp2/latest_checkpointed_iteration.txt 
    will not load any checkpoints and will start from random
Optimizer = FusedAdam
Partition Activations False and Correctness Check False
 iteration      100/     600 | elapsed time per iteration (ms): 810.9 | learning rate 1.444E-04 | lm loss 5.023855E+00 | loss scale 8192.0 |
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
  warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
  warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved
  warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/cuda/memory.py FutureWarning: torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved
  warnings.warn(
after 100 iterations memory (MB) | allocated: 3447.24365234375 | max allocated: 6237.830078125 | cached: 7890.0 | max cached: 7890.0
time (ms) | forward: 252.44 | backward: 550.96 | allreduce: 12.11 | optimizer: 7.26 | batch generator: 7.15 | data loader: 6.35
 iteration      200/     600 | elapsed time per iteration (ms): 844.2 | learning rate 1.210E-04 | lm loss 1.112287E-01 | loss scale 8192.0 |
time (ms) | forward: 242.53 | backward: 589.63 | allreduce: 11.37 | optimizer: 10.92 | batch generator: 4.28 | data loader: 2.71
 iteration      300/     600 | elapsed time per iteration (ms): 824.7 | learning rate 8.518E-05 | lm loss 8.868908E-03 | loss scale 8192.0 |
time (ms) | forward: 240.10 | backward: 572.66 | allreduce: 11.63 | optimizer: 11.32 | batch generator: 3.64 | data loader: 2.12
 iteration      400/     600 | elapsed time per iteration (ms): 790.5 | learning rate 4.666E-05 | lm loss 2.208042E-03 | loss scale 8192.0 |
time (ms) | forward: 233.81 | backward: 547.29 | allreduce: 11.90 | optimizer: 9.11 | batch generator: 1.16 | data loader: 0.21
 iteration      500/     600 | elapsed time per iteration (ms): 792.8 | learning rate 1.574E-05 | lm loss 8.129998E-04 | loss scale 8192.0 |
time (ms) | forward: 234.04 | backward: 549.56 | allreduce: 13.62 | optimizer: 9.02 | batch generator: 0.91 | data loader: 0.16
 iteration      600/     600 | elapsed time per iteration (ms): 787.7 | learning rate 6.939E-07 | lm loss 6.003926E-04 | loss scale 8192.0 |
time (ms) | forward: 234.25 | backward: 544.30 | allreduce: 10.23 | optimizer: 9.00 | batch generator: 0.83 | data loader: 0.12
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
 validation loss at the end of training for val data | LM loss: 1.231077E+01 | LM PPL: 2.220759E+05
----------------------------------------------------------------------------------------------------
global rank 1 is saving checkpoint at iteration     600 to checkpoints/gpt2_345m_mp2/iter_0000600/mp_rank_01/model_optim_rng.pt
global rank 0 is saving checkpoint at iteration     600 to checkpoints/gpt2_345m_mp2/iter_0000600/mp_rank_00/model_optim_rng.pt
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/nn/modules/module.py UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
  warnings.warn(
/home/zhangxiaoyu/miniconda3/envs/eval/lib/python3.9/site-packages/torch/nn/modules/module.py UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
  warnings.warn(
  successfully saved checkpoints/gpt2_345m_mp2/iter_0000600/mp_rank_01/model_optim_rng.pt
  successfully saved checkpoints/gpt2_345m_mp2/iter_0000600/mp_rank_00/model_optim_rng.pt
Evaluating iter 100/100
----------------------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------------------
 validation loss at the end of training for test data | LM loss: 1.215604E+01 | LM PPL: 1.902403E+05
-----------------------------------------------------------------------------------------------------

GPU memory usage:


Because the model parameters are split across the two GPUs, the per-GPU peak memory drops from around 15 GB with data parallelism to around 9 GB.
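The "allocated" figures in the two training logs are also consistent with the 20-bytes-per-parameter rule of thumb used earlier, applied to the per-rank parameter counts each log reports:

MiB = 1024 ** 2
print(354_871_296 * 20 / MiB)   # ≈ 6769 MiB vs "allocated: 6784.88" in the single-GPU / data-parallel log
print(178_100_224 * 20 / MiB)   # ≈ 3397 MiB vs "allocated: 3447.24" in the 2-way model-parallel log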

If we run inference directly with this model, load_checkpoint fails because the parameters don't match the model definition. This version of the Megatron code doesn't handle loading checkpoints produced by model-parallel training, so the only option here would be to merge the two model-parallel sub-models into one complete single-GPU model for Megatron to load and run inference with.

However, the Megatron-LM source used in this article doesn't provide a model-merging tool either, so we won't run inference with this model-parallel checkpoint. If you want inference on model-parallel checkpoints, the simplest route is to train and run inference with NVIDIA's latest Megatron-LM code, which supports not only model parallelism but also pipeline parallelism and can load checkpoints with any combination of the two. The official Megatron also provides a tool to convert a checkpoint with arbitrary model-parallel and pipeline-parallel sizes into one with the model-parallel and pipeline-parallel sizes you specify (https://github.com/NVIDIA/Megatron-LM/tree/main#evaluation-and-tasks).

