深度学习笔记40_中文文本分类-Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、我的环境

1.语言环境:Python 3.8

2.编译器:Pycharm

3.深度学习环境:

  • torch==1.12.1+cu113
  • torchvision==0.13.1+cu113

、导入数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")             #忽略警告信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")import pandas as pd# 加载自定义中文数据
train_data = pd.read_csv('./data/train.csv', sep='\t', header=None)
print(train_data.head())

结果:

                       0              1
0      还有双鸭山到淮阴的汽车票吗13号的   Travel-Query
1                从这里怎么回家   Travel-Query
2       随便播放一首专辑阁楼里的佛里的歌     Music-Play
3              给看一下墓王之王嘛  FilmTele-Play
4  我想看挑战两把s686打突变团竞的游戏视频     Video-Play

、构建词典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba# 中文分词方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 设置默认索引,如果找不到单词,则会选择默认索引print(vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频']))

结果:[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]

text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))
结果:[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
4

生成数据批次和迭代器

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_text, _label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即语句的总词汇量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # 返回维度dim中输入元素的累计和return text_list.to(device), label_list.to(device), offsets.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle=False,collate_fn=collate_batch)

定义模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,  # 词典大小embed_dim,  # 嵌入的维度sparse=False)  #self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)  # 初始化权重self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()  # 偏置值归零def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

定义实例

num_class  = len(label_name)
vocab_size = len(vocab)
em_size    = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)

定义训练函数与评估函数

import timedef train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time = time.time()for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()  # grad属性归零loss = criterion(predicted_label, label)  # 计算网络输出和真实值之间的差距,label为真实值loss.backward()  # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度裁剪optimizer.step()  # 每一步自动更新# 记录acc与losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc / total_count, train_loss / total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text, label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc / total_count, train_loss / total_count

训练模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset# 超参数
EPOCHS = 10  # epoch
LR = 5  # 学习率
BATCH_SIZE = 64  # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset) * 0.8), int(len(train_dataset) * 0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time() - epoch_start_time,val_acc, val_loss, lr))print('-' * 69)

 结果:

Batch [50/152], Loss: 0.0340, Accuracy: 0.4203
Batch [100/152], Loss: 0.0235, Accuracy: 0.5851
Batch [150/152], Loss: 0.0309, Accuracy: 0.6572
---------------------------------------------------------------------
| epoch 1 | time: 0.55s | valid_acc 0.814 valid_loss 0.012 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0104, Accuracy: 0.8165
Batch [100/152], Loss: 0.0099, Accuracy: 0.8215
Batch [150/152], Loss: 0.0092, Accuracy: 0.8329
---------------------------------------------------------------------
| epoch 2 | time: 0.44s | valid_acc 0.855 valid_loss 0.008 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0068, Accuracy: 0.8790
Batch [100/152], Loss: 0.0065, Accuracy: 0.8778
Batch [150/152], Loss: 0.0064, Accuracy: 0.8809
---------------------------------------------------------------------
| epoch 3 | time: 0.44s | valid_acc 0.874 valid_loss 0.007 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0050, Accuracy: 0.9105
Batch [100/152], Loss: 0.0051, Accuracy: 0.9101
Batch [150/152], Loss: 0.0048, Accuracy: 0.9130
---------------------------------------------------------------------
| epoch 4 | time: 0.44s | valid_acc 0.882 valid_loss 0.006 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0039, Accuracy: 0.9366
Batch [100/152], Loss: 0.0039, Accuracy: 0.9339
Batch [150/152], Loss: 0.0038, Accuracy: 0.9350
---------------------------------------------------------------------
| epoch 5 | time: 0.44s | valid_acc 0.896 valid_loss 0.006 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0028, Accuracy: 0.9519
Batch [100/152], Loss: 0.0030, Accuracy: 0.9517
Batch [150/152], Loss: 0.0030, Accuracy: 0.9494
---------------------------------------------------------------------
| epoch 6 | time: 0.44s | valid_acc 0.898 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0025, Accuracy: 0.9580
Batch [100/152], Loss: 0.0024, Accuracy: 0.9616
Batch [150/152], Loss: 0.0024, Accuracy: 0.9609
---------------------------------------------------------------------
| epoch 7 | time: 0.44s | valid_acc 0.902 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0018, Accuracy: 0.9764
Batch [100/152], Loss: 0.0019, Accuracy: 0.9739
Batch [150/152], Loss: 0.0019, Accuracy: 0.9724
---------------------------------------------------------------------
| epoch 8 | time: 0.44s | valid_acc 0.900 valid_loss 0.005 | lr 5.000000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0015, Accuracy: 0.9810
Batch [100/152], Loss: 0.0014, Accuracy: 0.9817
Batch [150/152], Loss: 0.0014, Accuracy: 0.9818
---------------------------------------------------------------------
| epoch 9 | time: 0.49s | valid_acc 0.906 valid_loss 0.005 | lr 0.500000
---------------------------------------------------------------------
Batch [50/152], Loss: 0.0013, Accuracy: 0.9831
Batch [100/152], Loss: 0.0013, Accuracy: 0.9831
Batch [150/152], Loss: 0.0014, Accuracy: 0.9825
---------------------------------------------------------------------
| epoch 10 | time: 0.54s | valid_acc 0.906 valid_loss 0.005 | lr 0.500000
---------------------------------------------------------------------

