基于深度学习的毒蘑菇检测

文章目录

    • 任务介绍
    • 数据概览
    • 数据处理
        • 数据读取与拼接
        • 字符数据转化
        • 标签数据映射
        • 数据集划分
        • 数据标准化
    • 模型构建与训练
        • 模型构建
        • 数据批处理
        • 模型训练
    • 文件提交
    • 结果
    • 附录

任务介绍

本次任务为毒蘑菇的二元分类,任务本身并不复杂,适合初学者,主要亮点在于对字符数据的处理,还有尝试了加深神经网络深度的效果,之后读者也可自行改变观察效果,比赛路径将于附录中给出。

数据概览

本次任务的数据集比较简单

  • train.csv 训练文件
  • test.csv 测试文件
  • sample_submission.csv 提交示例文件

具体内容就是关于毒蘑菇的各种特征,可在附录中获取数据集。

数据处理

数据读取与拼接

这段代码提取了数据文件,并且对两个不同来源的数据集进行了拼接,当我们的数据集较小时,就可采用这种方法,获取其他的数据集并将两个数据集合并起来。

import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler
file = pd.read_csv("/kaggle/input/playground-series-s4e8/train.csv", index_col="id")
file2 = pd.read_csv("/kaggle/input/mushroom-classification-edible-or-poisonous/mushroom.csv")
file_all = pd.concat([file, file2])
字符数据转化

这段代码主要就是提取出字符数据,因为字符是无法直接被计算机处理,所以我们提取出来后,再将字符数据映射为数字数据。

char_features = ['cap-shape', 'cap-surface', 'cap-color', 'does-bruise-or-bleed', 'gill-attachment', 'gill-spacing', 'gill-color', 'stem-root', 'stem-surface', 'stem-color', 'veil-type', 'veil-color', 'has-ring', 'ring-type', 'spore-print-color', 'habitat', 'season']
for i in char_features:file_all[i] = LabelEncoder().fit_transform(file_all[i])
file_all = file_all.fillna(0)
train_col = ['cap-diameter', 'stem-height', 'stem-width', 'cap-shape', 'cap-surface', 'cap-color', 'does-bruise-or-bleed', 'gill-attachment', 'gill-spacing', 'gill-color', 'stem-root', 'stem-surface', 'stem-color', 'veil-type', 'veil-color', 'has-ring', 'ring-type', 'spore-print-color', 'habitat', 'season']
X = file_all[train_col]
y = file_all['class']
标签数据映射

除了用上述方法进行字符转化外,还可以使用map函数,以下是具体操作。

y.unique()
# 构建映射字典
applying = {'e': 0, 'p': 1}
y = y.map(applying)
数据集划分

这段代码使用sklearn库将数据集划分为训练集和测试集。

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
x_train.shape, y_train.shape
数据标准化

这段代码将我们的数据进行归一化,减小数字大小方便计算,但是仍然保持他们之间的线性关系,不会对结果产生影响。

scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.fit_transform(x_test)

模型构建与训练

这段代码使用torch库构建了深度学习模型,主要运用了线性层,还进行了正则化操作,防止模型过拟合。

模型构建
import torch
import torch.nn as nn
class Model(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(20, 256)self.relu = nn.ReLU()self.dropout = nn.Dropout(p=0.2)self.linear1 = nn.Linear(256, 128)self.linear2 = nn.Linear(128, 64)self.linear3 = nn.Linear(64, 48)self.linear4 = nn.Linear(48, 32)self.linear5 = nn.Linear(32, 2)def forward(self, x):out = self.linear(x)out = self.relu(out)out = self.linear1(out)out = self.relu(out)out = self.dropout(out)out = self.linear2(out)out = self.relu(out)out = self.linear3(out)out = self.dropout(out)out = self.relu(out)out = self.linear4(out)out = self.relu(out)out = self.linear5(out)return out

对模型类进行实例化。

model = Model()
数据批处理

由于数据一条一条的处理起来很慢,因此我们可以将数据打包,一次给模型输入多条数据,能有效节省时间。

import torch.nn.functional as F
class Dataset(torch.utils.data.Dataset):def __init__(self, x, y):self.x = xself.y = ydef __len__(self):return len(self.x)def __getitem__(self, i):x = torch.Tensor(self.x[i])y = torch.tensor(self.y.iloc[i])return x, y
train_data = Dataset(x_train, y_train)
test_data = Dataset(x_test, y_test)
loader = torch.utils.data.DataLoader(train_data, batch_size=64, drop_last=True, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=256, drop_last=True, shuffle=True)
模型训练

这段代码就是模型的训练过程,包括创建优化器,定义损失函数等,还在训练过程中测试准确率与损失函数值,动态的观察训练过程。

from tqdm import tqdm
import matplotlib.pyplot as plt
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-5)
from sklearn.metrics import matthews_corrcoef
flag = 0
for i in range(10):for x, label in tqdm(loader):out = model(x)loss = criterion(out, label)loss.backward()optimizer.step()optimizer.zero_grad()flag+=1if flag%500 == 0:test = next(iter(test_loader))t_out = model(test[0]).argmax(dim=1)print("loss=", loss.item())acc = (t_out == test[1]).sum().item()/len(test[1])mcc = matthews_corrcoef(t_out, test[1])print("acc=", acc)print("mcc=", mcc)

