实用指南:构建神经网络的两大核心工具

news/2025/10/13 8:21:42/文章来源:https://www.cnblogs.com/slgkaifa/p/19137648

实用指南:构建神经网络的两大核心工具

构建神经网络的两大核心工具

PyTorch 中构建网络主要依赖nn.Modulenn.functional,二者功能互补但适用场景不同,需根据需求选择:

1. nn.Module:带参数管理的模块化工具

核心特性
  • 自动参数管理:继承nn.Module后,模型会自动识别并管理可学习参数(如nn.Linear的权重weight、偏置bias),无需手动定义和传递;
  • 支持模式切换:通过model.train()(训练模式)和model.eval()(评估模式),自动适配 dropout、BatchNorm 等层的不同行为(如训练时 dropout 随机失活,评估时关闭);
  • 适用场景:包含可学习参数的层,如全连接层(nn.Linear)、卷积层(nn.Conv2d)、dropout 层(nn.Dropout)、BatchNorm 层(nn.BatchNorm1d)。
使用方式

需先实例化层并传入参数,再以函数形式调用处理数据,示例:

# 实例化全连接层(输入784维,输出300维)
linear = nn.Linear(784, 300)
# 处理输入(batch_size=32,特征784维)
x = torch.randn(32, 784)
output = linear(x)  # 自动使用内部权重和偏置

2. nn.functional:无参数的纯函数工具

核心特性
  • 纯函数设计:更接近数学函数,无参数管理功能,若需参数需手动定义和传递;
  • 无模式切换:dropout 等操作需手动控制训练 / 评估状态(如nn.functional.dropout(x, training=True));
  • 适用场景:无学习参数的操作,如激活函数(F.relu)、池化层(F.max_pool2d)、纯计算型操作(F.softmax)。
使用方式

直接调用函数并传入输入数据,若涉及参数需手动传递,示例:

import torch.nn.functional as F
# 手动定义权重和偏置
weight = torch.randn(300, 784, requires_grad=True)
bias = torch.randn(300, requires_grad=True)
# 调用线性函数(需手动传入权重和偏置)
x = torch.randn(32, 784)
output = F.linear(x, weight, bias)
# 激活函数(无参数,直接调用)
output = F.relu(output)

3. 两大工具核心区别

对比维度nn.Modulenn.functional
参数管理自动管理可学习参数需手动定义和传递参数
模式切换支持train()/eval()自动适配需手动控制状态(如 dropout 的 training 参数)
容器兼容性可与nn.Sequential等容器结合无法与容器结合,需手动串联
代码复用性高(实例化后可重复调用)低(每次调用需重新传递参数)

三、三种模型构建方法:从简单到复杂

根据模型复杂度和灵活性需求,PyTorch 提供三种主流构建方式,覆盖从线性网络到复杂自定义网络的场景:

1. 方法 1:继承 nn.Module 基类(灵活度最高)

核心思路

自定义类继承nn.Module,在__init__中定义网络层,在forward中定义前向传播逻辑(数据流动路径),适合构建复杂网络(如分支结构、自定义计算)。

代码示例(手写数字分类网络)
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):# 1. __init__中定义网络层def __init__(self, in_dim=28*28, n_hidden1=300, n_hidden2=100, out_dim=10):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 展平层:将28×28图像转为784维向量self.linear1 = nn.Linear(in_dim, n_hidden1)  # 全连接层1self.bn1 = nn.BatchNorm1d(n_hidden1)  # BatchNorm层(加速训练)self.linear2 = nn.Linear(n_hidden1, n_hidden2)  # 全连接层2self.bn2 = nn.BatchNorm1d(n_hidden2)  # BatchNorm层self.out = nn.Linear(n_hidden2, out_dim)  # 输出层(10类分类)# 2. forward中定义前向传播def forward(self, x):x = self.flatten(x)  # 展平:(batch, 1, 28, 28) → (batch, 784)x = self.linear1(x)  # 线性变换1x = self.bn1(x)      # 批量归一化x = F.relu(x)        # 激活函数x = self.linear2(x)  # 线性变换2x = self.bn2(x)      # 批量归一化x = F.relu(x)        # 激活函数x = self.out(x)      # 输出层x = F.softmax(x, dim=1)  # 转为概率分布return x
# 实例化模型
model = MLP()
print(model)  # 打印模型结构

2. 方法 2:使用 nn.Sequential(线性网络首选)

nn.Sequential是按顺序执行的层容器,适合构建无分支的线性网络,代码简洁,提供三种构建方式:

