【循环神经网络5】GRU模型实战,从零开始构建文本生成器 - 详解

news/2025/10/20 14:02:36/文章来源:https://www.cnblogs.com/yxysuanfa/p/19152502

【循环神经网络3】门控循环单元GRU详解-CSDN博客https://blog.csdn.net/colus_SEU/article/details/152218119?spm=1001.2014.3001.5501在以上▲笔记中,我们详解了GRU模型,接下来我们通过实战来进一步理解。

1 项目概述

本项目旨在构建一个字符级的语言模型,使用PyTorch框架实现GRU(门控循环单元)网络,来学习莎士比亚戏剧文本的语言模式,并生成具有莎士比亚风格的新文本。

  • 核心任务:文本生成。

  • 数据集:Kaggle上的莎士比亚戏剧数据集,包含超过11万行角色台词,下载地址:Shakespeare plays(直接点击download下载zip文件解压后将文件夹中的Shakespeare_data.csv文件放到项目目录指定位置即可)

  • 最终成果:一个能够生成语法基本正确、词汇风格贴近莎士比亚戏剧的文本的GRU模型。

2 项目目录

 GRU_shakespeare_sonnets/
 ├── data/
 │   └── raw/
 │       └── Shakespeare_data.csv
 ├── models/
 │   └── gru_model.pth
 ├── src/
 │   ├── data_loader.py   # 数据下载、预处理和加载
 │   ├── model.py         # GRU模型定义
 │   ├── train.py         # 训练脚本
 │   └── generate.py      # 文本生成脚本
 ├── vocab.pkl            # 保存的词汇表
 └── requirements.txt     # 项目依赖

3 项目代码

 # src/model.py
 import torch
 import torch.nn as nn
 ​
 class GRUModel(nn.Module):
     def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_layers=2):
         '''
         :param vocab_size: 词汇表的大小,即模型中使用的不同单词的数量。
         :param embed_dim: 词嵌入的维度,默认为256。词嵌入是将单词转换为固定长度向量的过程,这些向量捕获单词的语义信息。
         :param hidden_dim: GRU隐藏层的维度,默认为512。这决定了模型的学习能力和复杂性。
         :param num_layers: GRU层的数量,默认为2。多层GRU可以捕获更复杂的序列模式。
         '''
         super(GRUModel, self).__init__()
         self.hidden_dim = hidden_dim
         self.num_layers = num_layers
 ​
         # 词嵌入层
         self.embedding = nn.Embedding(vocab_size, embed_dim)
         # 创建一个词嵌入层,使用nn.Embedding类。这个层将词汇表中的每个单词映射到一个embed_dim维的向量。
 ​
         # GRU层
         self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
         # 创建一个GRU层,使用nn.GRU类。
         # batch_first=True:表示输入数据的第一个维度是批量大小,这使得数据组织更加直观。
 ​
         # 全连接输出层
         self.fc = nn.Linear(hidden_dim, vocab_size)
         # 这个层将GRU的输出映射到词汇表的大小,从而预测下一个单词。hidden_dim是输入尺寸,vocab_size是输出尺寸。
 ​
 ​
     def forward(self, x, hidden):
         # x shape: (batch_size, seq_length)
         # hidden shape: (num_layers, batch_size, hidden_dim)
 ​
         # 嵌入层
         embedded = self.embedding(x)  # shape: (batch_size, seq_length, embed_dim)
 ​
         # GRU层
         # out shape: (batch_size, seq_length, hidden_dim)
         # hidden shape: (num_layers, batch_size, hidden_dim)
         out, hidden = self.gru(embedded, hidden)
 ​
         # 将GRU的输出传入全连接层
         # 我们需要将out重塑以便通过fc层
         out = out.contiguous().view(-1, self.hidden_dim)  # shape: (batch_size * seq_length, hidden_dim)
         out = self.fc(out)  # shape: (batch_size * seq_length, vocab_size)
 ​
         return out, hidden
 ​
     def init_hidden(self, batch_size, device):
         """初始化隐藏状态"""
         # 形状: (num_layers, batch_size, hidden_dim)
         weight = next(self.parameters()).data
         hidden = weight.new(self.num_layers, batch_size, self.hidden_dim).zero_().to(device)
         return hidden
 # src/train.py
 ​
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import os
 import pickle
 from tqdm import tqdm
 ​
 from data_loader import get_data_loaders, PROJECT_ROOT  # 导入项目根路径
 from model import GRUModel
 ​
 ​
 # --- 超参数配置 ---
 class Config:
     # 路径配置 (使用动态路径)
     MODEL_DIR = os.path.join(PROJECT_ROOT, "models")
     VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
     MODEL_SAVE_PATH = os.path.join(MODEL_DIR, "gru_model.pth")
 ​
     BATCH_SIZE = 128
     SEQUENCE_LENGTH = 100
     N_EPOCHS = 20
     LEARNING_RATE = 0.0005
     EMBED_DIM = 256
     HIDDEN_DIM = 512
     N_LAYERS = 2
 ​
     # 创建模型目录
     if not os.path.exists(MODEL_DIR):
         os.makedirs(MODEL_DIR)
 ​
 ​
 def train():
     """主训练函数"""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"使用设备: {device}")
 ​
     train_loader, val_loader, vocab_size = get_data_loaders(
         batch_size=Config.BATCH_SIZE,
         sequence_length=Config.SEQUENCE_LENGTH
     )
 ​
     model = GRUModel(
         vocab_size=vocab_size,
         embed_dim=Config.EMBED_DIM,
         hidden_dim=Config.HIDDEN_DIM,
         num_layers=Config.N_LAYERS
     ).to(device)
 ​
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)
 ​
     print("模型训练开始...")
     for epoch in range(Config.N_EPOCHS):
         model.train()
         total_loss = 0
 ​
         progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{Config.N_EPOCHS}")
         for i, (inputs, targets) in enumerate(progress_bar):
             # 将数据移动到指定设备
             inputs, targets = inputs.to(device), targets.to(device)
 ​
             # --- 关键修改 ---
             # 在每个批次开始时,根据当前批次的大小重新初始化隐藏状态
             current_batch_size = inputs.size(0)
             hidden = model.init_hidden(current_batch_size, device)
 ​
             # --- 关键修改:添加梯度裁剪 ---
             # 防止梯度爆炸,将梯度的L2范数限制在5.0以内
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
 ​
             # 梯度清零
             optimizer.zero_grad()
 ​
             # 前向传播
             output, hidden = model(inputs, hidden)
 ​
             # 将 targets 展平,以便与 output 的维度匹配
             targets = targets.view(-1)
 ​
             # 计算损失
             loss = criterion(output, targets)
 ​
             # 反向传播和优化
             loss.backward()
             optimizer.step()
 ​
             total_loss += loss.item()
             progress_bar.set_postfix(loss=loss.item())
 ​
         avg_loss = total_loss / len(train_loader)
         print(f"Epoch {epoch + 1} 完成, 平均训练损失: {avg_loss:.4f}")
 ​
         torch.save({
             'epoch': epoch,
             'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'loss': avg_loss,
             'config': Config
         }, Config.MODEL_SAVE_PATH)
         print(f"模型已保存至 {Config.MODEL_SAVE_PATH}")
 ​
     print("模型训练完成!")
 ​
 ​
 if __name__ == '__main__':
     train()
