深度学习4.4笔记

《动手学深度学习》-4.4-笔记

验证数据集:通常是从训练集中划分出来的一部分数据,不要和训练数据混在一起,评估模型好坏的数据集

测试数据集:只用一次的数据集

k-折交叉验证(k-Fold Cross-Validation)是一种统计方法,用于评估和比较机器学习模型的性能。它通过将数据集分成k个子集(或“折”)来实现,每个子集都作为一次测试集,而剩余的k-1个子集则作为训练集。这个过程会重复k次,每次选择不同的子集作为测试集,最终将k次测试结果的平均值作为模型的性能评估。常用k=5/10,在没有足够多数据使用时。

总结:

训练数据集:训练模型参数

验证数据集:选择模型超参数

非大型数据集上通常使用k-折交叉验证

欠拟合(Underfitting)

欠拟合是指模型对训练数据的拟合程度不够,无法捕捉到数据中的规律和模式。换句话说,模型过于简单,无法很好地描述数据的特征。

过拟合(Overfitting)

过拟合是指模型对训练数据拟合得过于完美,以至于模型在训练数据上表现很好,但在新的、未见过的数据上表现很差。换句话说,模型过度学习了训练数据中的噪声和细节,而无法泛化到新的数据。

模型容量的定义

表示容量:模型的最大拟合能力,即通过调节参数,模型能够表示的函数族

  1. 模型参数数量:参数越多,模型容量通常越高。

  2. 模型结构复杂度:例如,神经网络的层数和每层的神经元数量。

  3. 数据复杂度:数据的复杂度(如样本数量、特征数量)也会影响模型容量的选择

模型容量与过拟合、欠拟合的关系

  • 容量不足:模型无法很好地拟合训练数据,导致欠拟合。

  • 容量过高:模型可能会过度拟合训练数据中的噪声,导致过拟合

总结;

模型容量需要匹配数据复杂度,否则可能过拟合或欠拟合

统计机器学习提供数学工具来衡量模型复杂度 

代码部分:
 

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

引入需要的库

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))#随机生成200个样本点(服从标准正态分布的x值)。
np.random.shuffle(features)#打乱样本顺序。
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)#加上噪音

分析:

生成一个多项式回归的训练/测试数据集。也就是说,我们在模拟一个“隐藏函数”,然后加一点噪声,生成一些数据,来用于模型训练。

多项式阶数:我们打算生成最多20阶的多项式数据(比如 1, x, x², ..., x¹⁹)。

true_w = np.zeros(max_degree)  # 创建一个长度为20的权重数组,初始值全是0
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

这一步设置了我们想要“模拟”的真实多项式模型的参数。它实际上模拟了一个三阶多项式:

y = 5 + 1.2x - 3.4x² + 5.6x³

其余的高阶项(x⁴ ~ x¹⁹)的系数为0。

poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))


这一步是关键!构造一个 多项式特征矩阵。

假设 features = [[x1], [x2], ..., [x200]],
我们把它转化为:
[[1, x1, x1², x1³, ..., x1^19],
 [1, x2, x2², x2³, ..., x2^19],
 ...
]

for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!

这一步是做多项式特征的缩放处理,用的是数学中的Gamma函数

举例:

  • gamma(1) = 0! = 1

  • gamma(2) = 1! = 1

  • gamma(3) = 2! = 2

  • gamma(4) = 3! = 6 ...

所以这是在做归一化的处理,让高阶项不会变得太大

labels = np.dot(poly_features, true_w)

这一步是最核心的:根据我们设定的权重 true_w 计算标签 y 值

可以理解为:
对每一行的多项式特征向量和权重向量做内积(点乘),
也就是:

  • 所以最终的标签是:
    真实标签 + 小范围扰动

