Pytorch intermediate(四) Language Model (RNN-LM)

       前一篇中介绍了一种双向的递归神经网络,将数据进行正序输入和倒序输入,兼顾向前的语义以及向后的语义,从而达到更好的分类效果。

       之前的两篇使用递归神经网络做的是分类,可以发现做分类时我们不需要使用时序输入过程中产生的输出,只需关注每个时序输入产生隐藏信息,最后一个时序产生的输出即最后的输出。

       这里将会介绍语言模型,这个模型中我们需要重点关注的是每个时序输入过程中产生的输出。可以理解为,我输入a,那么我需要知道这个时序的输出是不是b,如果不是那么我就要调整模型了。


import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils import clip_grad_norm_
from data_utils import Dictionary, Corpusdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')embed_size = 128 
hidden_size = 1024 
num_layers = 1
num_epochs = 5 
num_samples = 1000 
batch_size = 20 
seq_length = 30 
learning_rate = 0.002 corpus = Corpus()
ids = corpus.get_data('data/train.txt', batch_size)
vocab_size = len(corpus.dictionary)
num_batches = ids.size(1) // seq_lengthprint(ids.size())
print(vocab_size)
print(num_batches)#torch.Size([20, 46479])
#10000
#1549

参数解释

1、ids:从train.txt中获取的训练数据,总共为20条,下面的模型只对这20条数据进行训练。

2、vocab_size:词库,总共包含有10000个单词

3、num_batch:可能有人要问前面有batch_size,这里的num_batch是干嘛用的?前面的batch_size是从语料库中抽取20条,每条数据长度为46497,除以序列长度seq_length(输入时序为30),个num_batch可以理解为是输入时序块的个数,也就是一个epoch中我们将所有语料输入网络需要循环的次数。


模型构建

模型很简单,但是参数比较难理解,这里在讲流程的时候依旧对参数进行解释。

1、Embedding层:保存了固定字典和大小的简单查找表,第一个参数是嵌入字典的大小,第二个是每个嵌入向量的大小。也就是说,每个时间序列的特征都被转化成128维的向量。假设一个序列维[20, 30],经过嵌入会变成[20, 30, 128]

2、LSTM层:3个重要参数,输入维度即为嵌入向量大小embed_size = 128,隐藏层神经元个数hidden_size = 1024,lstm单元个数num_layers = 1

3、LSTM的输出结果out中包含了30个时间序列的所有隐藏层输出,这里不仅仅只用最后一层了,要用到所有层的输出。

4、线性激活层:LSTM的隐藏层有1024个特征,要把这1024个特征通过全连接组合成我们词库特征10000,得到的就是这10000个词被选中的概率了。

class RNNLM(nn.Module):def __init__(self,vocab_size,embed_size,hidden_size,num_layers):super(RNNLM,self).__init__()#parameters - 1、嵌入字典的大小  2、每个嵌入向量的大小self.embed = nn.Embedding(vocab_size,embed_size)self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first = True)self.linear = nn.Linear(hidden_size, vocab_size)def forward(self, x, h):#转化为词向量x = self.embed(x)  #x.shape = torch.Size([20, 30, 128])#分成30个时序,在训练的过程中的循环中体现out,(h,c) = self.lstm(x,h)  #out.shape = torch.Size([20, 30, 1024])#out中保存每个时序的输出,这里不仅仅要用最后一个时序,要用上一层的输出和下一层的输入做对比,计算损失out = out.reshape(out.size(0) * out.size(1), out.size(2))   #输出10000是因为字典中存在10000个单词out = self.linear(out)   #out.shape = torch.Size([600, 10000])return out,(h,c)

实例化模型

向前传播时,我们需要输入两个参数,分别是数据x,h0和c0。每个epoch都要将h0和c0重新初始化。

可以看到在训练之前对输入数据做了一些处理。每次取出长度为30的序列输入,相应的依次向后取一位做为target,这是因为我们的目标就是让每个序列输出的值和下一个字符项相近似。

输出的维度为(600, 10000),将target维度进行转化,计算交叉熵时会自动独热处理。

反向传播过程,防止梯度爆炸,进行了梯度修剪。

model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)def detach(states):return [state.detach() for state in states] 
for epoch in range(num_epochs):# Set initial hidden and cell statesstates = (torch.zeros(num_layers, batch_size, hidden_size).to(device),torch.zeros(num_layers, batch_size, hidden_size).to(device))for i in range(0, ids.size(1) - seq_length, seq_length):# Get mini-batch inputs and targetsinputs = ids[:, i:i+seq_length].to(device)          #input torch.Size([20, 30])targets = ids[:, (i+1):(i+1)+seq_length].to(device) #target torch.Size([20, 30])# Forward passstates = detach(states)#用前一层输出和下一层输入计算损失outputs, states = model(inputs, states)             #output torch.Size([600, 10000])loss = criterion(outputs, targets.reshape(-1))# Backward and optimizemodel.zero_grad()loss.backward()clip_grad_norm_(model.parameters(), 0.5)            #梯度修剪optimizer.step()step = (i+1) // seq_lengthif step % 100 == 0:print ('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'.format(epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

