pytorch基于GloVe实现的词嵌入

PyTorch 实现 GloVe(Global Vectors for Word Representation) 的完整代码,使用 中文语料 进行训练,包括 共现矩阵构建、模型定义、训练和测试


 1. GloVe 介绍

基于词的共现信息(不像 Word2Vec 使用滑动窗口预测)
 适合较大规模的数据(比 Word2Vec 更稳定)
学习出的词向量能捕捉语义信息(如类比关系)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
from scipy.sparse import coo_matrix# ========== 1. 数据预处理 ==========
corpus = ["我们 喜欢 深度 学习","自然 语言 处理 是 有趣 的","人工智能 改变 了 世界","深度 学习 是 人工智能 的 重要 组成部分"
]# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}# 计算共现矩阵
window_size = 2
co_occurrence = Counter()for sentence in tokenized_corpus:indices = [word2idx[word] for word in sentence]for center_idx in range(len(indices)):center_word = indices[center_idx]for offset in range(-window_size, window_size + 1):context_idx = center_idx + offsetif 0 <= context_idx < len(indices) and context_idx != center_idx:context_word = indices[context_idx]co_occurrence[(center_word, context_word)] += 1# 转换为稀疏矩阵
rows, cols, values = zip(*[(c[0], c[1], v) for c, v in co_occurrence.items()])
X = coo_matrix((values, (rows, cols)), shape=(len(vocab), len(vocab)))# ========== 2. 定义 GloVe 模型 ==========
class GloVe(nn.Module):def __init__(self, vocab_size, embedding_dim):super(GloVe, self).__init__()self.w_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 中心词嵌入self.c_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 上下文词嵌入self.w_bias = nn.Embedding(vocab_size, 1)  # 中心词偏置self.c_bias = nn.Embedding(vocab_size, 1)  # 上下文词偏置nn.init.xavier_uniform_(self.w_embeddings.weight)nn.init.xavier_uniform_(self.c_embeddings.weight)def forward(self, center, context, co_occur):w_emb = self.w_embeddings(center)c_emb = self.c_embeddings(context)w_bias = self.w_bias(center).squeeze()c_bias = self.c_bias(context).squeeze()dot_product = (w_emb * c_emb).sum(dim=1)loss = (dot_product + w_bias + c_bias - torch.log(co_occur + 1e-8)) ** 2return loss.mean()# 初始化模型
embedding_dim = 10
model = GloVe(len(vocab), embedding_dim)# ========== 3. 训练 GloVe ==========
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100# 转换数据
co_occurrence_tensor = torch.tensor(X.data, dtype=torch.float)
pairs = list(zip(X.row, X.col, co_occurrence_tensor))for epoch in range(num_epochs):total_loss = 0np.random.shuffle(pairs)for center, context, co_occur in pairs:optimizer.zero_grad()loss = model(torch.tensor([center], dtype=torch.long),torch.tensor([context], dtype=torch.long),torch.tensor([co_occur], dtype=torch.float)  # 修正数据类型)loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")# ========== 4. 获取词向量 ==========
word_vectors = model.w_embeddings.weight.data.numpy()# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):if word not in word2idx:return "单词不在词汇表中"word_vec = word_vectors[word2idx[word]].reshape(1, -1)similarities = np.dot(word_vectors, word_vec.T).squeeze()similar_idx = similarities.argsort()[::-1][1:top_n + 1]return [(idx2word[idx], similarities[idx]) for idx in similar_idx]# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:print(f"【{word}】的相似单词:", most_similar(word))

数据预处理

  • 分词(使用 jieba.cut()
  • 构建共现矩阵(计算窗口内的单词共现频率)
  • 使用稀疏矩阵存储(提高计算效率)

GloVe 模型

  • Embedding 训练词向量(中心词和上下文词分开)
  • Bias 变量 用于调整预测值
  • 损失函数 最小化 log(共现次数) 与词向量点积的差值

 计算词向量相似度

  • 使用 cosine similarity
  • 找出 top_n 最相似的单词

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/67595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++ 堆栈分配的区别

这两种声明方式有什么区别 1.使用 new 关键字动态分配内存 动态分配&#xff1a;使用 new 关键字会在堆&#xff08;heap&#xff09;上分配内存&#xff0c;并返回一个指向该内存位置的指针。生命周期&#xff1a;对象的生命周期不会随着声明它的作用域结束而结束&#xff0…

深入解析 Linux 内核中的页面错误处理机制

在现代操作系统中,页面错误(Page Fault)是内存管理的重要组成部分。当程序试图访问未映射到物理内存的虚拟内存地址时,CPU 会触发页面错误异常。Linux 内核通过一系列复杂的机制来处理这些异常,确保系统的稳定性和性能。本文将深入解析 Linux 内核中处理页面错误的核心代码…

MATLAB-Simulink并行仿真示例

一、概述 在进行simulink仿真的过程中常常遇到CPU利用率较低&#xff0c;仿真缓慢的情况&#xff0c;可以借助并行仿真改善这些问题&#xff0c;其核心思想是将参数扫描、蒙特卡洛分析或多工况验证等任务拆分成多个子任务&#xff0c;利用多核CPU或计算集群的并行计算能力&…

Workbench 中的热源仿真

探索使用自定义工具对移动热源进行建模及其在不同行业中的应用。 了解热源动力学 对移动热源进行建模为各种工业过程和应用提供了有价值的见解。激光加热和材料加工使用许多激光束来加热、焊接或切割材料。尽管在某些情况下&#xff0c;热源 &#xff08;q&#xff09; 不是通…

I2C基础知识

引言 这里祝大家新年快乐&#xff01;前面我们介绍了串口通讯协议&#xff0c;现在我们继续来介绍另一种常见的简单的串行通讯方式——I2C通讯协议。 一、什么是I2C I2C 通讯协议&#xff08;Inter-Integrated Circuit&#xff09;是由Phiilps公司在上个世纪80年代开发的&#…

深度学习 DAY3:NLP发展史

NLP发展史 NLP发展脉络简要梳理如下&#xff1a; (远古模型&#xff0c;上图没有但也可以算NLP&#xff09; 1940 - BOW&#xff08;无序统计模型&#xff09; 1950 - n-gram&#xff08;基于词序的模型&#xff09; (近代模型&#xff09; 2001 - Neural language models&am…

CSS 背景与边框:从基础到高级应用

CSS 背景与边框&#xff1a;从基础到高级应用 1. CSS 背景样式1.1 背景颜色示例代码&#xff1a;设置背景颜色 1.2 背景图像示例代码&#xff1a;设置背景图像 1.3 控制背景平铺行为示例代码&#xff1a;控制背景平铺 1.4 调整背景图像大小示例代码&#xff1a;调整背景图像大小…

HarmonyOS简介:应用开发的机遇、挑战和趋势

问题 更多的智能设备并没有带来更好的全场景体验 连接步骤复杂数据难以互通生态无法共享能力难以协同 主要挑战 针对不同设备上的不同操作系统&#xff0c;重复开发&#xff0c;维护多套版本 多种语言栈&#xff0c;对人员技能要求高 多种开发框架&#xff0c;不同的编程…

【Linux】列出所有连接的 WiFi 网络的密码

【Linux】列出所有连接的 WiFi 网络的密码 终端输入 sudo grep psk /etc/NetworkManager/system-connections/*会列出所有连接过 Wifi 的信息&#xff0c;格式类似 /etc/NetworkManager/system-connections/AAAAA.nmconnection:pskBBBBBAAAAA 是 SSID&#xff0c;BBBBB 是对…

如何使用tushare pro获取股票数据——附爬虫代码以及tushare积分获取方式

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据 总结 一、Tushare 介绍 Tushare 是一个提供中国股市数据的API接口服务&#xff0c;它允许用户…

观察者模式和订阅发布模式的关系

有人把观察者模式等同于发布订阅模式&#xff0c;也有人认为这两种模式存在差异&#xff0c;本质上就是调度的方法不同。 发布订阅模式: 观察者模式: 相比较&#xff0c;发布订阅将发布者和观察者之间解耦。&#xff08;发布订阅有调度中心处理&#xff09;

linux 环境安装 dlib 的 gpu 版本

默认使用 pip 安装的 dlib 是不使用 gpu 的 在国内社区用百度查如何安装 gpu 版本的 dlib 感觉信息都不太对&#xff0c;都是说要源码编译还有点复杂 还需要自己安装 cuda 相关的包啥的&#xff0c;看着就头大 于是想到这个因该 conda 自己就支持了吧&#xff0c;然后查了一下…

【HarmonyOS之旅】基于ArkTS开发(三) -> 兼容JS的类Web开发(二)

目录 1 -> HML语法 1.1 -> 页面结构 1.2 -> 数据绑定 1.3 -> 普通事件绑定 1.4 -> 冒泡事件绑定5 1.5 -> 捕获事件绑定5 1.6 -> 列表渲染 1.7 -> 条件渲染 1.8 -> 逻辑控制块 1.9 -> 模板引用 2 -> CSS语法 2.1 -> 尺寸单位 …

三路排序算法

三路排序算法 引言 排序算法是计算机科学中基础且重要的算法之一。在数据分析和处理中&#xff0c;排序算法的效率直接影响着程序的执行速度和系统的稳定性。本文将深入探讨三路排序算法&#xff0c;包括其原理、实现和应用场景。 一、三路排序算法的原理 三路排序算法是一…

Python的那些事第五篇:数据结构的艺术与应用

新月人物传记&#xff1a;人物传记之新月篇-CSDN博客 目录 一、列表&#xff08;List&#xff09;&#xff1a;动态的容器 二、元组&#xff08;Tuple&#xff09;&#xff1a;不可变的序列 三、字典&#xff08;Dict&#xff09;&#xff1a;键值对的集合 四、集合&#xf…

【AI】DeepSeek 概念/影响/使用/部署

在大年三十那天&#xff0c;不知道你是否留意到&#xff0c;“deepseek”这个词出现在了各大热搜榜单上。这引起了我的关注&#xff0c;出于学习的兴趣&#xff0c;我深入研究了一番&#xff0c;才有了这篇文章的诞生。 概念 那么&#xff0c;什么是DeepSeek&#xff1f;首先百…

MapReduce简单应用(一)——WordCount

目录 1. 执行过程1.1 分割1.2 Map1.3 Combine1.4 Reduce 2. 代码和结果2.1 pom.xml中依赖配置2.2 工具类util2.3 WordCount2.4 结果 参考 1. 执行过程 假设WordCount的两个输入文本text1.txt和text2.txt如下。 Hello World Bye WorldHello Hadoop Bye Hadoop1.1 分割 将每个文…

Dest1ny漏洞库:用友 U8 Cloud ReleaseRepMngAction SQL 注入漏洞(CNVD-2024-33023)

大家好&#xff0c;今天是Dest1ny漏洞库的专题&#xff01;&#xff01; 会时不时发送新的漏洞资讯&#xff01;&#xff01; 大家多多关注&#xff0c;多多点赞&#xff01;&#xff01;&#xff01; 0x01 产品简介 用友U8 Cloud是用友推出的新一代云ERP&#xff0c;主要聚…

使用where子句筛选记录

默认情况下,SearchCursor将返回一个表或要素类的所有行.然而在很多情况下,常常需要某些条件来限制返回行数. 操作方法: 1.打开IDLE,加载先前编写的SearchCursor.py脚本 2.添加where子句,更新SearchCursor()函数,查找记录中有<>文本的<>字段 with arcpy.da.Searc…

使用国内镜像加速器解决 Docker Hub 拉取镜像慢或被屏蔽的问题

一、问题背景 Docker Hub 是 Docker 默认的镜像仓库&#xff0c;但由于网络限制&#xff0c;国内用户直接拉取镜像可能面临以下问题&#xff1a; 下载速度极慢&#xff08;尤其是大镜像&#xff09;。连接超时或完全被屏蔽&#xff08;部分网络环境&#xff09;。依赖国外源的…