、预测

def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item()# ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"model = model.to("cpu")print("该文本的类别是:%s" %label_name[predict(ex_text_str, text_pipeline)])
该文本的类别是:Travel-Query

总结: 

  1. 语料库(原始文本)‌:

    来源包括维基百科、网页文本、新闻资讯及内部文本。
  2. 文本清洗‌:

    清洗原始文本,包括去除标点符号和特殊字符。该流程主要用于将原始文本数据转化为可用于模型训练的数值化向量,再通过深度学习模型进行文本分类。
    • 分词‌:

      使用jieba分词工具对清洗后的文本进行分词处理。
    • 建模‌:

      采用不同的模型进行文本建模,包括循环神经网络(RNN)、卷积神经网络(CNN)、门控循环单元(GRU)和长短期记忆网络(LSTM)。
    • 文本向量化‌:

      将分词后的文本转换为向量表示,方法包括TF-IDF和Word2vec。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/79197.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

010302-oss_反向代理_负载均衡-web扩展2-基础入门-网络安全

文章目录 1 OSS1.1 什么是 OSS 存储&#xff1f;1.2 OSS 核心功能1.3 OSS 的优势1.4 典型使用场景1.5 如何接入 OSS&#xff1f;1.6 注意事项1.7 cloudreve实战演示1.7.1 配置cloudreve连接阿里云oss1.7.2 常见错误1.7.3 安全测试影响 2 反向代理2.1 正向代理和反向代理2.2 演示…

【 Node.js】 Node.js安装

下载 下载 | Node.js 中文网https://nodejs.cn/download/ 安装 双击安装包 点击Next 勾选使用许可协议&#xff0c;点击Next 选择安装位置 点击Next 点击Next 点击Install 点击Finish 完成安装 添加环境变量 编辑【系统变量】下的变量【Path】添加Node.js的安装路径--如果…

Python基本语法(自定义函数)

自定义函数 Python语言没有子程序&#xff0c;只有自定义函数&#xff0c;目的是方便我们重复使用相同的一 段程序。将常用的代码块定义为一个函数&#xff0c;以后想实现相同的操作时&#xff0c;只要调用函数名就可以了&#xff0c;而不需要重复输入所有的语句。 函数的定义…

OpenGL-ES 学习(11) ---- EGL

目录 EGL 介绍EGL 类型和初始化EGL初始化方法获取 eglDisplay初始化 EGL选择 Config构造 Surface构造 Context开始绘制 EGL Demo EGL 介绍 OpenGL-ES 是一个操作GPU的图像API标准&#xff0c;它通过驱动向 GPU 发送相关图形指令&#xff0c;控制图形渲染管线状态机的运行状态&…

极简5G专网解决方案

极简5G专网解决方案 利用便携式即插即用私有 5G 网络提升您的智能创新。为您的企业提供无缝、安全且可扩展的 5G 解决方案。 提供极简5G专网解决方案 Mantiswave Network Private Limited 提供全面的 5G 专用网络解决方案&#xff0c;以满足您企业的独特需求。我们创新的“…

html:table表格

表格代码示例&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title> </head> <body><!-- 标准表格。 --><table border"5"cellspacing&qu…

tkinter 电子时钟 实现时间日期 可实现透明

以下是一个使用Tkinter模块创建一个简单的电子时钟并显示时间和日期的示例代码&#xff1a; import tkinter as tk import time# 创建主窗口 root tk.Tk() root.overrideredirect(True) # 隐藏标题栏 root.attributes(-alpha, 0.7) # 设置透明度# 显示时间的标签 time_labe…

【报错问题】 macOS 的安全策略(Gatekeeper)阻止了未签名的原生模块(bcrypt_lib.node)加载

这个错误是由于 macOS 的安全策略&#xff08;Gatekeeper&#xff09;阻止了未签名的原生模块&#xff08;bcrypt_lib.node&#xff09;加载 导致的。以下是具体解决方案&#xff1a; 1. 临时允许加载未签名模块&#xff08;推荐先尝试&#xff09; 在终端运行以下命令&#x…

AI实现制作logo的网站添加可选颜色模板