文件提交

这段代码主要就是使用训练好的模型在测试集上预测,并且将其整合成提交文件。

test_file = pd.read_csv("/kaggle/input/playground-series-s4e8/test.csv")
for i in char_features:test_file[i] = LabelEncoder().fit_transform(test_file[i])
test_file.fillna(0)
test_x = torch.Tensor(test_file[train_col].values)
test_x = torch.Tensor(scaler.fit_transform(test_x))
out = model(test_x)
out = pd.Series(out.argmax(dim=1))
map2 = {0: 'e', 1: 'p'}
result = out.map(map2)
answer = pd.DataFrame({'id': test_file['id'], "class": result})
answer.to_csv('submission.csv', index=False)

结果

将文件提交后,得到了0.97的成绩,已经非常接近1了,证明模型的效果非常不错。
在这里插入图片描述

附录

比赛链接:https://www.kaggle.com/competitions/playground-series-s4e8
额外数据集地址:https://www.kaggle.com/datasets/vishalpnaik/mushroom-classification-edible-or-poisonous

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/903931.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

时间给了我们什么?

时间给了我们什么? ​春秋易逝,青春难留,转瞬之间已过半百。 ​过往中,有得有失,这就是人生。 ​一日三餐四季,日起日落里,成就了昨天、今天和明天,在历史长河中,皆是…

软件工程国考

