0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

那些年在pytorch上踩過的坑

jf_78858299 ? 來源:天宏NLP ? 作者:tianhongzxy ? 2023-02-22 14:18 ? 次閱讀

今天又發(fā)現(xiàn)了一個pytorch的小坑,給大家分享一下。手上兩份同一模型的代碼,一份用tensorflow寫的,另一份是我拿pytorch寫的,模型架構(gòu)一模一樣,預(yù)處理數(shù)據(jù)的邏輯也一模一樣,測試發(fā)現(xiàn)模型推理的速度也差不多。一份預(yù)處理代碼是為pytorch模型寫的,用到的庫是torch,另一份是為tensorflow寫的,用到的是numpy。在訓(xùn)練時,每個epoch耗時居然差距非常大,pytorch的代碼在140w條數(shù)據(jù)上訓(xùn)練每輪耗時約45min,而tensorflow版的代碼耗時僅約12min。

我把代碼看了又看,百思不得其解,預(yù)處理的代碼比較復(fù)雜,都包含兩個for循環(huán),pytorch版代碼我把更多的預(yù)處理步驟放到了Dataset里,這樣訓(xùn)練時加載每個batch后,再要處理的步驟就更少了,速度也應(yīng)該更快,而tensorflow版代碼的for循環(huán)里預(yù)處理的步驟明明更多,怎么會速度比我的代碼還快呢?然而,經(jīng)過我的測試發(fā)現(xiàn),從加載每個batch的數(shù)據(jù)進來開始,經(jīng)過預(yù)處理,直到輸入到模型做計算前,兩者的耗時差了約7~8倍。最后發(fā)現(xiàn)問題出在對pytorch的tensor進行了頻繁的索引操作。

下面做個實驗給大家直觀體驗一下,對tensor做索引和對array做索引的速度差距有多大,tensorarray都是大小(1000x1000)的二維數(shù)組。

Pytorch(version==1.4.1)索引1000000次耗時:3.51秒

圖片

Numpy索引1000000次耗時:0.43秒

圖片

我還特意對比了一下對TensorFlow的tensor做索引的耗時

TensorFlow(version==2.1.0)索引1000000次耗時:118.89秒

圖片

由此可見tensorarray的索引速度至少差距在10倍,不過這也在情理之中,畢竟tensor要比array“重”得多。因此在使用pytorch和tensorflow時,頻繁需要索引的操作一定要先把tensor轉(zhuǎn)換為numpy.array來做!

除此之外,與其對二維數(shù)組進行索引,不如將其展平為一維數(shù)組,算上展平的時間,速度還會有不少提升。

Pytorch從3.51秒降到了1.94秒

圖片

Numpy從0.43秒降到了0.29秒

圖片

如果在訓(xùn)練和數(shù)據(jù)預(yù)處理過程中發(fā)現(xiàn)自己的代碼跑起來速度非常慢,記得看一看有沒有對tensor做太多次索引,如果有的話,要把它轉(zhuǎn)為numpy.array,還有,盡量把二維、三維的索引變成一維的索引,這些都能加快你訓(xùn)練模型的速度。

PS:最后我的代碼終于訓(xùn)練一輪也只需要不到12min了,后來又找了點加速的辦法,把訓(xùn)練一輪的時間控制到了9min以內(nèi),這些就放在以后再寫吧~

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 代碼
    +關(guān)注

    關(guān)注

    30

    文章

    4722

    瀏覽量

    68234
  • tensorflow
    +關(guān)注

    關(guān)注

    13

    文章

    328

    瀏覽量

    60473
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    802

    瀏覽量

    13115
