机器学习入门之KNN算法和交叉验证与超参数搜索(三)

机器学习入门之KNN算法和交叉验证与超参数搜索(三)

文章目录

  • 机器学习入门之KNN算法和交叉验证与超参数搜索(三)
  • 一、KNN算法-分类
    • 1. 样本距离判断
      • 明可夫斯基距离
    • 2. KNN 算法原理
    • 3. KNN 的缺点
    • 4. KNN 的 API
    • 5. 使用 sklearn 实现 KNN 示例
    • 6. 模型保存与加载
  • 二、模型选择与调优:交叉验证与超参数搜索
    • 1. 交叉验证
      • (1) 保留交叉验证(HoldOut)
      • (2) K-折交叉验证(K-fold)
      • (3) 分层 K-折交叉验证(Stratified K-fold)
      • (4) 其他验证方法
      • (5) API 示例
    • 2. 超参数搜索(网格搜索,Grid Search)
    • 3. sklearn API
    • 4. 示例:鸢尾花分类


一、KNN算法-分类

1. 样本距离判断

KNN 算法中,样本之间的距离是判断相似性的关键。常见的距离度量方式包括:

明可夫斯基距离

  • 欧式距离:明可夫斯基距离的特殊情况,公式为 (\sqrt{\sum_{i=1}^{n}(x_i - y_i)^2})。
  • 曼哈顿距离:明可夫斯基距离的另一种特殊情况,公式为 (\sum_{i=1}^{n}|x_i - y_i|)。

2. KNN 算法原理

K-近邻算法(K-Nearest Neighbors,简称 KNN)是一种基于实例的学习方法。其核心思想是:如果一个样本在特征空间中的 k 个最相似(最邻近)样本中的大多数属于某个类别,则该样本也属于这个类别。例如,假设我们有 10000 个样本,选择距离样本 A 最近的 7 个样本,其中类别 1 有 2 个,类别 2 有 3 个,类别 3 有 2 个,则样本 A 被认为属于类别 2。

3. KNN 的缺点

  • 计算量大:对于大规模数据集,需要计算测试样本与所有训练样本的距离。
  • 维度灾难:在高维数据中,距离度量可能变得不那么有意义。
  • 参数选择:需要选择合适的 k 值和距离度量方式,这可能需要多次实验和调整。

4. KNN 的 API

class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, algorithm='auto')
  • 参数
    • n_neighbors:用于 kneighbors 查询的近邻数,默认为 5。
    • algorithm:找到近邻的方式,可选值为 {'auto', 'ball_tree', 'kd_tree', 'brute'},默认为 'auto'
  • 方法
    • fit(X, y):使用 X 作为训练数据和 y 作为目标数据。
    • predict(X):预测提供的数据,返回预测结果。

5. 使用 sklearn 实现 KNN 示例

以下是一个使用 KNN 算法对鸢尾花进行分类的完整代码示例:

# 用 KNN 算法对鸢尾花进行分类
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier# 1)获取数据
iris = load_iris()
print(iris.data.shape)  # (150, 4)
print(iris.feature_names)  # ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
print(iris.target.shape)  # (150,)
print(iris.target)  # [0 0 0 ... 2 2 2]
print(iris.target_names)  # ['setosa' 'versicolor' 'virginica']# 2)划分数据集
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)# 3)特征工程:标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)# 4)KNN 算法预估器
estimator = KNeighborsClassifier(n_neighbors=7)
estimator.fit(x_train, y_train)# 5)模型评估
y_predict = estimator.predict(x_test)
print("y_predict:\n", y_predict)
print("直接比对真实值和预测值:\n", y_test == y_predict)
score = estimator.score(x_test, y_test)
print("准确率为:\n", score)  # 0.9473684210526315

6. 模型保存与加载

使用 joblib 可以方便地保存和加载模型:

import joblib# 保存模型
joblib.dump(estimator, "my_knn.pkl")# 加载模型
estimator = joblib.load("my_knn.pkl")# 使用模型预测
y_test = estimator.predict([[0.4, 0.2, 0.4, 0.7]])
print(y_test)

以下是整理后的 Markdown 格式内容:


二、模型选择与调优:交叉验证与超参数搜索

1. 交叉验证

交叉验证是评估模型性能的重要方法,常见的交叉验证技术包括:

(1) 保留交叉验证(HoldOut)

  • 原理:将数据集随机划分为训练集和验证集,通常比例为 70% 训练集和 30% 验证集。
  • 优点:简单易行。
  • 缺点
    • 不适用于不平衡数据集。
    • 一部分数据未参与训练,可能导致模型性能不佳。

(2) K-折交叉验证(K-fold)

  • 原理:将数据集划分为 K 个大小相同的部分,每次使用一个部分作为验证集,其余部分作为训练集,重复 K 次。
  • 优点:充分利用数据,模型性能更稳定。
  • 缺点:计算量较大。

(3) 分层 K-折交叉验证(Stratified K-fold)

  • 原理:在每一折中保持原始数据中各个类别的比例关系,确保每个折叠的类别分布与整体数据一致。
  • 优点:适用于不平衡数据集,验证结果更可信。

