Day08 【基于jieba分词实现词嵌入的文本多分类】

基于jieba分词的文本多分类

      • 目标
      • 数据准备
      • 参数配置
      • 数据处理
      • 模型构建
      • 主程序
      • 测试与评估
      • 测试结果

目标

本文基于给定的词表,将输入的文本基于jieba分词分割为若干个词,然后将词基于词表进行初步编码,之后经过网络层,输出在已知类别标签上的概率分布,从而实现一个简单文本的多分类。

数据准备

词表文件chars.txt

类别标签文件schema.json

{"停机保号": 0,"密码重置": 1,"宽泛业务问题": 2,"亲情号码设置与修改": 3,"固话密码修改": 4,"来电显示开通": 5,"亲情号码查询": 6,"密码修改": 7,"无线套餐变更": 8,"月返费查询": 9,"移动密码修改": 10,"固定宽带服务密码修改": 11,"UIM反查手机号": 12,"有限宽带障碍报修": 13,"畅聊套餐变更": 14,"呼叫转移设置": 15,"短信套餐取消": 16,"套餐余量查询": 17,"紧急停机": 18,"VIP密码修改": 19,"移动密码重置": 20,"彩信套餐变更": 21,"积分查询": 22,"话费查询": 23,"短信套餐开通立即生效": 24,"固话密码重置": 25,"解挂失": 26,"挂失": 27,"无线宽带密码修改": 28
}

训练集数据train.json训练集数据

验证集数据valid.json验证集数据

参数配置

config.py

# -*- coding: utf-8 -*-"""
配置参数信息
"""Config = {"model_path": "model_output","schema_path": "../data/schema.json","train_data_path": "../data/train.json","valid_data_path": "../data/valid.json","vocab_path":"../chars.txt","max_length": 20,"hidden_size": 128,"epoch": 10,"batch_size": 32,"optimizer": "adam","learning_rate": 1e-3,
}

数据处理

loader.py

# -*- coding: utf-8 -*-import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader"""
数据加载
"""class DataGenerator:def __init__(self, data_path, config):self.config = configself.path = data_pathself.vocab = load_vocab(config["vocab_path"])self.config["vocab_size"] = len(self.vocab)self.schema = load_schema(config["schema_path"])self.config["class_num"] = len(self.schema)self.load()def load(self):self.data = []with open(self.path, encoding="utf8") as f:for line in f:line = json.loads(line)#加载训练集if isinstance(line, dict):questions = line["questions"]label = line["target"]label_index = torch.LongTensor([self.schema[label]])for question in questions:input_id = self.encode_sentence(question)input_id = torch.LongTensor(input_id)self.data.append([input_id, label_index])else:assert isinstance(line, list)question, label = lineinput_id = self.encode_sentence(question)input_id = torch.LongTensor(input_id)label_index = torch.LongTensor([self.schema[label]])self.data.append([input_id, label_index])returndef encode_sentence(self, text):input_id = []if self.config["vocab_path"] == "words.txt":for word in jieba.cut(text):input_id.append(self.vocab.get(word, self.vocab["[UNK]"]))else:for char in text:input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))input_id = self.padding(input_id)return input_id#补齐或截断输入的序列,使其可以在一个batch内运算def padding(self, input_id):input_id = input_id[:self.config["max_length"]]input_id += [0] * (self.config["max_length"] - len(input_id))return input_iddef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]#加载字表或词表
def load_vocab(vocab_path):token_dict = {}with open(vocab_path, encoding="utf8") as f:for index, line in enumerate(f):token = line.strip()token_dict[token] = index + 1  #0留给padding位置,所以从1开始return token_dict#加载schema
def load_schema(schema_path):with open(schema_path, encoding="utf8") as f:return json.loads(f.read())#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dlif __name__ == "__main__":from config import Configdg = DataGenerator("valid_tag_news.json", Config)print(dg[1])

主要实现一个自定义数据加载器 DataGenerator,用于加载和处理文本数据。它通过词汇表和标签映射将输入文本转化为索引序列,并进行补齐或截断。

模型构建

model.py

# -*- coding: utf-8 -*-import torch
import torch.nn as nn
from torch.optim import Adam, SGD
"""
建立网络模型结构
"""class TorchModel(nn.Module):def __init__(self, config):super(TorchModel, self).__init__()hidden_size = config["hidden_size"]vocab_size = config["vocab_size"] + 1max_length = config["max_length"]class_num = config["class_num"]self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)self.layer = nn.Linear(hidden_size, hidden_size)self.classify = nn.Linear(hidden_size, class_num)self.pool = nn.AvgPool1d(max_length)self.activation = torch.relu     #relu做激活函数self.dropout = nn.Dropout(0.1)self.loss = nn.functional.cross_entropy  #loss采用交叉熵损失#当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, target=None):x = self.embedding(x)  #input shape:(batch_size, sen_len)x = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)x = self.pool(x.transpose(1,2)).squeeze() #input shape:(batch_size, sen_len, input_dim)predict = self.classify(x)                #input shape:(batch_size, input_dim)if target is not None:return self.loss(predict, target.squeeze())else:return predictdef choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)

定义了一个神经网络模型 TorchModel,继承自 nn.Module,用于文本分类任务。模型包括嵌入层、线性层、平均池化层和分类层,使用 ReLU 激活函数和 Dropout 防止过拟合。前向传播根据输入返回预测值或损失值(若提供标签)。choose_optimizer 函数根据配置选择 Adam 或 SGD 优化器,并设置学习率。模型通过交叉熵损失进行训练。

主程序

main.py

# -*- coding: utf-8 -*-import torch
import os
import random
import os
import numpy as np
import loggingfrom config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data, load_schemalogging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型训练主程序
"""def main(config):#创建保存模型的目录if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])#加载训练数据train_data = load_data(config["train_data_path"], config)#加载模型model = TorchModel(config)# 标识是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,迁移模型至gpu")model = model.cuda()#加载优化器optimizer = choose_optimizer(config, model)#加载效果测试类evaluator = Evaluator(config, model, logger)#训练for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况loss = model(input_id, labels)train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)loss.backward()# print(loss.item())# print(model.classify.weight.grad)optimizer.step()logger.info("epoch average loss: %f" % np.mean(train_loss))evaluator.eval(epoch)model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)torch.save(model.state_dict(), model_path)return model, train_datadef ask(model, question):input_id = train_data.dataset.encode_sentence(question)model.eval()model = model.cpu()cls = torch.argmax(model(torch.LongTensor([input_id])))schemes = load_schema(Config["schema_path"])ans = ""for name, val in schemes.items():if val == cls:ans = namereturn ansif __name__ == "__main__":model, train_data = main(Config)print(ask(model, "积分是怎么积的"))while True:question = input("请输入问题:")res = ask(model, question)print("命中问题:", res)print("-----------")

