Pytorch深度学习实践(8)多分类任务

多分类问题

多分类问题主要是利用了Softmax分类器,数据集采用MNIST手写数据集

设计方法:

  • 把每一个类别看成一个二分类的问题,分别输出10个概率

    在这里插入图片描述

    但是这种方法存在一种问题:不存在抑制问题,即按照常规来讲,当 y ^ 1 \hat y_1 y^1较大时,其他的值应该较小,换句话说,就是这10个概率相加不为1,即最终的输出结果不满足离散化分布的要求

  • 引入Softmax算子,即在输出层使用softmax
    在这里插入图片描述

Softmax

softmax函数定义:
P ( y = i ) = e z i ∑ j = 0 K − 1 e z j P(y = i) = \frac{e^{z_i}}{\sum _{j=0}^{K-1}e^{z_j}} P(y=i)=j=0K1ezjezi

  • 首先,使用指数 e z i e^{z_i} ezi,保证所有的输出都大于0
  • 其次,分母中对所有的输出进行指数计算后求和,相当于对结果归一化,保证了最后10个类别的概率相加为1

损失函数

NLL损失

L o s s ( Y ^ , Y ) = − Y l o g Y ^ Loss(\hat Y, Y) = -Ylog \hat Y Loss(Y^,Y)=YlogY^

在这里插入图片描述

y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum

也可以直接将标签为1的类对应的 l o g Y ^ log\hat Y logY^拿出来直接做运算

交叉熵损失

交叉熵损失 = LogSoftmax + NLLLoss
在这里插入图片描述

注意:

  • 交叉熵损失中,神经网络的最后一层不做激活,激活函数包括在交叉熵中了
  • y需要使用长整型的张量 y = torch.LongTensor([0])
  • 直接调用函数使用,criterion = torch.nn.CrossEntropyLoss()

代码实现

import torch
from torchvision import transforms  # 对图像进行处理
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F  # 非线性函数的包
import torch.optim as optim  # 优化

处理数据集

要把图像数据转化为Tensor形式,因此要定义transform对象

########### 数据集处理 ##########
batch_size = 64
## 将图像转化为张量
transform = transforms.Compose([transforms.ToTensor(),  # 使用ToTensor()方法将图像转化为张量transforms.Normalize((0.1307, ), (0.3081, ))  # 归一化 均值-标准差 切换到01分布
])
tran_dataset = datasets.MNIST(root='./dataset/mnist/',train=True,download=True,transform=transform)  # 将transform放到数据集里 直接处理train_loader = DataLoader(tran_dataset,shuffle=True,batch_size=batch_size)test_dataset = datasets.MNIST(root='./dataset/mnist/',train=False,download=True,transform=transform)test_loader = DataLoader(test_dataset,shuffle=True,batch_size=batch_size)

模型定义

由于线性模型的输入必须是一个向量,而图像为 28 × 28 × 1 28×28×1 28×28×1的矩阵,因此要先把图像拉成成一个 1 × 784 1×784 1×784的向量

这种做法会导致图像中的一些局部信息丢失

########## 模型定义 ##########
## 1×28×28 --view-> 1×784
## 784 --> 512 --> 256 --> 128 --> 128 --> 64 -->10
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)  # 图像矩阵拉伸成向量x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))x = self.l5(x)  # 最后一层不需要激活 sigmoid包含在cross entropy里return xmodel = Net()

注意最后一层不需要添加softmax,因为损失函数使用的cross entorpy已经包含softmax函数

损失函数和优化器设置

由于模型比较复杂,所以使用动量梯度下降方法,参数设置为0.5

########## 损失函数和优化器设置 ##########
criterion = torch.nn.CrossEntropyLoss()
## momentum指带冲量的梯度下降
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.5)

动量梯度下降(Gradient descent with momentum)

动量梯度下降方法是对小批量mini-batch的一种优化,可以有效地减少收敛过程中摆动幅度的大小,提高效率

