PyTorch中的线性变换:nn.Parameter VS nn.Linear

self.weight = nn.Parameter(torch.randn(in_channels, out_channels))self.linear = nn.Linear(in_channels, out_channels) 并不完全一致,尽管它们都可以用于实现线性变换(即全连接层),但它们的使用方式和内部实现有所不同。

nn.Parameter

当手动创建一个 nn.Parameter 时,是在显式地定义权重矩阵,并且需要自己管理这个参数以及它如何参与到计算中。例如:

self.weight = nn.Parameter(torch.randn(in_channels, out_channels))

这里,self.weight 是一个可学习的参数,可以将其视为模型的一部分,并在前向传播过程中手动与输入进行矩阵乘法运算。假设输入是 x,则输出可以这样计算:

output = torch.matmul(x, self.weight)

注意这里的数学公式是 Y = X W Y = XW Y=XW,其中 X X X 是输入矩阵, W W W 是权重矩阵。如果还需要加上偏置项 b b b,则变为 Y = X W + b Y = XW + b Y=XW+b。在这个例子中,需要另外定义并初始化偏置项 self.bias

示例 1:自定义实现线性层

import torch
import torch.nn as nnclass CustomLinear(nn.Module):def __init__(self, in_channels, out_channels):super(CustomLinear, self).__init__()# 初始化权重self.weight = nn.Parameter(torch.randn(in_channels, out_channels))# 初始化偏置self.bias = nn.Parameter(torch.randn(out_channels))def forward(self, x):# 线性变换:Y = XW + breturn torch.matmul(x, self.weight) + self.bias# 创建自定义线性层
custom_linear = CustomLinear(in_channels=3, out_channels=2)# 打印权重和偏置
print("Weights:", custom_linear.weight)
print("Bias:", custom_linear.bias)# 输入数据
input_data = torch.randn(4, 3)  # 4个样本,每个样本有3个特征# 前向传播
output = custom_linear(input_data)
print("Output:", output)

在这个示例中,我们手动创建了一个自定义的线性层 CustomLinear,它使用 nn.Parameter 来定义权重和偏置。在 forward 方法中,我们手动计算线性变换:Y = XW + b。这个实现与 nn.Linear 提供的功能类似,但更多地体现了手动管理权重和偏置的方式。

nn.Linear

另一方面,nn.Linear 是 PyTorch 提供的一个封装好的模块,用于执行线性变换。它不仅包含了权重矩阵,还自动处理了偏置项(除非明确设置 bias=False)。例如:

self.linear = nn.Linear(in_channels, out_channels)

当调用 self.linear(x) 时,它实际上是在执行以下操作:

output = torch.matmul(x, self.linear.weight.t()) + self.linear.bias

这里,self.linear.weight 的形状是 (out_channels, in_channels),而不是直接 (in_channels, out_channels),因此在进行矩阵乘法之前需要对其转置 (t() 方法)。这意味着数学公式实际上是 Y = X W T + b Y = XW^T + b Y=XWT+b,其中 W T W^T WT 表示权重矩阵的转置。

示例 2:使用 nn.Linear

import torch
import torch.nn as nn# 定义一个线性层
linear_layer = nn.Linear(in_features=3, out_features=2)# 打印权重和偏置
print("Weights:", linear_layer.weight)
print("Bias:", linear_layer.bias)# 输入数据
input_data = torch.randn(4, 3)  # 4个样本,每个样本有3个特征# 前向传播
output = linear_layer(input_data)
print("Output:", output)

在这个示例中,我们创建了一个线性层,它接受一个形状为 [4, 3] 的输入数据,并将其映射到一个形状为 [4, 2] 的输出数据。linear_layer.weightlinear_layer.bias 是自动初始化的。

