机器学习 - 不同分类模型的比较

一、模型训练

本案例中,我们将通过四种不同的模型来预测泰坦尼克号乘客的生存情况。
一下是训练的具体步骤。

加载数据

从seaborn库中加载目标数据。该数据集包括多个特征,如 PassengerId, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, 和 Embarked。我们训练使用特征 Pclass, Age, Fare, 和 Sex,标签列为 Survived

import pandas as pd
import seaborn as sns# Load the Titanic dataset
data = sns.load_dataset('titanic')
print(data.head())
ResultPassengerId  Survived  Pclass  \
0            1         0       3   
1            2         1       1   
2            3         1       3   
3            4         1       1   
4            5         0       3   Name     Sex   Age  SibSp  \
0                            Braund, Mr. Owen Harris    male  22.0      1   
1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   
2                             Heikkinen, Miss. Laina  female  26.0      0   
3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   
4                           Allen, Mr. William Henry    male  35.0      0   Parch            Ticket     Fare Cabin Embarked  
0      0         A/5 21171   7.2500   NaN        S  
1      0          PC 17599  71.2833   C85        C  
2      0  STON/O2. 3101282   7.9250   NaN        S  
3      0            113803  53.1000  C123        S  
4      0            373450   8.0500   NaN        S  

数据预处理

在本案例中,我们的目标是预测泰坦尼克号乘客的生存情况。首先,我将详细介绍使用的数据预处理方法,这是确保模型表现良好的重要步骤。

1. 缺失值处理

在泰坦尼克号数据集中,Age 是存在缺失值的重要特征。处理缺失值是确保模型准确性的关键步骤之一。

# Handle missing values for 'Age'
imputer = SimpleImputer(strategy='mean')
features['Age'] = imputer.fit_transform(features[['Age']])
  • SimpleImputer(strategy='mean'): 这行代码创建了一个填充器对象,指定使用均值(mean)来填充缺失值。
  • imputer.fit_transform(features[['Age']]): 这里应用填充器,计算所有已知年龄的均值,并填充到缺失的位置。这种方法假设年龄数据的缺失是随机的,使用均值是合理的首选。
2. 类别特征编码

Sex 列是分类数据,包含文本值(male/female),需要转换为模型可处理的数值形式。

# Convert 'Sex' from categorical to numerical
encoder = LabelEncoder()
features['Sex'] = encoder.fit_transform(features['Sex'])
  • LabelEncoder(): 这行代码创建了一个标签编码器,用于将文本标签转换为唯一的整数。
  • encoder.fit_transform(features['Sex']): 应用编码器,将 malefemale 分别转换为数值(例如 0 和 1)。这是必须的步骤,因为大多数机器学习算法在训练过程中不能直接处理文本数据。
3. 特征缩放

由于KNN和许多其他机器学习算法对数据的尺度敏感,所以对特征进行标准化是很重要的。

# Standardize the features
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
  • StandardScaler(): 创建一个标准化器,用于将特征缩放到具有零均值和单位方差的范围内。
  • scaler.fit_transform(features): 应用标准化处理,确保所有特征都处于相同的尺度,有助于改善模型的性能和收敛速度。
4. 数据集划分

最后,数据被划分为训练集和测试集,用于训练模型和评估其性能。

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(features_scaled, target, test_size=0.2, random_state=42)
  • train_test_split(): 这个函数将数据随机分为训练集和测试集,test_size=0.2 表示 20% 的数据用于测试,剩下的 80% 用于训练。random_state=42 确保每次数据分割的方式相同,这对于可复现性是很重要的。

这些预处理步骤确保了数据的一致性和适用性,是后续模型训练和验证的基础。

5.数据预处理代码
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import classification_report# Select the required features and the target
features = data[['Pclass', 'Age', 'Fare', 'Sex']]
target = data['Survived']# Handle missing values for 'Age'
imputer = SimpleImputer(strategy='mean')
features['Age'] = imputer.fit_transform(features[['Age']])# Convert 'Sex' from categorical to numerical
encoder = LabelEncoder()
features['Sex'] = encoder.fit_transform(features['Sex'])# Standardize the features
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(features_scaled, target, test_size=0.2, random_state=42)

KNN模型训练

from sklearn.neighbors import KNeighborsClassifier# Train the KNN model
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)# Predictions and evaluation
knn_predictions = knn.predict(X_test)
knn_report = classification_report(y_test, knn_predictions)knn_report

KNN模型评估报告:

              precision    recall  f1-score   support0       0.82      0.89      0.85       1051       0.82      0.72      0.76        74accuracy                           0.82       179macro avg       0.82      0.80      0.81       179
weighted avg       0.82      0.82      0.81       179

