基于CBOW模型的词向量训练实战:从原理到PyTorch实现

基于CBOW模型的词向量训练实战:从原理到PyTorch实现

在自然语言处理(NLP)领域,词向量是将单词映射为计算机可处理的数值向量的重要方式。通过词向量,单词之间的语义关系能够以数学形式表达,为后续的文本分析、机器翻译、情感分析等任务奠定基础。本文将结合连续词袋模型(CBOW),详细介绍如何使用PyTorch训练词向量,并通过具体代码实现和分析训练过程。

一、CBOW模型原理简介

CBOW(Continuous Bag-of-Words)模型是一种用于生成词向量的神经网络模型,它基于上下文预测目标词。其核心思想是:给定一个目标词的上下文单词,通过模型预测该目标词。在训练过程中,模型会不断调整参数,使得预测结果尽可能接近真实的目标词,最终训练得到的词向量能够捕捉单词之间的语义关系。

例如,在句子 “People create programs to direct processes” 中,如果目标词是 “programs”,CBOW模型会利用其上下文单词 “People”、“create”、“to”、“direct” 来预测 “programs”。通过大量类似样本的训练,模型能够学习到单词之间的语义关联,从而生成有效的词向量。

二、代码实现与详细解析

下面我会逐行解释你提供的代码,此代码借助 PyTorch 实现了一个连续词袋模型(CBOW)来学习词向量。

1. 导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm, trange  # 显示进度条
import numpy as np
  • torch:PyTorch 深度学习框架的核心库。
  • torch.nn:用于构建神经网络的模块。
  • torch.nn.functional:提供了许多常用的函数,像激活函数等。
  • torch.optim:包含各种优化算法。
  • tqdmtrange:用于在训练过程中显示进度条。
  • numpy:用于处理数值计算和数组操作。

2. 定义上下文窗口大小和原始文本