软件工程-同等学力计算机综合真题及答案 (2004-2014、2017-2024) 2004 年软工 第三部分 软件工程 (共 30 分) 一、单项选择题(每小题 1 分,共 5 分) 软件可用性是指( &#xff09…

数据结构*栈

栈 什么是栈 这里的栈与我们之前常说的栈是不同的。之前我们说的栈是内存栈,它是JVM内存的一部分,用于存储局部变量、方法调用信息等。每个线程都有自己独立的栈空间,当线程启动时,栈就会被创建;线程结束&#xff0c…

IntelliJ IDEA 保姆级使用教程

文章目录 一、创建项目二、创建模块三、创建包四、创建类五、编写代码六、运行代码注意 七、IDEA 常见设置1、主题2、字体3、背景色 八、IDEA 常用快捷键九、IDEA 常见操作9.1、类操作9.1.1、删除类文件9.1.2、修改类名称注意 9.2、模块操作9.2.1、修改模块名快速查看 9.2.2、导…

HTTP 快速解析

一、HTTP请求结构 HTTP请求和响应报文由以下部分组成(以请求报文为例): 请求报文结构: 请求行:包含HTTP方法(如GET/POST)、请求URL和协议版本(如HTTP/1.1,HTTP/2.0&…

【AI学习】李宏毅新课《DeepSeek-R1 这类大语言模型是如何进行「深度思考」(Reasoning)的?》的部分纪要

针对推理模型,主要讲了四种方法,两种不需要训练模型,两种需要。 对于reason和inference,这两个词有不同的含义! 推理时计算不是新鲜事,AlphaGo就是如此。 这张图片说明了将训练和推理时计算综合考虑的关系&…

Kotlin Flow流

一 Kotlin Flow 中的 stateIn 和 shareIn 一、简单比喻理解 想象一个水龙头(数据源)和几个水杯(数据接收者): 普通 Flow(冷流):每个水杯来接水时,都要重新打开水龙头从…

【嵌入式Linux】基于ARM-Linux的zero2平台的智慧楼宇管理系统项目

目录 1. 需求及项目准备(此项目对于虚拟机和香橙派的配置基于上一个垃圾分类项目,如初次开发,两个平台的环境变量,阿里云接入,摄像头配置可参考垃圾分类项目)1.1 系统框图1.2 硬件接线1.3 语音模块配置1.4 …

Linux运维中常用的磁盘监控方式

在Linux运维中,磁盘监控是一项关键任务,因为它能帮助我们预防磁盘空间不足或性能问题导致的服务中断或数据丢失。让我们来看看有哪些常用的磁盘监控方法吧! 1. 查看磁盘使用情况(df命令) df命令用于显示文件系统的…

OpenCV第6课 图像处理之几何变换(缩放)

1.简述 图像几何变换又称为图像空间变换,它将一幅图像中的坐标位置映射到另一幅图像中的新坐标位置。几何变换并不改变图像的像素值,只是在图像平面上进行像素的重新安排。 根据OpenCV函数的不同,本节课将映射关系划分为缩放、翻转、仿射变换、透视等。 2.缩放 2.1 函数…

(35)VTK C++开发示例 ---将图片映射到平面2

文章目录 1. 概述2. CMake链接VTK3. main.cpp文件4. 演示效果 更多精彩内容👉内容导航 👈👉VTK开发 👈 1. 概述 与上一个示例不同的是,使用vtkImageReader2Factory根据文件扩展名或内容自动创建对应的图像文件读取器&a…

【模型量化】量化基础

目录 一、认识量化 二、量化基础原理 2.1 对称量化和非对称量化 2.1.1 对称量化 2.1.2 非对称量化 2.1.3 量化后的矩阵乘 2.2 神经网络量化 2.2.1 动态量化 2.2.2 静态量化 2.3 量化感知训练 一、认识量化 量化的主要目的是节约显存、提高计算效率以及加快通信 dee…

【零基础入门】一篇掌握Python中的字典(创建、访问、修改、字典方法)【详细版】

🌈 个人主页:十二月的猫-CSDN博客 🔥 系列专栏: 🏀《PyTorch科研加速指南:即插即用式模块开发》-CSDN博客 💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 目录 1. 前言 2. 字典 2.1 字典的创建 2.1.1 大括号+直接赋值 2.1.2 大括号…

PHP-session

PHP中,session(会话)是一种在服务器上存储用户数据的方法,这些数据可以在多个页面请求或访问之间保持。Session提供了一种方式来跟踪用户状态,比如登录信息、购物车内容等。当用户首次访问网站时,服务器会创…

第 5 篇:红黑树:工程实践中的平衡大师

上一篇我们探讨了为何有序表需要“平衡”机制来保证 O(log N) 的稳定性能。现在,我们要认识一位在实际工程中应用最广泛、久经考验的“平衡大师”——红黑树 (Red-Black Tree)。 如果你用过 Java 的 TreeMap​ 或 TreeSet​,或者 C STL 中的 map​ 或 s…

第十六届蓝桥杯 2025 C/C++组 客流量上限

目录 题目: 题目描述: 题目链接: 思路: 打表找规律: 核心思路: 思路详解: 得到答案的方式: 按计算器: 暴力求解代码: 快速幂代码: 位运…

一天学完JDBC!!(万字总结)

文章目录 JDBC是什么 1、环境搭建 && 入门案例2、核心API理解①、注册驱动(Driver类)②、Connection③、statement(sql注入)④、PreparedStatement⑤、ResultSet 3、jdbc扩展(ORM、批量操作)①、实体类和ORM②、批量操作 4. 连接池①、常用连接池②、Durid连接池③、Hi…

从原理到实战讲解回归算法!!!

哈喽,大家好,我是我不是小upper, 今天系统梳理了线性回归的核心知识,从模型的基本原理、参数估计方法,到模型评估指标与实际应用场景,帮助大家深入理解这一经典的机器学习算法,助力数据分析与预测工作。 …

【dify—10】工作流实战——文生图工具

目录 一、创建工作流 应用 二、安装硅基流动 三、配置硅基流动 四、API测试 (1)进入API文档 (2)复制curl代码 (3)Postman测试API 五、 建立文生图工作流 (1)建立http请求 &…

Rust将结构导出到json如何处理小数点问题

简述 标准的 serde_json 序列化器不支持直接对浮点数进行格式化限制。如果将浮点数转换成字符串,又太low逼。这里重点推荐rust_decimal。 #[derive(Serialize)] pub struct StockTickRow {datetime: NaiveDateTime,code: String,name: String,#[serde(serialize_w…