数学公式的对比

  • 对于手动定义的 nn.Parameter,如果输入是 X X X (形状为 [ N , i n _ c h a n n e l s ] [N, in\_channels] [N,in_channels]),权重是 W W W (形状为 [ i n _ c h a n n e l s , o u t _ c h a n n e l s ] [in\_channels, out\_channels] [in_channels,out_channels]),那么输出 Y Y Y 将通过 Y = X W Y = XW Y=XW 计算。

  • 对于 nn.Linear,同样的输入 X X X (形状为 [ N , i n _ c h a n n e l s ] [N, in\_channels] [N,in_channels]),但是权重 W ′ W' W (形状为 [ o u t _ c h a n n e l s , i n _ c h a n n e l s ] [out\_channels, in\_channels] [out_channels,in_channels]),输出 Y Y Y 将通过 Y = X ( W ′ ) T + b Y = X(W')^T + b Y=X(W)T+b 计算。

从上面可以看出,虽然两者都实现了线性变换,但在 nn.Linear 中,权重矩阵的形状是倒置的,以适应其内部的实现细节。此外,nn.Linear 还自动处理了偏置项的添加,这使得它比手动定义参数更加方便和简洁。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/72694.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙生态日日新,夸克、顺丰速运、驾校一点通等多款应用功能更新

3月5日鸿蒙生态日日新PLOG:吉事办、健康甘肃等政务服务App上架原生鸿蒙应用市场;夸克、顺丰速运、驾校一点通等多款应用功能更新。

基于SpringBoot的智慧停车场小程序(源码+论文+部署教程)

运行环境 • 前端:小程序 Vue • 后端:Java • IDE工具:IDEA(可自行选择) HBuilderX 微信开发者工具 • 技术栈:小程序 SpringBoot Vue MySQL 主要功能 智慧停车场微信小程序主要包含小程序端和…

致同报告:香港财政赤字加剧,扩大税基与增收迫在眉睫

2月26日香港政府2025-26年度财政预算案,(以下简称“预算案”)发布,香港财政司司长陈茂波提出一系列旨在减少开支并振兴香港经济的措施,以应对日益增长的财政赤字。主要提案包括对所有公务员实施冻薪、针对性税务宽减措…

在Spring Boot项目中分层架构

常见的分层架构包括以下几层: 1. Domain 层(领域层) 作用:领域层是业务逻辑的核心,包含与业务相关的实体类、枚举、值对象等。它是对业务领域的抽象,通常与数据库表结构直接映射。 主要组件: 实体类(Entity):与数据库表对应的Java类,通常使用JPA或MyBatis等ORM框架…

实训任务2.2 使用Wireshark捕获数据包并分析