# NumPy ndarray转换为tensor
true_w, features, poly_features, labels = [torch.tensor(x, dtype=torch.float32) for x in [true_w, features, poly_features, labels]]
#这句用列表推导式,把之前的 NumPy 数组全部 转换成 PyTorch 的 tensor(张量)格式,这样就可以用 PyTorch 来训练模型啦!
features[:2], poly_features[:2, :], labels[:2]#这个不是赋值语句,而是查看前两个样本的输入特征、多项式特征和标签的值,

已经把 NumPy 的数组转成了 PyTorch 的张量

def evaluate_loss(net, data_iter, loss):  #@save"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和,样本数量for X, y in data_iter:out = net(X)#前向传播 + 计算损失 让模型对输入 X 做预测,得到输出 outy = y.reshape(out.shape)l = loss(out, y)#计算预测结果和真实值之间的损失metric.add(l.sum(), l.numel())#计算预测结果和真实值之间的损失return metric[0] / metric[1]#计算预测结果和真实值之间的损失

评估模型在某个数据集(data_iter)上的平均损失

分析:

  • net: 模型(PyTorch 中定义的神经网络)

  • data_iter: 数据迭代器(通常是训练集或测试集的 DataLoader

  • loss: 损失函数(比如 nn.MSELoss()

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):#定义了一个训练函数loss = nn.MSELoss(reduction='none')#均方误差损失函数(MSE),但不求平均,保留每个样本的损失值。input_shape = train_features.shape[-1]# 不设置偏置,因为我们已经在多项式中实现了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)#把训练和测试数据打包成 DataLoader,方便模型一批一批训练test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)#使用随机梯度下降(SGD)优化模型参数,学习率为 0.01animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])#用 D2L 里的 Animator 动态绘图类,记录训练过程的 loss 曲线for epoch in range(num_epochs):#训练一个 epoch,用的是 D2L 中封装好的 train_epoch_ch3(每轮完整训练一遍所有 batch)d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:#每隔20轮(或第1轮),就评估一下训练集和测试集上的平均损失,然后加到图上animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())  #打印最终训练得到的权重

PyTorch + 多项式特征训练一个线性模型,并可视化训练过程的

  • 用线性模型拟合你设计的多项式数据(多阶特征)

  • 使用 MSELoss + SGD 训练

  • 可视化训练和测试集上的损失变化

  • 打印最终训练好的模型参数,看看学得准不准

按书中的报错然后,你调用了 l.backward() 来反向传播,但这个 l 是一个 不需要梯度 的张量(requires_grad=False),所以无法反向传播!

还是之前的做法:

loss = nn.MSELoss(reduction='none')返回的是一个 每个样本的损失 的张量,而不是所有样本损失的平均或总和。

看看 train_epoch_ch3 的定义),它里面可能是直接用了 l = loss(y_hat, y),然后 l.backward()

修改后

正常:

欠拟合

如果用不同复杂度的模型来拟合这个函数,表现会怎样?

# 只用 1 和 x 两项(线性模型)
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

这意味着你在训练一个线性模型

这个模型完全忽略了二阶项 和三阶项 ,所以它根本学不出原来的复杂模式

结果就是:

  • 训练损失很高

  • 测试损失也高

  • 模型欠拟合:学得太简单,跟不上真实的非线性函数

# 使用与真实模型相同的特征阶数
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

 这次你用了前4项:

注意:你训练的时候也会拟合这几个特征,也就是:

而我们真实函数 y = 5 + 1.2x - 3.4x^2 + 5.6x^3,刚好就是3阶多项式

  • 训练损失下降得更快

  • 最终损失更低

  • 模型可以很好地拟合数据,不欠拟合也不过拟合

 

poly_features[:, :] 表示使用 所有20阶的多项式特征,也就是:

  • 训练了一个 20维输入的线性模型

  • 训练次数设为 1500 轮(比前面更多)

  • 但现在用一个 包含20阶的模型 去拟合这些数据,虽然原函数只有3阶,后面17个高阶项都是“多余的”。

  • 训练集表现很好(损失很低),但在测试集上 泛化能力变差

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/73808.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue 两种路由模式

