Pytorch intermediate(三) RNN分类

使用RNN对MNIST手写数字进行分类。RNN和LSTM模型结构

pytorch中的LSTM的使用让人有点头晕,这里讲述的是LSTM的模型参数的意义。


1、加载数据集

import torch 
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.utils.data as Data device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')sequence_length = 28 
input_size = 28 
hidden_size = 128 
num_layers = 2 
num_classes = 10 
batch_size = 128 
num_epochs = 2 
learning_rate = 0.01 train_dataset = torchvision.datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor())train_loader = Data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = Data.DataLoader(dataset=test_dataset,batch_size=batch_size)

 2、构建RNN模型

  • input_size – 输入的特征维度

  • hidden_size – 隐状态的特征维度

  • num_layers – 层数(和时序展开要区分开)

  • bias – 如果为False,那么LSTM将不会使用,默认为True

  • batch_first – 如果为True,那么输入和输出Tensor的形状为(batch, seq, feature)

  • dropout – 如果非零的话,将会在RNN的输出上加个dropout,最后一层除外。

  • bidirectional – 如果为True,将会变成一个双向RNN,默认为False

       1、上面的参数来自于文档,最基本的参数是input_size, hidden_size, num_layer三个。input_size:输入数据向量维度,在这里为28;hidden_size:隐藏层特征维度,也是输出的特征维度,这里是128;num_layers:lstm模块个数,这里是2。

       2、h0和c0的初始化维度为(num_layer,batch_size, hidden_size

       3、lstm的输出有out和(hn,cn),其中out.shape = torch.Size([128, 28, 128]),对应(batch_size,时序数,隐藏特征维度),也就是保存了28个时序的输出特征,因为做的分类,所以只需要最后的输出特征。所以取出最后的输出特征,进行全连接计算,全连接计算的输出维度为10(10分类)。

       4、batch_first这个参数比较特殊:如果为true,那么输入数据的维度为(batch, seq, feature),否则为(seq, batch, feature)

       5、num_layers:lstm模块个数,如果有两个,那么第一个模块的输出会变成第二个模块的输入。

       总结:构建一个LSTM模型要用到的参数,(输入数据的特征维度,隐藏层的特征维度,lstm模块个数);时序的个数体现在X中, X.shape = (batch_size,  时序长度, 数据向量维度)。

       可以理解为LSTM可以根据我们的输入来实现自动的时序匹配,从而达到输入长短不同的功能。

class RNN(nn.Module):def __init__(self, input_size,hidden_size,num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layers#input_size - 输入特征维度#hidden_size - 隐藏状态特征维度#num_layers - 层数(和时序展开要区分开),lstm模块的个数#batch_first为true,输入和输出的形状为(batch, seq, feature),true意为将batch_size放在第一维度,否则放在第二维度self.lstm = nn.LSTM(input_size,hidden_size,num_layers,batch_first = True)  self.fc = nn.Linear(hidden_size, num_classes)def forward(self,x):#参数:LSTM单元个数, batch_size, 隐藏层单元个数 h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)   #h0.shape = (2, 128, 128)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)#输出output :  (seq_len, batch, hidden_size * num_directions)#(h_n, c_n):最后一个时间步的隐藏状态和细胞状态#对out的理解:维度batch, eq_len, hidden_size,其中保存着每个时序对应的输出,所以全连接部分只取最后一个时序的#out第一维batch_size,第二维时序的个数,第三维隐藏层个数,所以和lstm单元的个数是无关的out,_ = self.lstm(x, (h0, c0))  #shape = torch.Size([128, 28, 128])out = self.fc(out[:,-1,:])  #因为batch_first = true,所以维度顺序batch, eq_len, hidden_sizereturn out

 训练部分

model = RNN(input_size,hidden_size, num_layers, num_classes).to(device)
print(model)#RNN(
#  (lstm): LSTM(28, 128, num_layers=2, batch_first=True)
#  (fc): Linear(in_features=128, out_features=10, bias=True)
#)criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)total_step = len(train_loader)
for epoch in range(num_epochs):for i,(images, labels) in enumerate(train_loader):#batch_size = -1, 序列长度 = 28, 数据向量维度 = 28images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward() optimizer.step()if (i+1) % 100 == 0:print(outputs.shape)print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# Test the model
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/76709.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

成都优优聚为什么值得信任?

成都优优聚能信任作为一家专业的电商服务公司,拥有丰富的经验和专业的团队,能够为商家提供全方位的美团代运营服务。 美团外卖作为国内领先的外卖平台,具有庞大的用户群体和丰富的商家资源。然而,美团代运营对于很多刚开始接触美团…

权威Scrum敏捷开发企业级实战培训-Leangoo领歌

​​​​​​​​课程简介 Scrum是目前运用最为广泛的敏捷开发方法,是一个轻量级的项目管理和产品研发管理框架。 这是一个两天的实训课程,面向研发管理者、项目经理、产品经理、研发团队等,旨在帮助学员全面系统地学习Scrum和敏捷开发, 帮…

Android平台GB28181接入SDK

华脉智联推出的Android平台GB28181接入SDK,可实现不具备国标音视频能力的 Android终端,通过平台注册接入到现有的GB/T28181—2016服务,可用于如执法记录仪、智能安全帽、智能监控、智慧零售、智慧教育、远程办公、明厨亮灶、智慧交通、智慧工…

ChatGPT追祖寻宗:GPT-1论文要点解读

论文地址:《Improving Language Understanding by Generative Pre-Training》 最近一直忙着打比赛,好久没更文了。这两天突然想再回顾一下GPT-1和GPT-2的论文, 于是花时间又整理了一下,也作为一个记录~话不多说,让我们…

生成多样、真实的评论(2019 IEEE International Conference on Big Data )

论文题目(Title):Learning to Generate Diverse and Authentic Reviews via an Encoder-Decoder Model with Transformer and GRU 研究问题(Question):评论生成,由上下文评论->生成评论 研…

vue3 props传入的组件无法正常刷新

问题描述: vue3写入的数据无法正常渲染,但是从子组件获取正常使用props导入 解决方案 在setup导出的时候,直接导入props,而不是导入props.变量 错误用法: props: [var1] let var1 "张三" setup() {ret…

如何学习运营管理

运营管理(Operations Management)是一门管理学科,它关注如何高效地组织和管理企业的生产、服务、供应链和业务过程以达到组织的目标。运营管理是企业管理的一个重要领域,它包含了多个内容和职能: 生产管理:…

2023高教社杯数学建模B题思路分析 - 多波束测线问题

# 1 赛题 B 题 多波束测线问题 单波束测深是利用声波在水中的传播特性来测量水体深度的技术。声波在均匀介质中作匀 速直线传播, 在不同界面上产生反射, 利用这一原理,从测量船换能器垂直向海底发射声波信 号,并记录从声波发射到…

界面控件DevExtreme DateRangeBox组件发布,支持日期范围选择!

在最新的v23.1版本中,DevExpress官方已经正式发布了DevExtreme DateRangeBox小部件,支持所有JavaScript框架,包括Angular、React、Vue和jQuery。这个新的控件允许最终用户选择一个日期范围,该组件继承了DateBox组件的特性&#xf…

一文了解气象观测站是什么?

一、气象观测站的定义 气象观测站是一种专门负责观测、记录气象数据的设施,包括风向、风速、温度、湿度、气压、降水量等多个气象要素。这些数据不仅对科研和预报具有重要意义,还对我们的日常生活有着极大的影响。 二、气象观测站的种类 气象观测站根…

Python练习分割字符串

str"itheima itcast boxuegu" # 统计字符串类有多少个“it”字符 count str.count("it") print(f"字符串类有{count}个“it”字符") # 将字符串内的空格全部替换为字符:“|” str_replace str.replace(" ", "|"…

一篇文章带你了解红黑树并将其模拟实现

了解红黑树并将其模拟实现 红黑树的概念和性质1. 概念2. 性质 红黑树的结构红黑树的节点定义及红黑树结构成员定义红黑树的插入1. 按照二叉搜索的树规则插入新节点2. 检测新节点插入后,红黑树的性质是否造到破坏情况一: cur为红,p为红,g为黑&…

从“白人饭”到美味佳肴,拓世AI为你打造独一无二的饮食计划

最近“白人饭”作为一种饮食方式在社交媒体上火了,成为打工人新的“午餐之光”。所谓“白人饭”,就是花最少的功夫准备仅仅能维系基本器官正常运作的食物,主打生吃或者简单炒,比如一个丹麦网友晒出的同事的午饭就是几根小胡萝卜和…

AjaxJavaScriptcss模仿百度一下模糊查询功能

1、效果 如下图所示,我们在输入大学时,程序会到后端查询名字中包含大学的数据,并展示到前端页面。 用户选择一个大学,该大学值会被赋值到input表单,同时关闭下拉表单; 当页面展示的数据都不符合条件时&…

四化智造MES(WEB)与金蝶云星空对接集成原材料/标准件采购查询(待采购)连通采购订单新增(其他采购订单行关闭-TEST)

四化智造MES(WEB)与金蝶云星空对接集成原材料/标准件采购查询(待采购)连通采购订单新增(其他采购订单行关闭-TEST) 数据源系统:四化智造MES(WEB) MES系统是集成生产管理、品质管理、设备管理、BI数据中心、…

AI是风口还是泡沫?

KlipC报道:狂热的人工智能追捧潮有所冷静,投资者在“上头”的追涨之后,开始回归到对基本面的关注。 KlipC的合伙人Andi D表示:“近日,有关英伟达二季度“破纪录”财报涉嫌造假的话题正在社交媒体和投资者论坛中甚嚣尘上…

MATLAB实现数据插值

目录 一.理论知识 二.一维插值实例 三.二维插值实例 一.理论知识 所谓插值,顾名思义,插入数值。很多时候,我们仅有离散点上的数据,这时如果我们想要分析变量之间的函数关系,则无法实现。但如果通过插值处理&#xf…

C#学习系列之UDP同端口收发问题

C#学习系列之UDP同端口收发问题 前言解决办法关于JoinMulticastGroup总结 前言 想测试自己的程序问题,建立了两个UDP程序,一个往端口中接到数就传出去,另一个从这个端口接数据来解析。 出现的问题是 每次打开端口,另一个程序就无…

AIGC之文本内容生成概述(下)—— GPT

GPT(GenerativePre-TrainedTransformer) 提到GPT模型,就不得不说众所周知的ChatGPT模型,ChatGPT的发展可以追溯到2018年,当时OpenAI发布了第一代GPT模型,即GPT-1,该模型采用Transformer结构和自…

Python实现机器学习(上)— 基础知识介绍及环境部署

前言:Hello大家好,我是小哥谈。本门课程将介绍人工智能相关概念,重点讲解机器学习原理机器基本算法(监督学习及非监督学习)。使用python,结合sklearn、jupyter-notebook进行编程,介绍iris、匹马…