【深度学习笔记】6_5 RNN的pytorch实现

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

6.5 循环神经网络的简洁实现

本节将使用PyTorch来更简洁地实现基于循环神经网络的语言模型。首先,我们读取周杰伦专辑歌词数据集。

import time
import math
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

6.5.1 定义模型

PyTorch中的nn模块提供了循环神经网络的实现。下面构造一个含单隐藏层、隐藏单元个数为256的循环神经网络层rnn_layer

num_hiddens = 256
# rnn_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens) # 已测试
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)

与上一节中实现的循环神经网络不同,这里rnn_layer的输入形状为(时间步数, 批量大小, 输入个数)。其中输入个数即one-hot向量长度(词典大小)。此外,rnn_layer作为nn.RNN实例,在前向计算后会分别返回输出和隐藏状态h,其中输出指的是隐藏层在各个时间步上计算并输出的隐藏状态,它们通常作为后续输出层的输入。需要强调的是,该“输出”本身并不涉及输出层计算,形状为(时间步数, 批量大小, 隐藏单元个数)。而nn.RNN实例在前向计算返回的隐藏状态指的是隐藏层在最后时间步的隐藏状态:当隐藏层有多层时,每一层的隐藏状态都会记录在该变量中;对于像长短期记忆(LSTM),隐藏状态是一个元组(h, c),即hidden state和cell state。我们会在本章的后面介绍长短期记忆和深度循环神经网络。关于循环神经网络(以LSTM为例)的输出,可以参考下图(图片来源)。

在这里插入图片描述

循环神经网络(以LSTM为例)的输出

来看看我们的例子,输出形状为(时间步数, 批量大小, 隐藏单元个数),隐藏状态h的形状为(层数, 批量大小, 隐藏单元个数)。

num_steps = 35
batch_size = 2
state = None
X = torch.rand(num_steps, batch_size, vocab_size)
Y, state_new = rnn_layer(X, state)
print(Y.shape, len(state_new), state_new[0].shape)

输出:

torch.Size([35, 2, 256]) 1 torch.Size([2, 256])

如果rnn_layernn.LSTM实例,那么上面的输出是什么?

接下来我们继承Module类来定义一个完整的循环神经网络。它首先将输入数据使用one-hot向量表示后输入到rnn_layer中,然后使用全连接输出层得到输出。输出个数等于词典大小vocab_size

# 本类已保存在d2lzh_pytorch包中方便以后使用
class RNNModel(nn.Module):def __init__(self, rnn_layer, vocab_size):super(RNNModel, self).__init__()self.rnn = rnn_layerself.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) self.vocab_size = vocab_sizeself.dense = nn.Linear(self.hidden_size, vocab_size)self.state = Nonedef forward(self, inputs, state): # inputs: (batch, seq_len)# 获取one-hot向量表示X = d2l.to_onehot(inputs, self.vocab_size) # X是个listY, self.state = self.rnn(torch.stack(X), state)# 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出# 形状为(num_steps * batch_size, vocab_size)output = self.dense(Y.view(-1, Y.shape[-1]))return output, self.state

6.5.2 训练模型

同上一节一样,下面定义一个预测函数。这里的实现区别在于前向计算和初始化隐藏状态的函数接口。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char,char_to_idx):state = Noneoutput = [char_to_idx[prefix[0]]] # output会记录prefix加上输出for t in range(num_chars + len(prefix) - 1):X = torch.tensor([output[-1]], device=device).view(1, 1)if state is not None:if isinstance(state, tuple): # LSTM, state:(h, c)  state = (state[0].to(device), state[1].to(device))else:   state = state.to(device)(Y, state) = model(X, state)if t < len(prefix) - 1:output.append(char_to_idx[prefix[t + 1]])else:output.append(int(Y.argmax(dim=1).item()))return ''.join([idx_to_char[i] for i in output])

让我们使用权重为随机值的模型来预测一次。

model = RNNModel(rnn_layer, vocab_size).to(device)
predict_rnn_pytorch('分开', 10, model, vocab_size, device, idx_to_char, char_to_idx)

输出:

'分开戏想暖迎凉想征凉征征'

接下来实现训练函数。算法同上一节的一样,但这里只使用了相邻采样来读取数据。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes):loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)model.to(device)state = Nonefor epoch in range(num_epochs):l_sum, n, start = 0.0, 0, time.time()data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device) # 相邻采样for X, Y in data_iter:if state is not None:# 使用detach函数从计算图分离隐藏状态, 这是为了# 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大)if isinstance (state, tuple): # LSTM, state:(h, c)  state = (state[0].detach(), state[1].detach())else:   state = state.detach()(output, state) = model(X, state) # output: 形状为(num_steps * batch_size, vocab_size)# Y的形状是(batch_size, num_steps),转置后再变成长度为# batch * num_steps 的向量,这样跟输出的行一一对应y = torch.transpose(Y, 0, 1).contiguous().view(-1)l = loss(output, y.long())optimizer.zero_grad()l.backward()# 梯度裁剪d2l.grad_clipping(model.parameters(), clipping_theta, device)optimizer.step()l_sum += l.item() * y.shape[0]n += y.shape[0]try:perplexity = math.exp(l_sum / n)except OverflowError:perplexity = float('inf')if (epoch + 1) % pred_period == 0:print('epoch %d, perplexity %f, time %.2f sec' % (epoch + 1, perplexity, time.time() - start))for prefix in prefixes:print(' -', predict_rnn_pytorch(prefix, pred_len, model, vocab_size, device, idx_to_char,char_to_idx))