一、两种模式比较 在vue.js中,路由模式分为两种:hash 模式和 history 模式。这两种模式决定了URL的结构和浏览器历史记录的管理方式。 1. hash 模式带 #,#后面的地址变化不会引起页面的刷新。换句话说,hash模式不会将#后面的地址…

Android生态大变革,谷歌调整开源政策,核心开发不再公开

“开源”这个词曾经是Android的护城河,如今却成了谷歌的烫手山芋。最近谷歌宣布调整Android的开源政策,核心开发将全面转向私有分支。翻译成人话就是:以后Android的核心更新,不再公开共享了。 这操作不就是开源变节吗,…

JavaScript中集合常用操作方法详解

JavaScript中集合常用操作方法详解 JavaScript中的集合主要包括数组(Array)、集合(Set)和映射(Map)。下面我将详细介绍这些集合类型的常用操作方法。 数组(Array) 数组是JavaScript中最常用的集合类型,提供了丰富的操作方法。 创建数组 // 字面量创建 const ar…

【HC-05】蓝牙串口通信模块调试与应用(1)

一、HC-05 基础学习视频 HC-05蓝牙串口通信模块调试与应用1 二、HC-05学习视频课件

【学Rust写CAD】18 定点数2D仿射变换矩阵结构体(MatrixFixedPoint结构别名)

源码 // matrix/fixed.rs use crate::fixed::Fixed; use super::generic::Matrix;/// 定点数矩阵类型别名 pub type MatrixFixedPoint Matrix<Fixed, Fixed, Fixed, Fixed, Fixed, Fixed>;代码解析 这段代码定义了一个定点数矩阵的类型别名 MatrixFixedPoint&#xff…

axios文件下载使用后端传递的名称

java后端通过HttpServletResponse 返回文件流 在Content-Disposition中插入文件名 一定要设置Access-Control-Expose-Headers&#xff0c;代表跨域该Content-Disposition返回Header可读&#xff0c;如果没有&#xff0c;前端是取不到Content-Disposition的&#xff0c;可以在统…

HarmonyOS之深入解析如何根据url下载pdf文件并且在本地显示和预览

一、文件下载 ① 网络请求配置 下载在线文件&#xff0c;需要访问网络&#xff0c;因此需要在 config.json 中添加网络权限&#xff1a; {"module": {"requestPermissions": [{"name": "ohos.permission.INTERNET","reason&qu…

鸿蒙前后端项目源码-点餐v3.0-原创!原创!原创!

鸿蒙前后端点餐项目源码含文档ArkTS语言. 原创作品.我半个月写的原创作品&#xff0c;请尊重原创。 原创作品&#xff0c;盗版必究&#xff01;&#xff01;&#xff01;&#xff01; 原创作品&#xff0c;盗版必究&#xff01;&#xff01;&#xff01;&#xff01; 原创作…

VUE3+TypeScript项目,使用html2Canvas+jspdf生成PDF并实现--分页--页眉--页尾

使用html2CanvasJsPDF生成pdf&#xff0c;并实现分页添加页眉页尾 1.封装方法htmlToPdfPage.ts /**path: src/utils/htmlToPdfPage.tsname: 导出页面为PDF格式 并添加页眉页尾 **/ /*** 封装思路* 1.将页面根据A4大小分隔边距&#xff0c;避免内容被中间截断* 所有元素层级不要…

5.Excel:从网上获取数据

一 用 Excel 数据选项卡获取数据的方法 连接。 二 要求获取实时数据 每1分钟自动更新数据。 A股市场_同花顺行情中心_同花顺财经网 用上面方法将数据加载进工作表中。 在表格内任意区域右键&#xff0c;刷新。 自动刷新&#xff1a; 三 缺点 Excel 只能爬取网页上表格类型的…

《深度剖析SQL之WHERE子句:数据过滤的艺术》