# src/generate.py
 ​
 import torch
 import pickle
 import os
 import random
 ​
 from model import GRUModel
 from train import Config, PROJECT_ROOT  # 导入项目根路径
 ​
 ​
 def generate_text(start_str="shall i compare thee to a summer's day?\n", gen_length=500, temperature=0.8):
     """使用训练好的模型生成文本"""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ​
     # 路径配置 (使用动态路径)
     VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
     MODEL_SAVE_PATH = Config.MODEL_SAVE_PATH
 ​
     # 1. 加载词汇表
     try:
         with open(VOCAB_FILE, 'rb') as f:
             char_to_ix, ix_to_char, vocab_size = pickle.load(f)
     except FileNotFoundError:
         print(f"错误: 词汇表文件 {VOCAB_FILE} 未找到。请先运行训练脚本。")
         return
 ​
     # 2. 初始化模型
     model = GRUModel(
         vocab_size=vocab_size,
         embed_dim=Config.EMBED_DIM,
         hidden_dim=Config.HIDDEN_DIM,
         num_layers=Config.N_LAYERS
     ).to(device)
 ​
     # 3. 加载模型权重
     try:
         checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
         model.load_state_dict(checkpoint['model_state_dict'])
         print(f"成功加载模型权重: {MODEL_SAVE_PATH}")
     except FileNotFoundError:
         print(f"错误: 模型文件 {MODEL_SAVE_PATH} 未找到。请先运行训练脚本。")
         return
 ​
     model.eval()
 ​
     input_seq = [char_to_ix[ch] for ch in start_str]
     input_tensor = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0).to(device)
     hidden = model.init_hidden(1, device)
 ​
     generated_text = start_str
     with torch.no_grad():
         for _ in range(gen_length):
             output, hidden = model(input_tensor, hidden)
             output = output.squeeze(0).div(temperature).exp()
             top_i = torch.multinomial(output, 1)[0]
             predicted_char = ix_to_char[top_i.item()]
             generated_text += predicted_char
             input_tensor = torch.tensor([[top_i]], dtype=torch.long).to(device)
 ​
     print("\n--- 生成的文本 ---")
     print(generated_text)
     print("------------------\n")
 ​
 ​
 if __name__ == '__main__':
     generate_text(start_str="from fairest creatures we desire increase,\n", gen_length=400, temperature=0.8)
 # src/data_loader.pyimport os
 import pandas as pd
 import torch
 from torch.utils.data import Dataset, DataLoader
 import pickle
 import sys
 ​
 # --- 路径配置 (本地优化) ---
 # 获取当前脚本所在的目录
 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 # 获取项目根目录
 PROJECT_ROOT = os.path.dirname(CURRENT_DIR)
 DATA_DIR = os.path.join(PROJECT_ROOT, "data")
 RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")
 # 修改为你实际的文件名
 CSV_FILE = os.path.join(RAW_DATA_DIR, "Shakespeare_data.csv")
 VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
 ​
 ​
 def check_data_exists():
     """检查数据文件是否存在,如果不存在则给出提示并退出"""
     if not os.path.exists(CSV_FILE):
         print("=" * 50)
         print("错误:数据文件未找到!")
         print(f"请确保数据文件已放置在以下位置:\n{CSV_FILE}")
         print("\n请按以下步骤操作:")
         print("1. 访问 https://www.kaggle.com/datasets/kingburrito666/shakespeare-sonnets")
         print("2. 点击 'Download' 按钮。")
         print("3. 解压下载的 zip 文件。")
         print("4. 将 'Shakespeare_data.csv' 文件复制到 'data/raw/' 目录下。")
         print("=" * 50)
         sys.exit(1)  # 退出程序
 ​
 ​
 def prepare_data(sequence_length=100):
     """加载、预处理数据并创建DataLoader"""
     # 1. 检查数据文件
     check_data_exists()
 ​
     # 2. 加载数据
     print(f"正在从 {CSV_FILE} 加载数据...")
     df = pd.read_csv(CSV_FILE)
 ​
     # 3. 数据清洗:过滤掉舞台说明,只保留角色台词
     player_lines_df = df.dropna(subset=['Player'])
     print(f"数据加载完成,共加载了 {len(player_lines_df)} 行台词。")
 ​
     # 4. 合并所有台词为一个长字符串
     text = player_lines_df['PlayerLine'].str.cat(sep='\n')
 ​
     # 5. 创建字符集
     chars = sorted(list(set(text)))
     vocab_size = len(chars)
 ​
     char_to_ix = {ch: i for i, ch in enumerate(chars)}
     ix_to_char = {i: ch for i, ch in enumerate(chars)}
 ​
     # 保存词汇表
     with open(VOCAB_FILE, 'wb') as f:
         pickle.dump((char_to_ix, ix_to_char, vocab_size), f)
     print(f"词汇表已保存至 {VOCAB_FILE},词汇量大小: {vocab_size}")
 ​
     # 6. 将整个文本转换为整数序列
     text_as_int = [char_to_ix[ch] for ch in text]
 ​
     # 7. 创建输入序列和目标序列 (关键修改!)
     input_sequences = []
     target_sequences = []
     for i in range(0, len(text_as_int) - sequence_length):
         # 输入是从 i 到 i+sequence_length-1
         input_sequences.append(text_as_int[i: i + sequence_length])
         # 目标是从 i+1 到 i+sequence_length (向左移动一位)
         target_sequences.append(text_as_int[i + 1: i + sequence_length + 1])
 ​
     print(f"总序列数: {len(input_sequences)}")
 ​
     # 8. 创建自定义Dataset (关键修改!)
     class ShakespeareDataset(Dataset):
         def __init__(self, sequences, targets):
             self.sequences = sequences
             self.targets = targets
 ​
         def __len__(self):
             return len(self.sequences)
 ​
         def __getitem__(self, idx):
             # 返回一个输入序列和对应的目标序列
             return torch.tensor(self.sequences[idx], dtype=torch.long), torch.tensor(self.targets[idx],
                                                                                      dtype=torch.long)
 ​
     dataset = ShakespeareDataset(input_sequences, target_sequences)
 ​
     return dataset, vocab_size
 ​
 ​
 def get_data_loaders(batch_size=64, sequence_length=100):
     """获取训练和验证DataLoader"""
     dataset, vocab_size = prepare_data(sequence_length)
 ​
     train_size = int(0.8 * len(dataset))
     val_size = len(dataset) - train_size
     train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
 ​
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
 ​
     return train_loader, val_loader, vocab_size
 ​