测试模型 

  测试时随机选择一个词作为输入,因为没有一个停止的标准,所以我们需要利用循环来控制到底输出多少个字符。

输入维度[1, 1],我们之前的输入是[20, 30]。

本来有一种想法:我们现在只有一个时序了,但是我们的训练时有30个时序,那么还有什么意义?忽然想起来我们训练的参数是公用的!!!所以只要输入一个数据就能预测下面的数据了,并不要所谓的30层。

这里的初始输入是1,那么能不能是2呢?或者是根据我们之前的输入取预测新的字符?其实是可以的,但是由于初始化h0和c0的问题,我们更改了输入的长度,相应的h0和c0也要改变的。

我们最后的输出结果需要转化成为概率,然后随机抽取

# Test the model
with torch.no_grad():with open('sample.txt', 'w') as f:# Set intial hidden ane cell statesstate = (torch.zeros(num_layers, 1, hidden_size).to(device),torch.zeros(num_layers, 1, hidden_size).to(device))# Select one word id randomlyprob = torch.ones(vocab_size)input = torch.multinomial(prob, num_samples=1).unsqueeze(1).to(device)for i in range(num_samples):# Forward propagate RNN output, state = model(input, state)   #output.shape = torch.Size([1, 10000])# Sample a word idprob = output.exp()word_id = torch.multinomial(prob, num_samples=1).item()   #根据输出的概率随机采样# Fill input with sampled word id for the next time stepinput.fill_(word_id)# File writeword = corpus.dictionary.idx2word[word_id]word = '\n' if word == '<eos>' else word + ' 'f.write(word)if (i+1) % 100 == 0:print('Sampled [{}/{}] words and save to {}'.format(i+1, num_samples, 'sample.txt'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/79019.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SSM - Springboot - MyBatis-Plus 全栈体系(八)

第二章 SpringFramework 四、SpringIoC 实践和应用 4. 基于 配置类 方式管理 Bean 4.4 实验三&#xff1a;高级特性&#xff1a;Bean 注解细节 4.4.1 Bean 生成 BeanName 问题 Bean 注解源码&#xff1a; public interface Bean {//前两个注解可以指定Bean的标识AliasFor…

思科的简易配置

vlan 划分配置 1. 拓扑连接 2. 终端设备配置&#xff0c;vlan(v2, v3)配置&#xff0c;模式设置 然后设置交换机 fa 0/5 口为 trunk 模式&#xff0c;使得不同交换机同一 vlan 下 PC 可以互连 3.测试配置结果 用 ip 地址为 192.168.1.1 的主机(PC0)向同一 vlan(v2)下的 192.…

Binder进程通信基础使用

Binder 进程通信基础使用 一、服务端进程创建 Service&#xff0c;Service 中创建 Binder 子类对象并于 onBind 中返回。xml 定义。 创建 Service&#xff0c;创建 Binder 子类对象并于 onBind 返回 class UserService : Service() {private companion object {const val TAG…

BI与数据治理以及数据仓库有什么区别

你可能已经听说过BI、数据治理和数据仓库这些术语&#xff0c;它们在现代企业中起着重要的作用。虽然它们都与数据相关&#xff0c;但它们之间有着明显的区别和各自独特的功能。数聚将详细探讨BI&#xff08;商业智能&#xff09;、数据治理和数据仓库之间的区别&#xff0c;帮…

如何统计iOS产品不同渠道的下载量?

一、前言 在开发过程中&#xff0c;Android可能会打出来很多的包&#xff0c;用于标识不同的商店下载量。原来觉得苹果只有一个商店&#xff1a;AppStore&#xff0c;如何做出不同来源的统计呢&#xff1f;本篇文章就是告诉大家如何做不同渠道来源统计。 二、正文 先看一下苹…

云智研发公司面试真题

1.静态方法可以被重写吗 静态方法不能被重写。静态方法是属于类的&#xff0c;而不是属于实例的。当子类继承一个父类时&#xff0c;子类会继承父类的静态方法&#xff0c;但是子类不能重写父类的静态方法。如果子类定义了一个与父类静态方法同名的静态方法&#xff0c;那么它…

算法——快乐数

202. 快乐数 - 力扣&#xff08;LeetCode&#xff09; 由图可知&#xff0c;其实这也是一个判断循环的过程&#xff0c;要用到快慢指针&#xff0c;且相遇后&#xff0c;若在全为1的循环里&#xff0c;那么就是快乐数&#xff0c;若相遇后不为1&#xff0c;说明这不是快乐数。 …

备份数据重删

