Hands on Deep Learning Chapter 3 线性神经网络

news/2025/10/20 20:58:37/文章来源:https://www.cnblogs.com/milesssma/p/19146502

3 线性神经网络

3.1 线性回归

回归(regression)、预测(prediction)、分类(classification)

3.1.1 线性回归的基本元素

线性模型:对输入特征进行一个仿射变换(affine transformation,加权和对特征进行线性变换,偏置项进行平移)
单个数据样本:\(\hat{y}=\mathbf{w^T x}+b\),w:d×1,x:d×1,d个特征。
整个数据集,每一行是一个样本:\(\mathbf{\hat{y}}=\mathbf{Xw}+b\),y:n×1,X:n×d,w:d×1,b是标量,使用了广播机制。
找到一组参数(w,b),使得误差小,在获得最优的w、b之前,还需要考虑:

  1. 如何衡量误差
  2. 如何更新w、b

损失函数
单样本平方误差损失函数 \(l^{(i)}(\mathbf{w},b)=\frac{1}{2}(\hat{y}^{(i)}-y^{(i)})^2\)
经验误差是关于模型参数的函数(即损失函数是关于参数的,训练集固定)
\(L(\mathbf{w},b)=\frac{1}{n}\sum^{n}_{i=1}l^{(i)}(\mathbf{w},b)=\frac{1}{n}\sum^{n}_{i=1}\frac{1}{2}(\mathbf{ w^T x^{(i)}}+b-y^{(i)})^2\)
找一组w、b令上市最小化

解析解:\(\mathbf{w^*=(X^T X)^{-1}X^T y}\),Aw对w求导为A的转置,w^T A对w求导是A,w^T Aw对w求导是2A

随机梯度下降
梯度下降:gradient descent,沿着梯度相反的方向去更新参数,从而使得损失函数减小。通常每次随机抽取一小批样本,called 小批量随机梯度下降(minibatch stochastic gradient descent)
1)初始化模型的值;
2)抽取小批量样本且在负梯度方向上更新参数。
超参数:batch size:批量大小、learning rate:学习率,不在训练过程中更新,根据训练迭代成果来调整,训练迭代的结果通过验证集(validation set)评估得到。
线性回归在一整个域中只有1个最小值,更难的是泛化(generalization),找到一组参数在没见过的数据集上表现良好。

3.1.3 正态分布与平方损失

均方损失误差可以用于线性回归的一个原因:假设观测中包含噪声,且服从正态分布(均值为0,方差恒定)。
给定x观测到特定y的似然是通过噪声的分布来建模的,噪声取到某个值的概率密度,直接导致观测到y的概率密度:
\(P(y|\mathbf{x})=\frac{1}{\sqrt{2\pi\sigma^2}}exp(-\frac{1}{2\sigma^2}(y-\mathbf{w^T x}-b)^2)\)
极大似然估计,选择w、b令取到这种特定数据集的概率最大。
\(P(\mathbf{y|X})=\prod^n_{i=1}p(y^{(i)}|x^{(i)})\),对数化
\(-logP(\mathbf{y|X})=\sum^{n}_{i=1}\frac{1}{2}log(2\pi\sigma^2)+\frac{1}{2\sigma^2}(y^{(i)}-\mathbf{w^T x^{(i)}}-b)^2\)
与均方损失误差的优化是一样的(前面常数不考虑),只需要假设方差是常数即可。
高斯假设下:最小化均方误差=对线性模型的极大化似然估计

3.2 线性回归的从零开始实现

1)数据集的生成与读取:每次读取一个小批量,input:batch_size, X, y,生成数据迭代器,每次返回batch size大小的一组特征和标签。
2)初始化模型参数:w用均值0,标准差0.01的正态,b为0;
3)定义模型:torch.maxmul(X, w)+b;
4)定义损失函数;
5)定义优化算法:用批量大小除以损失值,避免因为batch_size的选择导致损失过大or过小,导致计算出的梯度大小影响更新步长;
6)训练:执行循环:初始化参数、计算梯度、更新参数,可以迭代多个周期(epoch)。

3.3 线性回归的简洁实现

1)生成数据集、读取数据集: data.TensorDataset(features, labels), data.DataLoader(dataset, batch_size, shuffle=True),生成一个数据迭代器,可以用for循环遍历,并且执行sgd;
2)定义模型: net=nn.Sequential(nn.Linear(2, 1)), 第一个2指的是输入特征形状,第二个1是输出特征形状;
3)初始化模型参数: net[0].weight/bias.data.normal(0, 0.01)/fill_(0);
4)定义损失函数: nn.MSELoss();
5)定义优化算法: torch.optim.SGD(net.parameters(), lr=0.03);
6)训练:遍历数据迭代器,前向传播计算net(X)生成预测,计算损失,反向传播计算梯度,调用优化器来更新模型参数,注意每次迭代要梯度清0。

3.4 softmax回归

硬类别:是什么,软类别:每类的概率
独特编码(one-hot encoding),(1,0,0)、(0,1,0)、(0,0,1)