4 结果分析

模型训练

运行train.py,可以看到以下类似输出:

 使用设备: cuda
 正在从 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/data/raw/Shakespeare_data.csv 加载数据...
 数据加载完成,共加载了 111389 行台词。
 词汇表已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/vocab.pkl,词汇量大小: 77
 总序列数: 4365922
 模型训练开始...
 Epoch 1/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.34it/s, loss=0.953]
 Epoch 1 完成, 平均训练损失: 1.1017
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 2/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.28it/s, loss=0.927]
 Epoch 2 完成, 平均训练损失: 0.9059
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 3/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.47it/s, loss=0.728]
 Epoch 3 完成, 平均训练损失: 0.8570
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 4/20: 100%|██████████| 27288/27288 [08:49<00:00, 51.58it/s, loss=0.992]
 Epoch 4 完成, 平均训练损失: 0.8321
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 5/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.34it/s, loss=0.912]
 Epoch 5 完成, 平均训练损失: 0.8161
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 6/20: 100%|██████████| 27288/27288 [08:49<00:00, 51.55it/s, loss=0.786]
 Epoch 6 完成, 平均训练损失: 0.8047
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 7/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.48it/s, loss=0.797]
 Epoch 7 完成, 平均训练损失: 0.7960
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 8/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.37it/s, loss=0.74]
 Epoch 8 完成, 平均训练损失: 0.7892
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 9/20: 100%|██████████| 27288/27288 [08:45<00:00, 51.97it/s, loss=0.792]
 Epoch 9 完成, 平均训练损失: 0.7837
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 10/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.37it/s, loss=0.865]
 Epoch 10 完成, 平均训练损失: 0.7794
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 11/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.42it/s, loss=0.681]
 Epoch 11 完成, 平均训练损失: 0.7753
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 12/20: 100%|██████████| 27288/27288 [08:40<00:00, 52.45it/s, loss=0.852]
 Epoch 12 完成, 平均训练损失: 0.7720
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 13/20: 100%|██████████| 27288/27288 [08:44<00:00, 52.01it/s, loss=0.819]
 Epoch 13 完成, 平均训练损失: 0.7694
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 14/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.37it/s, loss=0.678]
 Epoch 14 完成, 平均训练损失: 0.7670
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 15/20: 100%|██████████| 27288/27288 [08:33<00:00, 53.19it/s, loss=0.532]
 Epoch 15 完成, 平均训练损失: 0.7647
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 16/20: 100%|██████████| 27288/27288 [08:34<00:00, 53.05it/s, loss=0.696]
 Epoch 16 完成, 平均训练损失: 0.7631
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 17/20: 100%|██████████| 27288/27288 [08:46<00:00, 51.85it/s, loss=0.561]
 Epoch 17 完成, 平均训练损失: 0.7614
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 18/20: 100%|██████████| 27288/27288 [08:34<00:00, 53.05it/s, loss=0.715]
 Epoch 18 完成, 平均训练损失: 0.7598
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 19/20: 100%|██████████| 27288/27288 [08:33<00:00, 53.17it/s, loss=0.674]
 Epoch 19 完成, 平均训练损失: 0.7582
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 20/20: 100%|██████████| 27288/27288 [08:53<00:00, 51.15it/s, loss=0.864]
 Epoch 20 完成, 平均训练损失: 0.7572
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 模型训练完成!

