pytorch使用SVM实现文本分类

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics# 1. 数据准备(中文文本)
texts = ["今天的足球比赛非常激烈,球队表现出色,最终赢得了比赛。","NBA比赛今天开打,球员们的表现非常精彩,球迷们热情高涨。","张艺谋的新电影上映了,票房成绩非常好,观众反响热烈。","娱乐圈最近又出了一些新闻,明星们的私生活成了大家讨论的焦点。","昨晚的篮球赛真是太精彩了,球员们的进攻和防守都非常强硬。","李宇春在最新的音乐会上演出了她的新歌,现场观众反应热烈。","今年的世界杯比赛激烈异常,球队之间的竞争越来越激烈。","最近的综艺节目非常火,明星嘉宾的表现让观众们大笑不已。"
]# 标签:0表示体育,1表示娱乐
labels = [0, 0, 1, 1, 0, 1, 0, 1]# 2. 数据预处理:中文分词和 TF-IDF 特征提取
def jieba_cut(text):return " ".join(jieba.cut(text))texts_cut = [jieba_cut(text) for text in texts]vectorizer = TfidfVectorizer(max_features=10000)
X_tfidf = vectorizer.fit_transform(texts_cut).toarray()
y = np.array(labels)# 3. 数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=42)# 4. PyTorch 数据加载
class NewsGroupDataset(torch.utils.data.Dataset):def __init__(self, features, labels):self.features = torch.tensor(features, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.labels[idx]train_dataset = NewsGroupDataset(X_train, y_train)
test_dataset = NewsGroupDataset(X_test, y_test)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)# 5. 定义 SVM 模型(使用线性层)
class SVM(nn.Module):def __init__(self, input_dim, output_dim):super(SVM, self).__init__()self.fc = nn.Linear(input_dim, output_dim)def forward(self, x):return self.fc(x)# 6. 获取特征数并初始化模型
input_dim = X_tfidf.shape[1]  # 自动获取特征数
model = SVM(input_dim=input_dim, output_dim=2)  # 使用特征数量设置输入维度
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 优化器,调整学习率# 7. 训练模型
num_epochs = 50  # 增加训练周期for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%")# 8. 测试模型
model.eval()
correct = 0
total = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {100 * correct / total}%")# 9. 输出性能指标
y_pred = []
y_true = []with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)y_pred.extend(predicted.numpy())y_true.extend(labels.numpy())print(metrics.classification_report(y_true, y_pred))# 10. 测试新样本
def predict(text, model, vectorizer):# 1. 进行分词text_cut = jieba_cut(text)# 2. 将文本转为 TF-IDF 特征向量text_tfidf = vectorizer.transform([text_cut]).toarray()# 3. 转换为 PyTorch 张量text_tensor = torch.tensor(text_tfidf, dtype=torch.float32)# 4. 模型预测model.eval()  # 设置模型为评估模式with torch.no_grad():output = model(text_tensor)_, predicted = torch.max(output.data, 1)# 5. 返回预测结果return predicted.item()# 测试一个新的中文文本
new_text = "今天的篮球比赛真是太精彩了,球员们的表现让大家都为之喝彩。"
predicted_label = predict(new_text, model, vectorizer)# 输出预测结果
if predicted_label == 0:print("预测类别: 体育")
else:print("预测类别: 娱乐")

1. 数据准备

  • 文本数据:我们定义了一个包含中文文本的列表,每条文本表示一个新闻或评论。
  • 标签:为每条文本分配了一个标签,0 代表“体育”,1 代表“娱乐”。

2. 数据预处理

  • 中文分词:使用 jieba 库对每条文本进行分词,并将分词后的结果连接成字符串。这是处理中文文本时的常见做法。
  • TF-IDF 特征提取:使用 TfidfVectorizer 将文本转化为数值特征。TF-IDF 是一种常见的文本表示方式,能够衡量单词在文档中的重要性。

3. 数据集分割

  • 使用 train_test_split 将数据分为训练集和测试集。80% 的数据用于训练,20% 用于测试。

4. PyTorch 数据加载

  • 定义 Dataset 类:创建了一个自定义的 NewsGroupDataset 类,继承自 torch.utils.data.Dataset,用于将文本特征和标签封装为 PyTorch 可用的数据集格式。
  • DataLoader:使用 DataLoader 将训练集和测试集数据进行批处理和加载。

5. 模型定义

  • 定义了一个简单的线性 SVM 模型。实际上,使用了一个线性层 (nn.Linear) 来进行分类,输入是文本的 TF-IDF 特征,输出是两个类别(体育或娱乐)。
  • 使用了 CrossEntropyLoss 作为损失函数,因为这是分类任务中常用的损失函数。
  • 优化器使用了随机梯度下降(SGD),并设置了学习率为 0.01。