动量梯度下降的过程类似于一个带有质量的小球在函数曲线上向下滚落,当小求滚落在最低点后,由于惯性还会继续上升一段距离,然后再滚落回来,再经最低点上去…最终小球停留在最低点处。而且由于小球具有惯性这一特点,当函数曲面比较复杂陡峭时,它便可以越过这些而尽快达到最低点

动量梯度下降每次更新参数时,对各个mini-batch求得的梯度 ∇ W \nabla W W ∇ b \nabla b b使用指数加权平均得到 V ∇ w V_{\nabla w} Vw V ∇ b V_{\nabla b} Vb,即通过历史数据来更新当前参数,从而得到更新公式:
V ∇ W n + 1 = β V ∇ W n + ( 1 − β ) ∇ W n V_{\nabla W_{n+1}} = \beta V_{\nabla W_n} + (1 - \beta)\nabla W_n VWn+1=βVWn+(1β)Wn

V ∇ b n + 1 = β V ∇ b n + ( 1 − β ) ∇ b n V_{\nabla b_{n+1}} = \beta V_{\nabla b_n} + (1 - \beta)\nabla b_n Vbn+1=βVbn+(1β)bn

模型训练

将训练阶段封装成一个函数

########## 模型训练 ##########
def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()  # 优化器清零## forwardoutputs = model(inputs)loss = criterion(outputs, labels)## backwardloss.backward()## updataoptimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:  # 每300轮 输出一次损失print('[%d %5d] loss: %.3f' % (epoch + 1, batch_idx+1, running_loss / 300))running_loss = 0.0

模型测试

将模型的测试封装成一个函数,且每训练一轮,测试一次

def test():correct = 0total = 0with torch.no_grad():  # 以下代码不会计算梯度for data in test_loader:images, labels = dataoutputs = model(images)  # 得到十个概率## 返回 最大值和最大值下标 我们只需要最大值下标即可_, predicted = torch.max(outputs.data, dim=1)  # 取最大值total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' % (100 * correct / total))

主函数

########### main ##########
if __name__ == '__main__':accuracy_history = []epoch_history = []for epoch in range(50):train(epoch)accuracy = test()accuracy_history.append(accuracy)epoch_history.append(epoch)plt.plot(epoch_history, accuracy_history)plt.xlabel('epoch')plt.ylabel('accuracy')plt.show()

完整代码

import torch
import matplotlib.pyplot as plt
from torchvision import transforms  # 对图像进行处理
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F  # 非线性函数的包
import torch.optim as optim  # 优化########### 数据集处理 ##########
batch_size = 64
## 将图像转化为张量
transform = transforms.Compose([transforms.ToTensor(),  # 使用ToTensor()方法将图像转化为张量transforms.Normalize((0.1307, ), (0.3081, ))  # 归一化 均值-标准差 切换到01分布
])
tran_dataset = datasets.MNIST(root='./dataset/mnist/',train=True,download=False,transform=transform)  # 将transform放到数据集里 直接处理train_loader = DataLoader(tran_dataset,shuffle=True,batch_size=batch_size)test_dataset = datasets.MNIST(root='./dataset/mnist/',train=False,download=False,transform=transform)test_loader = DataLoader(test_dataset,shuffle=True,batch_size=batch_size)########## 模型定义 ##########
## 1×28×28 --view-> 1×784
## 784 --> 512 --> 256 --> 128 --> 128 --> 64 -->10
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.l1 = torch.nn.Linear(784, 512)self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self, x):x = x.view(-1, 784)  # 图像矩阵拉伸成向量x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))x = self.l5(x)  # 最后一层不需要激活 sigmoid包含在cross entropy里return xmodel = Net()########## 损失函数和优化器设置 ##########
criterion = torch.nn.CrossEntropyLoss()
## momentum指带冲量的梯度下降
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.5)########## 模型训练 ##########
def train(epoch):running_loss = 0.0for batch_idx, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()  # 优化器清零## forwardoutputs = model(inputs)loss = criterion(outputs, labels)## backwardloss.backward()## updataoptimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:  # 每300轮 输出一次损失print('[%d %5d] loss: %.3f' % (epoch + 1, batch_idx+1, running_loss / 300))running_loss = 0.0
########## 模型测试 ##########
def test():correct = 0total = 0with torch.no_grad():  # 以下代码不会计算梯度for data in test_loader:images, labels = dataoutputs = model(images)  # 得到十个概率## 返回 最大值和最大值下标 我们只需要最大值下标即可_, predicted = torch.max(outputs.data, dim=1)  # 取最大值total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' % (100 * correct / total))return 100*correct / total########### main ##########
if __name__ == '__main__':accuracy_history = []epoch_history = []for epoch in range(50):train(epoch)accuracy = test()accuracy_history.append(accuracy)epoch_history.append(epoch)plt.plot(epoch_history, accuracy_history)plt.xlabel('epoch')plt.ylabel('accuracy(%)')plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/876491.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

