pytorch小记(十三):pytorch中`nn.ModuleList` 详解

pytorch小记(十三):pytorch中`nn.ModuleList` 详解

  • PyTorch 中的 `nn.ModuleList` 详解
    • 1. 什么是 `nn.ModuleList`?
    • 2. 为什么不直接使用普通的 Python 列表?
    • 3. `nn.ModuleList` 的基本用法
      • 示例:构建一个包含两层全连接网络的模型
    • 4. 使用 `nn.ModuleList` 计算参数总数(与普通列表对比)
      • 示例代码
    • 5. `nn.ModuleList` 的其他应用
      • 示例:构建动态 MLP 模型
      • Transformers中的多头注意力机制
    • 6. 总结


PyTorch 中的 nn.ModuleList 详解

在构建深度学习模型时,经常需要管理多个网络层(例如多个 nn.Linearnn.Conv2d 等)。在 PyTorch 中,nn.ModuleList 是一个非常有用的容器,可以帮助我们存储多个子模块,并自动注册它们的参数。这对于确保所有参数能够参与训练非常重要。本文将详细介绍 nn.ModuleList 的作用、使用方法及与普通 Python 列表的区别,并给出清晰的代码示例。


1. 什么是 nn.ModuleList

nn.ModuleList 是一个类似于 Python 列表的容器,但专门用来存储 PyTorch 的子模块(也就是继承自 nn.Module 的对象)。其主要特点是:

  • 自动注册子模块:将 nn.Module 存储在 ModuleList 中后,这些模块的参数会自动被添加到父模块的参数列表中。这意味着当你调用 model.parameters() 时,这些子模块的参数也会被包含进去,从而参与梯度计算和优化。

  • 灵活管理:它可以像普通列表一样进行索引、迭代和切片操作,方便构建动态网络结构。

注意nn.ModuleList 不会像 nn.Sequential 那样自动定义前向传播(forward)流程。你需要在模型的 forward() 方法中手动遍历 ModuleList 并调用各个子模块。


2. 为什么不直接使用普通的 Python 列表?

虽然可以将 nn.Module 对象存储在普通列表中,但这样做有一个主要问题:
普通列表中的模块不会自动注册为父模块的子模块
这会导致:

  • 调用 model.parameters() 时无法获取这些模块的参数;
  • 优化器无法更新这些参数,从而影响模型训练。

而使用 nn.ModuleList 可以避免这个问题,因为它会自动将内部所有的模块注册到父模块中。


3. nn.ModuleList 的基本用法

下面通过一个简单的示例来说明如何使用 nn.ModuleList 构建一个简单的神经网络模型。

