PyTorch 神经网络工具箱 - 实践

news/2025/9/24 16:21:43/文章来源:https://www.cnblogs.com/tlnshuju/p/19109454

PyTorch 神经网络工具箱 - 实践

2025-09-24 16:21  tlnshuju  阅读(0)  评论(0)    收藏  举报

在深度学习领域,PyTorch 凭借其动态计算图、简洁的 API 设计和强大的灵活性,成为科研与工程领域的热门框架。无论是搭建基础神经网络,还是实现复杂的自定义模块(如残差网络),PyTorch 都能提供高效的支持。本文将基于 PyTorch 核心知识点,从神经网络核心组件模型构建工具三种模型构建方法自定义网络模块完整训练流程,手把手带您掌握 PyTorch 构建神经网络的全流程。

一、神经网络核心组件:理解深度学习的 “积木”

神经网络的本质是通过 “层” 对数据进行变换,结合 “损失函数” 和 “优化器” 实现参数学习。这四大核心组件共同构成了深度学习模型的基础,其关系与作用如下表所示:

核心组件作用与说明
层(Layer)神经网络的基本结构单元,负责将输入张量(如图片、文本特征)通过权重变换为输出张量。常见层包括全连接层(Linear)、卷积层(Conv2d)、池化层(MaxPool2d)等。
模型(Model)由多个 “层” 按特定逻辑组合而成的网络结构,定义了数据从输入到输出的完整变换路径(如 CNN、RNN、Transformer)。
损失函数(Loss Function)参数学习的 “目标标尺”,量化模型预测值(Y')与真实值(Y)的差异。通过最小化损失函数,引导模型调整权重。常见损失函数:CrossEntropyLoss(分类)、MSELoss(回归)。
优化器(Optimizer)实现 “损失最小化” 的工具,通过反向传播计算的梯度,更新模型的可学习参数(如权重、偏置)。常见优化器:SGD(随机梯度下降)、Adam、RMSprop。

组件工作流程示意图

数据在组件间的流转逻辑可概括为:

  1. 输入数据(X)经过的变换,得到模型预测值(Y');
  2. 损失函数计算 Y' 与真实标签(Y)的差异(Loss);
  3. 优化器根据 Loss 的梯度,反向更新 “层” 中的权重参数;
  4. 重复上述步骤,直到 Loss 收敛到最小值。

二、PyTorch 模型构建核心工具:nn.Module vs nn.functional

PyTorch 提供了两种核心工具用于构建网络:nn.Module 和 nn.functional。两者功能有重叠,但设计理念和适用场景差异显著,掌握其区别是高效构建模型的关键。

1. 工具特性对比

对比维度nn.Modulenn.functional
本质面向对象的类(需实例化)纯函数(直接调用)
参数管理自动定义、存储和管理可学习参数(如 weight、bias)需手动定义和传入参数,无自动管理
与容器结合支持与nn.Sequential等容器结合,简化代码不支持与容器结合,需手动串联层
状态切换(如 Dropout)调用model.eval()后自动关闭 Dropout(测试模式)需手动传入training=True/False参数控制状态
适用场景卷积层、全连接层、Dropout 层等需参数管理的组件激活函数(ReLU、Sigmoid)、池化层等无参数组件

2. 代码示例:直观理解差异

以 “全连接层 + ReLU 激活” 为例,对比两种工具的使用方式:

(1)使用nn.Module
import torch
import torch.nn as nn
# 1. 实例化层(自动管理weight和bias)
linear = nn.Linear(in_features=10, out_features=5)  # 全连接层
relu = nn.ReLU()  # 激活函数(虽无参数,也可实例化)
# 2. 前向传播
x = torch.randn(3, 10)  # 输入:3个样本,每个样本10维特征
output = relu(linear(x))  # 先过全连接层,再过激活函数
print("nn.Module输出形状:", output.shape)  # torch.Size([3, 5])
(2)使用nn.functional
import torch
import torch.nn.functional as F
# 1. 手动定义参数(需指定形状,初始化权重)
weight = torch.randn(5, 10)  # 全连接层权重:out_features × in_features
bias = torch.randn(5)        # 偏置:out_features维
# 2. 前向传播(手动传入weight和bias)
x = torch.randn(3, 10)
output = F.relu(F.linear(x, weight, bias))  # 先调用linear函数,再调用relu函数
print("nn.functional输出形状:", output.shape)  # torch.Size([3, 5])

三、三种模型构建方法:从简单到灵活

PyTorch 支持多种模型构建方式,可根据模型复杂度选择。以下以 “MNIST 手写数字识别” 的全连接模型(输入 28×28=784 维,隐藏层 300/100 维,输出 10 类)为例,详细讲解三种构建方法。

方法 1:继承nn.Module基类(最灵活,推荐复杂模型)

当模型需要自定义前向传播逻辑(如分支结构、跳连)时,需继承nn.Module基类,并实现__init__(定义层)和forward(定义数据流转)两个方法。

代码实现
import torch
from torch import nn
import torch.nn.functional as F
# 定义模型类,继承nn.Module
class MNISTModel(nn.Module):
def __init__(self, in_dim=784, n_hidden1=300, n_hidden2=100, out_dim=10):
super(MNISTModel, self).__init__()  # 调用父类构造函数
# 1. 定义网络层(自动管理参数)
self.flatten = nn.Flatten()  # 将28×28图片展平为784维向量
self.linear1 = nn.Linear(in_dim, n_hidden1)  # 第1个全连接层
self.bn1 = nn.BatchNorm1d(n_hidden1)  # 批量归一化(加速训练)
self.linear2 = nn.Linear(n_hidden1, n_hidden2)  # 第2个全连接层
self.bn2 = nn.BatchNorm1d(n_hidden2)  # 批量归一化
self.out_layer = nn.Linear(n_hidden2, out_dim)  # 输出层(10类)
def forward(self, x):
# 2. 定义前向传播逻辑(数据流转路径)
x = self.flatten(x)          # 展平:(batch, 1, 28, 28) → (batch, 784)
x = self.linear1(x)          # 全连接1:(batch, 784) → (batch, 300)
x = self.bn1(x)              # 批量归一化
x = F.relu(x)                # 激活函数(引入非线性)
x = self.linear2(x)          # 全连接2:(batch, 300) → (batch, 100)
x = self.bn2(x)              # 批量归一化
x = F.relu(x)                # 激活函数
x = self.out_layer(x)        # 输出层:(batch, 100) → (batch, 10)
x = F.softmax(x, dim=1)      # 输出概率分布(dim=1表示按样本维度归一化)
return x
# 实例化模型并查看结构
in_dim, n_h1, n_h2, out_dim = 784, 300, 100, 10
model = MNISTModel(in_dim, n_h1, n_h2, out_dim)
print("MNISTModel结构:")
print(model)
运行结果
MNISTModel(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear1): Linear(in_features=784, out_features=300, bias=True)
(bn1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(linear2): Linear(in_features=300, out_features=100, bias=True)
(bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(out_layer): Linear(in_features=100, out_features=10, bias=True)
)

方法 2:使用nn.Sequential(简单线性模型,层顺序固定)

nn.Sequential是一个 “层容器”,可按顺序堆叠层,自动实现前向传播(无需手动写forward)。适合层顺序固定、无复杂分支的模型。它支持三种初始化方式:

方式 2.1:可变参数(最快,但无法指定层名)

直接将层作为可变参数传入,层名由容器自动分配(如 0、1、2...)。

import torch
from torch import nn
# 定义超参数
in_dim, n_h1, n_h2, out_dim = 784, 300, 100, 10
# 用可变参数初始化Sequential
model_seq = nn.Sequential(
nn.Flatten(),
nn.Linear(in_dim, n_h1),
nn.BatchNorm1d(n_h1),
nn.ReLU(),
nn.Linear(n_h1, n_h2),
nn.BatchNorm1d(n_h2),
nn.ReLU(),
nn.Linear(n_h2, out_dim),
nn.Softmax(dim=1)
)
print("nn.Sequential(可变参数)结构:")
print(model_seq)
方式 2.2:add_module方法(手动指定层名)

通过add_module("层名", 层实例)逐个添加层,适合需要明确层名的场景。

import torch
from torch import nn
in_dim, n_h1, n_h2, out_dim = 784, 300, 100, 10
# 初始化空的Sequential
model_seq_add = nn.Sequential()
# 逐个添加层并指定名称
model_seq_add.add_module("flatten", nn.Flatten())
model_seq_add.add_module("linear1", nn.Linear(in_dim, n_h1))
model_seq_add.add_module("bn1", nn.BatchNorm1d(n_h1))
model_seq_add.add_module("relu1", nn.ReLU())
model_seq_add.add_module("linear2", nn.Linear(n_h1, n_h2))
model_seq_add.add_module("bn2", nn.BatchNorm1d(n_h2))
model_seq_add.add_module("relu2", nn.ReLU())
model_seq_add.add_module("out_layer", nn.Linear(n_h2, out_dim))
model_seq_add.add_module("softmax", nn.Softmax(dim=1))
print("nn.Sequential(add_module)结构:")
print(model_seq_add)
方式 2.3:OrderedDict(有序键值对,层名 + 层实例)

通过collections.OrderedDict将 “层名 - 层实例” 按顺序封装,适合层名较多的场景。

import torch
from torch import nn
from collections import OrderedDict
in_dim, n_h1, n_h2, out_dim = 784, 300, 100, 10
# 用OrderedDict定义层名和层实例
layers_dict = OrderedDict([
("flatten", nn.Flatten()),
("linear1", nn.Linear(in_dim, n_h1)),
("bn1", nn.BatchNorm1d(n_h1)),
("relu1", nn.ReLU()),
("linear2", nn.Linear(n_h1, n_h2)),
("bn2", nn.BatchNorm1d(n_h2)),
("relu2", nn.ReLU()),
("out_layer", nn.Linear(n_h2, out_dim)),
("softmax", nn.Softmax(dim=1))
])
# 初始化Sequential
model_seq_ordered = nn.Sequential(layers_dict)
print("nn.Sequential(OrderedDict)结构:")
print(model_seq_ordered)

方法 3:继承nn.Module + 模型容器(兼顾灵活与简洁)

当模型需分 “模块” 管理(如残差网络的 “残差块”)时,可在nn.Module中嵌套模型容器nn.Sequentialnn.ModuleListnn.ModuleDict),既简化代码,又保留自定义逻辑。

容器 1:nn.Sequential(模块内顺序固定)

将模型的某个子模块(如 “全连接 + 批量归一化”)用nn.Sequential封装,减少代码冗余。

import torch
from torch import nn
import torch.nn.functional as F
class ModelWithSequential(nn.Module):
def __init__(self, in_dim=784, n_h1=300, n_h2=100, out_dim=10):
super(ModelWithSequential, self).__init__()
self.flatten = nn.Flatten()
# 用Sequential封装子模块1(linear1 + bn1)
self.block1 = nn.Sequential(
nn.Linear(in_dim, n_h1),
nn.BatchNorm1d(n_h1)
)
# 用Sequential封装子模块2(linear2 + bn2)
self.block2 = nn.Sequential(
nn.Linear(n_h1, n_h2),
nn.BatchNorm1d(n_h2)
)
# 输出层
self.out_block = nn.Sequential(
nn.Linear(n_h2, out_dim),
nn.Softmax(dim=1)
)
def forward(self, x):
x = self.flatten(x)
x = F.relu(self.block1(x))  # 子模块1 + 激活
x = F.relu(self.block2(x))  # 子模块2 + 激活
x = self.out_block(x)       # 输出子模块
return x
# 实例化模型
model_with_seq = ModelWithSequential()
print("ModelWithSequential结构:")
print(model_with_seq)
容器 2:nn.ModuleList(模块列表,支持迭代)

将层或子模块存入列表,支持索引访问和迭代,适合动态调整模块数量的场景(如可变层数的全连接网络)。

import torch
from torch import nn
import torch.nn.functional as F
class ModelWithModuleList(nn.Module):
def __init__(self, in_dim=784, hidden_dims=[300, 100], out_dim=10):
super(ModelWithModuleList, self).__init__()
self.flatten = nn.Flatten()
# 1. 构建模块列表:输入层→隐藏层1→隐藏层2
self.layers = nn.ModuleList()
prev_dim = in_dim  # 前一层的输出维度
for hidden_dim in hidden_dims:
self.layers.append(nn.Linear(prev_dim, hidden_dim))  # 全连接层
self.layers.append(nn.BatchNorm1d(hidden_dim))       # 批量归一化
prev_dim = hidden_dim
# 2. 输出层
self.layers.append(nn.Linear(prev_dim, out_dim))
self.layers.append(nn.Softmax(dim=1))
def forward(self, x):
x = self.flatten(x)
# 迭代模块列表,实现前向传播
for layer in self.layers:
# 对激活函数的特殊处理(此处简化,实际可单独判断)
if isinstance(layer, (nn.Linear, nn.BatchNorm1d, nn.Softmax)):
x = layer(x)
# 若有ReLU,可在此处添加:x = F.relu(x)
return x
# 实例化模型(隐藏层维度为[300, 100])
model_with_list = ModelWithModuleList(hidden_dims=[300, 100])
print("ModelWithModuleList结构:")
print(model_with_list)
容器 3:nn.ModuleDict(模块字典,支持键值访问)

将模块用 “键值对” 存储,可通过键名灵活调用模块,适合需要动态选择模块的场景(如多分支模型)。

import torch
from torch import nn
import torch.nn.functional as F
class ModelWithModuleDict(nn.Module):
def __init__(self, in_dim=784, n_h1=300, n_h2=100, out_dim=10):
super(ModelWithModuleDict, self).__init__()
# 1. 用字典存储模块
self.layers_dict = nn.ModuleDict({
"flatten": nn.Flatten(),
"linear1": nn.Linear(in_dim, n_h1),
"bn1": nn.BatchNorm1d(n_h1),
"relu": nn.ReLU(),
"linear2": nn.Linear(n_h1, n_h2),
"bn2": nn.BatchNorm1d(n_h2),
"out": nn.Linear(n_h2, out_dim),
"softmax": nn.Softmax(dim=1)
})
def forward(self, x):
# 2. 按键名顺序调用模块(自定义流转路径)
layer_order = ["flatten", "linear1", "bn1", "relu", "linear2", "bn2", "relu", "out", "softmax"]
for layer_name in layer_order:
x = self.layers_dict[layer_name](x)
return x
# 实例化模型
model_with_dict = ModelWithModuleDict()
print("ModelWithModuleDict结构:")
print(model_with_dict)

四、自定义网络模块:实现残差块与 ResNet18

在实际场景中,基础层往往无法满足需求(如深层网络的梯度消失问题)。此时需自定义复杂模块,以残差块(Residual Block)和 ResNet18 为例,讲解自定义模块的实现思路。

1. 残差块的核心思想

残差块通过 “跳连(Skip Connection)” 将输入直接加到输出,解决深层网络的梯度消失问题。根据输入与输出维度是否一致,残差块分为两种:

块 1:基础残差块(RestNetBasicBlock

输入与输出维度一致,直接跳连(无额外参数)。

import torch
import torch.nn as nn
import torch.nn.functional as F
class RestNetBasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(RestNetBasicBlock, self).__init__()
# 残差路径:2个3×3卷积 + 批量归一化
self.conv1 = nn.Conv2d(
in_channels, out_channels,
kernel_size=3, stride=stride, padding=1  # padding=1保证尺寸不变
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels,
kernel_size=3, stride=stride, padding=1
)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
residual = x  # 跳连:保存输入(残差)
# 残差路径计算
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
# 残差连接:输入 + 残差路径输出
out += residual
out = F.relu(out)
return out
块 2:下采样残差块(RestNetDownBlock

输入与输出维度不一致(如通道数增加、尺寸缩小),需通过 1×1 卷积调整输入维度,再进行跳连。

class RestNetDownBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=[2, 1]):
super(RestNetDownBlock, self).__init__()
# 残差路径:2个3×3卷积(第1个卷积下采样) + 批量归一化
self.conv1 = nn.Conv2d(
in_channels, out_channels,
kernel_size=3, stride=stride[0], padding=1  # stride=2下采样
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels,
kernel_size=3, stride=stride[1], padding=1  # stride=1保持尺寸
)
self.bn2 = nn.BatchNorm2d(out_channels)
# 1×1卷积:调整输入维度(通道数+尺寸),适配残差路径输出
self.extra_conv = nn.Sequential(
nn.Conv2d(
in_channels, out_channels,
kernel_size=1, stride=stride[0], padding=0  # 下采样+通道调整
),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
residual = self.extra_conv(x)  # 1×1卷积调整输入维度
# 残差路径计算
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
# 残差连接:调整后的输入 + 残差路径输出
out += residual
out = F.relu(out)
return out

2. 组合残差块构建 ResNet18

ResNet18 由 “1 个初始卷积层 + 4 个残差组(含 2 个残差块)+1 个全局平均池化 + 1 个全连接层” 构成。

class ResNet18(nn.Module):
def __init__(self, num_classes=10):
super(ResNet18, self).__init__()
# 1. 初始卷积层(将输入图片(3通道)转换为64通道)
self.init_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.init_bn = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 下采样
# 2. 残差组(4组,每组2个残差块)
self.layer1 = nn.Sequential(
RestNetBasicBlock(64, 64, stride=1),   # 输入64→输出64,无下采样
RestNetBasicBlock(64, 64, stride=1)
)
self.layer2 = nn.Sequential(
RestNetDownBlock(64, 128, stride=[2, 1]),  # 输入64→输出128,下采样
RestNetBasicBlock(128, 128, stride=1)
)
self.layer3 = nn.Sequential(
RestNetDownBlock(128, 256, stride=[2, 1]),  # 输入128→输出256,下采样
RestNetBasicBlock(256, 256, stride=1)
)
self.layer4 = nn.Sequential(
RestNetDownBlock(256, 512, stride=[2, 1]),  # 输入256→输出512,下采样
RestNetBasicBlock(512, 512, stride=1)
)
# 3. 全局平均池化(将512×7×7特征图转换为512×1×1)
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
# 4. 全连接层(输出类别概率)
self.fc = nn.Linear(512, num_classes)
def forward(self, x):
# 初始卷积+池化
out = self.init_conv(x)
out = self.init_bn(out)
out = F.relu(out)
out = self.maxpool(out)
# 残差组
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
# 池化+全连接
out = self.avgpool(out)
out = out.reshape(out.shape[0], -1)  # 展平:(batch, 512, 1, 1) → (batch, 512)
out = self.fc(out)
return out
# 实例化ResNet18(假设分类10类)
resnet18 = ResNet18(num_classes=10)
print("ResNet18结构:")
print(resnet18)

五、模型训练全流程:从数据到可视化

掌握模型构建后,需通过训练让模型 “学习” 数据规律。以下以 ResNet18 训练 CIFAR-10 数据集为例,讲解完整训练流程。

步骤 1:加载并预处理数据集

使用torchvision加载公开数据集(如 CIFAR-10),并进行数据增强(如随机翻转、归一化),提升模型泛化能力。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 1. 数据预处理与增强
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),  # 随机裁剪(32×32)
transforms.RandomHorizontalFlip(),     # 随机水平翻转
transforms.ToTensor(),                 # 转换为张量(0-1)
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # 归一化(CIFAR-10均值/标准差)
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# 2. 加载数据集
train_dataset = datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform_train
)
test_dataset = datasets.CIFAR10(
root="./data", train=False, download=True, transform=transform_test
)
# 3. 构建数据加载器(批量读取数据)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

步骤 2:定义损失函数与优化器

  • 损失函数:CIFAR-10 是分类任务,使用CrossEntropyLoss(含 Softmax,无需手动添加)。
  • 优化器:使用Adam(自适应学习率,收敛更快),并添加权重衰减(L2 正则化,防止过拟合)。
import torch.optim as optim
# 1. 定义模型(ResNet18)
model = ResNet18(num_classes=10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 优先使用GPU
model.to(device)  # 将模型移动到GPU/CPU
# 2. 定义损失函数
criterion = nn.CrossEntropyLoss()
# 3. 定义优化器
optimizer = optim.Adam(
model.parameters(),
lr=0.001,  # 初始学习率
weight_decay=5e-4  # 权重衰减(L2正则化)
)
# 4. 学习率调度器(可选,按epoch衰减学习率)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)  # 每20个epoch学习率减半

步骤 3:循环训练模型

训练流程核心是 “前向传播计算损失→反向传播计算梯度→优化器更新参数”,每个 epoch 包含 “训练阶段” 和 “验证阶段”。

import time
# 超参数
epochs = 50  # 训练轮次
best_acc = 0.0  # 最佳测试准确率
train_losses = []  # 记录训练损失
test_accs = []     # 记录测试准确率
for epoch in range(epochs):
# -------------------------- 训练阶段 --------------------------
model.train()  # 切换到训练模式(启用Dropout、BatchNorm更新)
train_loss = 0.0
start_time = time.time()
for batch_idx, (data, targets) in enumerate(train_loader):
# 1. 数据移动到设备(GPU/CPU)
data, targets = data.to(device), targets.to(device)
# 2. 清零梯度(防止梯度累积)
optimizer.zero_grad()
# 3. 前向传播:计算预测值
outputs = model(data)
# 4. 计算损失
loss = criterion(outputs, targets)
# 5. 反向传播:计算梯度
loss.backward()
# 6. 优化器更新参数
optimizer.step()
# 7. 累加损失
train_loss += loss.item() * data.size(0)  # 乘以batch_size,得到总损失
# 计算每个epoch的平均训练损失
avg_train_loss = train_loss / len(train_dataset)
train_losses.append(avg_train_loss)
# -------------------------- 测试阶段 --------------------------
model.eval()  # 切换到测试模式(关闭Dropout、固定BatchNorm)
test_acc = 0.0
with torch.no_grad():  # 禁用梯度计算,加速测试
for data, targets in test_loader:
data, targets = data.to(device), targets.to(device)
outputs = model(data)
_, predicted = torch.max(outputs, dim=1)  # 取概率最大的类别
test_acc += (predicted == targets).sum().item()  # 累加正确预测数
# 计算测试准确率
avg_test_acc = test_acc / len(test_dataset)
test_accs.append(avg_test_acc)
# 更新最佳准确率并保存模型
if avg_test_acc > best_acc:
best_acc = avg_test_acc
torch.save(model.state_dict(), "resnet18_best.pth")  # 保存模型权重
# 学习率调度
scheduler.step()
# 打印日志
end_time = time.time()
print(f"Epoch [{epoch+1}/{epochs}], "
f"Train Loss: {avg_train_loss:.4f}, "
f"Test Acc: {avg_test_acc:.4f}, "
f"Best Acc: {best_acc:.4f}, "
f"Time: {end_time - start_time:.2f}s")

步骤 4:可视化训练结果

使用matplotlib绘制 “训练损失曲线” 和 “测试准确率曲线”,直观分析模型训练趋势(如是否过拟合、是否收敛)。

import matplotlib.pyplot as plt
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei']
plt.rcParams['axes.unicode_minus'] = False
# 创建画布
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# 1. 绘制训练损失曲线
ax1.plot(range(1, epochs+1), train_losses, label='训练损失', color='blue')
ax1.set_xlabel('Epoch(训练轮次)')
ax1.set_ylabel('损失值')
ax1.set_title('ResNet18训练损失曲线')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 2. 绘制测试准确率曲线
ax2.plot(range(1, epochs+1), test_accs, label='测试准确率', color='red')
ax2.set_xlabel('Epoch(训练轮次)')
ax2.set_ylabel('准确率')
ax2.set_title('ResNet18测试准确率曲线')
ax2.legend()
ax2.grid(True, alpha=0.3)
# 保存图片
plt.tight_layout()
plt.savefig('resnet18_training_curve.png', dpi=300, bbox_inches='tight')
plt.show()

六、总结

本文从 PyTorch 神经网络的核心组件出发,详细讲解了三种模型构建方法(继承nn.Modulenn.SequentialModule+ 容器)、自定义残差模块的实现,以及完整的模型训练流程。通过本文的学习,您可以掌握:

  1. nn.Modulenn.functional的核心差异与适用场景;
  2. 根据模型复杂度选择合适的构建方法;
  3. 自定义复杂模块(如残差块)解决实际问题;
  4. 端到端的模型训练与结果可视化。

PyTorch 的灵活性在于其支持从简单到复杂的各类模型,建议您结合具体任务(如分类、检测、生成)实践本文代码,进一步加深理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/915935.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java函数式编程的学习01

java函数式编程:在stream流中经常用到 对stream流的理解:操作集合的一种方法 stream流的用法:创建流、中间操作、终结操作 创建流的方式以及一些注意事项: 如果是集合通过.stream()方法来创建流,如果是数组,可以…

Manim实现镜面反射特效

本文将介绍如何使用ManimCE框架实现镜面反射特效,让你的动画更加生动有趣。 1. 实现原理 1.1. 对称点计算 实现镜面反射的核心是计算点关于直线的对称点。 代码中的symmetry_point函数通过向量投影的方法计算对称点:…

25Java基础之IO(二)

IO流-字符流 FileReader(文件字符输入流)作用:以内存为基准,可以把文件中的数据以字符的形式读入到内存中去。案例:读取一个字符//目标:文件字符输入流的使用,每次读取一个字符。 public class FileReaderDemo01 …

【git】统计项目下每个人提交行数

git log --format=%aN | sort -u | while read name; do echo -en "$name\t"; git log --author="$name" --pretty=tformat: --numstat | awk { add += $1; subs += $2; loc += $1 - $2 } END { p…

【P2860】[USACO06JAN] Redundant Paths G - Harvey

题意 给定一个连通图,求最少要加多少条边使得图无割边。 思路 首先,我们可以先缩点再进行考虑。 缩点后整个连通图变成一棵树,为了使连边后不出现割边,可以将所有度为 \(1\) 的点两两连边,如果度为 \(1\) 的点的个…

GUI软件构造

GUI(桌面图形用户界面) 设计遵循规范,要标准,不繁杂 JAVA GUI设计模式 观察者模式是一种软件设计模式 ,他定义了一种一对多的依赖关系,一个对象改变其他对象自动更新 包含的角色 被观察对象(subject) 具体被观…

网站页面建设方案书模板wordpress模班之家

1. 字面含义不同 Comparable字面意思是“具有比较能力”,Comparator字面意思是“比较器”。 2. 用法不同 Comparable用法:对需要排序的类,实现Comparable接口,重写compareTo()方法。 Comparator用法:创建自定义比较…

ssh蒙语网站开发室内设计公司办公室图片

在孩子学习过程中,假设有一种“方法”,能让孩子成绩突飞猛进,你想不想掌握?在孩子学习过程中,假设有一套“系统”,能让孩子主动喜欢上学习,你想不想拥有?在孩子学习过程中&#xff0…

点餐网站怎么做哈尔滨网站建设制作

导读:本文主要围绕材料非线性问题的有限元Matlab编程求解进行介绍,重点围绕牛顿-拉普森法(切线刚度法)、初应力法、初应变法等三种非线性迭代方法的算法原理展开讲解,最后利用Matlab对材料非线性问题有限元迭代求解算法进行实现,展示了实现求解的核心代码。这些内容都将收…

【CV】GAN代码解析 image_folder.py

【CV】GAN代码解析 image_folder.pyPosted on 2025-09-24 16:07 SaTsuki26681534 阅读(0) 评论(0) 收藏 举报"""A modified image folder classWe modify the official PyTorch image folder (htt…

一些常用的网站

📚 我的常用网址收藏夹前言: 记录那些在我的数字生活中不可或缺的网站和工具,方便快速访问和分享。🚀 常用工具 开发与编程插件库: open-vsx - vscode/trae的历史插件下载 技术文档: MDN Web Docs - 前端开发者的…

systemd-nspawn容器体积精简和桥接网络实战

systemd-nspawn容器体积精简和桥接网络实战目录前言需求精简容器体积创建目录结构测试容器是否正常启动创建并测试容器的独立网络形成systemd服务文件。通过wifi连接网关的容器配置其他说明前言 以前我的树莓派服务是放…

运维自动化工具Ansible大总结20250914 - 教程

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

贵州省住房和城乡建设部网站成都住建局官网站首页

目录 概念图遍历深度优先搜索 (DFS)DFS 适用场景DFS 优缺点 广度优先搜索 (BFS)BFS 适用场景BFS 优缺点 DFS & BFS 异同点 图搜索Dijkstra算法A*算法Floyd算法Bellman-Ford算法SPFA算法 概念 图遍历和图搜索是解决图论问题时常用的两种基本操作。 图遍历是指从图中的某一个…

上海建筑 公司网站wordpress 伪静态

科技巨变,未来已来,八大技术趋势引领数字化时代。信息技术的迅猛发展,深刻改变了我们的生活、工作和生产方式。人工智能、物联网、云计算、大数据、虚拟现实、增强现实、区块链、量子计算等新兴技术在各行各业得到广泛应用,为各个领域带来了新的活力和变革。 为了更好地了解…

杭州网站前端建设seo全称是什么意思

目录 从上到下,你所看到的目录如下 /bin /bin 目录是包含一些二进制文件的目录,即可以运行的一些应用程序。 你会在这个目录中找到上面提到的 ls 程序,以及用于新建和删除文件和目录、移动它们基本工具。还有其它一些程序,等等。…

企业微信客服API模式接入第三方客服系统,对接大模型AI智能体

我们系统可以接入企业微信客服的API gofly.v1kf.com 联系vx:llike620企业微信客服是企业微信里面的一项功能,它整合了微信生态的优势,解决的是与临时访客进行实时沟通的需求 核心功能 多渠道接待:支持在微信内(公…

react使用ctx和reducer代替redux

入门版本 创建一个store,包含ctx、reduce、dispatch+action import { createContext, useContext } from react;// 定义ctx export const defaultValue = {count: 0, }; export const AppCtx = createContext(null);e…

KM 乱记

狠狠学习了先来看一个问题:给定 \(w_{1\sim n, 1\sim n}\),现在要求满足 \(\forall i, j\in [1, n], a_i + b_j\ge w_{i, j}\) 且 \(\sum a_i + \sum b_j\) 最小的 \(a_{1\sim n}, b_{1\sim n}\)。如果会线性规划对偶…

深入解析:B树与B+树的原理区别应用

深入解析:B树与B+树的原理&区别&应用pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", &q…