import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
# 1. Data loading and preprocessing.
# Normalize with mean=0.5 / std=0.5 per channel maps each value x to
# (x - 0.5) / 0.5; `imshow` reverses this with `img / 2 + 0.5`.
transform = Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Only the training set is shuffled; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Class names; an integer label `k` is displayed as `classes[k]`.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
# Optional: preview a few training samples.
def imshow(img):
    """Show a normalized image tensor with matplotlib.

    Args:
        img: image tensor normalized by `transform` (values in [-1, 1]);
            the (1, 2, 0) transpose below assumes channel-first (C, H, W)
            layout, as produced by `make_grid`.
    """
    img = img / 2 + 0.5  # undo Normalize(0.5, 0.5): back to [0, 1]
    npimg = img.cpu().numpy()  # .cpu() so this also works for GPU tensors
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()


dataiter = iter(train_loader)
images, labels = next(dataiter)
# Never ask for more samples than the batch actually contains.
n_preview = min(4, len(labels))
imshow(torchvision.utils.make_grid(images[:n_preview]))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(n_preview)))
# 2. Build the network.
class Net(nn.Module):
    """Three-stage CNN classifier producing 10 logits per image.

    Each stage is conv3x3 (padding=1) -> ReLU -> 2x2 max-pool -> dropout.
    The three pools shrink the spatial size by 8x. `fc1` expects a
    128 x 4 x 4 feature map, so inputs must be 3 x 32 x 32. `forward`
    returns raw logits of shape (batch, 10); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(128 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, 10) logits."""
        x = self.dropout(self.pool(torch.relu(self.conv1(x))))
        x = self.dropout(self.pool(torch.relu(self.conv2(x))))
        x = self.dropout(self.pool(torch.relu(self.conv3(x))))
        # Flatten per sample rather than view(-1, ...): a wrongly sized
        # input now fails loudly in fc1 instead of silently changing the
        # batch dimension.
        x = torch.flatten(x, 1)
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)
net = Net()

# 3. Loss function and optimizer.
# CrossEntropyLoss consumes the raw logits that Net.forward returns.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
# 4. Train the network, evaluating on the test set after every epoch.
epochs = 30
train_losses = []  # per-epoch mean training loss (per sample)
train_accs = []    # per-epoch training accuracy, in percent
test_losses = []   # per-epoch mean test loss (per sample)
test_accs = []     # per-epoch test accuracy, in percent

for epoch in range(epochs):
    # Training pass (train mode enables dropout).
    net.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # criterion returns the batch mean; weight it by batch size so a
        # short final batch is not over-weighted in the epoch average.
        running_loss += loss.item() * n_batch
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += n_batch

    train_loss = running_loss / total
    train_acc = 100. * correct / total
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Evaluation pass (eval mode disables dropout; no gradients needed).
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            n_batch = labels.size(0)
            test_loss += loss.item() * n_batch
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += n_batch

    test_loss = test_loss / total
    test_acc = 100. * correct / total
    test_losses.append(test_loss)
    test_accs.append(test_acc)
    print(f'Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.2f}%')

print('Finished Training')
# Plot the per-epoch loss and accuracy curves.
# Epochs are numbered from 1 to match the "Epoch N/..." training log.
epoch_axis = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(epoch_axis, train_losses, label='Train Loss')
plt.plot(epoch_axis, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.legend()
plt.title('Loss')

plt.subplot(1, 2, 2)
plt.plot(epoch_axis, train_accs, label='Train Acc')
plt.plot(epoch_axis, test_accs, label='Test Acc')
plt.xlabel('Epoch')
plt.legend()
plt.title('Accuracy')

plt.show()
# 5. Overall accuracy of the trained model on the test set.
net.eval()  # disable dropout for evaluation
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = net(inputs)
        _, predicted = outputs.max(1)  # index of the highest logit
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
# The printed label reads "Test set accuracy".
print(f'测试集精度: {100. * correct / total:.2f}%')
# Optional: per-class accuracy on the test set.
num_classes = len(classes)
class_correct = [0.] * num_classes
class_total = [0.] * num_classes

net.eval()  # make this section independent of whatever mode came before
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = net(inputs)
        _, predicted = outputs.max(1)
        # 1-D boolean tensor, one entry per sample. No .squeeze(): for a
        # batch of one it would produce a 0-d tensor that cannot be indexed.
        hits = predicted == labels
        # .tolist() yields plain Python ints/bools usable as list indices.
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

# Each printed line reads "accuracy of class <name>".
for i in range(num_classes):
    print(f'类别 {classes[i]} 的精度: {100. * class_correct[i] / class_total[i]:.2f}%')