6. 训练过程

  • 训练模型的过程包括:前向传播(计算输出),计算损失,反向传播(更新参数),并在每个 epoch 后输出损失和准确率。
  • 每个 batch 的训练过程中,模型会通过计算损失并进行优化,逐步提升准确率。

7. 测试和评估

  • 在测试过程中,将模型设置为评估模式 (model.eval()),并计算测试集上的准确率。通过比较预测标签与真实标签,计算正确的预测数量并输出准确率。
  • 使用 classification_report 来输出精确度、召回率和 F1 分数等更多的评估指标。

8. 预测新样本

  • 定义了一个 predict() 函数,用于预测新的文本样本的分类。
  • 预测过程包括:对文本进行分词,转化为 TF-IDF 特征,传入模型进行前向传播,最后返回模型预测的标签。

9. 输出

  • 代码的最后,会输出模型对新文本的预测结果,标明是属于体育类别还是娱乐类别。

关键技术点:

  • 中文分词:使用 jieba 对中文文本进行分词处理,这对于中文文本的处理至关重要。
  • TF-IDF:将文本转换为数值特征,便于模型处理。TF-IDF 是基于单词在文档中的出现频率及其在整个语料中的稀有度进行加权的。
  • 模型训练与评估:通过多轮训练提升模型准确度,使用测试集来评估模型的泛化能力。
  • PyTorch DataLoader:通过 DataLoader 高效地处理训练集和测试集,进行批处理和自动化管理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/69528.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【硬件介绍】三极管工作原理(图文+典型电路设计)

什么是三极管? 三极管,全称为双极型晶体三极管,是一种广泛应用于电子电路中的半导体器件。它是由三个掺杂不同的半导体材料区域组成的,这三个区域分别是发射极(E)、基极(B)和集电极&…

51单片机开发:串口通信

实验目标:电脑通过串口将数据发送给51单片机,单片机原封不动地将数据通过串口返送给电脑。 串口的内部结构如下图所示: 串口配置如下: TMOD | 0X20 ; //设置计数器工作方式 2 SCON 0X50 ; //设置为工作方式 1 PCON 0X80 ; …

DeepSeek-R1本地部署笔记

文章目录 效果概要下载 ollama终端下载模型【可选】浏览器插件 UIQ: 内存占用高,显存占用不高,正常吗 效果 我的配置如下 E5 2666 V3 AMD 590Gme 可以说是慢的一批了,内存和显卡都太垃圾了,回去用我的新设备再试试 概要 安装…

Linux 命令之技巧(Tips for Linux Commands)

Linux 命令之技巧 简介 Linux ‌是一种免费使用和自由传播的类Unix操作系统,其内核由林纳斯本纳第克特托瓦兹(Linus Benedict Torvalds)于1991年10月5日首次发布。Linux继承了Unix以网络为核心的设计思想,是一个性能稳定的多用户…

【愚公系列】《循序渐进Vue.js 3.x前端开发实践》029-组件的数据注入

标题详情作者简介愚公搬代码头衔华为云特约编辑,华为云云享专家,华为开发者专家,华为产品云测专家,CSDN博客专家,CSDN商业化专家,阿里云专家博主,阿里云签约作者,腾讯云优秀博主&…

deepseek-r1 本地部署

deepseek 最近太火了 1&#xff1a;环境 win10 cpu 6c 内存 16G 2: 部署 1>首先下载ollama 官网&#xff1a;https://ollama.com ollama 安装在c盘 模型可以配置下载到其他盘 OLLAMA_MODELS D:\Ollama 2>下载模型并运行 ollama run deepseek-r1:<标签> 1.5b 7b 8…

租赁系统为企业资产管理提供高效解决方案促进业务增长与创新

内容概要 在现代商业环境中&#xff0c;企业不断寻求高效的管理解决方案&#xff0c;以提高运营效率、降低成本并推动业务增长。而租赁系统正是一款理想的工具&#xff0c;能够帮助企业实现这一目标。 快鲸智慧园区(楼宇)管理系统作为数字化资产管理的领先选择&#xff0c;提供…

如何写美赛(MCM/ICM)论文中的Summary部分

美赛(MCM/ICM)作为一个数学建模竞赛,要求参赛者在有限的时间内解决一个复杂的实际问题,并通过数学建模、数据分析和计算机模拟等手段给出有效的解决方案。在美赛的论文中,Summary部分(通常也称为摘要)是非常关键的,它是整个论文的缩影,能让评审快速了解你解决问题的思…

Nginx 安装配置指南

