实现五折交叉验证进行模型训练 -

news/2025/11/22 20:45:31/文章来源:https://www.cnblogs.com/hx-top/p/19258571

1、实验目的
熟悉Python 的基本操作,掌握对数据集的读写实现、对模型性能的评估实现的能力;
加深对训练集、测试集、N折交叉验证、模型评估标准的理解。
2、实验内容
(1)利用pandas库从本地读取iris数据集;
(2)从scikit-learn 库中直接加载iris 数据集;
(3)实现五折交叉验证进行模型训练;
(4)计算并输出模型的准确度、精度、召回率和F1值。
3、操作要点
(1)安装Python及pycharm(一种Python开发IDE),并熟悉Python基本操作;
(2)学习pandas库里存取文件的相关函数,以及scikit-learn库里数据集下载、交叉验
证、模型评估等相关操作;
(3)可能用的库有pandas,scikit-learn,numpy 等,需要提前下载pip;
(4)测试模型可使用随机森林rf_classifier=RandomForestClassifier(n_estimators=100),
或其它分类器

实验代码:

"""
Iris数据集分类任务
实现五折交叉验证并计算模型评估指标
"""import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import warnings
import sys
import io# 设置输出编码为UTF-8,解决Windows PowerShell中文乱码问题
if sys.platform == 'win32':sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')warnings.filterwarnings('ignore')# # ==================== 步骤1:使用pandas从本地读取iris数据集 ====================
# print("=" * 70)
# print("步骤1:使用pandas从本地读取iris数据集")
# print("=" * 70)
#
# # 读取本地iris数据文件
# iris_local_path = 'iris/iris.data'
# iris_local = pd.read_csv(iris_local_path, header=None,
#                          names=['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'])
#
# print(f"本地数据集形状: {iris_local.shape}")
# print(f"前5行数据:")
# print(iris_local.head())
# print(f"\n类别分布:")
# print(iris_local['species'].value_counts())
# print()
#
# # 分离特征和标签
# X_local = iris_local.iloc[:, :-1].values
# y_local = iris_local.iloc[:, -1].values# ==================== 步骤2:从scikit-learn库中直接加载iris数据集 ====================
print("=" * 70)
print("步骤2:从scikit-learn库中直接加载iris数据集")
print("=" * 70)# 从sklearn加载iris数据集
iris_sklearn = load_iris()
X_sklearn = iris_sklearn.data
y_sklearn = iris_sklearn.target
feature_names = iris_sklearn.feature_names
target_names = iris_sklearn.target_namesprint(f"Sklearn数据集形状: {X_sklearn.shape}")
print(f"特征名称: {feature_names}")
print(f"类别名称: {target_names}")
print(f"类别分布: {np.bincount(y_sklearn)}")
print()# 将sklearn数据转换为DataFrame以便查看
iris_sklearn_df = pd.DataFrame(X_sklearn, columns=feature_names)
iris_sklearn_df['species'] = [target_names[i] for i in y_sklearn]
print("前5行数据:")
print(iris_sklearn_df.head())
print()# ==================== 步骤3:实现五折交叉验证进行模型训练 ====================
print("=" * 70)
print("步骤3:实现五折交叉验证进行模型训练")
print("=" * 70)# 选择使用sklearn加载的数据集(更标准)
X = X_sklearn
y = y_sklearn# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 创建模型(使用逻辑回归作为示例)
model = LogisticRegression(random_state=42, max_iter=1000)# 创建五折交叉验证
kfold = KFold(n_splits=5, shuffle=True, random_state=42)# 存储每折的预测结果
all_y_true = []
all_y_pred = []print("开始五折交叉验证...")
fold_num = 1
for train_idx, test_idx in kfold.split(X_scaled):print(f"\n第 {fold_num} 折:")print(f"  训练集大小: {len(train_idx)}, 测试集大小: {len(test_idx)}")# 划分训练集和测试集X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y[train_idx], y[test_idx]# 训练模型model.fit(X_train, y_train)# 预测y_pred = model.predict(X_test)# 保存预测结果all_y_true.extend(y_test)all_y_pred.extend(y_pred)# 计算当前折的指标acc = accuracy_score(y_test, y_pred)prec = precision_score(y_test, y_pred, average='weighted', zero_division=0)rec = recall_score(y_test, y_pred, average='weighted', zero_division=0)f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)print(f"  准确度: {acc:.4f}")print(f"  精度: {prec:.4f}")print(f"  召回率: {rec:.4f}")print(f"  F1值: {f1:.4f}")fold_num += 1print("\n五折交叉验证完成!")
print()# ==================== 步骤4:计算并输出模型的准确度、精度、召回率和F1值 ====================
print("=" * 70)
print("步骤4:计算并输出模型的评估指标(整体结果)")
print("=" * 70)# 计算整体指标
accuracy = accuracy_score(all_y_true, all_y_pred)
precision = precision_score(all_y_true, all_y_pred, average='weighted', zero_division=0)
recall = recall_score(all_y_true, all_y_pred, average='weighted', zero_division=0)
f1 = f1_score(all_y_true, all_y_pred, average='weighted', zero_division=0)print(f"\n整体评估指标(基于所有5折的预测结果):")
print(f"  准确度 (Accuracy): {accuracy:.4f}")
print(f"  精度 (Precision): {precision:.4f}")
print(f"  召回率 (Recall): {recall:.4f}")
print(f"  F1值 (F1-Score): {f1:.4f}")
print()# 使用cross_val_predict方法(另一种方式)
print("=" * 70)
print("使用cross_val_predict方法进行交叉验证(对比)")
print("=" * 70)y_pred_cv = cross_val_predict(model, X_scaled, y, cv=5)accuracy_cv = accuracy_score(y, y_pred_cv)
precision_cv = precision_score(y, y_pred_cv, average='weighted', zero_division=0)
recall_cv = recall_score(y, y_pred_cv, average='weighted', zero_division=0)
f1_cv = f1_score(y, y_pred_cv, average='weighted', zero_division=0)print(f"\n评估指标(cross_val_predict方法):")
print(f"  准确度 (Accuracy): {accuracy_cv:.4f}")
print(f"  精度 (Precision): {precision_cv:.4f}")
print(f"  召回率 (Recall): {recall_cv:.4f}")
print(f"  F1值 (F1-Score): {f1_cv:.4f}")
print()# 按类别显示详细指标
print("=" * 70)
print("按类别显示详细指标")
print("=" * 70)precision_per_class = precision_score(all_y_true, all_y_pred, average=None, zero_division=0)
recall_per_class = recall_score(all_y_true, all_y_pred, average=None, zero_division=0)
f1_per_class = f1_score(all_y_true, all_y_pred, average=None, zero_division=0)print("\n各类别指标:")
print(f"{'类别':<20} {'精度':<12} {'召回率':<12} {'F1值':<12}")
print("-" * 60)
for i, class_name in enumerate(target_names):print(f"{class_name:<20} {precision_per_class[i]:<12.4f} {recall_per_class[i]:<12.4f} {f1_per_class[i]:<12.4f}")print("\n" + "=" * 70)
print("任务完成!")
print("=" * 70)

