P22_损失函数与反向传播

news/2025/11/22 18:10:44/文章来源:https://www.cnblogs.com/Samar-blog/p/19258280

P22_损失函数与反向传播

22.1损失函数的作用

  1. 计算实际输出和目标之间的差距
  2. 为我们更新输出提供一定的依据(反向传播)

22.2几种官方文档中的损失函数

打开torch.nn—Loss Functions:

【注意:损失函数只能处理float类型的张量,可在张量设置时加入dtype=torch.float或者张量中出现小数点】

1.L1Loss:

(1)其参数如下:

class torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

其中reduction默认是mean,也可以设置为sum

(2)举个例子

输入X:1,2,3

对应的目标Y:1,2,5

mean:L1loss = (0+0+2)/3 = 0.6667

sum:LL1loss = 0+0+2 = 2

(3)将上例写入代码
点击查看代码
import torch
from torch.nn import L1Loss#损失函数只能处理float类型的张量
inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)#将inputs和targets变成1batch_size、1channel,1行3列
inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))loss = L1Loss()
result = loss(inputs,targets)print(result)
(4)输出:

tensor(0.6667)

(5)设置reduction为sum

loss = L1Loss(reduction='sum')

输出为:
tensor(2.)

2.MSELoss(平方差):

(1)举个例子

输入X:1,2,3

对应的目标Y:1,2,5

则MSE=(0+0+2^2)/3 = 4/3 = 1.333

(2)写入代码:
点击查看代码
import torch
from torch.nn import L1Loss, MSELoss
from torch import nn#损失函数只能处理float类型的张量
inputs = torch.tensor([1,2,3],dtype=torch.float32)
targets = torch.tensor([1,2,5],dtype=torch.float32)#将inputs和targets变成1batch_size、1channel,1行3列
inputs = torch.reshape(inputs,(1,1,1,3))
targets = torch.reshape(targets,(1,1,1,3))loss_mse = MSELoss()
result_mse = loss_mse(inputs,targets)print(result_mse)
(3)输出:

tensor(1.3333)

3.CrossEntropyLoss(交叉熵):

(1)交叉熵计算公式

想要loss小,就要使x[class]大(其前有负号)以及log的部分小;

假设要分成三类:person、dog、cat,现在给他们分别赋上标签:0,1,2

设person假如经过神经网络的输出output为[0.1,0.2,0.3],target是1(狗的标签是1);

此时,代入到计算公式里面,output就是x,target就是class:

P22_第三个loss的使用

则计算结果为:loss(x,class) = -0.2 + log(exp(0.1)+exp(0.2)+exp(0.3)),计算机里面log默认是以e为底,即loss(x,class) = -0.2 + ln(exp(0.1)+exp(0.2)+exp(0.3)),得到1.1019
P22_CrossEntropyLoss的计算结果

(2)CrossEntropyLoss的shape要求

Input:(N,C),C是类别数,N是batch_size

Target:N

Output:scalar

(3)写入代码:
点击查看代码
x = torch.tensor([0.1,0.2,0.3])
y = torch.tensor([1])#按照Input的shape要求,把Input:(N,C),C是类别数,N是batch_size
#即分成三类:person,cat,dog;batch_size为一个
x = torch.reshape(x,(1,3))
#新建交叉熵
loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x,y)print(result_cross)
结果如下: `tensor(1.1019)` 与上面我们计算的结果一致;

22.3在神经网络中使用Loss Function:

1.将CIFAR10的数据集放到神经网络

