0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內(nèi)不再提示

掌握PyTorch圖片分類的簡明教程

WpOh_rgznai100 ? 來源:lq ? 2019-07-18 15:24 ? 次閱讀

1.引文

深度學習的比賽中,圖片分類是很常見的比賽,同時也是很難取得特別高名次的比賽,因為圖片分類已經(jīng)被大家研究的很透徹,一些開源的網(wǎng)絡很容易取得高分。如果大家還掌握不了使用開源的網(wǎng)絡進行訓練,再慢慢去模型調(diào)優(yōu),很難取得較好的成績。

我們在[PyTorch小試牛刀]實戰(zhàn)六·準備自己的數(shù)據(jù)集用于訓練講解了如何制作自己的數(shù)據(jù)集用于訓練,這個教程在此基礎上,進行訓練與應用。

(實戰(zhàn)六鏈接:

https://blog.csdn.net/xiaosongshine/article/details/85225873)

2.數(shù)據(jù)介紹

數(shù)據(jù)下載地址:

https://download.csdn.net/download/xiaosongshine/11128410

這次的實戰(zhàn)使用的數(shù)據(jù)是交通標志數(shù)據(jù)集,共有62類交通標志。其中訓練集數(shù)據(jù)有4572張照片(每個類別大概七十個),測試數(shù)據(jù)集有2520張照片(每個類別大概40個)。數(shù)據(jù)包含兩個子目錄分別train與test:

為什么還需要測試數(shù)據(jù)集呢?這個測試數(shù)據(jù)集不會拿來訓練,是用來進行模型的評估與調(diào)優(yōu)。

train與test每個文件夾里又有62個子文件夾,每個類別在同一個文件夾內(nèi):

我從中打開一個文件間,把里面圖片展示出來:

其中每張照片都類似下面的例子,100*100*3的大小。100是照片的照片的長和寬,3是什么呢?這其實是照片的色彩通道數(shù)目,RGB。彩色照片存儲在計算機里就是以三維數(shù)組的形式。我們送入網(wǎng)絡的也是這些數(shù)組。

3.網(wǎng)絡構建

1.導入Python包,定義一些參數(shù)

1importtorchast 2importtorchvisionastv 3importos 4importtime 5importnumpyasnp 6fromtqdmimporttqdm 7 8 9classDefaultConfigs(object):1011data_dir="./traffic-sign/"12data_list=["train","test"]1314lr=0.00115epochs=1016num_classes=6217image_size=22418batch_size=4019channels=320gpu="0"21train_len=457222test_len=252023use_gpu=t.cuda.is_available()2425config=DefaultConfigs()

2.數(shù)據(jù)準備,采用PyTorch提供的讀取方式

注意一點Train數(shù)據(jù)需要進行隨機裁剪,Test數(shù)據(jù)不要進行裁剪了

1normalize=tv.transforms.Normalize(mean=[0.485,0.456,0.406], 2std=[0.229,0.224,0.225] 3) 4 5transform={ 6config.data_list[0]:tv.transforms.Compose( 7[tv.transforms.Resize([224,224]),tv.transforms.CenterCrop([224,224]), 8tv.transforms.ToTensor(),normalize]#tv.transforms.Resize用于重設圖片大小 9),10config.data_list[1]:tv.transforms.Compose(11[tv.transforms.Resize([224,224]),tv.transforms.ToTensor(),normalize]12)13}1415datasets={16x:tv.datasets.ImageFolder(root=os.path.join(config.data_dir,x),transform=transform[x])17forxinconfig.data_list18}1920dataloader={21x:t.utils.data.DataLoader(dataset=datasets[x],22batch_size=config.batch_size,23shuffle=True24)25forxinconfig.data_list26}

3.構建網(wǎng)絡模型(使用resnet18進行遷移學習,訓練參數(shù)為最后一個全連接層 t.nn.Linear(512,num_classes))

1defget_model(num_classes): 2 3model=tv.models.resnet18(pretrained=True) 4forparmainmodel.parameters(): 5parma.requires_grad=False 6model.fc=t.nn.Sequential( 7t.nn.Dropout(p=0.3), 8t.nn.Linear(512,num_classes) 9)10return(model)

如果電腦硬件支持,可以把下述代碼屏蔽,則訓練整個網(wǎng)絡,最終準確率會上升,訓練數(shù)據(jù)會變慢。

1forparmainmodel.parameters():2parma.requires_grad=False

模型輸出

1ResNet( 2(conv1):Conv2d(3,64,kernel_size=(7,7),stride=(2,2),padding=(3,3),bias=False) 3(bn1):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True) 4(relu):ReLU(inplace) 5(maxpool):MaxPool2d(kernel_size=3,stride=2,padding=1,dilation=1,ceil_mode=False) 6(layer1):Sequential( 7(0):BasicBlock( 8(conv1):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False) 9(bn1):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)10(relu):ReLU(inplace)11(conv2):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)12(bn2):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)13)14(1):BasicBlock(15(conv1):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)16(bn1):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)17(relu):ReLU(inplace)18(conv2):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)19(bn2):BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)20)21)22(layer2):Sequential(23(0):BasicBlock(24(conv1):Conv2d(64,128,kernel_size=(3,3),stride=(2,2),padding=(1,1),bias=False)25(bn1):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)26(relu):ReLU(inplace)27(conv2):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)28(bn2):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)29(downsample):Sequential(30(0):Conv2d(64,128,kernel_size=(1,1),stride=(2,2),bias=False)31(1):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)32)33)34(1):BasicBlock(35(conv1):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)36(bn1):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)37(relu):ReLU(inplace)38(conv2):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)39(bn2):BatchNorm2d(128,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)40)41)42(layer3):Sequential(43(0):BasicBlock(44(conv1):Conv2d(128,256,kernel_size=(3,3),stride=(2,2),padding=(1,1),bias=False)45(bn1):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)46(relu):ReLU(inplace)47(conv2):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)48(bn2):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)49(downsample):Sequential(50(0):Conv2d(128,256,kernel_size=(1,1),stride=(2,2),bias=False)51(1):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)52)53)54(1):BasicBlock(55(conv1):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)56(bn1):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)57(relu):ReLU(inplace)58(conv2):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)59(bn2):BatchNorm2d(256,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)60)61)62(layer4):Sequential(63(0):BasicBlock(64(conv1):Conv2d(256,512,kernel_size=(3,3),stride=(2,2),padding=(1,1),bias=False)65(bn1):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)66(relu):ReLU(inplace)67(conv2):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)68(bn2):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)69(downsample):Sequential(70(0):Conv2d(256,512,kernel_size=(1,1),stride=(2,2),bias=False)71(1):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)72)73)74(1):BasicBlock(75(conv1):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)76(bn1):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)77(relu):ReLU(inplace)78(conv2):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1),bias=False)79(bn2):BatchNorm2d(512,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)80)81)82(avgpool):AvgPool2d(kernel_size=7,stride=1,padding=0)83(fc):Sequential(84(0):Dropout(p=0.3)85(1):Linear(in_features=512,out_features=62,bias=True)86)87)