实现一个基于 PyTorch 的文本分类模型的训练和推理过程。首先,通过 main 函数创建模型训练的主流程。代码首先检查是否有 GPU 可用,并将模型迁移至 GPU(如果可用)。然后加载训练数据、模型、优化器以及效果评估类。训练过程中,模型使用交叉熵损失函数计算训练误差并进行反向传播更新参数,每个 epoch 后记录并输出平均损失。同时,训练结束后,将模型保存至指定路径。

在训练完成后,ask 函数用于推理,输入问题并通过模型进行预测。它首先将输入问题转化为模型所需的格式,然后利用训练好的模型进行分类,最后返回匹配的答案。整个程序支持通过命令行输入问题,模型根据训练结果给出对应的答案。

在主程序中,首先进行一次初始化训练,之后进入循环,可以持续输入问题并得到模型的预测答案。

测试与评估

evaluate.py

# -*- coding: utf-8 -*-
import torch
from loader import load_data"""
模型效果测试
"""class Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)self.stats_dict = {"correct":0, "wrong":0}  #用于存储测试结果def eval(self, epoch):self.logger.info("开始测试第%d轮模型效果:" % epoch)self.stats_dict = {"correct":0, "wrong":0}  #清空前一轮的测试结果self.model.eval()for index, batch_data in enumerate(self.valid_data):if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况with torch.no_grad():pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测self.write_stats(labels, pred_results)self.show_stats()returndef write_stats(self, labels, pred_results):assert len(labels) == len(pred_results)for true_label, pred_label in zip(labels, pred_results):pred_label = torch.argmax(pred_label)if int(true_label) == int(pred_label):self.stats_dict["correct"] += 1else:self.stats_dict["wrong"] += 1returndef show_stats(self):correct = self.stats_dict["correct"]wrong = self.stats_dict["wrong"]self.logger.info("预测集合条目总量:%d" % (correct +wrong))self.logger.info("预测正确条目:%d,预测错误条目:%d" % (correct, wrong))self.logger.info("预测准确率:%f" % (correct / (correct + wrong)))self.logger.info("--------------------")return