使用和上一节实验中一样的超参数(除了学习率)来训练模型。

num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2 # 注意这里的学习率设置
pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']
train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes)

输出:

epoch 50, perplexity 10.658418, time 0.05 sec- 分开始我妈  想要你 我不多 让我心到的 我妈妈 我不能再想 我不多再想 我不要再想 我不多再想 我不要- 不分开 我想要你不你 我 你不要 让我心到的 我妈人 可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的
epoch 100, perplexity 1.308539, time 0.05 sec- 分开不会痛 不要 你在黑色幽默 开始了美丽全脸的梦滴 闪烁成回忆 伤人的美丽 你的完美主义 太彻底 让我- 不分开不是我不要再想你 我不能这样牵着你的手不放开 爱可不可以简简单单没有伤害 你 靠着我的肩膀 你 在我
epoch 150, perplexity 1.070370, time 0.05 sec- 分开不能去河南嵩山 学少林跟武当 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 习武之人切记 仁者无敌- 不分开 在我会想通 是谁开没有全有开始 他心今天 一切人看 我 一口令秋软语的姑娘缓缓走过外滩 消失的 旧
epoch 200, perplexity 1.034663, time 0.05 sec- 分开不能去吗周杰伦 才离 没要你在一场悲剧 我的完美主义 太彻底 分手的话像语言暴力 我已无能为力再提起- 不分开 让我面到你 爱情来的太快就像龙卷风 离不开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不
epoch 250, perplexity 1.021437, time 0.05 sec- 分开 我我外的家边 你知道这 我爱不看的太  我想一个又重来不以 迷已文一只剩下回忆 让我叫带你 你你的- 不分开 我我想想和 是你听没不  我不能不想  不知不觉 你已经离开我 不知不觉 我跟了这节奏 后知后觉 

小结

  • PyTorch的nn模块提供了循环神经网络层的实现。
  • PyTorch的nn.RNN实例在前向计算后会分别返回输出和隐藏状态。该前向计算并不涉及输出层计算。

注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/733534.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python操作Redis 各种数据类型

本文将深入探讨如何使用Python操作Redis&#xff0c;覆盖从基础数据类型到高级功能的广泛主题。无论是字符串、列表、散列、集合还是有序集合&#xff0c;我们将一一解析&#xff0c;同时提供丰富的代码示例帮助读者更好地理解和应用。除此之外&#xff0c;本文还将介绍Redis的…

JVM和JVM内存管理

Java虚拟机&#xff08;Java Virtual Machine 简称JVM&#xff09;是运行所有Java程序的抽象计算机&#xff0c;是Java语言的运行环境&#xff0c;它是Java 最具吸引力的特性之一。Java源代码经过编译器编译后生成与平台无关的字节码文件&#xff08;.class文件&#xff09;。当…

【20240309】WORD宏设置批量修改全部表格格式

WORD宏设置批量修改全部表格格式 引言1. 设置表格文字样式2. 设置表格边框样式3. 设置所有表格边框样式为075pt4. 删除行参考 引言 这两周已经彻底变为office工程师了&#xff0c;更准确一点应该是Word工程师&#xff0c;一篇文档动不动就成百上千页&#xff0c;表格图片也是上…

计算机视觉(CV)自然语言处理(NLP)大模型应用,如何实现小模型

在人工智能领域&#xff0c;大模型已经成为引领创新和进步的重要推动力。它们不仅在自然语言处理、计算机视觉等任务中展现了强大的性能&#xff0c;还为各行各业带来了前所未有的机遇和挑战。本文将从一个高级写作专家的角度&#xff0c;深入探讨大模型的现状、技术突破以及未…

STM32之串口中断接收UART_Start_Receive_IT

网上搜索了好多&#xff0c;都是说主函数增加UART_Receive_IT()函数来着&#xff0c;实际正确的是UART_Start_Receive_IT()函数。 —————————————————— 参考时间&#xff1a;2024年3月9日 Cube版本&#xff1a;STM32CubeMX 6.8.1版本 参考芯片&#xff1a…

Svg Flow Editor 原生svg流程图编辑器(二)

系列文章 Svg Flow Editor 原生svg流程图编辑器&#xff08;一&#xff09; 说明 这项目也是我第一次写TS代码哈&#xff0c;现在还被绕在类型中头昏脑胀&#xff0c;更新可能会慢点&#xff0c;大家见谅~ 目前实现的功能&#xff1a;1. 元件的创建、移动、形变&#xff1b;2…

完全背包问题(一般写法与空间优化写法)