1.效果图 LogoPalette.jsx import React, {useState} from react import HeadingDescription from ./HeadingDescription import Lookup from /app/_data/Lookup import Colors from /app/_data/Colors function LogoPalette({onHandleInputChange}) { const [selectOptio…

云原生后端架构的挑战与应对策略

📝个人主页🌹:慌ZHANG-CSDN博客 🌹🌹期待您的关注 🌹🌹 随着云计算、容器化以及微服务等技术的快速发展,云原生架构已经成为现代软件开发和运维的主流趋势。企业通过构建云原生后端系统,能够实现灵活的资源管理、快速的应用迭代和高效的系统扩展。然而,尽管云原…

【C++】模板为什么要extern?

模板为什么要extern&#xff1f; 在 C 中&#xff0c;多个编译单元使用同一个模板时&#xff0c;是否可以不使用 extern 取决于模板的实例化方式&#xff08;隐式或显式&#xff09;&#xff0c;以及你对编译时间和二进制体积的容忍度。 1. 隐式实例化&#xff1a;可以不用 ex…

中小企业MES系统数据库设计

版本&#xff1a;V1.0 日期&#xff1a;2025年5月2日 一、数据库架构概览 1.1 数据库选型 数据类型数据库类型技术选型用途时序数据&#xff08;传感器读数&#xff09;时序数据库TimescaleDB存储设备实时监控数据结构化业务数据关系型数据库PostgreSQL工单、质量、设备等核心…

VUE篇之树形特殊篇

根节点是level:1, level3及其子节点有关联&#xff0c;但是和level2和他下面的子节点没有关联 思路&#xff1a;采用守护风琴效果&#xff0c;遍历出level1和level2级节点&#xff0c;后面level3的节点&#xff0c;采用树形结构进行关联 <template><div :class"…

洛圣电玩系列部署实录:一次自己从头跑通的搭建过程

写这篇文章不是为了“教大家怎么一步步安装”&#xff0c;而是想把我自己完整跑通洛圣电玩整个平台的经历复盘下来。因为哪怕你找到了所谓的全套源码资源&#xff0c;如果没人告诉你这些资源之间是怎么连起来的&#xff0c;你依旧是一脸懵逼。 我拿到的是什么版本&#xff1f; …

腾讯云web服务器配置步骤是什么?web服务器有什么用途?

腾讯云web服务器配置步骤是什么?web服务器有什么用途&#xff1f; Web服务器配置步骤&#xff08;以常见环境为例&#xff09; 1. 安装Web服务器软件 Linux系统&#xff08;如Ubuntu&#xff09; Apache: sudo apt update sudo apt install apache2 Nginx: sudo apt install…

第37课 绘制原理图——放置离页连接符

什么是离页连接符&#xff1f; 前边我们介绍了网络标签&#xff08;Net Lable&#xff09;&#xff0c;可以让两根导线“隔空相连”&#xff0c;使原理图更加清爽简洁。 但是网络标签的使用也具有一定的局限性&#xff0c;对于两张不同Sheet上的导线&#xff0c;网络标签就不…

Win下的Kafka安装配置

一、准备工作&#xff08;可以不做&#xff0c;毕竟最新版kafka也不需要zk&#xff09; 1、Windows下安装Zookeeper &#xff08;1&#xff09;官网下载Zookeeper 官网下载地址 &#xff08;2&#xff09;解压Zookeeper安装包到指定目录C:\DevelopApp\zookeeper\apache-zoo…

前端Vue3 + 后端Spring Boot,前端取消请求后端处理逻辑分析

在 Vue3 Spring Boot 的技术栈下&#xff0c;前端取消请求后&#xff0c;后端是否继续执行业务逻辑的答案仍然是 取决于请求处理的阶段 和 Spring Boot 的实现方式。以下是结合具体技术的详细分析&#xff1a; 1. 请求未到达 Spring Boot 场景&#xff1a;前端通过 AbortContr…

【蓝桥杯省赛真题58】Scratch画台扇 蓝桥杯scratch图形化编程 中小学生蓝桥杯省赛真题讲解

目录 scratch画台扇 一、题目要求 编程实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、解题思路 四、程序编写 五、考点分析 六、推荐资料 1、scratch资料 2、python资料 3、C++资料 scratch画台扇 第十五届青少年蓝桥杯scratch编程省赛真题解析 …

GPT-4o 图像生成与八个示例指南

什么是GPT-4o图像生成&#xff1f; 简单来说&#xff0c;GPT-4o图像生成是集成在ChatGPT内部的一项功能。用户可以直接在对话中&#xff0c;通过文本描述&#xff08;Prompt&#xff09;来创建、编辑和调整图像。这与之前的图像生成工具相比&#xff0c;体验更流畅、交互性更强…