(原文发表在知乎专栏上,时间为2020年8月13日)
今天在将一份 TensorFlow 代码转换为 PyTorch 时遇到了一点困难,经过多次调试(debug)之后终于弄清楚了这里应该如何进行转换,因此记录下来。
直接上代码。为了确保两个框架的结果可以逐一比对,这里将网络层的权重和偏置全部初始化为0。
import torch
import torch.nn as nn
import numpy as np
import tensorflow as tf
# Compare a TensorFlow GRU with PyTorch's nn.GRUCell and nn.GRU on identical
# inputs. Every weight and bias is zeroed so that the three implementations are
# directly comparable regardless of how each framework lays out its parameters.


def _zero_params(module):
    """Set every learnable parameter of *module* to zero in place."""
    # parameters() covers weights and biases; zeros_ runs under no_grad itself.
    for param in module.parameters():
        nn.init.zeros_(param)


def run_tf_gru(inputs, hidden):
    """Run a zero-initialised ``tf.compat.v1`` GRU over *inputs*.

    Args:
        inputs: array of shape (batch size, time steps, features). This is
            batch-major because ``dynamic_rnn`` is called with its default
            ``time_major=False``.
        hidden: initial state, shape (batch size, hidden size).

    Returns:
        ``(output, state)`` as returned by ``dynamic_rnn``. The output is
        (batch size, time steps, hidden size); the state is the final time
        step's (batch size, hidden size).
    """
    # Only the TensorFlow reference path needs the Keras initializers.
    from tensorflow.keras import initializers

    cell = tf.compat.v1.nn.rnn_cell.GRUCell(
        hidden.shape[-1],
        kernel_initializer=initializers.Zeros(),
        bias_initializer=initializers.Zeros(),
    )
    return tf.compat.v1.nn.dynamic_rnn(cell, inputs, initial_state=hidden)


def run_torch_gru_cell(inputs, hidden):
    """Unroll a zero-initialised ``nn.GRUCell`` over every time step of *inputs*.

    Args:
        inputs: numpy array of shape (batch size, time steps, features), the
            same layout that the TensorFlow version receives.
        hidden: numpy array of shape (batch size, hidden size).

    Returns:
        A list with one (batch size, hidden size) float32 tensor per time step.
        The last element is the final hidden state.
    """
    # GRUCell consumes one step at a time, so switch to a time-major layout:
    # (time steps, batch size, features).
    steps = torch.from_numpy(inputs).permute(1, 0, 2).float()
    cell = nn.GRUCell(steps.shape[-1], hidden.shape[-1])
    _zero_params(cell)

    state = torch.from_numpy(hidden).float()  # (batch size, hidden size)
    outputs = []
    # Iterating a tensor walks dim 0, i.e. one time step per iteration.
    for step in steps:
        state = cell(step, state)
        outputs.append(state)
    return outputs


def run_torch_gru(inputs, hidden):
    """Run a zero-initialised single-layer ``nn.GRU`` over *inputs*.

    Args:
        inputs: numpy array of shape (batch size, time steps, features).
        hidden: numpy array of shape (batch size, hidden size).

    Returns:
        ``(output, h_n)``. The output is (time steps, batch size, hidden size);
        h_n is (1, batch size, hidden size), with the leading 1 being
        num_layers * num_directions.
    """
    # nn.GRU defaults to batch_first=False: (time steps, batch size, features).
    sequence = torch.from_numpy(inputs).permute(1, 0, 2).float()
    rnn = nn.GRU(sequence.shape[-1], hidden.shape[-1])
    _zero_params(rnn)
    # h_0 is (num_layers * num_directions, batch size, hidden size). The new
    # leading axis is the layer axis, not time steps.
    h0 = torch.from_numpy(hidden).unsqueeze(0).float()
    return rnn(sequence, h0)


def main():
    """Feed random data through all three GRUs and print the results."""
    inputs = np.random.rand(3, 1, 5)  # (batch size, time steps, features)
    hidden = np.random.rand(3, 5)  # (batch size, hidden size)
    print("input: ", inputs.shape)
    print(inputs)
    print("hidden: ", hidden.shape)
    print(hidden)

    print("=" * 20, ' tensorflow result ', "=" * 20)
    tf_output, tf_state = run_tf_gru(inputs, hidden)
    print(tf_output)  # (batch size, time steps, features)
    print(tf_state)  # (batch size, features) for the final time step
    print('\n')

    print("=" * 20, ' rnn cell result ', "=" * 20)
    print(run_torch_gru_cell(inputs, hidden))
    print('\n')

    print("=" * 20, ' rnn result ', "=" * 20)
    pytorch_output, pytorch_state = run_torch_gru(inputs, hidden)
    print(pytorch_output, pytorch_output.shape)
    print(pytorch_state, pytorch_state.shape)


if __name__ == "__main__":
    main()
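最终的输出结果如下。可以看到三种实现的结果完全一致:由于所有权重和偏置都为0,更新门和重置门的取值都是 σ(0)=0.5,候选隐藏状态为 tanh(0)=0,因此新的隐藏状态恰好等于初始隐藏状态的一半(例如 0.46976021 / 2 = 0.23488011)。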
input: (3, 1, 5)
[[[0.98175333 0.59281082 0.47678967 0.70612923 0.73616147]][[0.8363702 0.85099391 0.75740424 0.30633335 0.20097122]][[0.60316062 0.21921029 0.16052985 0.25654177 0.40698399]]]
hidden: (3, 5)
[[0.46976021 0.19681885 0.59240364 0.79540728 0.27608136][0.39461795 0.29340918 0.4515729 0.6921841 0.44068605][0.89315058 0.72514622 0.2925488 0.45433305 0.59910906]]
==================== tensorflow result ====================
tf.Tensor(
[[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068]][[0.19730898 0.14670459 0.22578645 0.34609205 0.22034303]][[0.44657529 0.36257311 0.1462744 0.22716653 0.29955453]]], shape=(3, 1, 5), dtype=float64)
tf.Tensor(
[[0.23488011 0.09840942 0.29620182 0.39770364 0.13804068][0.19730898 0.14670459 0.22578645 0.34609205 0.22034303][0.44657529 0.36257311 0.1462744 0.22716653 0.29955453]], shape=(3, 5), dtype=float64)

==================== rnn cell result ====================
[tensor([[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]], grad_fn=<AddBackward0>)]

==================== rnn result ====================
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])
tensor([[[0.2349, 0.0984, 0.2962, 0.3977, 0.1380],[0.1973, 0.1467, 0.2258, 0.3461, 0.2203],[0.4466, 0.3626, 0.1463, 0.2272, 0.2996]]], grad_fn=<StackBackward>) torch.Size([1, 3, 5])

Process finished with exit code 0