3.4.2 网络架构

多个输入,每个类别对应一个输出,对一个样本就要做多输出了,每一个输出对应着一个仿射函数,比如4个特征、3个类别,需要12个权重+3个偏置。
o1=x1w11+x2w12+x3w13+x4w14+b1
o2=x1w21+x2w22+x3w23+x4w24+b2
o3=x1w31+x2w32+x3w33+x4w34+b3
\(\mathbf{o=Wx+b}\), softmax也是单层、全连接层。

3.4.3 参数开销

d个input转换为q个output成本O(dq),but可以减少到O(dq/n)

3.4.4 softmax运算

分类问题要求得到预测结果,可以选择最大概率的标签,不过我们有时也需要软标签,也就是具体每个类别的概率,我们需要这些数相加=1,大于等于0。
softmax函数:将未规范化的预测变换为非负数且总和为1:
\(\mathbf{\hat{y}}=softmax(\mathbf{o})\)
\(\hat{y_j}=\frac{exp(o_j)}{\sum_k exp(o_k)}\)
softmax不会更改大小次序,虽然是非线性函数,但是是线性模型,可以直接通过oj的值选择分类

3.4.7 信息论基础

information theory: 涉及编码、解码、发送以及尽可能简洁地处理信息or数据。

量化数据中的信息内容,该数值被成为分布P的熵(entropy)。
\(H[P]=\sum_{j}-P(j)logP(j)\)
信息论基本定理之一:为了对分布p中随机抽取的数据进行编码,我们至少需要H[P]“纳特nat”对其进行编码,“纳特”相当于以对数e为底rather than 2的比特。
\(H_{nat}=-\sum P_i log_e P_i = \frac{H_{bit}}{log_2 e}\)
对于2进制,信息为:-ln1/2 nats = -log_2 1/2 bit, 即:ln2nats=1bits,hence 1nats=1/ln2bits=1.44bits
信息流
如果很容易预测下一个数据,则可以把数据压缩过大,不用传递那么多信息,可以丢一些。
如果是常数数据流,我们不用传递任何信息,“下一个数据是xx”这个事件毫无信息量。
如果不能完全预测每一个事件,有时候会感到“诧异”,香农使用\(log\frac{1}{P(j)}=-logP(j)\)来量化这种惊异程度,P(j)是主观概率,如果概率较低,则出现时惊异程度会更大,该事件的信息量也更大。
熵的定义是分配的概率真正匹配数据生成过程时的信息量的期望,出现的概率是P(j),信息量是-logP(j)。
交叉熵,从P到Q记为\(H(P,Q)\),主观概率为Q的观察者看到根据概率P生成的数据时的预期惊异,当P=Q时,交叉熵最低。
主观概率Q,可以理解为我们模型训练出来的,客观概率P是数据集的分类,\(H(P,Q)=E_{x ~ p}(-log(Q(x))\)
两个角度理解交叉熵分类目标:
1)最大化观测数据的似然;
2)最小化传达标签所需要的惊异。

3.6 softmax的从零开始实现

1)制作数据迭代器;
2)模型参数初始化(W,b);
3)定义softmax操作;
4)定义模型;
5)定义损失函数;
6)分类精度;
7)训练;
8)预测;
一些问题:exp当input比较大时,可能导致python数据结构的溢出、且交叉熵损失要求input>0。

3.7 softmax的简洁实现

1)初始化模型参数;
2)重新审视softmax实现:
exp(o_k)可能上溢,let分子or分母无穷大,最后得到0、inf、nan,然后损失函数也算不出来。
可以把o_j换成o_j-max(o_k)。
如果o_j-max比较小,可能接近0发生下溢,结果为0,损失函数算不出来,为inf,反向传播一堆nan。
可以把softmax函数带入损失函数后,直接把o_j带入进去,最后结果是\(log(\hat{y_i})=o_j-max(o_k)-log(\sum_k exp(o_k-max(o_k)))\)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/941594.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

超越技术范畴:低代码如何重塑企业数字文化

当我们谈论低代码时,目光往往聚焦于其提升开发效率的技术特性。然而,它的深层影响力远不止于此。低代码更像是一颗投入企业静湖的石子,其激起的涟漪,正层层扩散,深刻地重塑着组织的协作模式、创新节奏乃至内在的数…

好用的网址