题目 有 N 种物品和一个容量是 V 的背包&#xff0c;每种物品都有无限件可用。 第 i 种物品的体积是 vi&#xff0c;价值是 wi。 求解将哪些物品装入背包&#xff0c;可使这些物品的总体积不超过背包容量&#xff0c;且总价值最大。 输出最大价值。 输入格式 第一行两个整…

最多几个直角三角形python

最多几个直角三角形 问题描述思路代码实现 问题描述 最多可以组成几个直角三角形&#xff0c;一个边只能用一次。 输入描述&#xff1a; 第一行输入一个正整数T&#xff08;1<&#xff1d;T<&#xff1d;100&#xff09;&#xff0c;表示有T组测试数据. 对于每组测试数据…

贪心算法: 奶牛做题

5289. 奶牛做题 - AcWing题库 贝茜正在参加一场奶牛智力竞赛。 赛事方给每位选手发放 n 张试卷。 每张试卷包含 k 道题目&#xff0c;编号 1∼k。 已知&#xff0c;不同卷子上的相同编号题目的难度相同&#xff0c;解题时间也相同。 其中&#xff0c;解决第 i 道题&#xff08;…

【C语言】字符指针

在指针的类型中我们知道有一种指针类型为字符指针char* 一般使用&#xff1a; int main() { char ch w; char *pc &ch; *pc w; return 0; } 还有一种使用方式&#xff0c;如下&#xff1a; int main() { const char* pstr "hello bit.";//这⾥是把⼀个字…

plantUML使用指南之序列图

文章目录 前言一、序列图1.1 语法规则1.1.1 参与者1.1.2 生命线1.1.3 消息1.1.4 自动编号1.1.5 注释1.1.6 其它1.1.7 例子 1.2 如何画好 参考 前言 在软件开发、系统设计和架构文档编写过程中&#xff0c;图形化建模工具扮演着重要的角色。而 PlantUML 作为一种强大且简洁的开…

【stm32 外部中断】

中断&#xff1a;在主程序运行过程中&#xff0c;出现了特定的中断触发条件&#xff08;中断源&#xff09;&#xff0c;使得CPU暂停当前正在运行的程序&#xff0c;转而去处理中断程序&#xff0c;处理完成后又返回原来被暂停的位置继续运行 中断优先级&#xff1a;当有多个中…

LoadBalancer (本地负载均衡)

1.loadbalancer本地负载均衡客户端 VS Nginx服务端负载均衡区别 Nginx是服务器负载均衡&#xff0c;客户端所有请求都会交给nginx&#xff0c;然后由nginx实现转发请求&#xff0c;即负载均衡是由服务端实现的。 loadbalancer本地负载均衡&#xff0c;在调用微服务接口时候&a…

考研复习C语言初阶(4)+标记和BFS展开的扫雷游戏

目录 1. 一维数组的创建和初始化。 1.1 数组的创建 1.2 数组的初始化 1.3 一维数组的使用 1.4 一维数组在内存中的存储 2. 二维数组的创建和初始化 2.1 二维数组的创建 2.2 二维数组的初始化 2.3 二维数组的使用 2.4 二维数组在内存中的存储 3. 数组越界 4. 冒泡…

【Java JVM】Class 文件的加载

Java 虚拟机把描述类的数据从 Class 文件加载到内存, 并对数据进行校验, 转换解析和初始化, 最终形成可以被虚拟机直接使用的 Java 类型, 这个过程被称作虚拟机的类加载机制。 与那些在编译时需要进行连接的语言不同, 在 Java 语言里面, 类的加载, 连接和初始化过程都是在程序…

Apache HBase

一、HBase简介 1、HBase定义 Apache HBase™是以hdfs为数据存储的&#xff0c;一种分布式、可扩展的NoSQL数据库。 HBase官网 Welcome to Apache HBase™ Apache HBase™ is the Hadoop database, a distributed, scalable, big data store.Use Apache HBase™ when you nee…

解决阿里云服务器开启frp服务端,内网服务器开启frp客户端却连接不上的问题

解决方法&#xff1a; 把阿里云自带的Alibabxxxxxxxlinux系统 换成centos 7系统&#xff01;&#xff01;&#xff01;&#xff01; 说一下我的过程和问题&#xff1a;由于我们内网的服务器在校外是不能连接的&#xff0c;因此我弄了个阿里云服务器做内网穿透&#xff0c;所谓…

【git】总结

一般提交流程&#xff1a; git add 或者直接加号暂存 git status 检查状态 git commit -m “提交信息” git pull git push origin HEAD:refs/for/分支 git pull 拉代码时有冲突&#xff1a; git add . 暂存所有 git stash 暂存 git pull 解决冲突 git stash apply 放出暂存…

大模型学习过程记录

一、基础知识 自然语言处理&#xff1a;能够让计算理解人类的语言。 检测计算机是否智能化的方法&#xff1a;图灵测试 自然语言处理相关基础点&#xff1a; 基础点1——词表示问题&#xff1a; 1、词表示&#xff1a;把自然语言中最基本的语言单位——词&#xff0c;将它转…

【C语言】C语言编程实战:Base64编解码算法从理论到实现(附完整代码)

&#x1f9d1; 作者简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟,欢迎关注。提供嵌入式方向的学习指导…