模型成功完成了20个Epoch的训练。训练损失从初始的1.10稳定下降到最终的0.75,表明模型有效地学习了数据中的语言模式。训练过程稳定,速度高效。

模型测试

模型训练完成后,使用generate.py脚本生成文本

 from fairest creatures we desire increase,
 that never shall become him.
 As you have that, you rascal, yet the saying is,
 Shall wait upon the bed of day to-morrow.
 So that, for mine I pray you:
 I make you these rocks to the elements,
 I'll never wear some strange enemy prove
 A full, and then I will be by her foot,
 Yet I am remember'd with the host:
 When such proceeding breedings, dogs she stands,
 A thousand of her many hours of good Clifford
  • 优点:

    • 词汇与拼写:生成的单词拼写完全正确,并成功使用了fairest, creatures, rascal等符合风格的词汇。

    • 语法与格式:基本遵循英文语法,标点符号使用合理。最令人惊喜的是,模型学会了剧本格式,在最后生成了一个角色名Clifford

    • 风格捕捉:文本整体带有一种戏剧化和古雅的腔调,成功模仿了莎士比亚文本的“形”。

  • 不足之处:

    • 语义不连贯:句子之间缺乏逻辑联系,整体内容是随机的、无意义的。这是字符级语言模型的普遍局限。

    • 逻辑混乱:模型不理解其生成内容的含义,只是在进行概率上的字符预测。