定义一个 Evaluator 类,用于评估深度学习模型在验证集上的表现。Evaluator 初始化时接受配置文件、模型和日志记录器,并加载验证数据。eval 方法用于进行模型评估,在每轮评估开始时清空统计信息,设置模型为评估模式,然后通过遍历验证数据集进行预测。预测结果通过 write_stats 方法与真实标签进行比对,统计正确和错误的预测条目。最后,show_stats 方法输出总预测条目数、正确条目数、错误条目数以及准确率。该类的作用是帮助监控模型在验证集上的性能,便于调整和优化模型。

测试结果

请输入问题:在官网上如何修改移动密码
命中问题: 移动密码修改
-----------
请输入问题:我想多加一个号码作为亲情号
命中问题: 亲情号码设置与修改
-----------
请输入问题:我已经交足了话费请立即帮我开机
命中问题: 话费查询
-----------
请输入问题:密码想换一下
命中问题: 密码修改

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/901640.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

入门-C编程基础部分:6、常量

飞书文档https://x509p6c8to.feishu.cn/wiki/MnkLwEozRidtw6kyeW9cwClbnAg C 常量 常量是固定值,在程序执行期间不会改变,可以让我们编程更加规范。 常量可以是任何的基本数据类型,比如整数常量、浮点常量、字符常量,或字符串字…

第二阶段:数据结构与函数

模块4:常用数据结构 (Organizing Lots of Data) 在前面的模块中,我们学习了如何使用变量来存储单个数据,比如一个数字、一个名字或一个布尔值。但很多时候,我们需要处理一组相关的数据,比如班级里所有学生的名字、一本…

【C++算法】61.字符串_最长公共前缀

文章目录 题目链接:题目描述:解法C 算法代码:解释 题目链接: 14. 最长公共前缀 题目描述: 解法 解法一:两两比较 先算前两个字符串的最长公共前缀,然后拿这个最长公共前缀和后面一个来比较&…

JVM 调优不再难:AI 工具自动生成内存优化方案

在 Java 应用程序的开发与运行过程中,Java 虚拟机(JVM)的性能调优一直是一项极具挑战性的任务,尤其是内存优化方面。不合适的 JVM 内存配置可能会导致应用程序出现性能瓶颈,甚至频繁抛出内存溢出异常,影响业…

纷析云开源财务软件:企业财务数字化转型的灵活解决方案

纷析云是一家专注于开源财务软件研发的公司,自2018年成立以来,始终以“开源开放”为核心理念,致力于通过技术创新助力企业实现财务管理的数字化与智能化转型。其开源财务软件凭借高扩展性、灵活部署和全面的功能模块,成为众多企业…

【数字图像处理】数字图像空间域增强(3)

图像锐化 图像细节增强 图像轮廓:灰度值陡然变化的部分 空间变化:计算灰度变化程度 图像微分法:微分计算灰度梯度突变的速率 一阶微分:单向差值 二阶微分:双向插值 一阶微分滤波 1:梯度法 梯度&#xff1…

基于Linux的ffmpeg python的关键帧抽取

1.FFmpeg的环境配置 首先强调,ffmpeg-python包与ffmpeg包不一样。 1) 创建一个虚拟环境env conda create -n yourenv python3.x conda activate yourenv2) ffmpeg-python包的安装 pip install ffmpeg-python3) 安装系统级别的 FFmpeg 工具 虽然安装了 ffmpeg-p…

C#进阶学习(四)单向链表和双向链表,循环链表(上)单向链表

目录 前置知识: 一、链表中的结点类LinkedNode 1、申明字段节点类: 2、申明属性节点类: 二、两种方式实现单向链表 ①定框架: ②增加元素的方法:因为是单链表,所以增加元素一定是只能在末尾添加元素,…

RK3588 Buildroot 串口测试工具

RK3588 Buildroot串口测试工具(含代码) 一、引言 1.1 目的 本文档旨在指导开发人员能快速测试串口功能 1.2 适用范围 本文档适用于linux 系统串口测试。 二、开发环境准备 2.1 硬件环境 开发板:RK3588开发板,确保其串口硬件连接正常,具备电源供应、调试串口等基本硬…

