pytorch实现基于Word2Vec的词嵌入

PyTorch 实现 Word2Vec(Skip-gram 模型) 的完整代码,使用 中文语料 进行训练,包括数据预处理、模型定义、训练和测试


1. 主要特点

支持中文数据,基于 jieba 进行分词
使用 Skip-gram 进行训练,适用于小数据集
支持负采样,提升训练效率
使用 cosine similarity 计算相似单词

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import jieba
from collections import Counter
from sklearn.metrics.pairwise import cosine_similarity# ========== 1. 数据预处理 ==========
corpus = ["我们 喜欢 深度 学习","自然 语言 处理 是 有趣 的","人工智能 改变 了 世界","深度 学习 是 人工智能 的 重要 组成部分"
]# 超参数
window_size = 2      # 窗口大小
embedding_dim = 10   # 词向量维度
num_epochs = 100     # 训练轮数
learning_rate = 0.01 # 学习率
batch_size = 4       # 批大小
neg_samples = 5      # 负采样个数# 分词 & 构建词汇表
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}# 统计词频
word_counts = Counter([word for sentence in tokenized_corpus for word in sentence])
total_words = sum(word_counts.values())# 计算负采样概率
word_freqs = {word: count / total_words for word, count in word_counts.items()}
word_powers = {word: freq ** 0.75 for word, freq in word_freqs.items()}
Z = sum(word_powers.values())
word_distribution = {word: prob / Z for word, prob in word_powers.items()}# 负采样函数
def negative_sampling(positive_word, num_samples=5):words = list(word_distribution.keys())probabilities = list(word_distribution.values())negatives = []while len(negatives) < num_samples:neg = np.random.choice(words, p=probabilities)if neg != positive_word:negatives.append(neg)return negatives# 生成 Skip-gram 训练数据
data = []
for sentence in tokenized_corpus:indices = [word2idx[word] for word in sentence]for center_idx in range(len(indices)):center_word = indices[center_idx]for offset in range(-window_size, window_size + 1):context_idx = center_idx + offsetif 0 <= context_idx < len(indices) and context_idx != center_idx:context_word = indices[context_idx]data.append((center_word, context_word))# 转换为 PyTorch 张量
data = [(torch.tensor(center), torch.tensor(context)) for center, context in data]# ========== 2. 定义 Word2Vec (Skip-gram) 模型 ==========
class Word2Vec(nn.Module):def __init__(self, vocab_size, embedding_dim):super(Word2Vec, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.output_layer = nn.Linear(embedding_dim, vocab_size)def forward(self, center_word):embed = self.embedding(center_word)  # 获取中心词向量out = self.output_layer(embed)       # 计算词分布return out# 初始化模型
model = Word2Vec(len(vocab), embedding_dim)# ========== 3. 训练 Word2Vec ==========
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)for epoch in range(num_epochs):total_loss = 0random.shuffle(data)  # 每轮打乱数据for center_word, context_word in data:optimizer.zero_grad()output = model(center_word.unsqueeze(0))  # 预测词分布loss = criterion(output, context_word.unsqueeze(0))  # 计算损失loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")# ========== 4. 测试词向量 ==========
word_vectors = model.embedding.weight.data.numpy()# 计算单词相似度
def most_similar(word, top_n=3):if word not in word2idx:return "单词不在词汇表中"word_vec = word_vectors[word2idx[word]].reshape(1, -1)similarities = cosine_similarity(word_vec, word_vectors)[0]# 获取相似度最高的 top_n 个单词(排除自身)similar_idx = similarities.argsort()[::-1][1:top_n+1]return [(idx2word[idx], similarities[idx]) for idx in similar_idx]# 测试相似词
test_words = ["深度", "学习", "人工智能"]
for word in test_words:print(f"【{word}】的相似单词:", most_similar(word))

数据预处理
  • 使用 jieba.cut() 进行分词
  • 创建 word2idxidx2word
  • 使用滑动窗口生成 (中心词, 上下文词) 训练样本
  • 实现 negative_sampling() 提高训练效率
模型
  • Embedding 学习词向量
  • Linear 计算单词的概率分布
  • CrossEntropyLoss 计算目标词与预测词的匹配度
  • 使用 Adam 进行梯度更新