5 原理总结

在【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中,我们对GRU的数学原理和结构有了非常扎实的理论了解。但是反观代码,似乎代码中并未体现相关内容。为什么呢?


从宏观到微观——PyTorch的“魔法”

首先需要明白的是,PyTorch这样的深度学习框架,其核心目标之一就是封装复杂的数学运算,让你能用更简洁、更高层的代码来表达模型。

我们在项目中使用的 nn.GRU 就是这样一个高度封装的“魔法盒子”。

 # 项目中的代码
 self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)

这短短一行代码,背后就对应了【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中所有的公式:重置门、更新门、候选隐状态的计算等等。框架自动初始化了所有的 Wb 参数,并实现了整个前向传播流程。

我们的目标:现在,我们要打开这个“魔法盒子”,亲手制作里面的每一个零件,看看它们是如何协同工作的。


从零开始,构建一个GRU单元

为了彻底理解,我们不直接用 nn.GRU,而是自己动手写一个 GRUCell。一个 GRUCell 就是GRU在单个时间步 t 所做的所有计算。

【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中列出了所有需要学习的参数矩阵和偏置向量。在我们的PyTorch代码中,它们通常以 nn.Linear 层的形式存在。

点拨nn.Linear(in_features, out_features) 本质上就是实现 Y = X @ W.T + b。注意这里有个转置 .T,所以 W 的形状是 (out_features, in_features)。这和【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中的 X_t * W_xr (其中 W_xrd x h) 是完全一致的。

 import torch
 import torch.nn as nn
 ​
 ​
 class GRUCell_FromScratch(nn.Module):
     """
     从零实现的GRU单元,其计算逻辑与PyTorch官方的nn.GRUCell完全一致。
     """
     def __init__(self, input_size, hidden_size):
         super(GRUCell_FromScratch, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
 ​
         # --- 定义所有线性变换的权重矩阵 ---
         # 注意:所有nn.Linear层都设置bias=False,因为我们将手动创建和管理所有偏置项,
         # 这样做可以更精确地控制偏置项的运算方式,以匹配官方实现。
 ​
         # 重置门 的权重: W_xr, W_hr
         self.linear_xr = nn.Linear(input_size, hidden_size, bias=False)
         self.linear_hr = nn.Linear(hidden_size, hidden_size, bias=False)
 ​
         # 更新门 的权重: W_xz, W_hz
         self.linear_xz = nn.Linear(input_size, hidden_size, bias=False)
         self.linear_hz = nn.Linear(hidden_size, hidden_size, bias=False)
 ​
         # 候选隐状态 的权重: W_xh, W_hh
         self.linear_xh = nn.Linear(input_size, hidden_size, bias=False)
         self.linear_hh = nn.Linear(hidden_size, hidden_size, bias=False)
 ​
         # --- 手动创建所有偏置项 ---
         # 重置门和更新门的偏置 (官方实现中,输入偏置和隐藏状态偏置是相加的)
         self.bias_r = nn.Parameter(torch.zeros(hidden_size))  # 对应 b_r = b_ir + b_hr
         self.bias_z = nn.Parameter(torch.zeros(hidden_size))  # 对应 b_z = b_iz + b_hz
         # 候选隐状态的两个独立偏置项
         self.bias_in = nn.Parameter(torch.zeros(hidden_size))  # 对应 b_in, 不被重置门门控
         self.bias_hn = nn.Parameter(torch.zeros(hidden_size))  # 对应 b_hn, 被重置门门控
 ​
     def forward(self, x_t, h_prev):
         """
         前向传播计算单个时间步。
         参数:
             x_t: 当前时间步的输入, shape (batch_size, input_size)
             h_prev: 前一时间步的隐状态, shape (batch_size, hidden_size)
         返回:
             h_t: 当前时间步计算出的新隐状态, shape (batch_size, hidden_size)
         """
         # --- 步骤 1: 计算重置门 ---
         # 公式: R_t = σ(X_t * W_xr + H_{t-1} * W_hr + b_r)
         r_t = torch.sigmoid(self.linear_xr(x_t) + self.linear_hr(h_prev) + self.bias_r)
 ​
         # --- 步骤 2: 计算更新门 ---
         # 公式: Z_t = σ(X_t * W_xz + H_{t-1} * W_hz + b_z)
         z_t = torch.sigmoid(self.linear_xz(x_t) + self.linear_hz(h_prev) + self.bias_z)
 ​
         # --- 步骤 3: 计算候选隐状态 ---
         # 公式: H_tilde = tanh(X_t * W_xh + R_t ⊙ (H_{t-1} * W_hh + b_hn) + b_in)
         # 这里的 * 运算符实现了哈达玛积 (⊙)
         h_tilde = torch.tanh(
             self.linear_xh(x_t) +
             r_t * (self.linear_hh(h_prev) + self.bias_hn) +
             self.bias_in
         )
 ​
         # --- 步骤 4: 计算最终的隐状态 ---
         # 公式: H_t = Z_t ⊙ H_{t-1} + (1 - Z_t) ⊙ H_tilde
         h_t = z_t * h_prev + (1 - z_t) * h_tilde
 ​
         return h_t

运行并验证我们的“手搓”GRU

理论讲完了,我们来做个小实验,验证一下我们写的 GRUCell_FromScratch 和PyTorch官方的 nn.GRUCell 是不是等价的。

 # --- 实验验证部分 ---
 # 定义超参数
 batch_size = 4
 input_size = 10
 hidden_size = 20
 ​
 # 创建随机的输入数据
 x_t = torch.randn(batch_size, input_size)
 h_prev = torch.randn(batch_size, hidden_size)
 ​
 # 实例化我们实现的GRU单元和PyTorch官方的GRU单元
 gru_scratch = GRUCell_FromScratch(input_size, hidden_size)
 gru_official = nn.GRUCell(input_size, hidden_size)
 ​
 # --- 将官方模型的参数复制到我们实现的模型中 ---
 # PyTorch将所有输入权重和偏置堆叠在 weight_ih 和 bias_ih 中
 # 将所有隐藏状态权重和偏置堆叠在 weight_hh 和 bias_hh 中
 ​
 # 复制权重
 gru_scratch.linear_xr.weight.data.copy_(gru_official.weight_ih[:hidden_size, :])
 gru_scratch.linear_hr.weight.data.copy_(gru_official.weight_hh[:hidden_size, :])
 gru_scratch.linear_xz.weight.data.copy_(gru_official.weight_ih[hidden_size:hidden_size * 2, :])
 gru_scratch.linear_hz.weight.data.copy_(gru_official.weight_hh[hidden_size:hidden_size * 2, :])
 gru_scratch.linear_xh.weight.data.copy_(gru_official.weight_ih[hidden_size * 2:, :])
 gru_scratch.linear_hh.weight.data.copy_(gru_official.weight_hh[hidden_size * 2:, :])
 ​
 # 复制偏置
 gru_scratch.bias_r.data.copy_(gru_official.bias_ih[:hidden_size] + gru_official.bias_hh[:hidden_size])
 gru_scratch.bias_z.data.copy_(gru_official.bias_ih[hidden_size:hidden_size * 2] + gru_official.bias_hh[hidden_size:hidden_size * 2])
 gru_scratch.bias_in.data.copy_(gru_official.bias_ih[hidden_size * 2:])
 gru_scratch.bias_hn.data.copy_(gru_official.bias_hh[hidden_size * 2:])
 ​
 # 分别用两个模型进行前向传播
 h_t_scratch = gru_scratch(x_t, h_prev)
 h_t_official = gru_official(x_t, h_prev)
 ​
 # 比较结果
 print("我们手搓的GRU输出:", h_t_scratch)
 print("PyTorch官方GRU输出:", h_t_official)
 ​
 # 检查两个输出是否在数值上几乎相同
 print("\n两个输出是否几乎相同?", torch.allclose(h_t_scratch, h_t_official))

当你运行这段代码,如果最后打印出 True,那么恭喜你!你已经成功地用代码复现了GRU的核心数学原理。这证明了你对GRU的理解已经深入到了“像素级”。以下是我运行的结果:

 我们手搓的GRU输出: tensor([[-0.8295, -0.2595,  0.9684,  0.2140, -0.5785, -0.1933,  0.5778,  0.3714,
           0.5419, -1.2080,  0.5973,  0.6886, -0.7297, -0.1810, -0.3550,  0.1375,
           0.3358, -0.2605, -0.4440,  0.6498],
         [ 0.3155,  0.0616, -0.5738, -0.1972, -0.0984,  0.0601,  0.3601,  0.1683,
          -0.4179,  0.4705,  0.4867, -0.5043,  1.2716,  0.0027, -0.4619, -0.3631,
          -0.4136, -0.6153, -0.3496, -0.8575],
         [-0.1301, -0.4527, -0.3129,  0.2685, -0.3576, -0.3155, -0.4003,  0.4550,
          -0.3802,  0.3482,  0.8009,  0.1505,  0.2446,  0.0780,  0.4634, -0.1107,
           0.2131,  0.3837, -0.4669,  0.0181],
         [-0.0648,  0.0902, -0.0132,  0.0585, -0.1076, -0.5664,  0.1125, -0.1067,
          -0.0702, -0.5483,  0.5603,  0.2239,  0.0498,  0.8238, -0.0751, -0.4099,
           0.1920,  0.5400, -0.1944,  0.4914]], grad_fn=)
 PyTorch官方GRU输出: tensor([[-0.8295, -0.2595,  0.9684,  0.2140, -0.5785, -0.1933,  0.5778,  0.3714,
           0.5419, -1.2080,  0.5973,  0.6886, -0.7297, -0.1810, -0.3550,  0.1375,
           0.3358, -0.2605, -0.4440,  0.6498],
         [ 0.3155,  0.0616, -0.5738, -0.1972, -0.0984,  0.0601,  0.3601,  0.1683,
          -0.4179,  0.4705,  0.4867, -0.5043,  1.2716,  0.0027, -0.4619, -0.3631,
          -0.4136, -0.6153, -0.3496, -0.8575],
         [-0.1301, -0.4527, -0.3129,  0.2685, -0.3576, -0.3155, -0.4003,  0.4550,
          -0.3802,  0.3482,  0.8009,  0.1505,  0.2446,  0.0780,  0.4634, -0.1107,
           0.2131,  0.3837, -0.4669,  0.0181],
         [-0.0648,  0.0902, -0.0132,  0.0585, -0.1076, -0.5664,  0.1125, -0.1067,
          -0.0702, -0.5483,  0.5603,  0.2239,  0.0498,  0.8238, -0.0751, -0.4099,
           0.1920,  0.5400, -0.1944,  0.4914]], grad_fn=)
 ​
 两个输出是否几乎相同? True

回归项目,融会贯通

现在,我们再回头看项目代码 src/model.py

 # src/model.py
 class GRUModel(nn.Module):
     # ...
     def forward(self, x, hidden):
         embedded = self.embedding(x)
         out, hidden = self.gru(embedded, hidden) # <--- 就是这里!
         out = out.contiguous().view(-1, self.hidden_dim)
         out = self.fc(out)
         return out, hidden

这里的 self.gru 就是我们上面验证的 nn.GRUCell 的“循环版本”。nn.GRU 会在内部自动地遍历输入序列的每一个时间步,反复调用 GRUCell 的计算逻辑,并把每一步的隐状态传递给下一步。

  • self.embedding:将你的字符ID转换为密集向量。

  • self.gru:处理整个序列,输出每个时间步的隐状态 out 和最后一个时间步的隐状态 hidden

  • self.fc:这就是【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记 2.4 节 的最终输出层 O_t = H_t * W_hq + b_q,它将GRU的输出映射到词汇表大小,用于预测下一个字符的概率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/941176.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

得帆AI aPaaS(AI低代码)1.0产品特性(5)-智能搭建(二)

在上一期中,我们探讨了「为什么用智能搭建」。今天,我们将拆解搭建任务管理系统的真实案例,亲历从「说需求」到「应用上线」的完整旅程,揭秘智能搭建的六大核心步骤! 第一步:需求理解——对话驱动,精准捕捉需求…

实用指南:如何快速学习一个网络协议?

实用指南:如何快速学习一个网络协议?2025-10-20 14:00 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !…

实用指南:【Linux 系统】命令行参数和环境变量

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

得帆AI aPaaS(AI低代码)1.0产品特性(4)-智能搭建(一)

得帆AI aPaaS是一个多智能体协同的Agent平台:把零散、需手动配置的动作,转成可执行的任务编排,快速落地“小而美”的业务场景,并能持续迭代。 智能搭建是其核心能力,可以基于业务语言交互、Excel、需求文档等,在…

日记11

今天终于搞懂了 ArrayList 和数组的区别!之前总混淆两者,今天对着代码调试才发现, ArrayList 能自动扩容,不用像数组那样一开始就定死长度,比如添加第11个元素时,它会悄悄把容量从10变成15。 不过写遍历代码时还…

element 表单校验失败定位到指定元素

this.$refs.generateForm.validate(valid => {if (valid) {} else {console.log(表单数据校验失败)this.moveToErr()}moveToErr() {this.$nextTick(() => {let isError = document.getElementsByClassName(is-er…

腾讯企业邮箱管理

一、邮箱更换手机号(切换新的登录人) 1.增加新邮箱A 在企业管理后台——通讯录——组织架构——添加成员 添加新成员 设置新成员是这个手机号,同时先给这个成员设置其他企业邮箱账号NewA@xxx.cn。 2、删除原邮箱Ol…

2025年湖北武汉实验室设计哪家口碑好/哪家信誉好/哪家售后好?

2025年湖北武汉实验室设计口碑之选——湖北特尔诺实验室设备有限公司 在当今科技日新月异的时代,实验室作为科研、教学与检测的重要场所,其设计与建设质量直接关系到实验结果的准确性与人员的安全。特别是在湖北武汉…

国产化Word处理控件Spire.Doc教程:用Java实现TXT文本与Word互转的完整教程

纯文本(.txt)文件因简洁通用被广泛使用,但无法支持字体、表格、图片等格式;而 Word(.docx)文件虽具备丰富的排版能力,却难以直接用于文本分析、索引等场景。本文将详细介绍如何通过 Spire.Doc for Java(一款轻…

C# Avalonia 16- Animation- BombDropper

C# Avalonia 16- Animation- BombDropper结合我们之前写的AnimationPlayer,现在实现一个小游戏。 定义自己的Style,前面有例子已经说明了如何在自己的Styles.axaml中写Style。<!-- Bomb 样式 --> <Style Se…

C# 使用NPOI生成Word文件

NuGet 安装 NPOI 1. 建立模板(可选): 手动建立Word模板, 多使用表格然后隐藏边框, 方便数据插入固定位置 2. 建立Word对象引入模板string dPath = $"{Environment.CurrentDirectory}\\Data\\Demo.docx"; Str…

2025年太阳能板定制厂家口碑排行榜单:权威推荐与选择指南

摘要 随着全球能源转型加速,太阳能板行业迎来爆发式增长,2025年市场规模预计突破3000亿美元。消费者在选择太阳能板定制厂家时面临诸多困惑,本文基于技术实力、产品质量、客户口碑等维度,为您呈现最新行业排行榜单…

2025年太阳能板定制厂家口碑排行榜前十强:专业评测与选择指南

摘要 随着全球能源转型加速,太阳能板行业在2025年迎来爆发式增长,定制化需求显著提升。本文基于市场调研和用户反馈,整理出太阳能板厂家口碑排行榜单,旨在帮助用户快速找到可靠供应商。榜单结合技术参数、服务质量…

Python3 statistics 模块

Python3 statistics 模块statistics 是 Python 3.4 引入的标准库,专注于提供基本的统计计算功能,可用于分析数值数据的集中趋势、离散程度、分布形状等。它无需额外安装,接口简洁,适合快速完成简单的统计分析任务(…

linux内核开发学习计划

目录岗位需求实习--字节跳动--Linux内核开发实习生--实时核方向正式--Linux内核驱动工程师 央企直招实习--乐研--linux内核研发工程师正式--京东--OS内核核心研发正式--小米--linux内核高级工程师 岗位需求 实习--字节…

随机生成动态头像

Multiavatar 是一个多文化的头像生成器,使用 JavaScript 编写。它能够生成代表不同种族、文化、年龄组、世界观和生活方式的头像。Multiavatar 可以生成超过 120 亿个独特的头像。Multiavatar Github 地址安装 pnpm i…

2025年湖北武汉实验室装修/实验室设计/实验室改造哪个厂家好

2025年湖北武汉实验室装修厂家推荐:湖北特尔诺实验室设备有限公司 在2025年,如果您正在寻找湖北武汉地区优质的实验室装修厂家,那么​​湖北特尔诺实验室设备有限公司​​无疑是您的不二之选。这家公司凭借其专业的…

能源AI天团:多智能体如何破解行业复杂任务 - 实践

能源AI天团:多智能体如何破解行业复杂任务 - 实践2025-10-20 13:39 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; displ…

2025年AI搜索优化品牌推荐排行榜前十强深度解析

摘要 随着人工智能技术的迅猛发展,AI搜索优化行业在2025年迎来爆发式增长,企业依赖高效、精准的搜索优化服务提升在线可见性和业务转化。本文基于市场调研和数据统计,解析2025年AI搜索优化品牌排行榜前十强,为读者…

2025年AI搜索优化品牌推荐排行榜:技术深度解析与选择指南

摘要 随着人工智能技术的飞速发展,AI搜索优化行业在2025年迎来爆发式增长,旨在提升内容精准度和用户体验。本排行基于技术实力、服务口碑、案例实效等维度综合评估,为寻求加盟或服务的企业提供参考。表单数据来源于…