pytorch-LSTM的输入和输出尺寸

LSTM的输入和输出尺寸

CLASS torch.nn.LSTM(*args, **kwargs)

Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
For each element in the input sequence, each layer computes the following function:

对于一个输入序列实现多层长短期记忆的RNN网络,对于输入序列中的每一个元素,LSTM的每一层进行如下计算:
it=σ(Wiixt+bii+Whiht−1+bhi)ft=σ(Wifxt+bif+Whfht−1+bhf)gt=tanh⁡(Wigxt+big+Whght−1+bhg)ot=σ(Wioxt+bio+Whoht−1+bho)ct=ft⊙ct−1+it⊙gtht=ot⊙tanh⁡(ct)i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \\ it=σ(Wiixt+bii+Whiht1+bhi)ft=σ(Wifxt+bif+Whfht1+bhf)gt=tanh(Wigxt+big+Whght1+bhg)ot=σ(Wioxt+bio+Whoht1+bho)ct=ftct1+itgtht=ottanh(ct)
其中:

  • ht:h_t:ht时间步t的隐藏状态
  • ct:c_t:ct时间步t的细胞状态
  • xt:x_t:xt时间步t的输入
  • ht−1:h_{t-1}:ht1时间步t-1的隐藏状态或者初始化的隐藏状态(时间步0)
  • it、ft、gt:i_t、f_t、g_t:itftgt分别是输入门,遗忘门,单元门和输出门
  • σ:\sigma:σsigmoid函数
  • ⊙:\odot:Hadamard积

其中的参数:

input_size :输入的维度hidden_size:h的维度num_layers:堆叠LSTM的层数,默认值为1bias:偏置 ,默认值:Truebatch_first: 如果是True,则input为(batch, seq, input_size)。默认值为:False(seq_len, batch, input_size)bidirectional :是否双向传播,默认值为False

输入

Inputs: input, (h_0, c_0)
  • Input输入维度是(seq_len, batch, input_size),即(句子中字的数量,批量大小,每个字向量的长度)

  • h_0 的维度(num_layers * num_directions, batch, hidden_size),即(层数∗*LSTM方向数量(单向或者双向),批量大小,隐藏向量维度)

  • c_0 的维度 (num_layers * num_directions, batch, hidden_size),即(层数∗*LSTM方向数量,隐藏向量维度)

  • If (h_0, c_0) is not provided, both h_0 and c_0 default to zero,h_0和c_0的默认参数都是全0.

输出

Outputs: output, (h_n, c_n)
  • output 输出维度 (seq_len, batch, num_directions * hidden_size),即(句子中字的数量,批量大小,LSTM方向数量∗*隐藏向量维度)
  • h_n 维度 (num_layers * num_directions, batch, hidden_size)
  • c_n 维度 (num_layers * num_directions, batch, hidden_size)

举个例子

  • num_layers = 1
import torch.nn as nn
import torch
x = torch.rand(5,50,100)#(seq_len, batch, input_size)
lstm = nn.LSTM(100,20,num_layers=2)
output,(hidden,cell) = lstm(x)
print("output size:{} \nhidden size:{} \ncell size:{}".format(output.size(),hidden.size(),cell.size()))

输出:

output size:torch.Size([5, 50, 20]) 
hidden size:torch.Size([2, 50, 20]) 
cell size:torch.Size([2, 50, 20])
  • bidirecrtional = True
import torch.nn as nn
import torch
x = torch.rand(5,50,100)
lstm = nn.LSTM(100,20,bidirectional=True)
output,(hidden,cell) = lstm(x)
print("output size:{} \nhidden size:{} \ncell size:{}".format(output.size(),hidden.size(),cell.size()))

输出:

output size:torch.Size([5, 50, 40]) 
hidden size:torch.Size([2, 50, 20]) 
cell size:torch.Size([2, 50, 20])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/507938.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python中的[-1]、[:-1]、[::-1]、[n::-1]

import numpy as np anp.random.rand(4) print(a)[0.48720333 0.67178384 0.65662903 0.40513918]print(a[-1]) #取最后一个元素 0.4051391774882336print(a[:-1]) #去除最后一个元素 [0.48720333 0.67178384 0.65662903]print(a[::-1]) #逆序 [0.40513918 0.65662903 0.67178…

torchtext.data.Field

