dynamic_rnn转nn.GRU详细记录

news/2025/11/20 15:29:14/文章来源:https://www.cnblogs.com/jcchan/p/19247263

(原文发表在知乎专栏上,时间为2020年8月13日)

今天在将一份tensorflow的代码转为pytorch时遇到的一点困难,经过多次debug以后终于弄清楚了这里应该是如何进行转换的,因此记录下来。

直接上代码吧,为了确保最终的结果是一致的,这里我将网络层的权重全部初始化为0。

import torch
import torch.nn as nn
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializersinput = np.random.rand(3, 1, 5)
hidden = np.random.rand(3, 5)print("input: ", input.shape)
print(input)
print("hidden: ", hidden.shape)
print(hidden)print("="*20, ' tensorflow result ', "="*20)
# cell with zeros initializer
cell = tf.compat.v1.nn.rnn_cell.GRUCell(5, kernel_initializer=initializers.Zeros(), bias_initializer=initializers.Zeros())
tf_output, tf_state = tf.compat.v1.nn.dynamic_rnn(cell, input, initial_state=hidden)
print(tf_output)        # (batch size, time steps, features)
print(tf_state)         # (batch size, features) for the final time steps
print('\n')print("="*20, ' rnn cell result ', "="*20)
# rnn cell
pytorch_rnn_cell = nn.GRUCell(5, 5)
for k, v in pytorch_rnn_cell.state_dict().items():torch.nn.init.constant_(v, 0)
pytorch_input_cell = torch.from_numpy(input).permute(1, 0, 2).float()   # (time steps, batch size, features)
pytorch_hidden_cell = torch.from_numpy(hidden).float()                  # (batch size, features)
pytorch_output_cell = []
for i in range(1):pytorch_hidden_cell = pytorch_rnn_cell(pytorch_input_cell[i], pytorch_hidden_cell)pytorch_output_cell.append(pytorch_hidden_cell)
print(pytorch_output_cell)
print('\n')print("="*20, ' rnn result ', "="*20)
# rnn
pytorch_rnn = nn.GRU(5, 5)
for k, v in pytorch_rnn.state_dict().items():torch.nn.init.constant_(v, 0)
pytorch_input = torch.from_numpy(input).permute(1, 0, 2).float()        # (time steps, batch size, feature size)
pytorch_hidden = torch.from_numpy(hidden).unsqueeze(0).float()          # (time steps, batch size, hidden size)
pytorch_output, pytorch_state = pytorch_rnn(pytorch_input, pytorch_hidden)
print(pytorch_output, pytorch_output.shape)
print(pytorch_state, pytorch_state.shape)

最后的结果如下

input:  (3, 1, 5)
[[[0.98175333 0.59281082 0.47678967 0.70612923 0.73616147]][[0.8363702  0.85099391 0.75740424 0.30633335 0.20097122]][[0.60316062 0.21921029 0.16052985 0.25654177 0.40698399]]]
hidden:  (3, 5)
[[0.46976021 0.19681885 0.59240364 0.79540728 0.27608136][0.39461795 0.29340918 0.4515729  0.6921841  0.44068605][0.89315058 0.72514622 0.2925488  0.45433305 0.59910906]]
====================  tensorflow result  ====================
tf.Tensor(
[[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068]][[0.19730898 0.14670459 0.22578645 0.34609205 0.22034303]][[0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]]], shape=(3, 1, 5), dtype=float64)
tf.Tensor(
[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068][0.19730898 0.14670459 0.22578645 0.34609205 0.22034303][0.44657529 0.36257311 0.1462744  0.22716653 0.29955453]], shape=(3, 5), dtype=float64)====================  rnn cell result  ====================
[tensor([[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]], grad_fn=<AddBackward0>)]====================  rnn result  ====================
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])Process finished with exit code 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/971096.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NAS、对象存储与 JuiceFS:百亿量化基金的存储选型实践

在量化投资领域,存储系统的性能与可扩展性是支撑高效研究与计算任务的关键基础。JuiceFS 已广泛应用于多家头部百亿级量化私募机构,在回测与模型训练等核心环节中支撑高性能、低成本、可弹性扩展的存储体系。 本文将…

我踩遍了所有坑,终于搞懂了企业微信聊天记录存档!

vx: llike620 gofly.v1kf.com 作为一名技术开发者,最近我接到了一个需求:实现企业微信的聊天记录存档功能。本以为就是个简单的API调用,没想到这一脚踩进去,发现水不是一般的深。 那个藏在后台的神秘功能 事情是这…

实用指南:【Linux基础知识系列:第一百五十九篇】磁盘健康监测:smartctl

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025年风机联云端批发厂家权威推荐榜单:风机物联网云平台/风机物联网/小型物联网风系统平台源头厂家精选

随着工业4.0与物联网技术的深度融合,风机行业正经历一场智能化革命。据《2024-2029年中国风机行业市场展望与投资分析报告》显示,集成云端监控功能的智能风机市场年增长率已超过25%,预计到2025年,其在工业风机中的…

CF2172H Shuffling Cards with Problem Solver 68!

首先切牌肯定有性质,但是你认为我没有脑子,建图倍增可以快速将最终序列的每个位置对应的原位置求出来。 相当于我要循环位移目前数列,使得按照给定关键字排序后字典序最小。 借用后缀排序的思路,维护一个长度的倍增…