在当今数据驱动的时代&#xff0c;数据处理和分析能力已成为职场中至关重要的技能。SQL作为一种强大的结构化查询语言&#xff0c;在数据管理和分析领域占据着核心地位。而WHERE子句&#xff0c;作为SQL中用于数据过滤的关键组件&#xff0c;就像是一把精准的手术刀&#xff0c…

华为eNSP-配置静态路由与静态路由备份

一、静态路由介绍 静态路由是指用户或网络管理员手工配置的路由信息。当网络拓扑结构或者链路状态发生改变时&#xff0c;需要网络管理人员手工修改静态路由信息。相比于动态路由协议&#xff0c;静态路由无需频繁地交换各自的路由表&#xff0c;配置简单&#xff0c;比较适合…

Docker 快速入门指南

Docker 快速入门指南 1. Docker 常用指令 Docker 是一个轻量级的容器化平台&#xff0c;可以帮助开发者快速构建、测试和部署应用程序。以下是一些常用的 Docker 命令。 1.1 镜像管理 # 搜索镜像 docker search <image_name># 拉取镜像 docker pull <image_name>…

基础认证-单选题(一)

单选题 1、下列关于request方法和requestlnStream方法说法错误的是(C) A 都支持取消订阅响应事件 B 都支持订阅HTTP响应头事件 C 都支持HttpResponse返回值类型 D 都支持传入URL地址和相关配置项 2、如需修改Text组件文本的透明度可通过以下哪个属性方法进行修改 (C) A dec…

Logback使用和常用配置

Logback 是 Spring Boot 默认集成的日志框架&#xff0c;相比 Log4j&#xff0c;它性能更高、配置更灵活&#xff0c;并且天然支持 Spring Profile 多环境配置。以下是详细配置步骤及常用配置示例。 一、添加依赖&#xff08;非 Spring Boot 项目&#xff09; 若项目未使用 Sp…

MySQL基础语法DDLDML

目录 #1.创建和删除数据库 ​#2.如果有lyt就删除,没有则创建一个新的lyt #3.切换到lyt数据库下 #4.创建数据表并设置列及其属性,name是关键词要用name包围 ​编辑 #5.删除数据表 #5.查看创建的student表 #6.向student表中添加数据,数据要与列名一一对应 #7.查询studen…

在windows下安装windows+Ubuntu16.04双系统(下)

这篇文章的内容主要来源于这篇文章&#xff0c;为正式安装windowsUbuntu16.04双系统部分。在正式安装前&#xff0c;若还没有进行前期准备工作&#xff08;1.分区2.制作启动u盘&#xff09;&#xff0c;见《在windows下安装windowsUbuntu16.04双系统(上)》 二、正式安装Ubuntu …

Ubuntu24.04 离线安装 MySQL8.0.41

一、环境准备 1.1 官方下载MySQL8.0.41 完整包 1.2 上传包 & 解压 上传包名称是&#xff1a;mysql-server_8.0.41-1ubuntu24.04_amd64.deb-bundle.tar # 切换到上传目录 cd /home/MySQL8 # 解压&#xff1a; tar -xvf mysql-server_8.0.41-1ubuntu24.04_amd64.deb-bundl…

记录一次Dell服务器更换内存条报错解决过程No memory found

文章目录 问题问题分析解决流程总结 问题 今天给服务器添加了几个内存条&#xff0c;开启后报错 No memory found No useable DlMMs found. Verify the DlMMsare properly seated and that they are installed in the correct sockets. 问题分析 这个错误说明服务器在启动时没…

Apache HttpClient使用

一、Apache HttpClient 基础版 HttpClients 是 Apache HttpClient 库中的一个工具类&#xff0c;用于创建和管理 HTTP 客户端实例。Apache HttpClient 是一个强大的 Java HTTP 客户端库&#xff0c;用于发送 HTTP 请求并处理 HTTP 响应。HttpClients 提供了多种方法来创建和配…