4.訓練模型(支持自動GPU加速)

1deftrain(epochs): 2 3model=get_model(config.num_classes) 4print(model) 5loss_f=t.nn.CrossEntropyLoss() 6if(config.use_gpu): 7model=model.cuda() 8loss_f=loss_f.cuda() 910opt=t.optim.Adam(model.fc.parameters(),lr=config.lr)11time_start=time.time()1213forepochinrange(epochs):14train_loss=[]15train_acc=[]16test_loss=[]17test_acc=[]18model.train(True)19print("Epoch{}/{}".format(epoch+1,epochs))20forbatch,datasintqdm(enumerate(iter(dataloader["train"]))):21x,y=datas22if(config.use_gpu):23x,y=x.cuda(),y.cuda()24y_=model(x)25#print(x.shape,y.shape,y_.shape)26_,pre_y_=t.max(y_,1)27pre_y=y28#print(y_.shape)29loss=loss_f(y_,pre_y)30#print(y_.shape)31acc=t.sum(pre_y_==pre_y)3233loss.backward()34opt.step()35opt.zero_grad()36if(config.use_gpu):37loss=loss.cpu()38acc=acc.cpu()39train_loss.append(loss.data)40train_acc.append(acc)41#if((batch+1)%5==0):42time_end=time.time()43print("Batch{},Trainloss:{:.4f},Trainacc:{:.4f},Time:{}"\44.format(batch+1,np.mean(train_loss)/config.batch_size,np.mean(train_acc)/config.batch_size,(time_end-time_start)))45time_start=time.time()4647model.train(False)48forbatch,datasintqdm(enumerate(iter(dataloader["test"]))):49x,y=datas50if(config.use_gpu):51x,y=x.cuda(),y.cuda()52y_=model(x)53#print(x.shape,y.shape,y_.shape)54_,pre_y_=t.max(y_,1)55pre_y=y56#print(y_.shape)57loss=loss_f(y_,pre_y)58acc=t.sum(pre_y_==pre_y)5960if(config.use_gpu):61loss=loss.cpu()62acc=acc.cpu()6364test_loss.append(loss.data)65test_acc.append(acc)66print("Batch{},Testloss:{:.4f},Testacc:{:.4f}".format(batch+1,np.mean(test_loss)/config.batch_size,np.mean(test_acc)/config.batch_size))6768t.save(model,str(epoch+1)+"ttmodel.pkl")69707172if__name__=="__main__":73train(config.epochs)