这个报告显示了模型在测试集上的表现,包括精确度(precision)、召回率(recall)、F1分数和总体准确度(accuracy)。

线性回归模型训练

from sklearn.linear_model import LinearRegression# Train the Linear Regression model
# Note: Linear regression is not typically used for classification tasks, but we'll demonstrate it here for learning purposes.
linear_reg = LinearRegression()
linear_reg.fit(X_train, y_train)# Predictions
linear_reg_predictions = linear_reg.predict(X_test)# Convert predictions to binary to evaluate (0 if < 0.5 else 1)
linear_reg_predictions_binary = [1 if x >= 0.5 else 0 for x in linear_reg_predictions]# Evaluation
linear_reg_report = classification_report(y_test, linear_reg_predictions_binary)linear_reg_report

线性回归评估报告:

              precision    recall  f1-score   support0       0.81      0.84      0.82       1051       0.76      0.72      0.74        74accuracy                           0.79       179macro avg       0.78      0.78      0.78       179
weighted avg       0.79      0.79      0.79       179

尽管线性回归通常不用于分类任务,这里我们通过将输出阈值设为0.5来将其用于二分类问题。

逻辑回归模型训练

from sklearn.linear_model import LogisticRegression# Train the Logistic Regression model
logistic_reg = LogisticRegression(random_state=42)
logistic_reg.fit(X_train, y_train)# Predictions
logistic_reg_predictions = logistic_reg.predict(X_test)# Evaluation
logistic_reg_report = classification_report(y_test, logistic_reg_predictions)logistic_reg_report

逻辑回归评估报告:

              precision    recall  f1-score   support0       0.82      0.86      0.84       1051       0.78      0.73      0.76        74accuracy                           0.80       179macro avg       0.80      0.79      0.80       179
weighted avg       0.80      0.80      0.80       179

决策树模型训练

from sklearn.tree import DecisionTreeClassifier# Train the Decision Tree model
decision_tree = DecisionTreeClassifier(random_state=42)
decision_tree.fit(X_train, y_train)# Predictions
decision_tree_predictions = decision_tree.predict(X_test)# Evaluation
decision_tree_report = classification_report(y_test, decision_tree_predictions)decision_tree_report

决策树评估报告:

              precision    recall  f1-score   support0       0.79      0.77      0.78       1051       0.69      0.72      0.70        74accuracy                           0.75       179macro avg       0.74      0.74      0.74       179
weighted avg       0.75      0.75      0.75       179

模型比较

在这项分析中,我们使用了四种不同的机器学习模型来处理同一数据集,下面是每个模型的性能总结和对比:

KNN (K-Nearest Neighbors)

  • 精确度 (Precision): 0.82 (平均)
  • 召回率 (Recall): 0.80 (平均)
  • F1 分数: 0.81 (平均)
  • 总体准确率 (Accuracy): 82%
  • 优点: 相对直观易懂,不需要假设数据分布。
  • 缺点: 对异常值敏感,计算量较大,需要调整超参数(如K值)。

线性回归 (Linear Regression)

  • 精确度: 0.78 (平均)
  • 召回率: 0.78 (平均)
  • F1 分数: 0.78 (平均)
  • 总体准确率: 79%
  • 优点: 实现简单,解释性强。
  • 缺点: 不适合用于分类任务,需要转换为分类输出,容易受到异常值的影响。

逻辑回归 (Logistic Regression)

  • 精确度: 0.80 (平均)
  • 召回率: 0.79 (平均)
  • F1 分数: 0.80 (平均)
  • 总体准确率: 80%
  • 优点: 输出可解释性强,输出值具有概率意义。
  • 缺点: 非线性问题表现一般。

决策树 (Decision Tree)

  • 精确度: 0.74 (平均)
  • 召回率: 0.74 (平均)
  • F1 分数: 0.74 (平均)
  • 总体准确率: 75%
  • 优点: 不需要数据预处理,对非线性关系处理得好,易于理解和解释。
  • 缺点: 容易过拟合,对于数据变化较敏感。

总结

  • 性能最佳: KNN 和逻辑回归在本次任务中表现最佳,具有较高的准确率和平衡的精确度与召回率。
  • 适用性: 逻辑回归提供了概率输出,更适合需要概率解释的场景。决策树在解释性和处理非线性数据方面有优势。
  • 资源消耗: KNN在大数据集上运行较慢,因为需要计算每个实例之间的距离。决策树和逻辑回归相对资源消耗较低。

很好,我们可以进一步探讨一些可能的改进方法或者根据模型性能的具体分析来提供额外的见解。

模型优化和选择