点击查看代码
import torchvision.transforms
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential#定义数据集
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)#定义神经网络
class DYL_seq(nn.Module):def __init__(self):super(DYL_seq, self).__init__()self.model1 = Sequential(Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 32, 5, padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 64, 5, padding=2),MaxPool2d(kernel_size=2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return x#创建网络
dyl_seq = DYL_seq()#从dataloader去取数据
for data in dataloader:imgs,targets = dataoutputs = dyl_seq(imgs)print(outputs)print(targets)

2.输出:

点击查看代码
Files already downloaded and verified
tensor([[ 0.1075, -0.0205, -0.0137, -0.1302,  0.0355, -0.0131,  0.0246, -0.0382,0.1338,  0.0322]], grad_fn=<AddmmBackward0>)
tensor([3])
tensor([[ 0.0851,  0.0023,  0.0055, -0.1246,  0.0670, -0.0242,  0.0132, -0.0222,0.1512,  0.0251]], grad_fn=<AddmmBackward0>)
tensor([8])
......
结果中:

输入的图片放在神经网络中得到一个输出:一共有10个输出,每个值代表预测是这个类别的概率

[ 0.1075, -0.0205, -0.0137, -0.1302, 0.0355, -0.0131, 0.0246, -0.0382,0.1338, 0.0322]

其对应的tensor([3])代表其真实类别是3;

3.使用Loss Function

(1)添加代码:
点击查看代码
loss = nn.CrossEntropyLoss()#创建网络
dyl_seq = DYL_seq()#从dataloader去取数据
for data in dataloader:imgs,targets = dataoutputs = dyl_seq(imgs)result_loss = loss(outputs,targets)#CrossEntropyLoss的计算公式中,x就是神经网络的输出outputs,class就是targetsprint(result_loss)
(2)输出结果如下:
点击查看代码
Files already downloaded and verified
tensor(2.3312, grad_fn=<NllLossBackward0>)
tensor(2.1667, grad_fn=<NllLossBackward0>)
.......

我们得到的这些数:就是神经网络的输出和真实输出的误差

4.lossfunction的第二个作用:为我们更新输出提供一定的依据(反向传播)

(1)即grad

给每一个卷积核中的参数设置了一个grad梯度,当我们采用反向传播时,每一个要更新的参数就会求出对应多个梯度,在优化过程中,就会对其中参数进行一个优化

(2)代码如下
点击查看代码
x for data in dataloader:    imgs,targets = data    outputs = dyl_seq(imgs)    result_loss = loss(outputs,targets)    #CrossEntropyLoss的计算公式中,x就是神经网络的输出outputs,class就是targets    result_loss.backward()    #print(result_loss)    print("ok")
(3)把断点打在: result_loss.backward(),然后debug

可见:dyl_seq的modules下边的‘model1’中的modules的第一层‘0’里面的weight下边的gard为None

P22_weight下边的gard为None

(4)然后调试进入到下一步:

其gard梯度就计算出来了:
P22_调试到下一步后weight下边的gard为None

但是如果删掉了result_loss.backward,就会发现gard始终为none

所以:

要用backward去做反向传播,他就可以计算出每一个节点的参数,就得到各个节点的参数的梯度,接下里就可以选用合适的优化器,对其参数极限那优化,对loss达到降低的目的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/973312.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

团队作业3-需求改进系统设计

团队作业3-综合报告(Alpha阶段)项目 详情这个作业属于哪个课程 计科23级12班这个作业要求在哪里 作业要求链接这个作业的目标 对现有项目进行设计和需求&原型改进,进行 Alpha 阶段任务分配队名与队员: MCoder,…

完整教程:Opencv(一): 用Opencv了解图像

完整教程:Opencv(一): 用Opencv了解图像pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "M…

docker compose插件安装

参考链接 在 Ubuntu 22.04 中,我们需要从官方 Docker 仓库安装 Docker Compose 插件。首先,让我们确保我们具备必要的先决条件: sudo apt-get install -y ca-certificates curl gnupg现在,添加 Docker 的官方 GPG …

完整教程:树与二叉树的奥秘全解析

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

C#扩展成员全面解析:从方法到属性的演进

本文详细介绍了C#中扩展成员的发展历程,从传统的扩展方法到C# 14中的扩展属性和索引器,通过具体代码示例展示如何更优雅地扩展现有类型功能。C#扩展成员:你需要知道的一切 扩展方法在C#中已经存在很长时间。它们允许…

多机elasticsearch集群部署,超详细教程

假设我们有三台机器,172.24.52.209,172.24.52.210,172.24.52.211。 用户名xjw 三台机器都创建文件夹/home/xjw/docker/learning/elasticsearch,和/home/xjw/docker/learning/kibana,learning为项目名 mkdir -p /h…

DeepSeek 提取 交易所网站核心500词汇(名词与术语)

DeepSeek 提取 交易所网站核心500词汇(名词与术语)交易所网站核心词汇扩充(名词与术语) 1. 市场结构与微观结构 英文术语中文翻译Auction 竞价Opening Auction 开市竞价Volatility Control Mechanism 市场波动调节…

[251122 678mAh] 模拟赛没破防有感 3.0

/hanx我写完 T4 了。 编译运行。 好的,卡住了。 应该是 RE 了。 ?为什么是在输出完答案之后卡住。 (想起了不好的回忆) 诶,别! 别别别别别! 就剩一个小时了别给我干死机了,这次可不是在线提交啊机子还原一下整…

白银市一对一培训机构推荐,2026年最新课外辅导全面测评口碑排名榜

在白银这座教育资源蓬勃发展的城市,从白银区繁华的北京路商圈到平川区快速崛起的会展中心周边,从靖远县文教氛围浓厚的学府街区到会宁县底蕴深厚的教育板块,从景泰县充满活力的新城商圈到皋兰县快速成长的教育园区,…

天水市一对一培训机构推荐,2026最新课外辅导机构口碑深度测评排名榜

在天水市,无论是秦州、麦积两区的繁华都市圈,还是秦安、甘谷、武山、清水、张家川回族自治县等地的莘莘学子,家长们都怀揣着同样的期望:让孩子在接受优质校内教育的同时,能通过课外辅导弥补短板、拔高优势,在求学…

CSAPP bomblab

规则:对于每个\(phase\),你都需要输入一个字符串,使得\(explode\_bomb\)函数不被运行 在bomb目录下使用objdump -d bomb > bomb.s得到反汇编文件\(bomb.s\) \(shell\) 中使用 gdb bomb进入\(gdb\)调试phase_1000…

history of linux

Linux 是一个开源的、跨平台的操作系统,其历史可以追溯到 1991 年。以下是 Linux 的主要发展历史阶段:1. 前身:Minix(1987)开发者:Andrew S. Tanenbaum特点:一个小型、可移植的操作系统内核,主要用于教学。与 …

history linux

当你在 Linux 系统中运行 history 命令时,它会显示你之前执行过的命令历史记录。这个命令是 Shell(如 Bash)内置的,用于跟踪用户在终端中执行过的命令。1. 基本用法history功能:显示当前终端中执行过的命令历史记…

Spring BeanFactoryPostProcessor 接口

[[Spring IOC 源码学习总笔记]] BeanFactoryPostProcessor是 Spring 框架提供的一个扩展点接口,它允许开发者在 Spring 在BeanFactory 加载了所有bean定义,但尚未实例化任何bean 之后,对底层的 BeanDefinition 和 B…

嘉峪关市一对一培训机构推荐,2026年最新课外补习辅导口碑排名

在雄伟的嘉峪关脚下,教育的热潮正席卷这座城市的每个角落。从雄关区的人民商城周边,到长城区的富强路商圈,再到镜铁区的润泽园社区,随处可见家长们为子女教育奔波的身影。小学生的数学思维拓展与语文阅读能力提升,…

2025 AI 教育培训权威推荐榜深度评测排名

2025 AI 教育培训权威推荐榜深度评测排名 痛点深度剖析 我们团队在实践中发现,当前 AI 教育培训领域存在着诸多核心技术挑战。在教学内容方面,AI 技术发展迅猛,知识更新换代极快,很多培训机构的课程内容难以跟上技…

详细介绍:第七篇:匹配篇 | 怎么像做产品一样,为每个岗位“定制”你的简历?

详细介绍:第七篇:匹配篇 | 怎么像做产品一样,为每个岗位“定制”你的简历?2025-11-22 17:40 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow…

2025年布袋除尘器供应商权威推荐榜单:塑烧板除尘器/耐高温除尘器/防爆除尘器源头厂家精选

在环保要求日益严格的工业制造领域,布袋除尘器作为工业粉尘治理的核心设备,其过滤效率与运行稳定性直接关系到企业的环保合规与生产成本。 工业布袋除尘器通过滤袋过滤、脉冲清灰等技术,能有效捕集工业生产中产生的…

hbuilder是否支持云端部署

HBuilder确实支持云端部署。它提供了云端打包功能,允许开发者将项目上传到云端服务器进行打包,生成Android和iOS平台的安装包。以下是HBuilder云端部署的相关信息: HBuilder云端部署支持云端打包功能:HBuilder支持…

创建矩形并让矩形移动

RGB是颜色值 使⽤⼀个元组 (R, G, B) 表示,每个值范围 0-255 。 ⿊⾊: (0, 0, 0) ⽩⾊: (255, 255, 255) 绿⾊: (0, 255, 0) pygame 坐标系 原点 (0, 0):窗⼝的左上⻆。 X 轴:向右增加 Y 轴:向下增加 按下的按键类…