PyTorch 是一個(gè)流行的開源機(jī)器學(xué)習(xí)庫,廣泛用于計(jì)算機(jī)視覺和自然語言處理等領(lǐng)域。它提供了強(qiáng)大的計(jì)算圖功能和動態(tài)圖特性,使得模型的構(gòu)建和調(diào)試變得更加靈活和直觀。
數(shù)據(jù)準(zhǔn)備
在訓(xùn)練模型之前,首先需要準(zhǔn)備好數(shù)據(jù)集。PyTorch 提供了 torch.utils.data.Dataset
和 torch.utils.data.DataLoader
兩個(gè)類來幫助我們加載和批量處理數(shù)據(jù)。
1. 定義 Dataset
Dataset
類需要我們實(shí)現(xiàn) __init__
、__len__
和 __getitem__
三個(gè)方法。__init__
方法用于初始化數(shù)據(jù)集,__len__
返回?cái)?shù)據(jù)集中的樣本數(shù)量,__getitem__
根據(jù)索引返回單個(gè)樣本。
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data = self.data[index]
label = self.labels[index]
return data, label
2. 使用 DataLoader
DataLoader
類用于封裝數(shù)據(jù)集,并提供批量加載、打亂數(shù)據(jù)和多線程加載等功能。
from torch.utils.data import DataLoader
dataset = CustomDataset(data, labels)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
模型定義
在 PyTorch 中,模型是通過繼承 torch.nn.Module
類來定義的。我們需要實(shí)現(xiàn) __init__
方法來定義網(wǎng)絡(luò)層,并實(shí)現(xiàn) forward
方法來定義前向傳播。
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(784, 128) # 以 MNIST 數(shù)據(jù)集為例
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
損失函數(shù)和優(yōu)化器
1. 選擇損失函數(shù)
PyTorch 提供了多種損失函數(shù),如 nn.CrossEntropyLoss
、nn.MSELoss
等。根據(jù)任務(wù)的不同,選擇合適的損失函數(shù)。
criterion = nn.CrossEntropyLoss()
2. 選擇優(yōu)化器
PyTorch 也提供了多種優(yōu)化器,如 torch.optim.SGD
、torch.optim.Adam
等。優(yōu)化器用于在訓(xùn)練過程中更新模型的權(quán)重。
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
訓(xùn)練循環(huán)
訓(xùn)練循環(huán)是模型訓(xùn)練的核心,它包括前向傳播、計(jì)算損失、反向傳播和權(quán)重更新。
model = MyModel()
num_epochs = 10
for epoch in range(num_epochs):
for data, labels in data_loader:
optimizer.zero_grad() # 清空梯度
outputs = model(data) # 前向傳播
loss = criterion(outputs, labels) # 計(jì)算損失
loss.backward() # 反向傳播
optimizer.step() # 更新權(quán)重
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
模型評估
在訓(xùn)練過程中,我們還需要定期評估模型的性能,以監(jiān)控訓(xùn)練進(jìn)度和過擬合情況。
def evaluate(model, data_loader):
model.eval() # 設(shè)置為評估模式
total = 0
correct = 0
with torch.no_grad(): # 禁用梯度計(jì)算
for data, labels in data_loader:
outputs = model(data)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
model.train() # 恢復(fù)訓(xùn)練模式
-
模型
+關(guān)注
關(guān)注
1文章
3108瀏覽量
48645 -
機(jī)器學(xué)習(xí)
+關(guān)注
關(guān)注
66文章
8344瀏覽量
132287 -
自然語言處理
+關(guān)注
關(guān)注
1文章
594瀏覽量
13479 -
pytorch
+關(guān)注
關(guān)注
2文章
802瀏覽量
13109
發(fā)布評論請先 登錄
相關(guān)推薦
評論