改进模型性能的策略

  1. KNN:

    • 参数调整: KNN的K值对模型的性能有重大影响。通过交叉验证来找到最佳的K值可以改进模型的精确度和召回率。
    • 距离度量: 选择不同的距离度量(如欧氏距离、曼哈顿距离)可能会对结果产生影响,特别是在特征差异性较大的数据集中。
  2. 线性回归和逻辑回归:

    • 特征工程: 引入多项式特征或交互特征可以帮助模型捕捉更复杂的关系,尤其是在逻辑回归中处理非线性边界时。
    • 正则化: 对逻辑回归使用L1或L2正则化可以帮助避免过拟合,同时选择合适的正则化强度是关键。
  3. 决策树:

    • 剪枝: 对决策树进行剪枝(限制树的深度、叶节点的最小样本数等)可以减少过拟合,提高模型的泛化能力。
    • 集成方法: 使用随机森林或梯度提升决策树(Gradient Boosting Decision Trees, GBDT)等集成方法可以显著提升决策树的性能和稳定性。

模型选择的考虑因素

  • 数据大小和特征数: 对于大规模数据集,计算密集型的模型(如KNN)可能不是最佳选择。相反,决策树和逻辑回归在大数据集上的表现通常更优。
  • 预测时间要求: 如果应用场景对预测速度有严格要求,需要考虑模型的预测效率。例如,决策树的预测速度通常非常快。
  • 模型的解释性: 在需要解释模型决策的应用中(如医疗、金融领域),决策树和逻辑回归的可解释性优势可能更为重要。

更多问题咨询

Cos机器人

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/11696.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

科技查新中的工法查新点如何确立与提炼?案例讲解!