torchtext.data.Field 类接口 class torchtext.data.Field(sequentialTrue, use_vocabTrue, init_tokenNone, eos_tokenNone, fix_lengthNone, dtypetorch.int64, preprocessingNone, postprocessingNone, lowerFalse, tokenizeNone, tokenizer_languageen, include_lengthsF…

np.triu

np.triu numpy.triu(m, k0) Upper triangle of an array. Return a copy of a matrix with the elements below the k-th diagonal zeroed. 返回一个矩阵的上三角矩阵,第k条对角线以下的元素归零 例如: import numpy as np np.triu(np.ones([4,4]), …

python读取json格式的超参数

python读取json格式的超参数 json文件: {"full_finetuning": true,"max_len": 180,"learning_rate": 3e-5,"weight_decay": 0.01,"clip_grad": 2,"batch_size": 30,"epoch_num": 20,"…

python缺少标准库_干货分享:Python如何自动导入缺失的库

很多同学在写Python项目时会遇到导入模块失败的情况:ImportError: No module named xxx或者ModuleNotFoundError: No module named xxx。导入模块失败通常分为两种:一种是导入自己写的模块(即以 .py 为后缀的文件),另一种是导入三方库。接下来…

.val()数据乱码_【目标检测数据集】PASCAL VOC制作

【VOC20072012】数据集地址:https://pjreddie.com/projects/pascal-voc-dataset-mirror/PASCAL VOC为图像识别和分类提供了一整套标准化的优秀的数据集,用于构建和评估用于图像分类(Classification),检测(O…

pytorch-多GPU训练(单机多卡、多机多卡)

pytorch-多GPU训练(单机多卡、多机多卡) pytorch 单机多卡训练 首先是数据集的分布处理 需要用到的包: torch.utils.data.distributed.DistributedSampler torch.utils.data.DataLoader torch.utils.data.Dataset DistributedSampler这个…

机器人 铑元素_智能机器人 三十三

福里斯特茫然不知所措地从窃窃私语的黑暗中转过身来,沉重的失败感犹如巨怪压得他喘不过气来。他顺从地跛着脚走向笼内角落里的小浴室,冲着微笑着的木偶刚才经过的地方点点头,漫不经意地问道:“你们是如何抓住他们的?”…

character-level OCR之Character Region Awareness for Text Detection(CRAFT) 论文阅读

Character Region Awareness for Text Detection 论文阅读 论文地址(arXiv) ,pytorch版本代码地址 最近在看一些OCR的问题,CRAFT是在场景OCR中效果比较好的模型,记录一下论文的阅读 已有的文本检测工作大致如下: 基于回归的文…

c# wpf 面试_【远程面试】九强通信 | 九洲电器集团全资子公司

成都IT内推圈成立于2016年,专注成都IT互联网领域的招聘与求职;覆盖精准IT人群10W,通过内推圈推荐且已入职人数超过5000,合作公司均系成都知名或靠谱公司.此公众号每天7:30AM准时推送当天职位详情,敬请关注并置顶!岗位投递一、登陆内推圈官网: www.itneituiquan.com,…

ViT(Vision Transformer)学习

ViT(Vison Transformer)学习 Paper:An image is worth 1616 words: transformers for image recognition at scale. In ICLR, 2021. Transformer 在 NLP领域大放异彩,并且随着模型和数据集的不断增长,仍然没有表现出饱和的迹象。这使得使用更大规模的数…

mysql php宝塔 root_[转载]在安卓中安装宝塔面板运行PHP+MySQL

手机上的操作我用的手机是小米10pro,其他手机应该也能用相同的方法安装成功。安装Linux Deploy,然后给它root权限。点击左上角的菜单按钮。点击号,创建一个名为debian的配置文件。如果已经有了名为debian的配置文件,选择它即可。返…

cpri带宽不足的解决方法_u盘容量不足怎么办 u盘容量不足解决方法【介绍】

我们在使用u盘的时候总能碰到各种各样的问题,其中u盘容量不足问题也是神烦,很多时候打开并没有发现有文件存在,但是在你存文件的时候又被提示u盘容量不足无法操作,关于这个问题u启动通过整理和大家一起分享下解决办法。1、u盘里的…

(python numpy) np.array.shape 中 (3,)、(3,1)、(1,3)的区别

(python numpy) np.array.shape 中 (3,)、(3,1)、(1,3)的区别 被人问到这个问题,就记录一下吧 1. (3,) (3,)是[x,y,z][x,y,z][x,y,z]的形式,即为一维数组,访问数组元素用一个index for example: >>> array1 np.array([1,2,3]) …

复合的赋值运算符例题_Java学习:运算符的使用与注意事项

运算符的使用与注意事项四则运算当中的加号“”有常见的三种用法:对于数值来,那就是加法。对于字符char类型来说,在计算之前,char会被提升成为int,然后再计算。char类型字符,和int类型数字之间的对照关系比…

腾讯会议如何使用讲演者模式进行汇报(nian gao)

腾讯会议如何使用讲演者模式进行汇报(nian gao) 首先列出步骤,再一一演示: altf5 开启讲演者模式,调整讲演者模式的窗口为小窗alttab 切换回腾讯会议界面,屏幕共享power point窗口(注意不是“…

bulk这个词的用法_15、形容词与副词(二)比较的用法

初中英语语法——形容词与副词(二)比较的用法语法解释1、形容词与副词比较级和最高级的规则变化单音节词与部分双音节词:(1)一般情况加-er,-estlong-longer-longest strong-stronger-strongestclean-cleaner-cleanest(2)以不发音的e结尾的词,…

pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题

首先很多网上的博客,讲的都不对,自己跟着他们踩了很多坑 1.单卡训练,单卡加载 这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件里,这样就可…

retinex 的水下图像增强算法_图像增强论文:腾讯优图CVPR2019

Underexposed Photo Enhancement using Deep Illumination Estimation基于深度学习优化光照的暗光下的图像增强论文地址:Underexposed Photo Enhancement using Deep Illumination Estimation暗光拍照也清晰,这是手机厂商目前激烈竞争的新拍照目标。提出…

python 实现 BCH 纠错码的方法

python 实现 BCH 纠错码的方法 BCH码是一类重要的纠错码,它把信源待发的信息序列按固定的κ位一组划分成消息组,再将每一消息组独立变换成长为n(n>κ)的二进制数字组,称为码字。如果消息组的数目为M(显然M>2),由此所获得的M个码字的全…