从代码学习深度学习 - LSTM PyTorch版

文章目录

  • 前言
  • 一、数据加载与预处理
    • 1.1 代码实现
    • 1.2 功能解析
  • 二、LSTM介绍
    • 2.1 LSTM原理
    • 2.2 模型定义
      • 代码解析
  • 三、训练与预测
    • 3.1 训练逻辑
      • 代码解析
    • 3.2 可视化工具
      • 功能解析
      • 功能结果
  • 总结


前言

深度学习中的循环神经网络(RNN)及其变种长短期记忆网络(LSTM)在处理序列数据(如文本、时间序列等)方面表现出色。本篇博客将通过一个完整的PyTorch实现,带你从零开始学习如何使用LSTM进行文本生成任务。我们将基于H.G. Wells的《时间机器》数据集,逐步展示数据预处理、模型定义、训练与预测的全过程。通过代码和文字的结合,帮助你深入理解LSTM的实现细节及其在自然语言处理中的应用。

本文的代码分为四个主要部分:

  1. 数据加载与预处理(utils_for_data.py
  2. LSTM模型定义(Jupyter Notebook中的模型部分)
  3. 训练与预测逻辑(utils_for_train.py
  4. 可视化工具(utils_for_huitu.py

以下是详细的实现与解析。


一、数据加载与预处理

首先,我们需要加载《时间机器》数据集并进行预处理。以下是utils_for_data.py中的完整代码及其功能说明。

1.1 代码实现

import random
import re
import torch
from collections import Counterdef read_time_machine():"""将时间机器数据集加载到文本行的列表中"""with open('timemachine.txt', 'r') as f:lines = f.readlines()return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]def tokenize(lines, token='word'):"""将文本行拆分为单词或字符词元"""if token == 'word':return [line.split() for line in lines]elif token == 'char':return [list(line) for line in lines]else:print(f'错误:未知词元类型:{token}')def count_corpus(tokens):"""统计词元的频率"""if not tokens:return Counter()if isinstance(tokens[0], list):flattened_tokens = [token for sublist in tokens for token in sublist]else:flattened_tokens = tokensreturn Counter(flattened_tokens)class Vocab:"""文本词表类,用于管理词元及其索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1@staticmethoddef _count_corpus(tokens):if not tokens:return Counter()if isinstance(tokens[0], list):tokens = [token for sublist in tokens for token in sublist]return Counter(tokens)def __len__(self):return len(self.idx_to_token)def __getitem__(self, tokens):if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self[token] for token in tokens]def to_tokens(self, indices):if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]@propertydef unk(self):return 0@propertydef token_freqs(self):return self._token_freqsdef load_corpus_time_machine(max_tokens=-1):lines = read_time_machine()tokens = tokenize(lines, 'char')vocab = Vocab(tokens)corpus = [vocab[token] for line in tokens for token in line]if max_tokens > 0:corpus = corpus[:max_tokens]return corpus, vocabdef seq_data_iter_random(corpus, batch_size, num_steps):offset = random.randint(0, num_steps - 1)corpus = corpus[offset:]num_subseqs = (len(corpus) - 1) // num_stepsinitial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices)def data(pos):return corpus[pos:pos + num_steps]num_batches = num_subseqs // batch_sizefor i in range(0, batch_size * num_batches, batch_size):initial_indices_per_batch = initial_indices[i:i + batch_size]X = [data(j) for j in initial_indices_per_batch]Y = [data(j + 1) for j in initial_indices_per_batch]yield torch.tensor(X), torch.tensor(Y)def seq_data_iter_sequential(corpus, batch_size, num_steps):offset = random.randint(0, num_steps)num_tokens = ((len(corpus) - offset - 1) // batch_size) *

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/74708.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

easy-poi 一对多导出

1. 需求&#xff1a; 某一列上下两行单元格A,B值一样且这两个单元格&#xff0c; 前面所有列对应单元格值一样的话&#xff0c; 就对A,B 两个单元格进行纵向合并单元格 1. 核心思路&#xff1a; 先对数据集的国家&#xff0c;省份&#xff0c;城市...... id 身份证进行排序…

AI比人脑更强,因为被植入思维模型【42】思维投影思维模型

giszz的理解&#xff1a;本质和外在。我们的行为举止&#xff0c;都是我们的内心的表现。从外边可以看内心&#xff0c;从内心可以判断外在。曾国藩有&#xff17;个识人的方法&#xff0c;大部分的人在他的面前如同没穿衣服一样。对于我们自身的启迪&#xff0c;我认为有四点&…

Spring Boot 打印日志

1.通过slf4j包中的logger对象打印日志 Spring Boot内置了日志框架slf4j&#xff0c;在程序中调用slf4j来输出日志 通过创建logger对象打印日志&#xff0c;Logger 对象是属于 org.slf4j 包下的不要导错包。 2.日志级别 日志级别从高到低依次为: FATAL:致命信息&#xff0c;表…

【IOS webview】源代码映射错误,页面卡住不动

报错场景 safari页面报源代码映射错误&#xff0c;页面卡住不动。 机型&#xff1a;IOS13 技术栈&#xff1a;react 其他IOS也会报错&#xff0c;但不影响页面显示。 debug webpack配置不要GENERATE_SOURCEMAP。 解决方法&#xff1a; GENERATE_SOURCEMAPfalse react-app…

ES中经纬度查询geo_point

0. ES版本 6.x版本 1. 创建索引 PUT /location {"settings": {"number_of_shards": 1,"number_of_replicas": 0},"mappings": {"location": {"properties": {"id": {"type": "keywor…

OpenCV界面编程

《OpenCV计算机视觉开发实践&#xff1a;基于Python&#xff08;人工智能技术丛书&#xff09;》(朱文伟&#xff0c;李建英)【摘要 书评 试读】- 京东图书 OpenCV的Python开发环境搭建(Windows)-CSDN博客 OpenCV也支持有限的界面编程&#xff0c;主要是针对窗口、控件和鼠标…

GOC L2 第五课模运算和周期二

课堂回顾&#xff1a; 求取余数的过程叫做模运算 每轮的动作都是重复的&#xff0c;我们称这个过程位周期。 课堂学习&#xff1a; 剩余计算器 秋天到了&#xff0c;学校里的苹果熟了&#xff0c;太乙老师&#xff0c;想让哪吒帮忙设计一个计算器&#xff0c;看每个小朋友能分…

54.大学生心理健康管理系统(基于springboot项目)

目录 1.系统的受众说明 2.相关技术 2.1 B/S结构 2.2 MySQL数据库 3.系统分析 3.1可行性分析 3.1.1时间可行性 3.1.2 经济可行性 3.1.3 操作可行性 3.1.4 技术可行性 3.1.5 法律可行性 3.2系统流程分析 3.3系统功能需求分析 3.4 系统非功能需求分析 4.系统设计…

Redis 除了数据类型外的核心功能 的详细说明,包含事务、流水线、发布/订阅、Lua 脚本的完整代码示例和表格总结

以下是 Redis 除了数据类型外的核心功能 的详细说明&#xff0c;包含事务、流水线、发布/订阅、Lua 脚本的完整代码示例和表格总结&#xff1a; 1. Redis 事务&#xff08;Transactions&#xff09; 功能描述 事务通过 MULTI 和 EXEC 命令将一组命令打包执行&#xff0c;保证…

STM32F103C8T6单片机硬核原理篇:讨论GPIO的基本原理篇章1——只讨论我们的GPIO简单输入和输出

目录 前言 输出时的GPIO控制部分 标准库是如何操作寄存器完成GPIO驱动的初始化的&#xff1f; 问题1&#xff1a;如何掌握GPIO的编程细节——跟寄存器如何打交道 问题2&#xff1a;哪些寄存器&#xff0c;去哪里找呢&#xff1f; 问题三&#xff0c;寄存器的含义&#xff…

前端布局难题:父元素padding导致子元素无法全屏?3种解决方案

大家好&#xff0c;我是一诺。今天要跟大家分享一个我在实际项目中经常用到的CSS技巧——如何让子元素突破父元素的padding限制&#xff0c;实现真正的全屏宽度效果。 为什么会有这个需求&#xff1f; 记得我刚入行的时候&#xff0c;接到一个需求&#xff1a;要在内容区插入…

当网页受到DDOS网络攻击有哪些应对方法?

分布式拒绝服务攻击也是人们较为熟悉的DDOS攻击&#xff0c;这类攻击会通过大量受控制的僵尸网络向目标服务器发送请求&#xff0c;以此来消耗服务器中的资源&#xff0c;致使用户无法正常访问&#xff0c;当网页受到分布式拒绝服务攻击时都有哪些应对方法呢&#xff1f; 建立全…

LeNet-5简介及matlab实现

文章目录 一、LeNet-5网络结构简介二、LeNet-5每一层的实现原理2.1. 第一层 (C1) &#xff1a;卷积层&#xff08;Convolution Layer&#xff09;2.2. 第二层 (S2) &#xff1a;池化层&#xff08;Pooling Layer&#xff09;2.3. 第三层&#xff08;C3&#xff09;&#xff1a;…

【LLM】MCP(Python):实现 stdio 通信的Client与Server

本文将详细介绍如何使用 Model Context Protocol (MCP) 在 Python 中实现基于 STDIO 通信的 Client 与 Server。MCP 是一个开放协议&#xff0c;它使 LLM 应用与外部数据源和工具之间的无缝集成成为可能。无论你是构建 AI 驱动的 IDE、改善 chat 交互&#xff0c;还是构建自定义…

Docker 安装 Elasticsearch 教程

目录 一、安装 Elasticsearch 二、安装 Kibana 三、安装 IK 分词器 四、Elasticsearch 常用配置 五、Elasticsearch 常用命令 一、安装 Elasticsearch &#xff08;一&#xff09;创建 Docker 网络 因为后续还需要部署 Kibana 容器&#xff0c;所以需要让 Elasticsearch…

Swagger @ApiOperation

ApiOperation 注解并非 Spring Boot 自带的注解&#xff0c;而是来自 Swagger 框架&#xff0c;Swagger 是一个规范且完整的框架&#xff0c;用于生成、描述、调用和可视化 RESTful 风格的 Web 服务&#xff0c;而 ApiOperation 主要用于为 API 接口的操作添加描述信息。以下为…

【奇点时刻】GPT4o新图像生成模型底层原理深度洞察报告(篇2)

由于上一篇解析深度不足&#xff0c;经过查看学习相关论文&#xff0c;以下是一份对 GPT-4o 最新的图像生成模型 的深度梳理与洞察&#xff0c;从模型原理到社区解读、对比传统扩散模型&#xff0c;再到对未来趋势的分析。为了便于阅读&#xff0c;整理成以下七个部分&#xff…

C# 窗体应用(.FET Framework ) 打开文件操作

一、 打开文件或文件夹加载数据 1. 定义一个列表用来接收路径 public List<string> paths new List<string>();2. 打开文件选择一个文件并将文件放入列表中 OpenFileDialog open new OpenFileDialog(); // 过滤 open.Filter "(*.jpg;*.jpge;*.bmp;*.png…

Scala 面向对象编程总结

​​​抽象属性和抽象方法 基本语法 定义抽象类&#xff1a;abstract class Person{} //通过 abstract 关键字标记抽象类定义抽象属性&#xff1a;val|var name:String //一个属性没有初始化&#xff0c;就是抽象属性定义抽象方法&#xff1a;def hello():String //只声明而没…

人工智能赋能工业制造:智能制造的未来之路

一、引言 随着人工智能技术的飞速发展&#xff0c;其应用场景不断拓展&#xff0c;从消费电子到医疗健康&#xff0c;从金融科技到交通运输&#xff0c;几乎涵盖了所有行业。而工业制造作为国民经济的支柱产业&#xff0c;也在人工智能的浪潮中迎来了深刻的变革。智能制造&…