pytorch逻辑回归实现垃圾邮件检测

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np# 增强的数据集:更多的垃圾邮件与正常邮件样本
X = ["Congratulations! You've won a $1000 gift card. Claim it now!","Dear friend, I hope you are doing well. Let's catch up soon.","Urgent: Your bank account has been compromised. Please contact support immediately.","Hello, just wanted to confirm our meeting at 2 PM today.","You have a new message from your friend. Click here to read.","Get a free iPhone now! Limited offer, click here.","Last chance to claim your prize, you won $500!","Meeting scheduled for tomorrow. Please confirm.","Hello! You are invited to an exclusive event!","Click here to get free lottery tickets. Hurry up!","Reminder: Your subscription will expire soon, renew now.","Don't forget to submit your report by end of day today."
]
y = [1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0]  # 1 为垃圾邮件,0 为正常邮件# 使用 TfidfVectorizer 进行文本向量化
vectorizer = TfidfVectorizer(stop_words='english')  # 去除停用词
X_vec = vectorizer.fit_transform(X).toarray()# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_vec, y, test_size=0.33, random_state=42)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self, input_dim):super(LogisticRegressionModel, self).__init__()self.fc = nn.Linear(input_dim, 1)  # 线性层,输入维度是特征的数量,输出是1def forward(self, x):return torch.sigmoid(self.fc(x))  # 使用sigmoid激活函数输出0到1之间的概率# 定义训练过程
def train_model(model, X_train, y_train, num_epochs=200, learning_rate=0.001):criterion = nn.BCELoss()  # 二分类交叉熵损失optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 使用Adam优化器X_train_tensor = torch.tensor(X_train, dtype=torch.float32)y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train_tensor)loss = criterion(outputs, y_train_tensor)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
def evaluate_model(model, X_test, y_test):model.eval()X_test_tensor = torch.tensor(X_test, dtype=torch.float32)y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)with torch.no_grad():outputs = model(X_test_tensor)predictions = (outputs >= 0.5).float()  # 阈值设为0.5accuracy = accuracy_score(y_test, predictions.numpy())print(f'Accuracy: {accuracy * 100:.2f}%')# 训练并评估模型
input_dim = X_train.shape[1]  # 输入特征的数量
model = LogisticRegressionModel(input_dim)
train_model(model, X_train, y_train, num_epochs=200, learning_rate=0.001)
evaluate_model(model, X_test, y_test)# 预测新邮件
def predict(model, new_email):model.eval()new_email_vec = vectorizer.transform([new_email]).toarray()new_email_tensor = torch.tensor(new_email_vec, dtype=torch.float32)with torch.no_grad():prediction = model(new_email_tensor)return "Spam" if prediction >= 0.5 else "Not Spam"# 检测新邮件
email_1 = "Congratulations! You have a limited time offer for a free cruise."
email_2 = "Hi, let's discuss the project updates tomorrow."print(f"Email 1: {predict(model, email_1)}")  # 可能输出:Spam
print(f"Email 2: {predict(model, email_2)}")  # 可能输出:Not Spam
1. 数据预处理
  • 准备数据集:包含垃圾邮件(Spam)和正常邮件(Not Spam)。
  • 文本向量化:使用 TfidfVectorizer 将文本转换为数值特征,使模型能够处理。
  • 去除停用词:排除无意义的常见词(如 "the", "is", "and"),提高模型性能。
2. 训练集与测试集划分
  • 将数据集拆分为训练集和测试集,以 67% 训练,33% 测试,保证模型有足够数据训练,同时可以评估其泛化能力。
3. 逻辑回归模型
  • 搭建 PyTorch 逻辑回归模型
    • 采用 nn.Linear() 构建一个单层神经网络(输入为文本特征,输出为 1 个数值)。
    • 使用 sigmoid 作为激活函数,将输出转换为 0-1 之间的概率值。