stm32h7串口发送寄存器空中断

关于stm32串口的发送完成中断UART_IT_TC网上资料挺多的,但是使用发送寄存器空中断UART_IT_TXE的不太多 UART_IT_TC 和 UART_IT_TXE区别 UART_IT_TC 和 UART_IT_TXE 是两种不同的 UART 中断源,用于表示不同的发送状态。它们的主要区别如下: …

raise JSONDecodeError(“Expecting value”, s, err.value) from None

raise JSONDecodeError(“Expecting value”, s, err.value) from None 目录 raise JSONDecodeError(“Expecting value”, s, err.value) from None 【常见模块错误】 【解决方案】 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是…

实战: SpringBoot中5种增强的方法 : 加解密、脱敏、格式转换、时间时区处理(码到三十五)

1. 使用JsonSerialize和JsonDeserialize注解 2. 全局配置Jackson的ObjectMapper 3. 使用ControllerAdvice配合InitBinder 4. 自定义HttpMessageConverter 5. 使用AOP进行切面编程 结语 在Spring Boot中,对接口的请求入参和出参进行自定义的增强或者修改&…

数字图像处理笔记(三) ---- 傅里叶变换的基本原理

系列文章目录 数字图像处理笔记(一)---- 图像数字化与显示 数字图像处理笔记(二)---- 像素加图像统计特征 数字图像处理笔记(三) ---- 傅里叶变换的基本原理 文章目录 系列文章目录前言一、傅里叶变换二、离散傅里叶变…

ChatTTS(文本转语音) 一键本地安装爆火语音模型

想不想让你喜欢的文章,有着一个动听的配音,没错,他就可以实现。 ChatTTS 是一款专为对话场景设计的文本转语音模型,例如 LLM 助手对话任务。它支持英语和中文两种语言。 当下爆火模型,在Git收获23.5k的Star&#xff…

Android中集成前端页面探索(Capacitor 或 Cordova 插件)待完善......

探索目标:Android中集成前端页面 之前使用的webview加载html页面,使用bridge的方式进行原生安卓和html页面的通信的方式,探索capacitor-android插件是如何操作的 capacitor-android用途 Capacitor 是一个用于构建现代跨平台应用程序的开源框…

可消费的媒体类型和可生成的媒体类型

可消费的媒体类型和可生成的媒体类型 在 Spring MVC 中,“可消费的媒体类型”和“可生成的媒体类型”是两个重要的概念,用于控制控制器方法处理和返回的内容类型。它们分别通过 consumes 和 produces 属性来指定。下面是它们的详细区别: 可…

【Pod 详解】Pod 的概念、使用方法、容器类型

《Pod 详解》系列,共包含以下几篇文章: Pod 的概念、使用方法、容器类型Pod 的生命周期(一):Pod 阶段与状况、容器的状态与重启策略Pod 的生命周期(二):Pod 的健康检查之容器探针Po…

C++入门基础:C++中的常用操作符练习

开头介绍下C语言先,C是一种广泛使用的计算机程序设计语言,起源于20世纪80年代,由比雅尼斯特劳斯特鲁普在贝尔实验室开发。它是C语言的扩展,增加了面向对象编程的特性。C的应用场景广泛,包括系统软件、游戏开发、嵌入式…

智慧医院临床检验管理系统源码(LIS),全套LIS系统源码交付,商业源码,自主版权,支持二次开发

