深度学习分词器char-level实战详解

一、三种分词器基本介绍

word-level:将文本按照空格或者标点分割成单词,但是词典大小太大

subword-level:词根分词(主流)

char-level:将文本按照字母级别分割成token

二、charlevel代码

导包:

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as Fprint(sys.version_info)
for module in mpl, np, pd, sklearn, torch:print(module.__name__, module.__version__)device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

数据准备(需下载):

# https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
#文件已经下载好了
with open("./shakespeare.txt", "r", encoding="utf8") as file:text = file.read()print("length", len(text))
print(text[0:100])

 构造字典:

# 1. generate vocab
# 2. build mapping char->id
# 3. data -> id_data  把数据都转为id
# 4. a b c d [EOS] -> [BOS] b c d  预测下一个字符生成的模型,也就是输入是a,输出就是b#去重,留下独立字符,并排序(排序是为了好看)
vocab = sorted(set(text)) # 利用set去重,sorted排序
print(len(vocab))
print(vocab)#每个字符都编好号,enumerate对每一个位置编号,生成的是列表中是元组,下面字典生成式
char2idx = {char:idx for idx, char in enumerate(vocab)}
print(char2idx)# 把vocab从列表变为ndarray
idx2char = np.array(vocab)
print(idx2char)#把字符都转换为id
text_as_int = np.array([char2idx[c] for c in text])
print(text_as_int.shape)
print(len(text_as_int))
print(text_as_int[0:10])
print(text[0:10])
  • enumerate() 是Python内置函数,用于给可迭代对象添加序号
  • 语法:enumerate(iterable, start=0)
  • 作用:将列表/字符串等转换为(索引, 元素)元组的序列

一共1115394个字符,这里分为11043个batch,每个样本101个字符,原因如下:

比如有Jeep四个字符,那么那前三个字母输入J就预测到e,再输入e预测到e再预测到p,相当于错开预测。前100和最后一个错开,就是上图的效果。

把text分为样本:

rom torch.utils.data import Dataset, DataLoaderclass CharDataset(Dataset):#text_as_int是字符的id列表,seq_length是每个样本的长度def __init__(self, text_as_int, seq_length):self.sub_len = seq_length + 1 #一个样本的长度self.text_as_int = text_as_intself.num_seq = len(text_as_int) // self.sub_len #样本的个数def __getitem__(self, index):#index是样本的索引,返回的是一个样本,比如第一个,就是0-100的字符,总计101个字符return self.text_as_int[index * self.sub_len: (index + 1) * self.sub_len]def __len__(self): #返回样本的个数return self.num_seq#batch是一个列表,列表中的每一个元素是一个样本,有101个字符,前100个是输入,后100个是输出
def collat_fct(batch):src_list = [] #输入trg_list = [] #输出for part in batch:src_list.append(part[:-1]) #输入trg_list.append(part[1:]) #输出src_list = np.array(src_list) #把列表转换为ndarraytrg_list = np.array(trg_list) #把列表转换为ndarrayreturn torch.Tensor(src_list).to(dtype=torch.int64), torch.Tensor(trg_list).to(dtype=torch.int64) #返回的是一个元组,元组中的每一个元素是一个torch.Tensor#每个样本的长度是101,也就是100个字符+1个结束符
train_ds = CharDataset(text_as_int, 100)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=collat_fct)
#%%
  • seq_length:模型输入的序列长度(例如100)

  • sub_len:实际存储长度 = 输入长度 + 目标长度(每个样本多存1个字符用于构造目标)

假设原始文本数字编码为:[1,2,3,4,5,6,7,8,9,10],当seq_length=3时:样本1: [1,2,3,4] → 输入[1,2,3],目标[2,3,4] 样本2: [5,6,7,8] → 输入[5,6,7],目标[6,7,8] 剩余字符[9,10]被舍弃

定义模型:

class CharRNN(nn.Module):def __init__(self, vocab_size, embedding_dim=256, hidden_dim=1024):super(CharRNN, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)#batch_first=True,输入的数据格式是(batch_size, seq_len, embedding_dim)self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, x, hidden=None):x = self.embedding(x) #(batch_size, seq_len) -> (batch_size, seq_len, embedding_dim) (64, 100, 256)#这里和02的差异是没有只拿最后一个输出,而是把所有的输出都拿出来了#(batch_size, seq_len, embedding_dim)->(batch_size, seq_len, hidden_dim)(64, 100, 1024)output, hidden = self.rnn(x, hidden)x = self.fc(output) #[bs, seq_len, hidden_dim]--->[bs, seq_len, vocab_size] (64, 100,65)return x, hidden #x的shape是(batch_size, seq_len, vocab_size)vocab_size = len(vocab)print("{:=^80}".format(" 一层单向 RNN "))       
for key, value in CharRNN(vocab_size).named_parameters():print(f"{key:^40}paramerters num: {np.prod(value.shape)}")