目录 【实训目标】 【实训环境】 【实训内容】 【实训步骤】 1.启动WireShark 2. 使用Wireshark捕获数据包 (1)选择网络接口 (2)捕获数据包 (1)设置Wireshark过滤器并捕获数据包 (2&…

工业自动化核心:BM100 信号隔离器的强大力量

安科瑞 吕梦怡 18706162527 BM100系列信号隔离器可以对电流、电压等电量参数或温度、电阻等非电量参数进行快速精确测量,经隔 离转换成标准的模拟信号输出。既可以直接与指针表、数显表相接,也可以与自控仪表(如PLC)、各种 A/D …

并发编程——累加器

目录 1 AtomicLong 1.1 核心功能 1.2 实现原理: (1)基于 Unsafe 的底层操作 (2) volatile字段的内存可见性 (3)CAS 操作与 ABA 问题 1.3 性能分析 1.4 使用场景 2 LongAdder 核心设计原理 1 分段存储 2 分散更新策略 3.处理高竞…

大模型管理工具:LLaMA-Factory

目录 一、安装与环境配置 二、​启动 Web 界面 三、数据准备 四、模型训练 五、模型评估 七、模型导出 八、API服务部署 LLaMA-Factory 是一个开源的大语言模型(LLM)微调框架,旨在简化大规模模型的训练、微调和部署流程。它支持多种主…

推流项目的ffmpeg配置和流程重点总结一下

ffmpeg的初始化配置,在合成工作都是根据这个ffmpeg的配置来做的,是和成ts流还是flv,是推动远端还是保存到本地, FFmpeg 的核心数据结构,负责协调编码、封装和写入操作。它相当于推流的“总指挥”。 先来看一下ffmpeg的…

大语言模型从理论到实践(第二版)-学习笔记(绪论)

大语言模型的基本概念 1.理解语言是人工智能算法获取知识的前提 2.语言模型的目标就是对自然语言的概率分布建模 3.词汇表 V 上的语言模型,由函数 P(w1w2 wm) 表示,可以形式化地构建为词序列 w1w2 wm 的概率分布,表示词序列 w1w2 wm…

strace工具的交叉编译

1、下载源码 git clone https://github.com/strace/strace.git cd strace 2、运行 bootstrap 脚本(如果需要) 如果源码中没有 configure 脚本,运行以下命令生成: ./bootstrap 3. 配置编译参数 运行 configure 脚本&#xff…

Vue 3 组件库持续集成 (CI) 实战:GitHub Actions 自动化测试与 Storybook 文档构建 - 构建高效可靠的组件库 CI 流程

引言 欢迎再次回到 Vue 3 + 现代前端工程化 系列技术博客! 在昨天的第十篇博客中,我们深入学习了代码覆盖率分析,掌握了利用 Jest 代码覆盖率报告提升单元测试有效性的方法,进一步巩固了组件库的质量防线。 今天,我们将迈向 自动化流程 的构建,聚焦于 持续集成 (Continu…

无穿戴动捕数字人互动方案 | 畅享零束缚、高沉浸的虚实交互体验

在数字化浪潮席卷而来的当下,虚拟人互动体验正逐渐成为各领域的新宠。长久以来,虚拟人驱动主要依靠穿戴式动作捕捉设备,用户需要通过佩戴传感器或标记点来实现动作捕捉。然而,随着技术的不断突破,一种全新的无穿戴动作…

03 HarmonyOS Next仪表盘案例详解(二):进阶篇

温馨提示:本篇博客的详细代码已发布到 git : https://gitcode.com/nutpi/HarmonyosNext 可以下载运行哦! 文章目录 前言1. 响应式设计1.1 屏幕适配1.2 弹性布局 2. 数据展示与交互2.1 数据卡片渲染2.2 图表区域 3. 事件处理机制3.1 点击事件处理3.2 手势…

python-leetcode-统计构造好字符串的方案数

2466. 统计构造好字符串的方案数 - 力扣(LeetCode) 这个问题可以用**动态规划(DP)**来解决,思路如下: 思路 1. 定义 DP 数组 设 dp[i] 表示长度为 i 的好字符串的个数。 2. 状态转移方程 我们可以在 dp…

MySQL------存储引擎和用户和授权

9.存储引擎 1.两种引擎 MyISAM和InnoDB 2.两种区别 1.事务: MyISAM不支持事务 2.存储文件: innodb : frm、ibd MyISAM: frm、MYD、MYI 3.数据行锁定: MyISAM不支持 4.全文索引: INNODB不支持,所以MYISAM做select操作速度很快 5.外键约束: MyISAM…

题海拾贝:P9241 [蓝桥杯 2023 省 B] 飞机降落

Hello大家好&#xff01;很高兴我们又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路&#xff01; 我的博客&#xff1a;<但凡. 我的专栏&#xff1a;《编程之路》、《数据结构与算法之美》、《题海拾贝》 欢迎点赞&#xff0c;关注&#xff01; 1、题…

删除有序数组中的重复项(js实现,LeetCode:26)

给你一个 非严格递增排列 的数组 nums &#xff0c;请你原地删除重复出现的元素&#xff0c;使每个元素只出现一次&#xff0c;返回删除后数组的新长度。元素的相对顺序应该保持一致。然后返回 nums 中唯一元素的个数。 考虑 nums 的唯一元素的数量为 k &#xff0c;你需要做以…

3-9 WPS JS宏单元格复制、重定位应用(拆分单表到多表)

************************************************************************************************************** 点击进入 -我要自学网-国内领先的专业视频教程学习网站 *******************************************************************************************…

大白话react第十六章React 与 WebGL 结合的实战项目

大白话react第十六章React 与 WebGL 结合的实战项目 1. 项目简介 React 是一个构建用户界面的强大库&#xff0c;而 WebGL 则允许我们在网页上实现高性能的 3D 图形渲染。将它们结合起来&#xff0c;我们可以创建出炫酷的 3D 网页应用&#xff0c;比如 3D 产品展示、虚拟场景…