填验证码(? 题解格式化。 画图。 代码格式化。 纯文字图片生成器。

【C++实战(71)】解锁C++音视频编写:FFmpeg从入门到实战

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

20251020

正睿NOIP 二十连测 A 有 \(m(m \le 95)\) 种药剂,每种药剂有 \(n_i(\sum n_i \le 10^{15})\) 瓶,等级为 \(p_i\)(\(2 \le p_i \le 499\))。要将这些药剂分成两个不相交的集合 \(X, Y\),\(X\) 的价值为其组内所有药…

低代码赋能业务创新:打破数字鸿沟,释放业务潜能

在数字化转型的浪潮中,一个突出的矛盾日益显现:业务部门汹涌的创新需求,与IT部门有限的开发资源之间,形成了一道难以逾越的“数字鸿沟”。当市场部门需要一个临时的活动报名系统,当HR部门渴望一个高效的内部推荐工…

【大模型】大模型训练的几个不同阶段

总结:各方法的典型关联(以大语言模型为例)Pre-Training:先让模型学“通识知识”(如语言、世界知识)。 Supervised Fine-Tuning (SFT):用标注数据让模型学“任务基本模式”(如指令遵循)。 Reward Modeling:训…

详细介绍:1、手把手教你入门设计半桥LLC开关电源设计,LLC谐振腔器件计算

详细介绍:1、手把手教你入门设计半桥LLC开关电源设计,LLC谐振腔器件计算pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family…

十六天

今日重点学习关系型数据库基础,核心掌握三个模块:一是数据表的结构化设计,明确字段类型(如INT、VARCHAR)需与数据属性匹配,避免后续数据存储异常;二是主键的作用,通过实操验证其“唯一标识记录”的必要性——未…

10/20/2025杂题 关于在线性时间内求解低次多项式的幂

例 设 \(g = ax^2 + bx + c\),求: \[ f = g^n\]其中 \(0 \leq n \leq 3 \times 10^5\)。结果对 \(10^9 + 7\) 取模。 首先可以直接用 MTT 在 \(O(n \log n)\) 的时间复杂度内求解。然而此做法常数太大,在需要多次求…

歌手与模特儿

https://www.luogu.com.cn/problem/AT_nikkei2019_2_final_h 第一次见到能 manacher 但不能二分+哈希的题! 直接上 manacher,当尝试将区间拓展为 \([l,r]\) 时,考察 \(nxt_l\) 和 \(lst_r\) 的位置关系,可以 check…

20251019

正睿 NOIP 十连测 B 有 \(n\) 个数 \(a_1 \sim a_n\)。初始有一个 \(x = 1\),每次需要将 \(x\) 变为某个 \(i\),花费代价为 \(\min(|i - x|, n - |i - x|)\),且 \(a_x \le a_i\)。问访问所有 \(i\) 需花费的最小代价…

计算机毕业设计 基于EChants的海洋气象数据可视化平台设计与建立 Python 大数据毕业设计 Hadoop毕业设计选题【附源码+文档报告+安装调试】

计算机毕业设计 基于EChants的海洋气象数据可视化平台设计与建立 Python 大数据毕业设计 Hadoop毕业设计选题【附源码+文档报告+安装调试】pre { white-space: pre !important; word-wrap: normal !important; overflo…

SpringBoot整合Redis教程

一、Redis 简介 Redis(Remote Dictionary Server)是一个开源的高性能键值对存储数据库,基于内存运行并支持持久化,常用于缓存、会话存储、消息队列等场景。其核心特点包括:速度快:基于内存操作,单线程模型避免上…

https://www.luogu.com.cn/problem/CF1635E

考虑一个事情,两辆车方向一定相反,弱化限制后,建二元关系图,发现一定是一张二分图。 钦定左部点为向左,其他点为向右,然后发现位置满足一个二元大小关系限制,建 DAG 跑拓扑序即可。

ZR 2025 NOIP 二十连测 Day 5

85 + 32 + 5 + 5 = 127, Rank 67/128.呜呜我错了……我再也不开太大的 vector 了呜呜……/dk /dk /dk25noip二十连测day5 链接:link 题解:题目内 时间:4h (2025.10.20 14:00~18:00) 题目数:4 难度:A B C D\(\colo…

关于单片机内部ADC采样率,采样精度的理解与计算整理 - 实践

关于单片机内部ADC采样率,采样精度的理解与计算整理 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Conso…

整体架构与数据流

下面给出对题目及当前代码求解方法的系统、深入解析,涵盖三问建模逻辑、数据流、关键算法、假设与局限、以及改进建议。内容按“题目需求 -> 代码实现 -> 差异/假设 -> 评估/改进”结构展开,方便你写论文或…

【上青了】

【上青了】赶紧把面板记录一下先,怕下次又又又掉了 没什么好讲的,本来上场打完就差 11 分,这场只要正常发挥就没问题变色,所以也没什么激动,该激动的上次都激动完了,哎哎哎 要说就是这次状态还行,不算差,前面 …

[VIM] reverse multiple lines in VIM

推荐方法: If you’re on a Unix-like system (FreeBSD, Linux, macOS), use :14,19!tac.来自chatgptTo reverse the display order of lines 51 to 54 in Vim, you can use the :g and :tac-style command combinati…

DeviceNet 转 Ethernet/IP:三菱 Q 系列 PLC 与欧姆龙 CJ2M PLC 在食品饮料袋装生产线包装材料余量预警的通讯配置案例

案例背景 DeviceNet 转 Ethernet/IP在食品饮料行业,包装生产线涉及多种设备,如灌装机、贴标机、封口机等。不同设备可能采用不同的工业总线协议,为了实现整个包装生产线的自动化控制和数据共享,需要将采用 Etherne…