Nginx 安装配置指南 引言 Nginx 是一款高性能的 HTTP 和反向代理服务器&#xff0c;同时也可以作为 IMAP/POP3/SMTP 代理服务器。由于其稳定性、丰富的功能集以及低资源消耗而被广泛应用于各种场景。本文将为您详细介绍 Nginx 的安装与配置过程。 系统要求 在安装 Nginx 之…

Direct2D 极速教程(2) —— 画淳平

极速导航 创建新项目&#xff1a;002-DrawJunpeiWIC 是什么用 WIC 加载图片画淳平 创建新项目&#xff1a;002-DrawJunpei 右键解决方案 -> 添加 -> 新建项目 选择"空项目"&#xff0c;项目名称为 “002-DrawJunpei”&#xff0c;然后按"创建" 将 “…

自然语言处理——从原理、经典模型到应用

1. 概述 自然语言处理&#xff08;Natural Language Processing&#xff0c;NLP&#xff09;是一门借助计算机技术研究人类语言的科学&#xff0c;是人工智能领域的一个分支&#xff0c;旨在让计算机理解、生成和处理人类语言。其核心任务是将非结构化的自然语言转换为机器可以…

【2025年数学建模美赛F题】(顶刊论文绘图)模型代码+论文

全球网络犯罪与网络安全政策的多维度分析及效能评估 摘要1 Introduction1.1 Problem Background1.2Restatement of the Problem1.3 Literature Review1.4 Our Work 2 Assumptions and Justifications数据完整性与可靠性假设&#xff1a;法律政策独立性假设&#xff1a;人口统计…

06-AD向导自动创建P封装(以STM32-LQFP48格式为例)

自动向导创建封装 自动向导创建封装STM32-LQFP48Pin封装1.选则4排-LCC或者QUAD格式2.计算焊盘相定位长度3.设置默认引脚位置(芯片逆时针)4.特殊情况下:加额外的标记 其他问题测量距离:Ctrl M测量 && Ctrl C清除如何区分一脚和其他脚?芯片引脚是逆时针看的? 自动向导…

MATLAB基础应用精讲-【数模应用】迭代扩展卡尔曼滤波(IEKF)(附MATLAB、python和C语言代码实现)

目录 前言 几个高频面试题目 卡尔曼滤波和扩展卡尔曼滤波的区别? 算法原理 卡尔曼滤波 数据融合 数学模型 KF计算公式 KF使用说明 尔曼滤波案例——多目标跟踪 卡尔曼滤波器——预测阶段 卡尔曼滤波器——更新阶段 扩展卡尔曼滤波 EKF EKF计算公式 EKF迭代过程 …

【Linux探索学习】第二十七弹——信号(一):Linux 信号基础详解

Linux学习笔记&#xff1a; https://blog.csdn.net/2301_80220607/category_12805278.html?spm1001.2014.3001.5482 前言&#xff1a; 前面我们已经将进程通信部分讲完了&#xff0c;现在我们来讲一个进程部分也非常重要的知识点——信号&#xff0c;信号也是进程间通信的一…

微服务网关鉴权之sa-token

目录 前言 项目描述 使用技术 项目结构 要点 实现 前期准备 依赖准备 统一依赖版本 模块依赖 配置文件准备 登录准备 网关配置token解析拦截器 网关集成sa-token 配置sa-token接口鉴权 配置satoken权限、角色获取 通用模块配置用户拦截器 api模块配置feign…

Java基于SSM框架的互助学习平台小程序【附源码、文档】

博主介绍&#xff1a;✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3…

vue中的el是指什么

简介&#xff1a; 在Vue.js中&#xff0c;el指的是Vue实例的挂载元素。 具体来说&#xff0c;el是一个选项&#xff0c;用于指定Vue实例应该挂载到哪个DOM元素上。通过这个选项&#xff0c;Vue可以知道应该从哪个元素开始进行模板编译和渲染。它可以是一个CSS选择器字符串&…

实战纪实 | 真实HW漏洞流量告警分析

视频教程在我主页简介和专栏里 目录&#xff1a; 一、web.xml 文件泄露 二、Fastjson 远程代码执行漏洞 三、hydra工具爆破 四、绕过验证&#xff0c;SQL攻击成功 五、Struts2代码执行 今年七月&#xff0c;我去到了北京某大厂参加HW行动&#xff0c;因为是重点领域—-jr&…

WSL安装CUDA

WSL安装CUDA 参考文档&#xff1a; ​ 总安装文档&#xff1a;https://docs.nvidia.com/cuda/cuda-installation-guide-linux/#wsl-installation 1. 下载cuda ​ 进入下载界面&#xff1a;https://developer.nvidia.com/cuda-downloads?target_osLinux&target_archx86_…