收藏 人收藏

    評論

    相關(guān)推薦

    使用STM32采集電池電壓那些

    本文來解析一個盆友在使用STM32采集電池電壓。以STM32F4 的ADC屬于逐次逼近SAR 型ADC為例進行分析,參考STM32F405xxDatasheet,對于如何編寫ADC程序就不做描述了。
    發(fā)表于 03-01 07:39

    開發(fā)STM32 USB HID

    記錄一下 開發(fā)STM32 USB HID一、前言二、代碼配置一、前言MCU: STM32F103C8T6CubeMX: STM32CubeMX 5.3.0二、代碼配置引腳配置時鐘樹配置我
    發(fā)表于 08-24 07:15

    使用樹莓派搭建stm32開發(fā)環(huán)境以及碰到的問題

    使用樹莓派搭建stm32開發(fā)環(huán)境了很多,下面主要是記錄一下,以及碰到的問題。##開發(fā)方式的選擇1.使用Eclipse+GDB+O
    發(fā)表于 08-24 07:47

    Linux學(xué)習(xí)過程與如何解決

    Linux記錄記錄Linux學(xué)習(xí)過程與如何解決
    發(fā)表于 11-04 08:44

    移植debian系統(tǒng)

    基本的linux系統(tǒng),板子的交叉編譯器是arm-linux-gnueabihf-gcc,這給我?guī)砹瞬簧俚穆闊?,以至于想重新移植一下debian系統(tǒng)。ok,轉(zhuǎn)入正題,說說這兩天我吧。首先...
    發(fā)表于 12-14 08:42

    使用MDK5時出現(xiàn)的一些error分享

    使用MDK5時出現(xiàn)的一些error分享
    發(fā)表于 12-17 07:49

    關(guān)于RK1808板子調(diào)試過程記錄

    關(guān)于RK1808板子調(diào)試過程記錄
    發(fā)表于 02-16 06:38

    STM32G070CB cubemx串口調(diào)試哪些

    使用G070CB時寫的中斷程序是怎樣的?STM32G070CB cubemx串口調(diào)試哪些呢?
    發(fā)表于 02-18 06:08

    專訪技術(shù)創(chuàng)業(yè)工程師吳才澤:感恩這些年

    本期采訪對象技術(shù)創(chuàng)業(yè)工程師吳才澤,這些年從工程師到創(chuàng)業(yè)那些呢?
    發(fā)表于 11-25 16:53 ?3350次閱讀

    使用STM32采集電池電壓資料下載

    電子發(fā)燒友網(wǎng)為你提供使用STM32采集電池電壓資料下載的電子資料下載,更有其他相關(guān)的電路圖、源代碼、課件教程、中文資料、英文資料、參考設(shè)計、用戶指南、解決方案等資料,希望可以幫助到廣大的電子工程師們。
    發(fā)表于 04-05 08:49 ?73次下載
    使用STM32采集電池電壓<b class='flag-5'>踩</b><b class='flag-5'>過</b>的<b class='flag-5'>坑</b>資料下載

    嵌入式Linux記錄

    Linux記錄記錄Linux學(xué)習(xí)過程與如何解決
    發(fā)表于 11-01 17:21 ?10次下載
    嵌入式Linux<b class='flag-5'>踩</b><b class='flag-5'>坑</b>記錄

    Arduino-IDE配置ESP32-CAM開發(fā)環(huán)境那些

    Arduino-IDE配置ESP32-CAM開發(fā)環(huán)境那些
    發(fā)表于 11-30 18:36 ?24次下載
    Arduino-IDE配置ESP32-CAM開發(fā)環(huán)境<b class='flag-5'>踩</b><b class='flag-5'>過</b>的<b class='flag-5'>那些</b><b class='flag-5'>坑</b>

    推挽電路的,你沒?

    推挽電路的,你沒?
    的頭像 發(fā)表于 11-24 16:25 ?1043次閱讀
    推挽電路的<b class='flag-5'>坑</b>,你<b class='flag-5'>踩</b><b class='flag-5'>過</b>沒?

    關(guān)于圖像傳感器圖像質(zhì)量的四大誤區(qū)!你幾個?

    關(guān)于圖像傳感器圖像質(zhì)量的四大誤區(qū)!你幾個?
    的頭像 發(fā)表于 11-27 16:56 ?409次閱讀
    關(guān)于圖像傳感器圖像質(zhì)量的四大誤區(qū)!你<b class='flag-5'>踩</b><b class='flag-5'>過</b>幾個<b class='flag-5'>坑</b>?

    反相輸入放大器的,你沒有?

    反相輸入放大器的,你沒有?
    的頭像 發(fā)表于 12-06 15:35 ?566次閱讀
    反相輸入放大器的<b class='flag-5'>坑</b>,你<b class='flag-5'>踩</b><b class='flag-5'>過</b>沒有?