文章目录
- 前言
- 一、数据加载与预处理
- 1.1 代码实现
- 1.2 功能解析
- 二、LSTM介绍
- 2.1 LSTM原理
- 2.2 模型定义
- 代码解析
- 三、训练与预测
- 3.1 训练逻辑
- 代码解析
- 3.2 可视化工具
- 功能解析
- 功能结果
- 总结
前言
深度学习中的循环神经网络(RNN)及其变种长短期记忆网络(LSTM)在处理序列数据(如文本、时间序列等)方面表现出色。本篇博客将通过一个完整的PyTorch实现,带你从零开始学习如何使用LSTM进行文本生成任务。我们将基于H.G. Wells的《时间机器》数据集,逐步展示数据预处理、模型定义、训练与预测的全过程。通过代码和文字的结合,帮助你深入理解LSTM的实现细节及其在自然语言处理中的应用。
本文的代码分为四个主要部分:
- 数据加载与预处理(
utils_for_data.py
) - LSTM模型定义(Jupyter Notebook中的模型部分)
- 训练与预测逻辑(
utils_for_train.py
) - 可视化工具(
utils_for_huitu.py
)
以下是详细的实现与解析。
一、数据加载与预处理
首先,我们需要加载《时间机器》数据集并进行预处理。以下是utils_for_data.py
中的完整代码及其功能说明。
1.1 代码实现
import random
import re
import torch
from collections import Counterdef read_time_machine():"""将时间机器数据集加载到文本行的列表中"""with open('timemachine.txt', 'r') as f:lines = f.readlines()return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]def tokenize(lines, token='word'):"""将文本行拆分为单词或字符词元"""if token == 'word':return [line.split() for line in lines]elif token == 'char':return [list(line) for line in lines]else:print(f'错误:未知词元类型:{token}')def count_corpus(tokens):"""统计词元的频率"""if not tokens:return Counter()if isinstance(tokens[0], list):flattened_tokens = [token for sublist in tokens for token in sublist]else:flattened_tokens = tokensreturn Counter(flattened_tokens)class Vocab:"""文本词表类,用于管理词元及其索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1@staticmethoddef _count_corpus(tokens):if not tokens:return Counter()if isinstance(tokens[0], list):tokens = [token for sublist in tokens for token in sublist]return Counter(tokens)def __len__(self):return len(self.idx_to_token)def __getitem__(self, tokens):if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self[token] for token in tokens]def to_tokens(self, indices):if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]@propertydef unk(self):return 0@propertydef token_freqs(self):return self._token_freqsdef load_corpus_time_machine(max_tokens=-1):lines = read_time_machine()tokens = tokenize(lines, 'char')vocab = Vocab(tokens)corpus = [vocab[token] for line in tokens for token in line]if max_tokens > 0:corpus = corpus[:max_tokens]return corpus, vocabdef seq_data_iter_random(corpus, batch_size, num_steps):offset = random.randint(0, num_steps - 1)corpus = corpus[offset:]num_subseqs = (len(corpus) - 1) // num_stepsinitial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices)def data(pos):return corpus[pos:pos + num_steps]num_batches = num_subseqs // batch_sizefor i in range(0, batch_size * num_batches, batch_size):initial_indices_per_batch = initial_indices[i:i + batch_size]X = [data(j) for j in initial_indices_per_batch]Y = [data(j + 1) for j in initial_indices_per_batch]yield torch.tensor(X), torch.tensor(Y)def seq_data_iter_sequential(corpus, batch_size, num_steps):offset = random.randint(0, num_steps)num_tokens = ((len(corpus) - offset - 1) // batch_size) *