訓練結果如下:

1Epoch1/10 2115it[00:48,2.63it/s] 3Batch115,Trainloss:0.0590,Trainacc:0.4635,Time:48.985504150390625 463it[00:24,2.62it/s] 5Batch63,Testloss:0.0374,Testacc:0.6790,Time:24.648272275924683 6Epoch2/10 7115it[00:45,3.22it/s] 8Batch115,Trainloss:0.0271,Trainacc:0.7576,Time:45.68823838233948 963it[00:23,2.62it/s]10Batch63,Testloss:0.0255,Testacc:0.7524,Time:23.27178287506103511Epoch3/1012115it[00:45,3.19it/s]13Batch115,Trainloss:0.0181,Trainacc:0.8300,Time:45.926485061645511463it[00:23,2.60it/s]15Batch63,Testloss:0.0212,Testacc:0.7861,Time:23.8078927993774416Epoch4/1017115it[00:45,3.28it/s]18Batch115,Trainloss:0.0138,Trainacc:0.8767,Time:45.275250196456911963it[00:23,2.57it/s]20Batch63,Testloss:0.0173,Testacc:0.8385,Time:23.73632144927978521Epoch5/1022115it[00:44,3.22it/s]23Batch115,Trainloss:0.0112,Trainacc:0.8950,Time:44.9836382865905762463it[00:22,2.69it/s]25Batch63,Testloss:0.0156,Testacc:0.8520,Time:22.79007434844970726Epoch6/1027115it[00:44,3.19it/s]28Batch115,Trainloss:0.0095,Trainacc:0.9159,Time:45.104269504547122963it[00:22,2.77it/s]30Batch63,Testloss:0.0158,Testacc:0.8214,Time:22.8041245937347431Epoch7/1032115it[00:45,2.95it/s]33Batch115,Trainloss:0.0081,Trainacc:0.9280,Time:45.304390430450443463it[00:23,2.66it/s]35Batch63,Testloss:0.0139,Testacc:0.8528,Time:23.12237954139709536Epoch8/1037115it[00:44,3.23it/s]38Batch115,Trainloss:0.0073,Trainacc:0.9300,Time:44.3047628402709963963it[00:22,2.74it/s]40Batch63,Testloss:0.0142,Testacc:0.8496,Time:22.80183553695678741Epoch9/1042115it[00:43,3.19it/s]43Batch115,Trainloss:0.0068,Trainacc:0.9361,Time:44.084140300750734463it[00:23,2.44it/s]45Batch63,Testloss:0.0142,Testacc:0.8437,Time:23.60441923141479546Epoch10/1047115it[00:46,3.12it/s]48Batch115,Trainloss:0.0063,Trainacc:0.9337,Time:46.765970468521124963it[00:24,2.65it/s]50Batch63,Testloss:0.0130,Testacc:0.8591,Time:24.64351773262024