4. 训练模型
  • 定义损失函数:使用二元交叉熵损失 (BCELoss),适用于二分类问题。
  • 优化器:采用 Adam 优化器,以 0.001 学习率进行参数优化。
  • 训练流程
    1. 计算前向传播的输出。
    2. 计算损失值,衡量预测结果与真实标签的差距。
    3. 进行反向传播,更新权重参数。
    4. 迭代多轮(如 200 轮),不断优化模型。
5. 评估模型
  • 将测试数据输入模型,预测结果并与真实标签进行对比。
  • 计算准确率,评估模型在未见过的数据上的表现。
6. 预测新邮件
  • 将新邮件转换为数值特征(与训练时相同的方法)。
  • 使用训练好的模型进行预测
  • 阈值判断:如果输出概率 ≥ 0.5,则判断为垃圾邮件,否则为正常邮件。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/67466.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【 CVE-2025-21298】 通过ghidriff查看完整补丁差异

ole32_dec24.dll-ole32.dll 差异 目录 视觉图表差异元数据 Ghidra 差异引擎 命令行二进制元数据差异程序选项

洛谷P3383 【模板】线性筛素数

题目链接:P3383 【模板】线性筛素数 - 洛谷 | 计算机科学教育新生态 题目难度:普及一 题目分析:本题是模板题,用到了线性筛法,其中原理是保证范围内的每个合数都被删掉(在 bool 数组里面标记为非素数…

STM32标准库移植RT-Thread nano

STM32标准库移植RT-Thread Nano 哔哩哔哩教程链接:STM32F1标准库移植RT_Thread Nano 移植前的准备 stm32标准库的裸机代码(最好带有点灯和串口)RT-Thread Nano Pack自己的开发板 移植前的说明 本人是在读学生,正在学习阶段&a…

JVM--类加载器

概念 类加载器:只参与加载过程中的字节码获取并加载到内存中的部分;java虚拟机提供给应用程序去实现获取类和接口字节码数据的一种技术,也就是说java虚拟机是允许程序员写代码去获取字节码信息 类加载是加载的第一步,主要有以下三…

ECMAScript 6语法

1.ES6简介 ECMAScript 6(简称ES6)是于2015年6月正式发布的JavaScript语言的标准,正式名为ECMAScript 2015(ES2015)。它的目标是使得JavaScript语言可以用来编写复杂的大型应用程序,成为企业级开发语言 。 …

联想Y7000+RTX4060+i7+Ubuntu22.04运行DeepSeek开源多模态大模型Janus-Pro-1B+本地部署

直接上手搓了: conda create -n myenv python3.10 -ygit clone https://github.com/deepseek-ai/Janus.gitcd Januspip install -e .pip install webencodings beautifulsoup4 tinycss2pip install -e .[gradio]pip install pexpect>4.3python demo/app_januspr…

Tez 0.10.1安装

个人博客地址:Tez 0.10.1安装 | 一张假钞的真实世界 具体安装步骤参照官网安装手册即可。此处只对官网手册进行补充。 从官网下载apache-tez-0.10.1-bin.tar.gz进行安装未成功,出现下面的异常。最终按照官网源代码编译的方式安装测试成功。 环境 Had…

FastAPI + GraphQL + SQLAlchemy 实现博客系统

本文将详细介绍如何使用 FastAPI、GraphQL(Strawberry)和 SQLAlchemy 实现一个带有认证功能的博客系统。 技术栈 FastAPI:高性能的 Python Web 框架Strawberry:Python GraphQL 库SQLAlchemy:Python ORM 框架JWT&…

微服务入门(go)

微服务入门(go) 和单体服务对比:里面的服务仅仅用于某个特定的业务 一、领域驱动设计(DDD) 基本概念 领域和子域 领域:有范围的界限(边界) 子域:划分的小范围 核心域…

深入解析 Linux 内核内存管理核心:mm/memory.c

在 Linux 内核的众多组件中,内存管理模块是系统性能和稳定性的关键。mm/memory.c 文件作为内存管理的核心实现,承载着页面故障处理、页面表管理、内存区域映射与取消映射等重要功能。本文将深入探讨 mm/memory.c 的设计思想、关键机制以及其在内核中的作用,帮助读者更好地理…

安卓通过网络获取位置的方法

一 方法介绍 1. 基本权限设置 首先需要在 AndroidManifest.xml 中添加必要权限&#xff1a; xml <uses-permission android:name"android.permission.INTERNET" /> <uses-permission android:name"android.permission.ACCESS_NETWORK_STATE" /&g…

【B站保姆级视频教程:Jetson配置YOLOv11环境(二)SSH连接的三种方式】

B站同步视频教程&#xff1a;https://www.bilibili.com/video/BV1m5wUeyEQD/ 在Jetson设备上配置YOLOv11环境时&#xff0c;SSH连接是实现远程高效开发与管理的关键一环。不同的网络环境和硬件配置可能会影响SSH连接的方式&#xff0c;本文将结合相关视频内容&#xff0c;详细…

视频拼接,拼接时长版本

目录 视频较长&#xff0c;分辨率较大&#xff0c;这个效果很好&#xff0c;不耗用内存 ffmpeg imageio&#xff0c;适合视频较短 视频较长&#xff0c;分辨率较大&#xff0c;这个效果很好&#xff0c;不耗用内存 ffmpeg import subprocess import glob import os from nats…

Vue.js 什么是 Composition API?

Vue.js 什么是 Composition API&#xff1f; 今天我们来聊聊 Vue 3 引入的一个重要特性&#xff1a;组合式 API&#xff08;Composition API&#xff09;。如果你曾在开发复杂的 Vue 组件时感到代码难以维护&#xff0c;那么组合式 API 可能正是你需要的工具。 什么是组合式 …

Selenium配合Cookies实现网页免登录

文章目录 前言1 方案一&#xff1a;使用Chrome用户数据目录2 方案二&#xff1a;手动获取并保存Cookies&#xff0c;后续使用保存的Cookies3 注意事项 前言 在进行使用Selenium进行爬虫、网页自动化操作时&#xff0c;登录往往是一个必须解决的问题&#xff0c;但是Selenium每次…

计算机毕业设计Python+知识图谱大模型AI医疗问答系统 健康膳食推荐系统 食谱推荐系统 医疗大数据 机器学习 深度学习 人工智能 爬虫 大数据毕业设计

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…

关于el-table翻页后序号列递增的组件封装

需求说明&#xff1a; 项目中经常会用到的一个场景&#xff0c;表格第一列显示序号&#xff08;1、2、3...&#xff09;&#xff0c;但是在翻页后要递增显示序号&#xff0c;例如10、11、12&#xff08;假设一页显示10条数据&#xff09;&#xff0c;针对这种情况&#xff0c;封…

Elasticsearch的索引生命周期管理

目录 说明零、参考一、ILM的基本概念二、ILM的实践步骤Elasticsearch ILM策略中的“最小年龄”是如何计算的&#xff1f;如何监控和调整Elasticsearch ILM策略的性能&#xff1f; 1. **监控性能**使用/_cat/thread_pool API基本请求格式请求特定线程池的信息响应内容 2. **调整…

AI大模型开发原理篇-3:词向量和词嵌入

简介 词向量是用于表示单词意义的向量&#xff0c; 并且还可以被认为是单词的特征向量或表示。 将单词映射到实向量的技术称为词嵌入。在实际应用中&#xff0c;词向量和词嵌入这两个重要的NLP术语通常可以互换使用。它们都表示将词汇表中的单词映射到固定大小的连续向量空间中…

[内网安全] 内网渗透 - 学习手册

这是一篇专栏的目录文档&#xff0c;方便读者系统性的学习&#xff0c;笔者后续会持续更新文档内容。 如果没有特殊情况的话&#xff0c;大概是一天两篇的速度。&#xff08;实验多或者节假日&#xff0c;可能会放缓&#xff09; 笔者也是一边学习一边记录笔记&#xff0c;如果…