HOJ PZ

https://docs.hdoi.cn/deploy 单体部署 请到~/hoj-deploy/standAlone的目录下,即是与docker-compose.yml的文件同个目录下,该目录下有个叫hoj的文件夹,里面的文件夹介绍如下: hoj ├── file # 存储了上传的图片、上传的临…

EtherCAT 的优点与缺点

EtherCAT(以太网控制自动化技术)是一种高性能的工业以太网协议,广泛应用于实时自动化控制。以下是其核心优缺点分析: ​一、EtherCAT 的核心优点​ 1. ​超低延迟 & 高实时性​ ​原理​:采用"​Processing…

高并发多级缓存架构实现思路

目录 1.整体架构 3.安装环境 1.1 使用docket安装redis 1.2 配置redis缓存链接: 1.3 使用redisTemplate实现 1.4 缓存注解优化 1.4.1 常用缓存注解简绍 1.4.2 EnableCaching注解的使用 1.4.3使用Cacheable 1.4.4CachePut注解的使用 1.4.5 优化 2.安装Ngin…

Qt QML实现Windows桌面颜色提取器

前言 实现一个简单的小工具,使用Qt QML实现Windows桌面颜色提取器,实时显示鼠标移动位置的颜色值,包括十六进制值和RGB值。该功能在实际应用中比较常见,比如截图的时候,鼠标移动就会在鼠标位置实时显示坐标和颜色值&a…

vue3+vite 多个环境配置

同一套代码 再也不用在不同的环境里来回切换请求地址了 然后踩了一个坑 就是env的文件路径是在当前项目下 不是在views内 因为公司项目需求只有dev和pro两个环境 虽然我新增了3个 但是只在这两个里面配置了 .env是可以配置一些公共配置的 目前需求来说不需要 所以我也懒得配了。…

AI赋能PLC(一):三菱FX-3U编程实战初级篇

前言 在工业自动化领域,三菱PLC以其高可靠性、灵活性和广泛的应用场景,成为众多工程师的首选控制设备。然而,传统的PLC编程往往需要深厚的专业知识和经验积累,开发周期长且调试复杂。随着人工智能技术的快速发展,利用…

XSS 跨站Cookie 盗取表单劫持网络钓鱼溯源分析项目平台框架

漏洞原理:接受输入数据,输出显示数据后解析执行 基础类型:反射 ( 非持续 ) ,存储 ( 持续 ) , DOM-BASE 拓展类型: jquery , mxss , uxss , pdfxss , flashx…

鸿蒙应用(医院诊疗系统)开发篇2·Axios网络请求封装全流程解析

一、项目初始化与环境准备 1. 创建鸿蒙工程 src/main/ets/ ├── api/ │ ├── api.ets # 接口聚合入口 │ ├── login.ets # 登录模块接口 │ └── request.ets # 网络请求核心封装 └── pages/ └── login.ets # 登录页面逻辑…

ADAS高级驾驶辅助系统详细介绍

ADAS(高级驾驶辅助系统)核心模块,通过 “监测→预警→干预” 三层逻辑提升行车安全。用户选择车辆时,可关注传感器配置(如是否标配毫米波雷达)、功能覆盖场景(如 AEB 是否支持夜间行人&#xff…

Prometheus+Grafana+K8s构建监控告警系统

一、技术介绍 Prometheus、Grafana及K8S服务发现详解 Prometheus简介 Prometheus是一个开源的监控系统和时间序列数据库,最初由SoundCloud开发,现已成为CNCF(云原生计算基金会)的毕业项目‌。它专注于实时监控和告警,特别适合云原生和分布式…

MATLAB脚本实现了一个三自由度的通用航空运载器(CAV-H)的轨迹仿真,主要用于模拟升力体在不同飞行阶段(初始滑翔段、滑翔段、下压段)的运动轨迹

%升力体:通用航空运载器CAV-H %读取数据1 升力系数 alpha = [10 15 20]; Ma = [3.5 5 8 10 15 20 23]; alpha1 = 10:0.1:20; Ma1 = 3.5:0.1:23; [Ma1, alpha1] = meshgrid(Ma1, alpha1); CL = readmatrix(simulation.xlsx, Sheet, Sheet1, Range, B2:H4); CL1 = interp2(…