计算词相似度
  • 使用 cosine_similarity 计算词向量相似度
  • 找出 top_n 个最相似的单词

 5. 可优化点

 使用更大的中文语料库(如 THUCNews
 使用 t-SNE 进行词向量可视化
增加负采样,提升模型训练效率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894445.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】_链表经典算法OJ(力扣/牛客第二弹)

目录 1. 题目1&#xff1a;返回倒数第k个节点 1.1 题目链接及描述 1.2 解题思路 1.3 程序 2. 题目2&#xff1a;链表的回文结构 2.1 题目链接及描述 2.2 解题思路 2.3 程序 1. 题目1&#xff1a;返回倒数第k个节点 1.1 题目链接及描述 题目链接&#xff1a; 面试题 …

pytorch基于 Transformer 预训练模型的方法实现词嵌入(tiansz/bert-base-chinese)

以下是一个完整的词嵌入&#xff08;Word Embedding&#xff09;示例代码&#xff0c;使用 modelscope 下载 tiansz/bert-base-chinese 模型&#xff0c;并通过 transformers 加载模型&#xff0c;获取中文句子的词嵌入。 from modelscope.hub.snapshot_download import snaps…

爬虫基础之爬取某站视频

目标网址:为了1/4螺口买小米SU7&#xff0c;开了一个月&#xff0c;它值吗&#xff1f;_哔哩哔哩_bilibili 本案例所使用到的模块 requests (发送HTTP请求)subprocess(执行系统命令)re (正则表达式操作)json (处理JSON数据) 需求分析: 视频的名称 F12 打开开发者工具 or 右击…

DeepSeek R1本地化部署 Ollama + Chatbox 打造最强 AI 工具

&#x1f308; 个人主页&#xff1a;Zfox_ &#x1f525; 系列专栏&#xff1a;Linux 目录 一&#xff1a;&#x1f525; Ollama &#x1f98b; 下载 Ollama&#x1f98b; 选择模型&#x1f98b; 运行模型&#x1f98b; 使用 && 测试 二&#xff1a;&#x1f525; Chat…

【linux网络(5)】传输层协议详解(下)

目录 前言1. TCP的超时重传机制2. TCP的流量控制机制3. TCP的滑动窗口机制4. TCP的拥塞控制机制5. TCP的延迟应答机制6. TCP的捎带应答机制7. 总结以及思考 前言 强烈建议先看传输层协议详解(上)后再看这篇文章. 上一篇文章讲到TCP协议为了保证可靠性而做的一些策略, 这篇文章…

DeepSeek 遭 DDoS 攻击背后:DDoS 攻击的 “千层套路” 与安全防御 “金钟罩”

当算力博弈升级为网络战争&#xff1a;拆解DDoS攻击背后的技术攻防战——从DeepSeek遇袭看全球网络安全新趋势 在数字化浪潮席卷全球的当下&#xff0c;网络已然成为人类社会运转的关键基础设施&#xff0c;深刻融入经济、生活、政务等各个领域。从金融交易的实时清算&#xf…

二、CSS笔记

(一)css概述 1、定义 CSS是Cascading Style Sheets的简称,中文称为层叠样式表,用来控制网页数据的表现,可以使网页的表现与数据内容分离。 2、要点 怎么找到标签怎么操作标签对象(element) 3、css的四种引入方式 3.1 行内式 在标签的style属性中设定CSS样式。这种方…

第三篇:模型压缩与量化技术——DeepSeek如何在边缘侧突破“小而强”的算力困局

——从算法到芯片的全栈式优化实践 随着AI应用向移动终端与物联网设备渗透&#xff0c;模型轻量化成为行业核心挑战。DeepSeek通过自研的“算法-编译-硬件”协同优化体系&#xff0c;在保持模型性能的前提下&#xff0c;实现参数量与能耗的指数级压缩。本文从技术原理、工程实…

C++编程语言:抽象机制:泛型编程(Bjarne Stroustrup)

泛型编程(Generic Programming) 目录 24.1 引言(Introduction) 24.2 算法和(通用性的)提升(Algorithms and Lifting) 24.3 概念(此指模板参数的插件)(Concepts) 24.3.1 发现插件集(Discovering a Concept) 24.3.2 概念与约束(Concepts and Constraints) 24.4 具体化…

DeepSeek-R1本地部署实践

一、下载安装 --Ollama Ollama是一个开源的 LLM&#xff08;大型语言模型&#xff09;服务工具&#xff0c;用于简化在本地运行大语言模型&#xff0c;降低使用大语言模型的门槛&#xff0c;使得大模型的开发者、研究人员和爱好者能够在本地环境快速实验、管理和部署最新大语言…

AI技术路线(marked)

人工智能&#xff08;AI&#xff09;是一个非常广泛且充满潜力的领域&#xff0c;它涉及了让计算机能够执行通常需要人类智能的任务&#xff0c;比如感知、推理、学习、决策等。人工智能的应用已经渗透到各行各业&#xff0c;从自动驾驶到医疗诊断&#xff0c;再到推荐系统和自…

【leetcode详解】T598 区间加法

598. 区间加法 II - 力扣&#xff08;LeetCode&#xff09; 思路分析 核心在于将问题转化&#xff0c; 题目不是要求最大整数本身&#xff0c;而是要求解最大整数的个数 结合矩阵元素的增加原理&#xff0c;我们将抽象问题转为可操作的方法&#xff0c;其实就是再找每组ops中…

【最后203篇系列】004 -Smarklink

说明 这个用来替代nginx。 最初是希望用nginx进行故障检测和负载均衡&#xff0c;花了很多时间&#xff0c;大致的结论是&#xff1a;nginx可以实现&#xff0c;但是是在商业版里。非得要找替代肯定可以搞出来&#xff0c;但是太麻烦了&#xff08;即使是nginx本身的配置也很烦…

完全卸载mysql server步骤

1. 在控制面板中卸载mysql 2. 打开注册表&#xff0c;运行regedit, 删除mysql信息 HKEY_LOCAL_MACHINE-> SYSTEM->CurrentContolSet->Services->EventLog->Application->Mysql HKEY_LOCAL_MACHINE-> SYSTEM->CurrentContolSet->Services->Mysql …

1. 【.NET Aspire 从入门到实战】--理论入门与环境搭建--引言

在当前软件开发领域&#xff0c;云原生和微服务架构已经成为主流趋势&#xff0c;传统的单体应用正逐步向分布式系统转型。随着业务需求的不断变化与用户规模的迅速扩大&#xff0c;如何在保证高可用、高扩展性的同时&#xff0c;还能提高开发效率与降低维护成本&#xff0c;成…

Ubuntu 22.04系统安装部署Kubernetes v1.29.13集群

Ubuntu 22.04系统安装部署Kubernetes v1.29.13集群 简介Kubernetes 的工作流程概述Kubernetes v1.29.13 版本Ubuntu 22.04 系统安装部署 Kubernetes v1.29.13 集群 1 环境准备1.1 集群IP规划1.2 初始化步骤&#xff08;各个节点都需执行&#xff09;1.2.1 主机名与IP地址解析1.…

基于SpringBoot的新闻资讯系统的设计与实现(源码+SQL脚本+LW+部署讲解等)

专注于大学生项目实战开发,讲解,毕业答疑辅导&#xff0c;欢迎高校老师/同行前辈交流合作✌。 技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;…

每日一题——包含min函数的栈

包含min函数的栈 题目数据范围&#xff1a;示例C语言代码实现解释1. push(value)2. pop()3. top()4. min() 总结大小堆 题目 定义栈的数据结构&#xff0c;请在该类型中实现一个能够得到栈中所含最小元素的 min 函数&#xff0c;输入操作时保证 pop、top 和 min 函数操作时&am…

RDP协议详解

以下内容包含对 RDP&#xff08;Remote Desktop Protocol&#xff0c;远程桌面协议&#xff09;及其开源实现 FreeRDP 的较为系统、深入的讲解&#xff0c;涵盖协议概要、历史沿革、核心原理、安全机制、安装与使用方法、扩展与未来发展趋势等方面&#xff0c; --- ## 一、引…

【Linux系统】计算机世界的基石:冯诺依曼架构与操作系统设计

文章目录 一.冯诺依曼体系结构1.1 为什么体系结构中要存在内存&#xff1f;1.2 冯诺依曼瓶颈 二.操作系统2.1 设计目的2.2 系统调用与库函数 一.冯诺依曼体系结构 冯诺依曼体系结构&#xff08;Von Neumann Architecture&#xff09;是计算机的基本设计理念之一&#xff0c;由…