(4) 其他验证方法

  • 留一交叉验证:每次只留一个样本作为验证集。
  • 蒙特卡罗交叉验证:随机划分训练集和测试集,多次重复。
  • 时间序列交叉验证:适用于时间序列数据。

(5) API 示例

from sklearn.model_selection import StratifiedKFoldstrat_k_fold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
indexs = strat_k_fold.split(X, y)for train_index, test_index in indexs:X_train, X_test = X[train_index], X[test_index]y_train, y_test = y[train_index], y[test_index]

2. 超参数搜索(网格搜索,Grid Search)

网格搜索是一种自动寻找最佳超参数的方法,通过遍历所有可能的参数组合来找到最优解。

3. sklearn API

from sklearn.model_selection import GridSearchCVGridSearchCV(estimator, param_grid, cv=5)
  • 参数
    • estimator:模型实例。
    • param_grid:超参数的网格,例如 {"n_neighbors": [1, 3, 5, 7, 9, 11]}
    • cv:交叉验证的折数,默认为 5。
  • 属性
    • best_params_:最佳参数。
    • best_score_:最佳模型的交叉验证分数。
    • best_estimator_:最佳模型实例。
    • cv_results_:交叉验证结果。

4. 示例:鸢尾花分类

以下是一个使用 KNN 算法对鸢尾花进行分类的完整代码示例,结合了分层 K-折交叉验证和网格搜索:

from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler# 加载数据
iris = load_iris()
X = iris.data
y = iris.target# 初始化分层 K-折交叉验证器
strat_k_fold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)# 创建 KNN 分类器实例
knn = KNeighborsClassifier()# 网格搜索与交叉验证
param_grid = {"n_neighbors": [1, 3, 5, 7, 9, 11]}
grid_search = GridSearchCV(knn, param_grid, cv=strat_k_fold)
grid_search.fit(X, y)# 输出结果
print("最佳参数:", grid_search.best_params_)  # {'n_neighbors': 3}
print("最佳准确率:", grid_search.best_score_)  # 0.9553030303030303
print("最佳模型:", grid_search.best_estimator_)  # KNeighborsClassifier(n_neighbors=3)

通过分层 K-折交叉验证和网格搜索,我们可以找到最优的超参数,从而提高模型的性能。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/81128.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小刚说C语言刷题—1700请输出所有的2位数中,含有数字2的整数

1.题目描述 请输出所有的 2 位数中,含有数字 2 的整数有哪些,每行 1个,按照由小到大输出。 比如: 12、20、21、22、23… 都是含有数字 2的整数。 输入 无 输出 按题意要求由小到大输出符合条件的整数,每行 1 个。…

在MYSQL中导入cookbook.sql文件

参考资料: GitHub 项目:svetasmirnova/mysqlcookbook CSDN 博客:https://blog.csdn.net/u011868279/category_11645577.html 建库: mysql> use mysql Reading table information for completion of table and column names …

Scrapy框架下地图爬虫的进度监控与优化策略

1. 引言 在互联网数据采集领域,地图数据爬取是一项常见但具有挑战性的任务。由于地图数据通常具有复杂的结构(如POI点、路径信息、动态加载等),使用传统的爬虫技术可能会遇到效率低下、反爬策略限制、任务进度难以监控等问题。 …

【Win32 API】 lstrcmpA()

作用 比较两个字符字符串(比较区分大小写)。 lstrcmp 函数通过从第一个字符开始检查,若相等,则检查下一个,直到找到不相等或到达字符串的末尾。 函数 int lstrcmpA(LPCSTR lpString1, LPCSTR lpString2); 参数 lpStr…

代码随想录60期day38

2维背包 #include<bits/stdc.h> using namespace std;int main(){int n,bagweight;cin>>n>>bagweight;vector<int>weight(n,0);vector<int>value(n,0);for(int i 0 ; i <n;i){cin>>weight[i];}for(int j 0;j<n;j){cin>>val…

[模型部署] 1. 模型导出

&#x1f44b; 你好&#xff01;这里有实用干货与深度分享✨✨ 若有帮助&#xff0c;欢迎&#xff1a;​ &#x1f44d; 点赞 | ⭐ 收藏 | &#x1f4ac; 评论 | ➕ 关注 &#xff0c;解锁更多精彩&#xff01;​ &#x1f4c1; 收藏专栏即可第一时间获取最新推送&#x1f514;…

mac的Cli为什么输入python3才有用python --version显示无效,pyenv入门笔记,如何查看mac自带的标准库模块

根据你的终端输出&#xff0c;可以得出以下结论&#xff1a; 1. 你的 Mac 当前只有一个 Python 版本 系统默认的 Python 3 位于 /usr/bin/python3&#xff08;这是 macOS 自带的 Python&#xff09;通过 which python3 确认当前使用的就是系统自带的 Pythonbrew list python …

Java注解详解:从入门到实战应用篇

1. 引言 Java注解&#xff08;Annotation&#xff09;是JDK 5.0引入的一种元数据机制&#xff0c;用于为代码提供附加信息。它广泛应用于框架开发、代码生成、编译检查等领域。本文将从基础到实战&#xff0c;全面解析Java注解的核心概念和使用场景。 2. 注解基础概念 2.1 什…