实验结果得到的数据:
image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/973405.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

KingbaseES:为银行核心系统迁移开启新航道 - 详解

KingbaseES:为银行核心系统迁移开启新航道 - 详解2025-11-22 20:38 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; displ…

用 ffmpeg 命令去除视频的重复帧、剪视频、修改视频尺寸 - 详解

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

20232422 2025-2026-1 《网络与系统攻防技术》实验六实验报告

20232422 2025-2026-1 《网络与系统攻防技术》实验六实验报告 1.实验内容 本次实验是Metasploit工具的实战应用,先通过主机发现、端口扫描完成前期信息搜集,再针对Metasploitable2靶机的4个已知漏洞(Vsftpd后门漏洞…

毕业论文写作全流程:从选题到答辩的完整指南

毕业论文写作挑战重重,本文提供从选题到答辩的完整指南。选题与开题准备部分介绍选题原则方法、开题报告撰写及文献检索整理技巧;论文写作核心流程涵盖大纲搭建、摘要引言结论撰写、正文论证及参考文献规范;修改、降…

html空间如何添加滚动条

在HTML空间(通常指的是一个div元素)中添加滚动条,可以通过CSS样式来实现。以下是一个简单的示例,展示了如何为一个div元素添加垂直滚动条:HTML结构: <!DOCTYPE html> <html lang="en"> &l…

实用指南:Jenkins 持续集成与部署指南

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025年11月DR耐油橡胶热缩管,氟橡胶热缩管,防滑花纹热缩管厂家最新推荐:耐老化性能实测榜单

2025年11月DR耐油橡胶热缩管,氟橡胶热缩管,防滑花纹热缩管厂家最新推荐:耐老化性能实测榜单在热缩管市场中,DR耐油橡胶热缩管、氟橡胶热缩管以及防滑花纹热缩管等产品凭借各自独特的性能,在众多领域发挥着重要作用…

2025年11月DR耐油橡胶热缩管,线缆标识热缩管,防滑花纹热缩管厂商推荐:耐油等级与使用寿命解析

2025年11月DR耐油橡胶热缩管,线缆标识热缩管,防滑花纹热缩管厂商推荐:耐油等级与使用寿命解析在众多热缩管厂商中,广州容信塑胶制品有限公司是一家值得关注的企业。该公司成立于2009年1月,是一家专业的热缩套管、…

[游记]CSP 2025

和一位,能不能比去年考得好啊 /ll Day -114514 以领先分数线 \(\Theta(1)\) 分的优势苟进了复赛。 Day 10.28 csp 前最后一场模拟赛,获得了极低的分数,太有信心了! 深度思考一整场 1log 怎么做,结果是不知道经典 …

11.22题解

A.栞 考虑面积公式 \(S = \frac{1}{2} ab \sin C\),则 \(4S^2 = ab(1 - \cos^2C)\),则我固定 ab 的情况下,我要 \(\sinC\) 最大,也就是 cos 绝对值最小。 考虑定序,若令 \(a > b > c\) 那么 C 一定是锐角,…

电梯调度问题的三次迭代

电梯调度问题的三次迭代 目录第一章 引言第二章 设计与分析第三章 踩坑心得第四章 改进建议第五章 总结第一章 引言 在现代城市生活中,电梯作为垂直交通的核心工具,其运行效率直接影响着人们的出行体验与楼宇的整体运…

【minimap2】一定要注意组合参数

当我需要minimap2在输出sam文件中包含secondary alignment时,我认为默认的输出开关就应该是开着的,因此没有设置--secondary=yes,使用以下参数:minimap2 -ax sr -t $threads ${INDEX} ${fastq_dictory}/${prefix}_…

3-数据库

3.数据库 2025.11.13 Day14 3.1 一条SQL查询语句是如何执行的? 连接器: 连接器负责跟客户端建立连接、获取权限、维持和管理连接。 查询缓存: MySQL 拿到一个查询请求后,会先到查询缓存看看,之前是不是执行过这条语…

4-java

4.java 2025.11.20 DAY23 4.1 String、StringBuffer、StringBuilder的区别 在 Java 中,String、StringBuilder 和 StringBuffer 都是用于处理字符序列的类。它们最核心的区别在于可变性、线程安全和性能。 1. 核心区别…

重构高阶智驾:天瞳威视以国产芯片,解锁Robotaxi平民化路径 - 实践

重构高阶智驾:天瞳威视以国产芯片,解锁Robotaxi平民化路径 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: &quo…

1-计算机网络

1.计算机网络 2025.10.29 DAY01 1.1 介绍一下TCP/IP模型和OSI模型的区别 OSI:物联网叔会使用 TCP/IP:接网叔用 OSI模型是国际标准化组织(ISO)制定的一个用于计算机或通信系统间互联的标准体系,它将网络通信精细地…

实用指南:MCU定点计算深度解析:原理、技巧与实现

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2-操作系统

2.计算机组成原理 2025.11.07 DAY10 2.1 进程和线程之间有什么区别 线程是程序执行的最小单位,线程是进程的子任务,是进程内的执行单元。 一个进程至少有一个线程,一个进程可以运行多个线程,这些线程共享同一块内存…

html空间如何添加图片

在HTML空间中添加图片,可以通过以下两种方法: 内联方式在HTML代码中,使用<img>标签插入图片。例如,要插入一张名为“example.jpg”的图片,且该图片位于与HTML文件相同的目录下,可以使用以下代码:<img …

html空间可以设置边框吗

HTML空间可以设置边框。在HTML中,可以使用CSS样式来设置边框。例如,<div style="border: 1px solid black;">这段代码就会给<div>元素设置一个黑色实线边框。同样地,也可以使用其他边框样式、…