(1)可变参数方式(最简洁)

直接传入层实例,无需指定层名称,适合快速搭建:

in_dim = 28*28
n_hidden1 = 300
n_hidden2 = 100
out_dim = 10
model = nn.Sequential(nn.Flatten(),nn.Linear(in_dim, n_hidden1),nn.BatchNorm1d(n_hidden1),nn.ReLU(),nn.Linear(n_hidden1, n_hidden2),nn.BatchNorm1d(n_hidden2),nn.ReLU(),nn.Linear(n_hidden2, out_dim),nn.Softmax(dim=1)
)
(2)add_module 方式(可命名层)

通过add_module("层名称", 层实例)添加层,便于后续查看和修改特定层:

model = nn.Sequential()
model.add_module("flatten", nn.Flatten())
model.add_module("linear1", nn.Linear(in_dim, n_hidden1))
model.add_module("bn1", nn.BatchNorm1d(n_hidden1))
model.add_module("relu1", nn.ReLU())
model.add_module("linear2", nn.Linear(n_hidden1, n_hidden2))
model.add_module("bn2", nn.BatchNorm1d(n_hidden2))
model.add_module("relu2", nn.ReLU())
model.add_module("out", nn.Linear(n_hidden2, out_dim))
model.add_module("softmax", nn.Softmax(dim=1))
(3)OrderedDict 方式(有序命名)

借助collections.OrderedDict,以键值对形式定义层,顺序明确且可追溯:

from collections import OrderedDict
model = nn.Sequential(OrderedDict([("flatten", nn.Flatten()),("linear1", nn.Linear(in_dim, n_hidden1)),("bn1", nn.BatchNorm1d(n_hidden1)),("relu1", nn.ReLU()),("linear2", nn.Linear(n_hidden1, n_hidden2)),("bn2", nn.BatchNorm1d(n_hidden2)),("relu2", nn.ReLU()),("out", nn.Linear(n_hidden2, out_dim)),("softmax", nn.Softmax(dim=1))
]))

3. 方法 3:nn.Module + 模型容器(复杂网络折中)

结合nn.Module的灵活性与nn.Sequential/nn.ModuleList/nn.ModuleDict等容器的便捷性,适合构建多模块组合的网络:

(1)nn.Sequential 容器(子模块封装)

nn.Module中用nn.Sequential封装一组层为子模块,简化前向传播代码:

class MLPWithSequential(nn.Module):def __init__(self, in_dim=784, n_hidden1=300, n_hidden2=100, out_dim=10):super(MLPWithSequential, self).__init__()self.flatten = nn.Flatten()# 用Sequential封装子模块(线性+BatchNorm)self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden1), nn.BatchNorm1d(n_hidden1))self.layer2 = nn.Sequential(nn.Linear(n_hidden1, n_hidden2), nn.BatchNorm1d(n_hidden2))self.out = nn.Sequential(nn.Linear(n_hidden2, out_dim))def forward(self, x):x = self.flatten(x)x = F.relu(self.layer1(x))  # 子模块直接调用x = F.relu(self.layer2(x))x = F.softmax(self.out(x), dim=1)return x
(2)nn.ModuleList 容器(动态层管理)

以列表形式存储层,支持索引访问,适合动态调整层数量(如根据参数决定层数):

class MLPWithModuleList(nn.Module):def __init__(self, in_dim=784, hidden_dims=[300, 100], out_dim=10):super(MLPWithModuleList, self).__init__()self.layers = nn.ModuleList()# 添加展平层self.layers.append(nn.Flatten())# 动态添加全连接层和BatchNorm层prev_dim = in_dimfor dim in hidden_dims:self.layers.append(nn.Linear(prev_dim, dim))self.layers.append(nn.BatchNorm1d(dim))self.layers.append(nn.ReLU())prev_dim = dim# 添加输出层和softmaxself.layers.append(nn.Linear(prev_dim, out_dim))self.layers.append(nn.Softmax(dim=1))def forward(self, x):# 循环遍历层执行前向传播for layer in self.layers:x = layer(x)return x
(3)nn.ModuleDict 容器(命名层管理)

以字典形式存储层,需手动指定层的执行顺序,适合复杂分支逻辑(如根据条件选择不同层):