实验室信息系统是集申请、采样、核收、计费、检验、审核、发布、质控、查询、耗材控制等检验科工作为一体的网络管理系统。它的开发和应用将加快检验科管理的统一化、网络化、标准化的进程。一体化设计,与其他系统无缝连接,全程化条码管理。支持危机值管…

DataX(二):DataX安装与入门

1. 官方地址 下载地址:http://datax-opensource.oss-cn-hangzhou.aliyuncs.com/datax.tar.gz 源码地址:GitHub - alibaba/DataX: DataX是阿里云DataWorks数据集成的开源版本。 2. 前置要求 Linux JDK(1.8 以上,推荐 1.8) Python(推荐 Pyt…

一文总结代理:代理模式、代理服务器

概述 代理在计算机编程领域,是一个很通用的概念,包括:代理设计模式,代理服务器等。 代理类持有具体实现类的实例,将在代理类上的操作转化为实例上方法的调用。为某个对象提供一个代理,以控制对这个对象的…

测试分类篇

按测试对象划分 这里可以分为界面测试, 可靠性测试, 容错率测试, 文档测试, 兼容性测试, 安装卸载测试, 安全测试, 性能测试, 内存泄露测试. 界面测试 界面测试(简称UI测试),指按照界面的需求(一般是UI设计稿)和界面的设计规则…

Vue3+element-plus 实现图片图片

在看下面内容之前,请一定要去看看 element-plus 中上传组件 el-upload组件 上传组件 重点关注下面几个属性 :auto-upload“false” , 关闭自动上传 :on-change“onUploadFile” 监听上传情况 简单示例: <el-form-item label"文章封面" prop"cover_img"&…

flume知识点

1. 简述什么是Flume &#xff1f; flume 作为 cloudera 开发的实时日志收集系统&#xff0c;受到了业界的认可与广泛应用。Flume 初始的发行版本目前被统称为 Flume OG&#xff08;original generation&#xff09;&#xff0c;属于 cloudera。 但随着 FLume 功能的扩展&#…

AI大模型学习必备十大网站

随着人工智能技术的快速发展&#xff0c;AI大模型&#xff08;如GPT-3、BERT等&#xff09;在自然语言处理、计算机视觉等领域取得了显著的成果。对于希望深入学习AI大模型的开发者和研究者来说&#xff0c;找到合适的学习资源至关重要。本文将为大家推荐十大必备网站&#xff…

[AI]在家中使用日常设备运行您自己的 AI 集群.适用于移动、桌面和服务器的分布式 LLM 推理。

创作不易 只因热爱!! 热衷分享&#xff0c;一起成长! “你的鼓励就是我努力付出的动力” AI发展不可谓不快, 从ollama个人电脑CPU运行到现在,日常设备AI集群. 下面对比一下,两款开源AI 大模型的分布式推理应用, exo 和cake. 1.AI 集群推理应用exo 和cake的简单对比 #mermaid-s…

DOS攻击实验

实验背景 Dos 攻击是指故意的攻击网络协议实现的缺陷或直接通过野蛮手段&#xff0c;残忍地耗尽被攻击对象的资源&#xff0c;目的是让目标计算机或网络无法提供正常的服务或资源访问&#xff0c;使目标系统服务系统停止响应甚至崩溃。 实验设备 一个网络 net:cloud0 一台模…

在Ubuntu 18.04上安装和使用PostgreSQL的方法

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。 简介 关系数据库管理系统是许多网站和应用程序的关键组件。它们提供了一种结构化的方式来存储、组织和访问信息。 PostgreSQL&#xf…

基于微信小程序+SpringBoot+Vue的儿童预防接种预约系统(带1w+文档)

基于微信小程序SpringBootVue的儿童预防接种预约系统(带1w文档) 基于微信小程序SpringBootVue的儿童预防接种预约系统(带1w文档) 开发合适的儿童预防接种预约微信小程序&#xff0c;可以方便管理人员对儿童预防接种预约微信小程序的管理&#xff0c;提高信息管理工作效率及查询…