STM32HAL库通用定时器学后笔记 - 实践

STM32HAL库通用定时器学后笔记 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Mona…

2025年手工雕刻石碑生产厂家权威推荐榜单:汉白玉墓碑/石碑/汉白玉石碑源头厂家精选

一块质朴的石头,在匠人手中被赋予生命与温度,这或许就是手工雕刻石碑的魅力所在。 在现代化机械加工普及的今天,手工雕刻石碑因其独特的艺术价值和不可复制的文化内涵,依然在市场中占据着重要地位。随着消费者对个…

2025不容错过!可燃气体报警器十大实力厂家大盘点

2025不容错过!可燃气体报警器十大实力厂家大盘点 一、引言 在工业生产和日常生活中,可燃气体的广泛使用带来了诸多便利,但同时也伴随着泄漏、爆炸等安全隐患。可燃气体报警器作为预防此类事故的关键设备,能够实时监…

记基于现有项目架构通过ai生成的一个语音助手功能开发设计文档

记基于现有项目架构通过ai生成的一个语音助手功能开发设计文档题前不得不赞叹一句有了AI的协同,实在是太高效了 📘 语音助手功能设计文档 目录系统架构概览 核心流程 翻译模式详解 内部处理机制 关键讨论点系统架构…

2025 最新推荐海外仓服务平台榜单:覆盖欧美东南亚等核心市场,美国 / 英国 / 德国 / 法国海外仓/换标 / 维修 / 检测优质服务商权威测评

引言 跨境电商行业的全球化扩张推动海外仓需求持续激增,据国际跨境物流协会(ICLA)2025 年度测评报告显示,全球海外仓服务商数量年增 37%,但服务合规率仅 62%,物流延误、库存失控等问题导致卖家平均损失率达 18%。…

Agent Dart证书验证漏洞深度解析

本文详细分析了Agent Dart库中存在的证书验证漏洞CVE-2024-48915,包括委托验证缺失canister_ranges检查和时间戳验证问题,这些安全缺陷可能导致子网冒充和证书无限期有效等严重风险。Agent Dart缺失证书验证检查 CV…

2025年北京集团法律顾问服务权威推荐榜单:私人法律顾问/高级法律顾问/社区法律顾问服务精选

在法治环境日益完善的今天,北京集团法律顾问服务市场已形成专业化、精细化的服务格局,为企业稳健经营提供着坚实的法律保障。 随着企业法律需求的多元化和复杂化,北京地区的集团法律顾问服务行业呈现出专业化分工、…

2025年螺旋输送机批发厂家权威榜单:带式输送机/链板输送机/皮带输送机设备源头厂家精选

在工业物料输送领域,螺旋输送机凭借其结构紧凑、密封性好、操作简便等优势,成为粮食加工、化工生产、矿山冶炼、环保处理等行业的关键设备。根据2024年行业数据统计,国内螺旋输送机市场规模已突破50亿元,其中管式螺…

【图像超分】论文复现:轻量化超分 | RLFN的Pytorch源码复现,跑通源码,整合到EDSR-PyTorch中进行训练、测试 - 教程

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025年合肥外呼系哪家好--外呼系统推荐

摘要 在数字化营销转型的浪潮下,电销外呼系统依然是金融、教育、企业客服服务等行业触达客户最直接、最高效的手段。然而,随着外呼行业规范化发展,“高频封号”、“接通率低”的问题已成为制约外呼行业增长的痛点。…

2025年四川搭建网站维护服务权威推荐:四川网站搭建平台/四川企业网站开发/四川企业官网搭建公司源头机构精选

在数字化转型浪潮下,一个稳定高效的网站已成为企业市场竞争的关键一环。 在数字经济蓬勃发展的背景下,四川省网站建设及维护服务市场呈现出快速增长态势。据相关统计,2025年四川省企业网站建设需求较去年同期增长约…

js yield Generator

// 定义Generator函数 function* simpleGenerator() {debugger;console.log(开始执行);debugger;yield 第一次暂停;debugger;console.log(恢复执行);debugger;yield 第二次暂停;debugger;console.log(结束执行);debugg…

c++11之移动构造函数

class CObject { public:CObject(string str):m_str(str) {cout << "构造函数" << endl;}CObject(const CObject& obj) {m_str = string(obj.m_str);cout << "拷贝构造函数"…

2025年高光谱成像技术应用实力榜:高光谱成像系统、高校用的高光谱相机、高校教学高光谱相机、科研机构高光谱相机、工业用高光谱相机、五家企业以产品性能与专业服务赢得市场认可

随着遥感与精准探测需求的持续增长,高光谱成像系统作为关键设备,其技术性能与数据精度成为科研与行业应用的关注焦点。在高校教学、科研机构、农业遥感、工业检测等主流应用场景中,一批具备自主研发能力与专业技术服…

《浙商》杂志|协作方能共赢,湘湖论剑网易专场对接会描绘AI人机共生新蓝图

前言:近日,网易伏羲受邀出席2025湘湖论剑“中国视谷”产业生态大会,人机协作任务平台网易有灵智能体和工程机械智能化品牌网易灵动同步亮相,近40位企业家及产学研界代表围绕行业智能化转型的方向需求和前景共同探讨…