CONTEXT_SIZE = 2
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules called a program. 
People create programs to direct processes. 
In effect,we conjure the spirits of the computer with our spells.""".split()
  • CONTEXT_SIZE:上下文窗口的大小,意味着在预测目标词时,会考虑其前后各 CONTEXT_SIZE 个单词。
  • raw_text:原始文本,将其按空格分割成单词列表。

3. 构建词汇表和索引映射

vocab = set(raw_text)  # 集合,词库,里面的内容独一无二(将文本中所有单词去重后得到的词汇表)
vocab_size = len(vocab)  # 词汇表的大小word_to_idx = {word: i for i, word in enumerate(vocab)}  # 单词到索引的映射字典
idx_to_word = {i: word for i, word in enumerate(vocab)}  # 索引到单词的映射字典
  • vocab:把原始文本中的所有单词去重后得到的词汇表。
  • vocab_size:词汇表的大小。
  • word_to_idx:将单词映射为对应的索引。
  • idx_to_word:将索引映射为对应的单词。

4. 构建训练数据集

data = []  # 获取上下文词,将上下文词作为输入,目标词作为输出,构建训练数据集(用于存储训练数据,每个元素是一个元组,包含上下文词列表和目标词)
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):context = ([raw_text[i - (2 - j)] for j in range(CONTEXT_SIZE)]+ [raw_text[i + j + 1] for j in range(CONTEXT_SIZE)])  # 获取上下文词target = raw_text[i]  # 获取目标词data.append((context, target))  # 将上下文词和目标词保存到 data 中
  • data:用于存储训练数据,每个元素是一个元组,包含上下文词列表和目标词。
  • 通过循环遍历原始文本,提取每个目标词及其上下文词,然后将它们添加到 data 中。

5. 定义将上下文词转换为张量的函数

def make_context_vector(context, word_to_ix):  # 将上下词转换为 one - hotidxs = [word_to_ix[w] for w in context]return torch.tensor(idxs, dtype=torch.long)
  • make_context_vector:把上下文词列表转换为对应的索引张量。

6. 打印第一个上下文词的索引张量

print(make_context_vector(data[0][0], word_to_idx))
  • 打印第一个训练样本的上下文词对应的索引张量。

7. 定义 CBOW 模型

class CBOW(nn.Module):  # 神经网络def __init__(self, vocab_size, embedding_dim):super(CBOW, self).__init__()self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.proj = nn.Linear(embedding_dim, 128)self.output = nn.Linear(128, vocab_size)def forward(self, inputs):embeds = sum(self.embeddings(inputs)).view(1, -1)out = F.relu(self.proj(embeds))  # nn.relu() 激活层out = self.output(out)nll_prob = F.log_softmax(out, dim=1)return nll_prob
  • CBOW:继承自 nn.Module,定义了 CBOW 模型的结构。
    • __init__:初始化模型的层,包含一个嵌入层、一个线性层和另一个线性层。
    • forward:定义了前向传播过程,将输入的上下文词索引转换为嵌入向量,求和后经过线性层和激活函数,最后输出对数概率。

8. 选择设备并创建模型实例

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  # 字符串的格式化
model = CBOW(vocab_size, 10).to(device)
  • device:检查当前设备是否支持 GPU(CUDA 或 MPS),若支持则使用 GPU,否则使用 CPU。
  • model:创建 CBOW 模型的实例,并将其移动到指定设备上。

9. 定义优化器、损失函数和损失列表

optimizer = optim.Adam(model.parameters(), lr=0.001)  # 创建一个优化器,
losses = []  # 存储损失的集合
loss_function = nn.NLLLoss()
  • optimizer:使用 Adam 优化器来更新模型的参数。
  • losses:用于存储每个 epoch 的损失值。
  • loss_function:使用负对数似然损失函数。

10. 训练模型

model.train()for epoch in tqdm(range(200)):total_loss = 0for context, target in data:context_vector = make_context_vector(context, word_to_idx).to(device)target = torch.tensor([word_to_idx[target]]).to(device)# 开始向前传播train_predict = model(context_vector)loss = loss_function(train_predict, target)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()losses.append(total_loss)print(losses)
  • model.train():将模型设置为训练模式。
  • 通过循环进行 200 个 epoch 的训练,每个 epoch 遍历所有训练数据。
    • 将上下文词和目标词转换为张量并移动到指定设备上。
    • 进行前向传播得到预测结果。
    • 计算损失。
    • 进行反向传播并更新模型参数。
    • 累加每个 epoch 的损失值。

11. 进行预测

context = ['People', 'create', 'to', 'direct']
context_vector = make_context_vector(context, word_to_idx).to(device)model.eval()  # 进入到测试模式
predict = model(context_vector)
max_idx = predict.argmax(1)  # dim = 1 表示每一行中的最大值对应的索引号, dim = 0 表示每一列中的最大值对应的索引号print("CBOW embedding weight =", model.embeddings.weight)  # GPU
W = model.embeddings.weight.cpu().detach().numpy()
print(W)
  • 选择一个上下文词列表进行预测。
  • model.eval():将模型设置为评估模式。
  • 进行预测并获取预测结果中概率最大的索引。
  • 打印嵌入层的权重,并将其转换为 NumPy 数组。

12. 构建词向量字典

word_2_vec = {}
for word in word_to_idx.keys():word_2_vec[word] = W[word_to_idx[word], :]
print('jiesu')
  • word_2_vec:将每个单词映射到其对应的词向量。

13. 保存和加载词向量

np.savez('word2vec实现.npz', file_1 = W)
data = np.load('word2vec实现.npz')
print(data.files)
  • np.savez:将词向量保存为 .npz 文件。
  • np.load:加载保存的 .npz 文件,并打印文件中的数组名称。

综上所述,这段代码实现了一个简单的 CBOW 模型来学习词向量,并将学习到的词向量保存到文件中。 。运行结果
在这里插入图片描述

三、总结

通过上述代码的实现和分析,我们成功地使用CBOW模型在PyTorch框架下完成了词向量的训练。从数据准备、模型定义,到训练和测试,再到词向量的保存,每一个步骤都紧密相连,共同构建了一个完整的词向量训练流程。

CBOW模型通过上下文预测目标词的方式,能够有效地学习到单词之间的语义关系,生成的词向量可以应用于各种自然语言处理任务。在实际应用中,我们还可以通过调整模型的超参数(如词向量维度、上下文窗口大小、训练轮数等),以及使用更大规模的数据集,进一步优化词向量的质量和模型的性能。希望本文的内容能够帮助读者更好地理解CBOW模型和词向量训练的原理与实践。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/78723.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux——进程终止/等待/替换

前言 本章主要对进程终止,进程等待,进程替换的详细认识,根据实验去理解其中的原理,干货满满! 1.进程终止 概念:进程终止就是释放进程申请的内核数据结构和对应的代码和数据 进程退出的三种状态 代码运行…

iOS开发架构——MVC、MVP和MVVM对比

文章目录 前言MVC(Model - View - Controller)MVP(Model - View - Presenter)MVVM(Model - View - ViewModel) 前言 在 iOS 开发中,MVC、MVVM、和 MVP 是常见的三种架构模式,它们主…

0506--01-DA

36. 单选题 在娱乐方式多元化的今天,“ ”是不少人(特别是中青年群体)对待戏曲的态度。这里面固然存在 的偏见、难以静下心来欣赏戏曲之美等因素,却也有另一个无法回避的原因:一些戏曲虽然与观众…

关于Java多态简单讲解

面向对象程序设计有三大特征,分别是封装,继承和多态。 这三大特性相辅相成,可以使程序员更容易用编程语言描述现实对象。 其中多态 多态是方法的多态,是通过子类通过对父类的重写,实现不同子类对同一方法有不同的实现…

【Trea】Trea国际版|海外版下载

Trea目前有两个版本,海外版和国内版。‌ Trae 版本差异 ‌大模型选择‌: ‌国内版‌:提供了字节自己的Doubao-1.5-pro以及DeepSeek的V3版本和R1版本。海外版:提供了ChartGPT以及Claude-3.5-Sonnet和3.7-Sonnt. ‌功能和界面‌&a…

Missashe考研日记-day33

Missashe考研日记-day33 1 专业课408 学习时间:2h30min学习内容: 今天开始学习OS最后一章I/O管理的内容,听了第一小节的内容,然后把课后习题也做了。知识点回顾: 1.I/O设备分类:按信息交换单位、按设备传…

链表的面试题3找出中间节点

来来来,接着继续我们的第三道题 。 解法 暴力求解 快慢指针 https://leetcode.cn/problems/middle-of-the-linked-list/submissions/ 这道题的话,思路是非常明确的,就是让你找出我们这个所谓的中间节点并且输出。 那这道题我们就需要注意…

linux磁盘介绍与LVM管理

一、磁盘基本概述 GPT是全局唯一标识分区表的缩写,是全局唯一标示磁盘分区表格式。而MBR则是另一种磁盘分区形式,它是主引导记录的缩写。相比之下,MBR比GPT出现得要更早一些。 MBR 与 GPT MBR 支持的磁盘最大容量为 2 TB,GPT 最大支持的磁盘容量为 18 EB,当前数据盘支持…

突破测试环境文件上传带宽瓶颈!React Native 阿里云 OSS 直传文件格式问题攻克二

上一篇我们对服务端和阿里云oss的配置及前端调用做了简单的介绍,但是一直报错。最终判断是文件格式问题,通常我们在reactnative中用formData上传, formData.append(file, {uri: file, name: nameType(type), type: multipart/form-data});这…

Spring Boot 中 @Bean 注解详解:从入门到实践

在 Spring Boot 开发中,Bean注解是一个非常重要且常用的注解,它能够帮助开发者轻松地将 Java 对象纳入 Spring 容器的管理之下,实现对象的依赖注入和生命周期管理。对于新手来说,理解并掌握Bean注解,是深入学习 Spring…

TCP 协议设计入门:自定义消息格式与粘包解决方案

目录 一、为什么需要自定义 TCP 协议? TCP粘包问题的本质 1.1 粘包与拆包的定义 1.2 粘包的根本原因 1.3 粘包的典型场景 二、自定义消息格式设计 2.1 协议结构设计 方案1:固定长度协议 方案2:分隔符标记法 方案3:长度前…

了解一下OceanBase中的表分区

OceanBase 是一个高性能的分布式关系型数据库,它支持 SQL 标准的大部分功能,包括分区表。分区表可以帮助管理大量数据,提高查询效率,通过将数据分散到不同的物理段中,可以减少查询时的数据扫描量。 在 OceanBase 中操…

多线程网络编程:粘包问题、多线程/多进程服务器实战与常见问题解析

多线程网络编程:粘包问题、多线程/多进程服务器实战与常见问题解析 一、TCP粘包问题:成因、影响与解决方案 1. 粘包问题本质 TCP是面向流的协议,数据传输时没有明确的消息边界,导致多个消息可能被合并(粘包&#xf…

大模型主干

1.什么是语言模型骨架LLM-Backbone,在多模态模型中的作用? 语言模型骨架(LLM Backbone)是多模态模型中的核心组件之一。它利用预训练的语言模型(如Flan-T5、ChatGLM、UL2等)来处理各种模态的特征,进行语义…

[创业之路-350]:光刻机、激光器、自动驾驶、具身智能:跨学科技术体系全景解析(光-机-电-材-热-信-控-软-网-算-智)

光刻机、激光器、自动驾驶、具身智能四大领域的技术突破均依赖光、机、电、材、热、信、控、软、网、算、智十一大学科体系的深度耦合。以下从技术原理、跨学科融合、关键挑战三个维度展开系统性分析: 一、光刻机:精密制造的极限挑战 1. 核心技术与学科…

SVTAV1 编码函数 svt_aom_is_pic_skipped

一 函数解释 1.1 svt_aom_is_pic_skipped函数的作用是判断当前图片是否可以跳过编码处理。 具体分析如下 函数逻辑 参数说明:函数接收一个指向图片父控制集的指针PictureParentControlSet *pcs, 通过这个指针可以获取与图片相关的各种信息,用于判断是否跳…

【Redis新手入门指南】从小白入门到日常使用(全)

文章目录 前言redis是什么?定义原理与特点与MySQL对比 Redis安装方式一、Homebrew 快速安装 Redis(推荐)方式二、源码编译安装redisHomebrew vs 源码安装对比 redis配置说明修改redis配置的方法常见redis配置项说明 redis常用命令redis服务启…

Linux grep 命令详解及示例大全

文章目录 一、基本语法二、常用选项及示例1. 基本匹配:查找包含某字符串的行2. 忽略大小写匹配 -i3. 显示行号 -n4. 递归查找目录下的文件 -r 或 -R5. 仅显示匹配的字符串 -o6. 使用正则表达式 -E(扩展)或 egrep7. 显示匹配前后行 -A, -B, -C…

【排序算法】快速排序(全坤式超详解)———有这一篇就够啦

【排序算法】——快速排序 目录 一:快速排序——思想 二:快速排序——分析 三:快速排序——动态演示图 四:快速排序——单趟排序 4.1:霍尔法 4.2:挖坑法 4.3:前后指针法 五:…

【platform push 提示 Invalid source ref: HEAD】

platform push 提示 Invalid source ref: HEAD 场景:环境:排查过程:解决: 场景: 使用platform push 命令行输入git -v 可以输出git 版本号,但就是提示Invalid source ref: HEAD,platform creat…