因为字典太小,所以embedding_dim要放大。输入形状(bs,seq)→输出形状(bs,seq,emb_dim)。

这样的话才能把里面的信息分的更清楚,其他情况都是缩小。

生成的时候不能只取最后一个时间步了,全都要。

前向传播流程:x→Embedding→RNN→Linear

训练:

class SaveCheckpointsCallback:def __init__(self, save_dir, save_step=5000, save_best_only=True):"""Save checkpoints each save_epoch epoch. We save checkpoint by epoch in this implementation.Usually, training scripts with pytorch evaluating model and save checkpoint by step.Args:save_dir (str): dir to save checkpointsave_epoch (int, optional): the frequency to save checkpoint. Defaults to 1.save_best_only (bool, optional): If True, only save the best model or save each model at every epoch."""self.save_dir = save_dirself.save_step = save_stepself.save_best_only = save_best_onlyself.best_metrics = -1# mkdirif not os.path.exists(self.save_dir):os.mkdir(self.save_dir)def __call__(self, step, state_dict, metric=None):if step % self.save_step > 0:returnif self.save_best_only:assert metric is not Noneif metric >= self.best_metrics:# save checkpointstorch.save(state_dict, os.path.join(self.save_dir, "best.ckpt"))# update best metricsself.best_metrics = metricelse:torch.save(state_dict, os.path.join(self.save_dir, f"{step}.ckpt"))#%%
# 训练
def training(model, train_loader, epoch, loss_fct, optimizer, save_ckpt_callback=None,stateful=False      # 想用stateful,batch里的数据就必须连续,不能打乱):record_dict = {"train": [],}global_step = 0model.train()hidden = Nonewith tqdm(total=epoch * len(train_loader)) as pbar:for epoch_id in range(epoch):# trainingfor datas, labels in train_loader:datas = datas.to(device)labels = labels.to(device)# 梯度清空optimizer.zero_grad()# 模型前向计算,如果数据集打乱了,stateful=False,hidden就要清空# 如果数据集没有打乱,stateful=True,hidden就不需要清空logits, hidden = model(datas, hidden=hidden if stateful else None)# 计算损失,交叉熵损失第一个参数要是二阶张量,第二个参数要是一阶张量,所以要reshapeloss = loss_fct(logits.reshape(-1, vocab_size), labels.reshape(-1))# 梯度回传loss.backward()# 调整优化器,包括学习率的变动等optimizer.step()loss = loss.cpu().item()# recordrecord_dict["train"].append({"loss": loss, "step": global_step})# 保存模型权重 save model checkpointif save_ckpt_callback is not None:save_ckpt_callback(global_step, model.state_dict(), metric=-loss)# udate stepglobal_step += 1pbar.update(1)pbar.set_postfix({"epoch": epoch_id})return record_dictepoch = 100model = CharRNN(vocab_size=vocab_size)# 1. 定义损失函数 采用交叉熵损失 
loss_fct = nn.CrossEntropyLoss()
# 2. 定义优化器 采用 adam
# Optimizers specified in the torch.optim package
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# save best
if not os.path.exists("checkpoints"):os.makedirs("checkpoints")
save_ckpt_callback = SaveCheckpointsCallback("checkpoints/text_generation", save_step=1000, save_best_only=True)model = model.to(device)#%%
record = training(model,train_dl,epoch,loss_fct,optimizer,save_ckpt_callback=save_ckpt_callback,)
#%%
plt.plot([i["step"] for i in record["train"][::50]], [i["loss"] for i in record["train"][::50]], label="train")
plt.grid()
plt.show()
#%% md
## 推理
#%%#下面的例子是为了说明temperature
logits = torch.tensor([400.0,600.0]) #这里是logitsprobs1 = F.softmax(logits, dim=-1)
print(probs1)
#%%
logits = torch.tensor([0.04,0.06])  #现在 temperature是2probs1 = F.softmax(logits, dim=-1)
print(probs1)
#%%
import torch# 创建一个概率分布,表示每个类别被选中的概率
# 这里我们有一个简单的四个类别的概率分布
prob_dist = torch.tensor([0.1, 0.45, 0.35, 0.1])# 使用 multinomial 进行抽样
# num_samples 表示要抽取的样本数量
num_samples = 5# 抽取样本,随机抽样,概率越高,抽到的概率就越高,1代表只抽取一个样本,replacement=True表示可以重复抽样
samples_index = torch.multinomial(prob_dist, 1, replacement=True)print("概率分布:", prob_dist)
print("抽取的样本索引:", samples_index)# 显示每个样本对应的概率
print("每个样本对应的概率:", prob_dist[samples_index])
#%%
def generate_text(model, start_string, max_len=1000, temperature=1.0, stream=True):input_eval = torch.Tensor([char2idx[char] for char in start_string]).to(dtype=torch.int64, device=device).reshape(1, -1) #bacth_size=1, seq_len长度是多少都可以 (1,5)hidden = Nonetext_generated = [] #用来保存生成的文本model.eval()pbar = tqdm(range(max_len)) # 进度条print(start_string, end="")# no_grad是一个上下文管理器,用于指定在其中的代码块中不需要计算梯度。在这个区域内,不会记录梯度信息,用于在生成文本时不影响模型权重。with torch.no_grad():for i in pbar:#控制进度条logits, hidden = model(input_eval, hidden=hidden)# 温度采样,较高的温度会增加预测结果的多样性,较低的温度则更加保守。#取-1的目的是只要最后,拼到原有的输入上logits = logits[0, -1, :] / temperature #logits变为1维的# using multinomial to samplingprobs = F.softmax(logits, dim=-1) #算为概率分布idx = torch.multinomial(probs, 1).item() #从概率分布中抽取一个样本,取概率较大的那些input_eval = torch.Tensor([idx]).to(dtype=torch.int64, device=device).reshape(1, -1) #把idx转为tensortext_generated.append(idx)if stream:print(idx2char[idx], end="", flush=True)return "".join([idx2char[i] for i in text_generated])# load checkpoints
model.load_state_dict(torch.load("checkpoints/text_generation/best.ckpt", weights_only=True,map_location="cpu"))
start_string = "All: " #这里就是开头,什么都可以
res = generate_text(model, start_string, max_len=1000, temperature=0.5, stream=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/897464.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于SpringBoot实现旅游酒店平台功能六

一、前言介绍: 1.1 项目摘要 随着社会的快速发展和人民生活水平的不断提高,旅游已经成为人们休闲娱乐的重要方式之一。人们越来越注重生活的品质和精神文化的追求,旅游需求呈现出爆发式增长。这种增长不仅体现在旅游人数的增加上&#xff0…

git规范提交之commitizen conventional-changelog-cli 安装

一、引言 使用规范的提交信息可以让项目更加模块化、易于维护和理解,同时也便于自动化工具(如发布工具或 Changelog 生成器)解析和处理提交记录。 通过编写符合规范的提交消息,可以让团队和协作者更好地理解项目的变更历史和版本…

前端实现版本更新自动检测✅

🤖 作者简介:水煮白菜王,一位资深前端劝退师 👻 👀 文章专栏: 前端专栏 ,记录一下平时在博客写作中,总结出的一些开发技巧和知识归纳总结✍。 感谢支持💕💕&a…

硬件基础(4):(5)设置ADC电压采集中MCU的参考电压

Vref 引脚通常是 MCU (特别是带有 ADC 的微控制器) 上用来提供或接收基准电压的引脚,ADC 会以该基准电压作为量程参考对输入模拟信号进行数字化转换。具体来说: 命名方式 在不同厂家的 MCU 中,Vref 引脚可能会被标记为 VREF / VREF- / VREF_…

postman接口请求中的 Raw是什么

前言 在现代的网络开发中,API 的使用已经成为数据交换的核心方式之一。然而,在与 API 打交道时,关于如何发送请求体(body)内容类型的问题常常困扰着开发者们,尤其是“raw”和“json”这两个术语之间的区别…

为什么要使用前缀索引,以及建立前缀索引:sql示例

背景: 你想啊,数据库里有些字段,它老长了,就像那种 varchar(255) 的字段,这玩意儿要是整个字段都拿来建索引,那可太占地方了。打个比方,这就好比你要在一个超级大的笔记本上记东西,每…

【语料数据爬虫】Python爬虫|批量采集会议纪要数据(1)

前言 本文是该专栏的第2篇,后面会持续分享Python爬虫采集各种语料数据的的干货知识,值得关注。 在本文中,笔者将主要来介绍基于Python,来实现批量采集“会议纪要”数据。同时,本文也是采集“会议纪要”数据系列的第1篇。 采集相关数据的具体细节部分以及详细思路逻辑,笔…

Android 线程池实战指南:高效管理多线程任务

在 Android 开发中,线程池的使用非常重要,尤其是在需要处理大量异步任务时。线程池可以有效地管理线程资源,避免频繁创建和销毁线程带来的性能开销。以下是线程池的使用方法和最佳实践。 1. 线程池的基本使用 (1)创建线…

SQL29 计算用户的平均次日留存率

SQL29 计算用户的平均次日留存率 计算用户的平均次日留存率_牛客题霸_牛客网 题目:现在运营想要查看用户在某天刷题后第二天还会再来刷题的留存率。 示例:question_practice_detail -- 输入: DROP TABLE IF EXISTS question_practice_detai…

深度学习分类回归(衣帽数据集)

一、步骤 1 加载数据集fashion_minst 2 搭建class NeuralNetwork模型 3 设置损失函数,优化器 4 编写评估函数 5 编写训练函数 6 开始训练 7 绘制损失,准确率曲线 二、代码 导包,打印版本号: import matplotlib as mpl im…

【leetcode hot 100 19】删除链表的第N个节点

解法一:将ListNode放入ArrayList中,要删除的元素为num list.size()-n。如果num 0则将头节点删除;否则利用num-1个元素的next删除第num个元素。 /*** Definition for singly-linked list.* public class ListNode {* int val;* Lis…

【iOS逆向与安全】sms短信转发插件与上传服务器开发

一、目标 一步步分析并编写一个短信自动转发的deb插件 二、工具 mac系统已越狱iOS设备:脱壳及frida调试IDA Pro:静态分析测试设备:iphone6s-ios14.1.1三、步骤 1、守护进程 ​ 守护进程(daemon)是一类在后台运行的特殊进程,用于执行特定的系统任务。例如:推送服务、人…

Midjourney绘图参数详解:从基础到高级的全面指南

引言 Midjourney作为当前最受欢迎的AI绘图工具之一,其强大的参数系统为用户提供了丰富的创作可能性。本文将深入解析Midjourney的各项参数,帮助开发者更好地掌握这一工具,提升创作效率和质量。 一、基本参数配置 1. 图像比例调整 使用--ar…

音频进阶学习十九——逆系统(简单进行回声消除)

文章目录 前言一、可逆系统1.定义2.解卷积3.逆系统恢复原始信号过程4.逆系统与原系统的零极点关系 二、使用逆系统去除回声获取原信号的频谱原系统和逆系统幅频响应和相频响应使用逆系统恢复原始信号整体代码如下 总结 前言 在上一篇音频进阶学习十八——幅频响应相同系统、全…

vue3 使用sass变量

1. 在<style>中使用scss定义的变量和css变量 1. 在/style/variables.scss文件中定义scss变量 // scss变量 $menuText: #bfcbd9; $menuActiveText: #409eff; $menuBg: #304156; // css变量 :root {--el-menu-active-color: $menuActiveText; // 活动菜单项的文本颜色--el…

gbase8s rss集群通信流程

什么是rss RSS是一种将数据从主服务器复制到备服务器的方法 实例级别的复制 (所有启用日志记录功能的数据库) 基于逻辑日志的复制技术&#xff0c;需要传输大量的逻辑日志,数据库需启用日志模式 通过网络持续将数据复制到备节点 如果主服务器发生故障&#xff0c;那么备用服务…

熵与交叉熵详解

前言 本文隶属于专栏《机器学习数学通关指南》&#xff0c;该专栏为笔者原创&#xff0c;引用请注明来源&#xff0c;不足和错误之处请在评论区帮忙指出&#xff0c;谢谢&#xff01; 本专栏目录结构和参考文献请见《机器学习数学通关指南》 ima 知识库 知识库广场搜索&#…

程序化广告行业(3/89):深度剖析行业知识与数据处理实践

程序化广告行业&#xff08;3/89&#xff09;&#xff1a;深度剖析行业知识与数据处理实践 大家好&#xff01;一直以来&#xff0c;我都希望能和各位技术爱好者一起在学习的道路上共同进步&#xff0c;分享知识、交流经验。今天&#xff0c;咱们聚焦在程序化广告这个充满挑战…

探索在生成扩散模型中基于RAG增强生成的实现与未来

概述 像 Stable Diffusion、Flux 这样的生成扩散模型&#xff0c;以及 Hunyuan 等视频模型&#xff0c;都依赖于在单一、资源密集型的训练过程中通过固定数据集获取的知识。任何在训练之后引入的概念——被称为 知识截止——除非通过 微调 或外部适应技术&#xff08;如 低秩适…

DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)之添加列宽调整功能,示例Table14基础固定表头示例

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…