訓練10個Epoch,測試集準確率可以到達0.86,已經(jīng)達到不錯效果。通過修改參數(shù),增加訓練,可以達到更高的準確率。

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權轉載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學習之用,如有內(nèi)容侵權或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 數(shù)據(jù)集

    關注

    4

    文章

    1197

    瀏覽量

    24532
  • pytorch
    +關注

    關注

    2

    文章

    794

    瀏覽量

    13008

原文標題:實戰(zhàn):掌握PyTorch圖片分類的簡明教程 | 附完整代碼

文章出處:【微信號:rgznai100,微信公眾號:rgznai100】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    Protel DXP簡明教

    Protel DXP簡明教程電路設計自動化( Electronic Design Automation ) EDA 指的就是將電路設計中各種工作交由計算機來協(xié)助完成。如電路圖( Schematic
    發(fā)表于 12-09 15:13

    DSP中斷設置簡明教

    DSP中斷設置簡明教程 from hellodsp
    發(fā)表于 10-24 20:51

    protel簡明教

    protel簡明教
    發(fā)表于 01-13 13:47

    vivado簡明教

    本帖最后由 burkfly 于 2014-1-14 22:26 編輯 vivado簡明教程,初學者有用!
    發(fā)表于 01-14 22:22

    PROTEL簡明教

    簡明教程希望對你有用
    發(fā)表于 03-27 11:11

    Vivado 簡明教

    Vivado 簡明教
    發(fā)表于 05-07 11:25

    ADS版圖導入、編輯、仿真簡明教

    ADS版圖導入、編輯、仿真簡明教
    發(fā)表于 09-12 16:10 ?0次下載

    ZEMAX光學輔助設計簡明教

    ZEMAX光學輔助設計簡明教程 ZEMAX光學輔助設計簡明教
    發(fā)表于 10-30 17:57 ?0次下載

    電工學簡明教程習題+答案

    電工學簡明教程習題+答案高清pdf版本電工學簡明教程習題+答案高清pdf版本
    發(fā)表于 02-25 14:13 ?16次下載

    SKILL程序使用及開發(fā)簡明教

    Cadence SKILL語言的使用及開發(fā)簡明教程,從網(wǎng)上下的。。。
    發(fā)表于 08-26 15:09 ?0次下載

    Protel99簡明教

    本文檔詳細的介紹了Protel99使用的簡明教
    發(fā)表于 08-30 17:02 ?0次下載

    Altium-Designer-10簡明教

    Altium-Designer-10簡明教
    發(fā)表于 12-16 22:13 ?0次下載

    基于DSP中斷設置簡明教

    基于DSP中斷設置簡明教
    發(fā)表于 10-23 14:28 ?5次下載
    基于DSP中斷設置<b class='flag-5'>簡明教</b>程

    JCBus串口調(diào)試助手簡明教

    JCBus串口調(diào)試助手簡明教程說明。
    發(fā)表于 03-25 16:05 ?11次下載
    JCBus串口調(diào)試助手<b class='flag-5'>簡明教</b>程

    電磁兼容簡明教程(1)

    電磁兼容簡明教程(1)
    的頭像 發(fā)表于 12-05 16:23 ?457次閱讀
    電磁兼容<b class='flag-5'>簡明教</b>程(1)