重复数据删除&#xff1a; 在计算中&#xff0c;重复数据删除是一种消除重复数据重复副本的技术。此技术用于提高存储利用率&#xff0c;还可以应用于网络数据传输以减少必须发送的字节数。在重复数据删除过程中&#xff0c;将在分析过程中识别并存储唯一的数据块或字节模式。…

MySQL入门教程

MySQL 是最流行的关系型数据库管理系统 1、什么是数据库&#xff1f; 数据库&#xff08;Database&#xff09;是按照数据结构来组织、存储和管理数据的仓库。 每个数据库都有一个或多个不同的 API 用于创建&#xff0c;访问&#xff0c;管理&#xff0c;搜索和复制所保存的…

HAlcon例子

气泡思想 * This example shows the use of the operator dyn_threshold for * the segmentation of the raised dots of braille chharacters. * The operator dyn_threshold is especially usefull if the * background is inhomogeneously illuminated. In this example, *…

vue3的生命周期

1.vue3生命周期官方流程图 2.vue3中的选项式生命周期 vue3中的选项式生命周期钩子基本与vue2中的大体相同&#xff0c;它们都是定义在 vue实例的对象参数中的函数&#xff0c;它们在vue中实例的生命周期的不同阶段被调用。生命周期函数钩子会在我们的实例挂载&#xff0c;更新…

竞赛 基于机器视觉的火车票识别系统

文章目录 0 前言1 课题意义课题难点&#xff1a; 2 实现方法2.1 图像预处理2.2 字符分割2.3 字符识别部分实现代码 3 实现效果最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于机器视觉的火车票识别系统 该项目较为新颖&#xff0c;适合作为竞赛…

23下半年学习计划

大二上学期计划 现在已经是大二了&#xff0c;java只学了些皮毛&#xff0c;要学的知识还有很多&#xff0c;新的学期要找准方向&#xff0c;把要学的知识罗列&#xff0c;按部就班地完成计划&#xff0c;合理安排时间&#xff0c;按时完成学习任务。 学习node.js&#xff0c…

element-ui《input》输入框效验

目录 常用的 element-ui el-input 输入框 1. 过滤字母e&#xff0c; 2. 只能输入正整数 3. 只允许输入数字和小数 / 数字和空格 4. 只允许输入正整数且不能以0开头 4. 允许输入小数点后几位 5. 设置范围&#xff0c;最大值&#xff0c;最小值 6. form 表单中校验输入框只能…

VUE写后台管理(2)

VUE写后台管理&#xff08;2&#xff09; 1.环境2.Element界面3.Vue-Router路由后台1.左导航栏2.上面导航条 1.环境 1.下载管理node版本的工具nvm&#xff08;Node Version Manager&#xff09; 2.安装node(vue工程的环境管理工具)&#xff1a;nvm install 16.13.0 3.安装vue工…

JS for...in 和 for...of 的区别?

for...in 和for ...of的区别&#xff1f; 1. 前言2. for...in3. for...of4&#xff0c;区别5. 总结&#xff1a; 1. 前言 for...in和for...of都是JavaScript中遍历数据的方法&#xff0c;让我们来了解一下他们的区别。 2. for…in for…in是为遍历对象属性而构建的&#xff0…

运维学习之部署Grafana

sudo nohup wget https://dl.grafana.com/oss/release/grafana-10.1.1.linux-amd64.tar.gz &后台下载压缩包&#xff0c;然后按一下回车键。 ps -aux | grep 15358发现有两条记录&#xff0c;就是还在下载中。 ps -aux | grep 15358发现有一条记录&#xff0c;并且tail …

CAS(compare and swa)中的ABA问题及解决

CAS(compare and swap) CAS是&#xff08;compare and swap&#xff09;的缩写&#xff0c;字面意思是比较交换。CAS锁通常也是实现乐观锁的一种机制&#xff0c;首先会给它一个期望值&#xff0c;用期望值与老值做比较&#xff0c;如果相等就用新传入的值进行修改。但是CAS通常…

一百七十八、ClickHouse——海豚调度执行ClickHouse的.sql文件

一、目的 由于数仓的ADS层是在ClickHouse中&#xff0c;即把Hive中DWS层的结果数据同步到ClickHouse中&#xff0c;因此需要在ClickHouse中建表&#xff0c;于是需要海豚调度执行ClickHouse的.sql文件 二、实施步骤 &#xff08;一&#xff09;第一步&#xff0c;海豚建立Cl…

Python in Visual Studio Code 2023年9月更新

作者&#xff1a;Courtney Webster - Program Manager, Python Extension in Visual Studio Code 排版&#xff1a;Alan Wang 我们很高兴地宣布 Visual Studio Code 的 Python 和 Jupyter 扩展将于 2023 年 9 月发布&#xff01; 此版本包括以下内容&#xff1a; • 将 Python …