示例:构建一个包含两层全连接网络的模型

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 创建一个 ModuleList 来存储各层self.layers = nn.ModuleList([nn.Linear(10, 20),  # 第 1 层:输入 10 个特征,输出 20 个特征nn.ReLU(),          # 激活层nn.Linear(20, 5)    # 第 2 层:输入 20 个特征,输出 5 个特征])def forward(self, x):# 手动遍历 ModuleList 中的每个模块,并依次调用 forwardfor layer in self.layers:x = layer(x)return x# 创建模型实例
model = MyModel()# 打印模型结构
print("模型结构:")
print(model)# 生成一组示例输入
input_tensor = torch.randn(3, 10)  # 3 个样本,每个样本 10 个特征# 得到模型输出
output = model(input_tensor)
print("\n模型输出:")
print(output)
模型结构:
MyModel((layers): ModuleList((0): Linear(in_features=10, out_features=20, bias=True)(1): ReLU()(2): Linear(in_features=20, out_features=5, bias=True))
)模型输出:
tensor([[ 0.3741,  0.0883,  0.3550, -0.3930,  0.5173],[ 0.2171, -0.0978, -0.0585, -0.4568,  0.3331],[ 0.1232, -0.1491,  0.2026, -0.0978,  0.5478]],grad_fn=<AddmmBackward0>)

说明

  • __init__() 方法中,我们将各个层放在了 nn.ModuleList 中。
  • forward() 方法中,我们使用了一个简单的 for 循环,依次调用 self.layers 中的每个子模块。

4. 使用 nn.ModuleList 计算参数总数(与普通列表对比)

为了进一步说明 nn.ModuleList 与普通列表的区别,我们分别计算一下两种方式下模型的参数总数。

示例代码

import torch.nn as nn# 使用 ModuleList 存储模型层
layers_ml = nn.ModuleList([nn.Linear(10, 20),nn.Linear(20, 5)
])# 计算 ModuleList 中的参数总数
ml_params = 0
for p in layers_ml.parameters():ml_params += p.numel()# 使用普通 Python 列表存储模型层
layers_list = [nn.Linear(10, 20),nn.Linear(20, 5)
]# 计算普通列表中的参数总数
list_params = 0
# 先遍历列表中的每个层
for layer in layers_list:# 再遍历每个层的参数for p in layer.parameters():list_params += p.numel()print("ModuleList 参数总数:", ml_params)
print("普通列表参数总数:", list_params)
ModuleList 参数总数: 325
普通列表参数总数: 325

说明

  • 第一个 for 循环遍历 layers_ml.parameters(),直接累加所有参数的元素数。
  • 第二部分中,我们先遍历普通列表中的每个 layer,再单独遍历每个层的参数。这样做使每一步都清晰易懂。

5. nn.ModuleList 的其他应用

示例:构建动态 MLP 模型

当网络结构比较复杂或层数不固定时,可以利用列表生成器动态构建 ModuleList

class DynamicMLP(nn.Module):def __init__(self, layer_sizes):super(DynamicMLP, self).__init__()# 使用 for 循环构造每一层,存储在 ModuleList 中layers = []  # 先用普通列表保存层for i in range(len(layer_sizes) - 1):linear_layer = nn.Linear(layer_sizes[i], layer_sizes[i + 1])layers.append(linear_layer)# 将普通列表转换为 ModuleListself.layers = nn.ModuleList(layers)def forward(self, x):# 遍历每一层(没有嵌套循环,逐个执行)for layer in self.layers:x = torch.relu(layer(x))return x# 创建一个动态 MLP:输入 10,隐藏层 20, 30,输出 5
dynamic_model = DynamicMLP([10, 20, 30, 5])
print("动态 MLP 模型:")
print(dynamic_model)# 测试模型
input_tensor = torch.randn(4, 10)  # 4 个样本,每个样本 10 个特征
output = dynamic_model(input_tensor)
print("\n动态 MLP 模型输出:")
print(output)

说明

  • __init__() 方法中,我们使用一个普通列表 layers 存储每个 nn.Linear 层,然后再将它转换为 nn.ModuleList
  • forward() 方法中,使用单独的 for 循环逐个调用每一层,并对输出应用 ReLU 激活函数。
  • 这种写法适用于层数动态变化的网络(例如 MLP、RNN、Transformer 中部分模块)。

Transformers中的多头注意力机制

class SingleHeadAttention(nn.Module):def __init__(self, embed_dim, head_dim):super().__init__()self.query = nn.Linear(embed_dim, head_dim)self.key = nn.Linear(embed_dim, head_dim)self.value = nn.Linear(embed_dim, head_dim)def forward(self, x):# 实现注意力计算逻辑...return attended_valuesclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.head_dim = embed_dim // num_heads# 显式创建每个注意力头self.head1 = SingleHeadAttention(embed_dim, self.head_dim)self.head2 = SingleHeadAttention(embed_dim, self.head_dim)self.head3 = SingleHeadAttention(embed_dim, self.head_dim)# 使用ModuleList管理多个头self.heads = nn.ModuleList([self.head1,self.head2,self.head3])self.output_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):# 分别处理每个头head1_out = self.head1(x)head2_out = self.head2(x) head3_out = self.head3(x)# 拼接结果combined = torch.cat([head1_out, head2_out, head3_out], dim=-1)return self.output_proj(combined)

关键点解析:

  • 显式声明每个注意力头(避免循环)

  • 使用ModuleList统一管理注意力头

  • 在forward中分别调用每个头

  • 保持各头独立性,便于后续调试


6. 总结

  • nn.ModuleList 是专门用于存储多个子模块的容器,它会自动注册子模块,确保所有参数能参与训练。
  • 与普通 Python 列表相比,ModuleList 可以直接通过 model.parameters() 获取其中所有参数,从而方便地进行优化。
  • 使用 ModuleList 时,前向传播需要手动遍历其中的模块,这提供了更大的灵活性,但也要求开发者理解循环过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/72746.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Excel导出工具类--复杂的excel功能导出(使用自定义注解导出)

Excel导出工具类 前言: 简单的excel导出,可以用easy-excel, fast-excel, auto-poi,在导出实体类上加上对应的注解,用封装好的工具类直接导出,但对于复杂的场景, 封装的工具类解决不了,要用原生的excel导出(easy-excel, fast-excel, auto-poi都支持原生的) 业务场景: 根据…

批量测试IP和域名联通性2

在前面批量测试IP和域名联通性-CSDN博客的基础上&#xff0c;由于IP和域名多样性&#xff0c;比如带端口号的192.168.1.17:17&#xff0c;实际上应该ping 192.168.1.17。如果封禁http://www.abc.com/a.exe&#xff0c;实际可ping www.abc.com。所以又完善了代码。 echo off se…

国产编辑器EverEdit - 语法着色文件的语法

1 语法着色定义(官方文档) 1.1 概述 EverEdit有着优异的语法着色引擎&#xff0c;可以高亮现存的绝大多数的编程语言。在EverEdit的语法着色中有Region和Item两个概念&#xff0c;Region表示着不同的区块。而Item则代表着这些区块中不同的部分。一般情况下&#xff0c;Region…

Excel处理控件Aspose.Cells教程:如何自动将 HTML 转换为 Excel

在处理 HTML 表中呈现的结构化数据时&#xff0c;将 HTML 转换为 Excel 是一种常见需求。无论您是从网站、报告还是任何其他来源提取数据&#xff0c;将其转换为 Excel 都可以更好地进行分析、操作和共享。 开发人员通常更喜欢使用编程方法将 HTML 转换为 Excel&#xff0c;因…

基于springbo校园安全管理系统(源码+lw+部署文档+讲解),源码可白嫖!

摘要 随着信息时代的来临&#xff0c;过去信息校园安全管理方式的缺点逐渐暴露&#xff0c;本次对过去的校园安全管理方式的缺点进行分析&#xff0c;采取计算机方式构建校园安全管理系统。本文通过阅读相关文献&#xff0c;研究国内外相关技术&#xff0c;提出了一种集进出校…

vim在连续多行行首插入相同的字符

工作中经常需要用vim注释掉一段代码或者json文件中的一部分&#xff0c;需要在多行前面插入//或者#符号。在 Vim 中&#xff0c;在连续多行行首插入相同字符主要有以下两种方法&#xff1a; Visual Block 模式插入 将光标移到要插入相同内容的第一行的行首24。按下Ctrl v进入…

Git 实战指南:本地客户端连接 Gitee 全流程

本文将以 Gitee(码云)、系统Windows 11 为例,详细介绍从本地仓库初始化到远程协作的全流程操作 目录 1. 前期准备1.1 注册与配置 Gitee1.2 下载、安装、配置客户端1.3 配置公钥到 Gitee2. 本地仓库操作(PowerShell/Git Bash)2.1 初始化本地仓库2.2 关联 Gitee 远程仓库3. …

Pytest项目_day01(HTTP接口)

HTTP HTTP是一个协议&#xff08;服务器传输超文本到浏览器的传送协议&#xff09;&#xff0c;是基于TCP/IP通信协议来传输数据&#xff08;HTML文件&#xff0c;图片文件&#xff0c;查询结果等&#xff09;。 访问域名 例如www.baidu.com就是百度的域名&#xff0c;我们想…

MySQL超详细介绍(近2万字)

1. 简单概述 MySQL安装后默认有4个库不可以删除&#xff0c;存储的是服务运行时加载的不同功能的程序和数据 information_schema&#xff1a;是MySQL数据库提供的一个虚拟的数据库&#xff0c;存储了MySQL数据库中的相关信息&#xff0c;比如数据库、表、列、索引、权限、角色等…

SQLMesh宏操作符深度解析:掌握@star与@GENERATE_SURROGATE_KEY实战技巧

引言&#xff1a;解锁SQLMesh的动态查询能力 在复杂的数据处理场景中&#xff0c;手动编写重复性SQL代码不仅效率低下&#xff0c;还难以维护。SQLMesh作为新一代数据库中间件&#xff0c;通过其强大的宏系统赋予开发者编程式构建查询的能力。本文将重点解析两个核心操作符——…

超详细kubernetes部署k8s----一台master和两台node

一、部署说明 1、主机操作系统说明 2、主机硬件配置说明 二、主机准备&#xff08;没有特别说明都是三台都要配置&#xff09; 1、配置主机名和IP 2、配置hosts解析 3、防火墙和SELinux 4、时间同步配置 5、配置内核转发及网桥过滤 6、关闭swap 7、启用ipvs 8、句柄…

高光谱相机在水果分类与品质检测中的应用

一、核心应用领域 ‌外部品质检测‌ ‌表面缺陷识别&#xff1a;通过400-1000nm波段的高光谱成像&#xff0c;可检测苹果表皮损伤、碰伤等细微缺陷&#xff0c;结合图像分割技术实现快速分类‌。 ‌损伤程度评估&#xff1a;例如青香蕉的碰撞损伤会导致光谱反射率变化&#…

【蓝桥杯每日一题】3.17

&#x1f3dd;️专栏&#xff1a; 【蓝桥杯备篇】 &#x1f305;主页&#xff1a; f狐o狸x 他们说内存泄漏是bug&#xff0c;我说这是系统在逼我进化成SSR级程序员 OK来吧&#xff0c;不多废话&#xff0c;今天来点有难度的&#xff1a;二进制枚举 二进制枚举&#xff0c;就是…

Windows11 新机开荒(二)电脑优化设置

目录 前言&#xff1a; 一、注册微软账号绑定权益 二、此电脑 桌面图标 三、系统分盘及默认存储位置更改 3.1 系统分盘 3.2 默认存储位置更改 四、精简任务栏 总结&#xff1a; 前言&#xff1a; 本文承接上一篇 新机开荒&#xff08;一&#xff09; 上一篇文章地址&…

aws(学习笔记第三十三课) 深入使用cdk 练习aws athena

文章目录 aws(学习笔记第三十三课) 深入使用cdk学习内容&#xff1a;1. 使用aws athena1.1 什么是aws athena1.2 什么是aws glue1.2 为什么aws athena和aws glue一起使用 2. 开始练习aws athena2.1 代码链接2.2 整体架构2.3 代码解析2.3.1 创建测试数据的S3 bucket2.3.2 创建保…

每日学习Java之一万个为什么(待补充)

Git分支操作 git branch 分支名 git branch -v git checkout -b 分支名 git checkout 分支名 git merge 分支名 git branch -d | -D 分支名Git冲突 git同名文件合并的最基本单位是行。同名文件同一行不同就会发生冲突。 解决办法&#xff1a;及时沟通&#xff0c;手动更改&…

C++ 多生产者单消费者(MPSC)模式

根据你的需求,多生产者单消费者(MPSC)模式的日志任务队列需要调整设计。以下是改进后的代码实现,重点在于多线程安全入队、单线程消费任务,并确保停止时队列任务全部处理完毕: 多生产者单消费者(MPSC)任务队列实现 #include <iostream> #include <queue> …

OpenCV基础【图像和视频的加载与显示】

目录 一.创建一个窗口&#xff0c;显示图片 二.显示摄像头/多媒体文件 三.把摄像头录取到的视频存储在本地 四.鼠标回调事件 五.TrackBar滑动条 一.创建一个窗口&#xff0c;显示图片 import cv2img_path "src/fengjing.jpg" # 自己的图片路径 img cv2.imre…

c++--vector

1.定义vector vector的定义分为四种 (1)vector() ——————无参构造 (2)vector(size_t n,const value_type& val value_type()) ——————构造并初始化n个val (3)vector(const vector& v1) ———————拷贝构造 (4)vector(inputiterator first,inpu…

宇树科技纯技能要求总结

一、嵌入式开发与硬件设计 核心技能 嵌入式开发&#xff1a; 精通C/C&#xff0c;熟悉STM32、ARM开发熟悉Linux BSP开发及驱动框架&#xff08;SPI/UART/USB/FLASH/Camera/GPS/LCD&#xff09;掌握主流平台&#xff08;英伟达、全志、瑞芯微等&#xff09; 硬件设计&#xff1a…