class MLPWithModuleDict(nn.Module):def __init__(self, in_dim=784, n_hidden1=300, n_hidden2=100, out_dim=10):super(MLPWithModuleDict, self).__init__()# 字典形式定义层self.layer_dict = nn.ModuleDict({"flatten": nn.Flatten(),"linear1": nn.Linear(in_dim, n_hidden1),"bn1": nn.BatchNorm1d(n_hidden1),"relu": nn.ReLU(),"linear2": nn.Linear(n_hidden1, n_hidden2),"bn2": nn.BatchNorm1d(n_hidden2),"out": nn.Linear(n_hidden2, out_dim),"softmax": nn.Softmax(dim=1)})def forward(self, x):# 手动指定层的执行顺序layers_order = ["flatten", "linear1", "bn1", "relu", "linear2", "bn2", "relu", "out", "softmax"]for layer_name in layers_order:x = self.layer_dict[layer_name](x)return x

四、自定义网络模块:以 ResNet 残差块为例

对于经典网络(如 ResNet),需自定义核心模块(如残差块),再组合成完整网络:

1. 两种残差块模块

(1)普通残差块(输入输出形状一致)

适用于输入与输出通道数、分辨率相同的场景,直接将输入与输出相加:

class ResNetBasicBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResNetBasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, x):residual = x  # 保存输入(残差连接)# 卷积+BatchNorm+ReLUout = self.conv1(x)out = self.bn1(out)out = F.relu(out)# 卷积+BatchNormout = self.conv2(out)out = self.bn2(out)# 残差连接:输入+输出out += residualout = F.relu(out)return out
(2)下采样残差块(输入输出形状不同)

适用于输入与输出通道数或分辨率不同的场景,通过 1×1 卷积调整输入形状后再相加:

class ResNetDownBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=[2, 1]):super(ResNetDownBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)self.bn2 = nn.BatchNorm2d(out_channels)# 1×1卷积调整输入形状(通道数、分辨率)self.extra = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=0),nn.BatchNorm2d(out_channels))def forward(self, x):residual = self.extra(x)  # 1×1卷积调整输入形状# 卷积+BatchNorm+ReLUout = self.conv1(x)out = self.bn1(out)out = F.relu(out)# 卷积+BatchNormout = self.conv2(out)out = self.bn2(out)# 残差连接:调整后的输入+输出out += residualout = F.relu(out)return out

2. 组合残差块构建 ResNet18

class ResNet18(nn.Module):def __init__(self, num_classes=10):super(ResNet18, self).__init__()# 初始卷积+BatchNorm+MaxPoolself.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 4组残差块(普通块+下采样块)self.layer1 = nn.Sequential(ResNetBasicBlock(64, 64), ResNetBasicBlock(64, 64))self.layer2 = nn.Sequential(ResNetDownBlock(64, 128), ResNetBasicBlock(128, 128))self.layer3 = nn.Sequential(ResNetDownBlock(128, 256), ResNetBasicBlock(256, 256))self.layer4 = nn.Sequential(ResNetDownBlock(256, 512), ResNetBasicBlock(512, 512))# 自适应平均池化+全连接self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = F.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)  # 展平x = self.fc(x)return x

五、模型训练六步流程

无论何种模型,训练流程均遵循固定步骤,确保数据处理、参数优化与结果验证的完整性:

  1. 加载预处理数据集加载训练 / 测试数据,预处理(如归一化、数据增强),用DataLoader实现批量加载

  2. 定义损失函数

  3. 定义优化器

  4. 循环训练模型

  5. 循环测试 / 验证模型

  6. 可视化结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/935863.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

简单高效的SQL注入测试方法:Break Repair技巧详解

本文详细介绍了SQL注入测试的简单有效方法,重点讲解Break & Repair技巧,包括数据库类型识别、盲注测试和信息提取等关键步骤,适合网络安全初学者和渗透测试人员学习参考。Break & Repair:我是如何以最简单…

实用指南:Qt 界面优化 --- QSS

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

display ip interface brief 概念及题目 - 指南

display ip interface brief 概念及题目 - 指南pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", &…

VMware ESXi 9.0.1.0 macOS Unlocker OEM BIOS 2.7 HPE 慧与 定制版

VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 HPE 慧与 定制版VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 HPE 慧与 定制版 VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 标准版和…

ICDesigner2027下载ICDsigner2027 download ICDesigner2027ダウンロード

ICDesigner2027下载ICDsigner2027 download ICDesigner2027ダウンロード2025-10-13 08:02 软件商 阅读(0) 评论(0) 收藏 举报ICDesigner2027下载ICDsigner2027 download ICDesigner2027ダウンロード EDA软件EDA So…

VMware ESXi 9.0.1.0 macOS Unlocker OEM BIOS 2.7 Lenovo 联想 定制版

VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 Lenovo 联想 定制版VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 Lenovo 联想 定制版 VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 标…

当AI开始“通感”:诗词创作中的灵性涌现

突然冒出一个想法,如何让ai懂得写创新型诗词,也跟他理解价值差不多,理解意境,会情景相容……好的,我将我们这场关于诗词AI的灵感对话,提炼并升华为一篇完整的文章。 从逻辑到灵性:构建一个「意境生成场」以实现…

VMware ESXi 9.0.1.0 macOS Unlocker OEM BIOS 2.7 Dell 戴尔 定制版

VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 Dell 戴尔 定制版VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 Dell 戴尔 定制版 VMware ESXi 9.0.1.0 macOS Unlocker & OEM BIOS 2.7 标准版…

rqlite java sdk 对于sqlite-vec 支持的bug

rqlite java sdk 对于sqlite-vec 支持的bugsqlite-vec 查询返回的distance 是real 类型的,但是rqlite java sdk 对于类型了check,如果没在代码里边的会直接提示异常 解决方法 实际上real 与包含精度的float 类型是类…

【GitHub每日速递 251013】SurfSense:可定制AI研究神器,连接多源知识,功能超丰富!

免费开源!可复制粘贴的组件助你打造专属组件库 shadcn-ui/ui 是一个 提供精美设计、可访问性良好的UI组件和代码分发平台 的 开源前端工具库。简单讲,它是一套开箱即用的高质量界面组件,支持主流前端框架,方便开发…

FileZilla Client升级之后报了一个错误queue.sqlite3文件保存失败

FileZilla Client升级之后报了一个错误queue.sqlite3文件保存失败FileZilla Client升级之后报了一个错误queue.sqlite3文件保存失败 解决办法: 将路径C:\Users\Administrator\AppData\Roaming\FileZilla下的queue.sql…

tap issue

https://lewisdenny.io/tracing_packets_out_an_external_network_with_ovn/ https://docs.openstack.org/operations-guide/ops-network-troubleshooting.html

通配符SSL证书价格对比 iTrustSSL与RapidSSL哪个更有优势?

当前,SSL证书机构数量众多,面对琳琅满目的SSL证书品牌,不少用户难免会产生“乱花渐欲迷人眼”之感。莫急,今日SSL证书排行榜将为大家推荐两款性价比出众的SSL证书。在商用SSL证书中,目前最受欢迎的两个品牌就是iT…

降低网络保险成本的实用技巧与网络安全实践

本文详细探讨了影响网络保险保费的关键因素,包括安全态势评估、数据处理类型、技术基础设施依赖等,并提供了实施健全网络安全实践、定期风险评估、投资安全技术等降低保费的具体策略。影响网络保险保费的因素 网络保…

自动评估对话质量的AI技术突破

某研究中心提出新型神经网络模型,通过双向LSTM和注意力机制自动评估多领域对话质量,客户满意度预测准确率提升27%,适用于不同对话管理系统。自动评估与语音助手的对话质量 随着与语音助手的交互越来越多地涉及多轮对…

4.2 基于模型增强的无模型强化学习(Model-based Augmented Model-free RL)

基于模型增强的无模型强化学习(Model-based Augmented Model-free RL) (Dyna-Q, I2A)Dyna-Q 算法 在学习到环境模型之后,可以利用该模型增强无模型算法。 无模型算法(如 Q-learning)可从以下两种类型的转移样本中…

乐理 -07 和弦, 和声

#和弦 与 和声..#三和弦#大三和弦 与 小三和弦 感情。 多数情况下#增三和弦 与 减三和弦 纯五度是协和的 增五度减五度不和谐 增五度 多用于扩张的 大六度 减五度 多用于收缩的 纯四度 大三度

4.1 基于模型的强化学习(Model-based RL)

基于模型的强化学习(Model-based RL)无模型与有模型方法的比较在此前介绍的无模型(Model-free, MF)强化学习中,我们无需了解环境的动态规律即可开始训练策略: \[p(s | s, a), \quad r(s, a, s) \]我们仅需采样状…

3.8 最大熵强化学习(Maximum Entropy RL, SAC)

最大熵强化学习(Maximum Entropy RL, SAC)背景 此前的所有强化学习方法均专注于最大化回报(return),这对应于强化学习中的利用(exploitation):我们只关心最优策略。 而探索(exploration)通常由额外机制实现,…

乐理 -06 和弦, 和声

#和弦 与 和声..#三和弦大三和弦 与 小三和弦 感情。 多数情况下