前端方法的总结及记录

个人简介 &#x1f468;‍&#x1f4bb;‍个人主页&#xff1a; 魔术师 &#x1f4d6;学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全栈发展 &#x1f6b4;个人状态&#xff1a; 研发工程师&#xff0c;现效力于政务服务网事业 &#x1f1e8;&#x1f1f3;人生格言&…

组件导航 (HMRouter)+flutter项目搭建-混合开发+分栏效果

组件导航 (Navigation)flutter项目搭建 接上一章flutter项目的环境变量配置并运行flutter 1.flutter创建项目并运行 flutter create fluter_hmrouter 进入ohos目录打开编辑器先自动签名 编译项目-生成签名包 flutter build hap --debug 运行项目 HMRouter搭建安装 1.安…

城市排水管网流量监测系统解决方案

一、方案背景 随着工业的不断发展和城市人口的急剧增加&#xff0c;工业废水和城市污水的排放量也大量增加。目前&#xff0c;我国已成为世界上污水排放量大、增加速度快的国家之一。然而&#xff0c;总体而言污水处理能力较低&#xff0c;有相当部分未经处理的污水直接或间接排…

TCP/IP 知识体系

TCP/IP 知识体系 一、TCP/IP 定义 全称&#xff1a;Transmission Control Protocol/Internet Protocol&#xff08;传输控制协议/网际协议&#xff09;核心概念&#xff1a; 跨网络实现信息传输的协议簇&#xff08;包含 TCP、IP、FTP、SMTP、UDP 等协议&#xff09;因 TCP 和…

5G行业专网部署费用详解:投资回报如何最大化?

随着数字化转型的加速&#xff0c;5G行业专网作为企业提升生产效率、保障业务安全和实现智能化管理的重要基础设施&#xff0c;正受到越来越多行业客户的关注。部署5G专网虽然前期投入较大&#xff0c;但通过合理规划和技术选择&#xff0c;能够实现投资回报的最大化。 在5G行…

网页工具-OTU/ASV表格物种分类汇总工具

AI辅助下开发了个工具&#xff0c;功能如下&#xff0c;分享给大家&#xff1a; 基于Shiny开发的用户友好型网页应用&#xff0c;专为微生物组数据分析设计。该工具能够自动处理OTU/ASV_taxa表格&#xff08;支持XLS/XLSX/TSV/CSV格式&#xff09;&#xff0c;通过调用QIIME1&a…

【超分辨率专题】一种考量视频编码比特率优化能力的超分辨率基准

这是一个Benchmark&#xff0c;超分辨率视频编码&#xff08;2024&#xff09; 专题介绍一、研究背景二、相关工作2.1 SR的发展2.2 SR benchmark的发展 三、Benchmark细节3.1 数据集制作3.2 模型选择3.3 编解码器和压缩标准选择3.4 Benchmark pipeline3.5 质量评估和主观评价研…

保姆教程-----安装MySQL全过程

1.电脑从未安装过mysql的&#xff0c;先找到mysql官网&#xff1a;MySQL :: Download MySQL Community Server 然后下载完成后&#xff0c;找到文件&#xff0c;然后双击打开 2. 选择安装的产品和功能 依次点开“MySQL Servers”、“MySQL Servers”、“MySQL Servers 5.7”、…

【React中函数组件和类组件区别】

在 React 中,函数组件和类组件是两种构建组件的方式,它们在多个方面存在区别,以下详细介绍: 1. 语法和定义 类组件:使用 ES6 的类(class)语法定义,继承自 React.Component。需要通过 this.props 来访问传递给组件的属性(props),并且通常要实现 render 方法返回 JSX…

[基础] HPOP、SGP4与SDP4轨道传播模型深度解析与对比

HPOP、SGP4与SDP4轨道传播模型深度解析与对比 文章目录 HPOP、SGP4与SDP4轨道传播模型深度解析与对比第一章 引言第二章 模型基础理论2.1 历史演进脉络2.2 动力学方程统一框架 第三章 数学推导与摄动机制3.1 SGP4核心推导3.1.1 J₂摄动解析解3.1.2 大气阻力建模改进 3.2 SDP4深…

搭建运行若依微服务版本ruoyi-cloud最新教程

搭建运行若依微服务版本ruoyi-cloud 一、环境准备 JDK > 1.8MySQL > 5.7Maven > 3.0Node > 12Redis > 3 二、后端 2.1数据库准备 在navicat上创建数据库ry-seata、ry-config、ry-cloud运行SQL文件ry_20250425.sql、ry_config_20250224.sql、ry_seata_2021012…

Google I/O 2025 观看攻略一键收藏,开启技术探索之旅!

AIGC开放社区https://lerhk.xetlk.com/sl/1SAwVJ创业邦https://weibo.com/1649252577/PrNjioJ7XCSDNhttps://live.csdn.net/room/csdnnews/OOFSCy2g/channel/collectiondetail?sid2941619DONEWShttps://www.donews.com/live/detail/958.html凤凰科技https://flive.ifeng.com/l…