按《工程建设工法管理办法》( 建 质&#xff3b;2014&#xff3d;103 号) &#xff0c;工法&#xff0c;是指以工程为对象&#xff0c;以工艺为核心&#xff0c;运用系 统工程原理&#xff0c;把先进技术和科学管理结合起来&#xff0c;经过一定工程实践形成的综合配套的施工方…

探索美国动态IP池:技术赋能下的网络安全新篇章

在数字化飞速发展的今天&#xff0c;网络安全成为了各行各业关注的焦点。特别是在跨国业务中&#xff0c;如何保障数据的安全传输和合规性成为了企业面临的重要挑战。美国动态IP池作为一种新兴的网络技术&#xff0c;正逐渐走进人们的视野&#xff0c;为网络安全提供新的解决方…

黑马甄选离线数仓项目day02(数据采集)

datax介绍 官网&#xff1a; https://github.com/alibaba/DataX/blob/master/introduction.md DataX 是阿里云 DataWorks数据集成 的开源版本&#xff0c;在阿里巴巴集团内被广泛使用的离线数据同步工具/平台。 DataX 实现了包括 MySQL、Oracle、OceanBase、SqlServer、Postgre…

Java中List接口中方法的使用(初学者指南)

Java中List接口中方法的使用&#xff08;初学者指南&#xff09; 在Java中&#xff0c;List接口是Collection接口的子接口&#xff0c;它表示一个有序的集合&#xff0c;其中的元素都可以重复。List接口提供了许多额外的方法&#xff0c;用于对元素进行插入、删除、查询等操作…

计算机Java项目|Springboot学生读书笔记共享

作者主页&#xff1a;编程指南针 作者简介&#xff1a;Java领域优质创作者、CSDN博客专家 、CSDN内容合伙人、掘金特邀作者、阿里云博客专家、51CTO特邀作者、多年架构师设计经验、腾讯课堂常驻讲师 主要内容&#xff1a;Java项目、Python项目、前端项目、人工智能与大数据、简…

C++通过json文件配置参数

一、安装nlohmann json nlohmann json&#xff1a;安装_nlohmann安装-CSDN博客 依次执行下面指令&#xff1a; git clone https://gitee.com/cuihongxi/mov_from_github.gitcd json-developmkdir buildcd buildcmake ..makesudo make install 二、安装完成后使用 #include…

华为设备display查看命令

display version //查看版本信息 display current-configuration //查看配置详情 display this //查看当前视图有效配置 display ip routing-table //查看路由表 display ip routing-table 192.168.3.1 //查看去往3.1的路由 display ip interface brief //查看接口下ip信息 dis…

想跨境出海?云手机提供了一种可能性

全球化时代&#xff0c;越来越多的中国电商开始将目光投向了海外市场。这并不是偶然&#xff0c;而是他们在长期的市场运营中&#xff0c;看到了出海的必要性和潜在的机会。 中国的电商市场无疑是全球最大也最发达的之一。然而&#xff0c;随着市场的不断发展和竞争的日益加剧…

visual studio2022 JNI极简开发流程

文章目录 1 创建java类2 生成JNI头文件3 使用visual studio2022创建DLL项目3.1 选择模板中&#xff08;Windows桌面向导&#xff09;3.2 为项目命名3.3 选择应用程序类型为动态链接库3.4 项目概览 4 导入需要的头文件4.1 导入需要的头文件4.2 修改头文件 5 编写C实现6 生成dll文…

服务器3389端口,服务器3389端口风险提示的应对措施

3389端口是Windows操作系统中远程桌面协议&#xff08;RDP&#xff09;的默认端口。一旦该端口被恶意攻击者利用&#xff0c;可能会导致未经授权的远程访问和数据泄露等严重安全问题。 针对此风险&#xff0c;强烈建议您采取以下措施&#xff1a; 1. 修改默认端口&#xff1a;…

Java面试之抽象类和接口

Java的一个重要特性就是抽象&#xff0c;抽象是指将具体的事物抽象成更一般化、更抽象化的概念或模型。在Java中&#xff0c;抽象可以通过抽象类和接口来实现&#xff0c;它们让你能够定义一些方法但不提供具体实现&#xff0c;从而让子类去实现具体细节。 一、抽象类&#xf…

springboot3 集成spring-authorization-server (一 基础篇)

官方文档 Spring Authorization Server 环境介绍 java&#xff1a;17 SpringBoot&#xff1a;3.2.0 SpringCloud&#xff1a;2023.0.0 引入maven配置 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter…

识别AI论文生成内容,降低论文高AI率

AI写作工具能帮我们在短时间内高效生成一篇毕业论文、开通报告、文献综述、任务书、调研报告、期刊论文、课程论文等等&#xff0c;导致许多人开始使用AI写作工具作为撰写学术论文的辅助手段。而学术界为了杜绝此行为&#xff0c;开始使用AIGC检测系统来判断文章是由AI生成还是…

解锁商业AI,赋能新质生产力发展——思爱普中国峰会探展全纪录

ITValue 钛媒体独家探秘思爱普中国峰会&#xff0c;带你深刻感受SAP助力企业利用以商业AI为代表的数字化技术&#xff0c;实现质的飞跃&#xff0c;通过全数据、全球化、全绿色赋能新型中国企业发展新质生产力。 首发&#xff5c;钛媒体APP ITValue 5月10日&#xff0c;一年一度…

基于NTP服务器获取网络时间的实现

文章目录 1 NTP1.1 简介1.2 包结构1.3 UNIX 时间戳和NTP时间戳 2 代码实现2.1 实现步骤2.2 完整代码 3 结果 在某些场景下&#xff0c;单片机需要通过网络获取准确的时间进行数据同步&#xff0c;例如日志记录、定时任务等。然而&#xff0c;单片机本身无法直接获得准确的标准时…

Vue的学习 —— <vue指令>

目录 前言 正文 内容渲染指令 内容渲染指令的使用方法 v-text v-html 属性绑定指令 双向数据绑定指令 事件绑定指令 条件渲染指令 循环列表渲染指令 侦听器 前言 在完成Vue开发环境的搭建后&#xff0c;若想将Vue应用于实际项目&#xff0c;首要任务是学习Vue的基…

ORA-00932: inconsistent datatypes: expected - got CLOB的分析解决方案

最近在项目中遇到查询数据时报ORA-00932: inconsistent datatypes: expected - got CLOB错误&#xff0c;这个错误很明显是由于查询时类型的不匹配造成的。 问题分析&#xff1a; 一、检查你的查询的实体的类型是否于数据库的保持一致&#xff0c;如果不一致&#xff0c;那么需…

333_C++_编写一个go函数每次从文件中读取固定大小数据,且go作为回调,传递给其他函数中,多次调用,完成逐块传输数据

(core工程文件) tick_transfer_all_t类是一个用于异步传输数据的辅助类,它在某个异步操作完成后将_tick的值设置为0,并返回传输的结果 namespace hl {namespace http{namespace __detail{class tick_transfer_all_t{boost::shared_ptr<unsigned long long> _tick

MySQL 查询库 和 表 占用空间大小的 语句

查看mysql 数据库的大小 SELECT table_schema AS 数据库名称, ROUND(SUM(data_length index_length) / 1024 / 1024, 2) AS 数据库大小(MB) FROM information_schema.tables GROUP BY table_schema;查询数据库中表的 数据量&#xff08;这个方法 有缓存延迟&#xff0c;只能用…

[力扣题解] 96. 不同的二叉搜索树

题目&#xff1a;96. 不同的二叉搜索树 思路 动态规划 f[i]&#xff1a;有i个结点有多少种二叉搜索树 状态转移方程&#xff1a; 以n3为例&#xff1a; 以1为头节点&#xff0c;左子树有0个结点&#xff0c;右子树有2个结点&#xff1b; 以2为头节点&#xff0c;左子树有1个…