KNN算法详解

KNN算法

KNN算法简介

【理解】KNN算法思想

K-近邻算法(K Nearest Neighbor,简称KNN)。比如:根据你的“邻居”来推断出你的类别

KNN算法思想:如果一个样本在特征空间中的 k 个最相似的样本中的大多数属于某一个类别,则该样本也属于这个类别

思考:如何确定样本的相似性?

样本相似性:样本都是属于一个任务数据集的。样本距离越近则越相似。

利用K近邻算法预测电影类型

【知道】K值的选择

  1. K值过小:相当于用较小领域中的训练实例进行预测容易受到异常点的影响。K值的减小就意味着整体模型变得复杂,容易发生过拟合
    举例:K=N(N为训练样本个数)
    无论输入实例是什么,只会按训练集中最多的类别进行预测,受到样本均衡的影响
  2. K值过大:相当于用较大领域中的训练实例进行预测受到样本均衡的问题。且K值的增大就意味着整体的模型变得简单,欠拟合
    如何对K值超参数进行调优?
    需要一些方法来寻找这个最合适的K值
    交叉验证、网格搜索

【知道】KNN的应用方式

  • 解决问题:分类问题、回归问题

  • 算法思想:若一个样本在特征空间中的 k 个最相似的样本大多数属于某一个类别,则该样本也属于这个类别

  • 相似性:欧氏距离

  • 分类问题的处理流程:

1.计算未知样本到每一个训练样本的距离

2.将训练样本根据距离大小升序排列

3.取出距离最近的 K 个训练样本

4.进行多数表决,统计 K 个样本中哪个类别的样本个数最多

5.将未知的样本归属到出现次数最多的类别

  • 回归问题的处理流程:

1.计算未知样本到每一个训练样本的距离

2.将训练样本根据距离大小升序排列

3.取出距离最近的 K 个训练样本

4.把这个 K 个样本的目标值计算其平均值

5.作为将未知的样本预测的值

API介绍

分类API

KNN分类API:

sklearn.neighbors.KNeighborsClassifier(n_neighbors=5)

​ n_neighbors:int,可选(默认= 5),k_neighbors查询默认使用的邻居数

回归API

KNN分类API:

sklearn.neighbors.KNeighborsRegressor(n_neighbors=5)
# 1.工具包fromsklearn.neighborsimportKNeighborsClassifier,KNeighborsRegressor# from sklearn.neighbors import KNeighborsRegressor# 2.数据(特征工程)# 分类# x = [[0,2,3],[1,3,4],[3,5,6],[4,7,8],[2,3,4]]# y = [0,0,1,1,0]x=[[0,1,2],[1,2,3],[2,3,4],[3,4,5]]y=[0.1,0.2,0.3,0.4]# 3.实例化# model =KNeighborsClassifier(n_neighbors=3)model=KNeighborsRegressor(n_neighbors=3)# 4.训练model.fit(x,y)# 5.预测print(model.predict([[4,4,5]]))

距离度量方法

欧式距离

曼哈顿距离

切比雪夫距离

闵氏距离

闵可夫斯基距离 Minkowski Distance 闵氏距离,不是一种新的距离的度量方式。而是距离的组合 是对多个距离度量公式的概括性的表述

特征预处理

为什么进行归一化、标准化

特征的单位或者大小相差较大,或者某特征的方差相比其他的特征要大出几个数量级容易影响(支配)目标结果,使得一些模型(算法)无法学习到其它的特征。

归一化

通过对原始数据进行变换把数据映射到【mi,mx】(默认为[0,1])之间

数据归一化的API实现

sklearn.preprocessing.MinMaxScaler (feature_range=(0,1)… )

feature_range 缩放区间

  • 调用 fit_transform(X) 将特征进行归一化缩放

归一化受到最大值与最小值的影响,这种方法容易受到异常数据的影响, 鲁棒性较差,适合传统精确小数据场景

标准化

通过对原始数据进行标准化,转换为均值为0标准差为1的标准正态分布的数据

  • mean 为特征的平均值
  • σ 为特征的标准差

数据标准化的API实现

sklearn.preprocessing.StandardScaler()

调用 fit_transform(X) 将特征进行归一化缩放

# 1.导入工具包fromsklearn.preprocessingimportMinMaxScaler,StandardScaler# 2.数据(只有特征)x=[[90,2,10,40],[60,4,15,45],[75,3,13,46]]# 3.实例化(归一化,标准化)# process =MinMaxScaler()process=StandardScaler()# 4.fit_transform 处理1data=process.fit_transform(x)# print(data)print(process.mean_)print(process.var_)

对于标准化来说,如果出现异常点,由于具有一定数据量,少量的异常点对于平均值的影响并不大

利用KNN算法进行鸢尾花分类

鸢尾花Iris Dataset数据集是机器学习领域经典数据集,鸢尾花数据集包含了150条鸢尾花信息,每50条取自三个鸢尾花中之一:Versicolour、Setosa和Virginica

每个花的特征用如下属性描述:

代码实现:

# 导入工具包fromsklearn.datasetsimportload_iris# 加载鸢尾花测试集的.importseabornassnsimportpandasaspdimportmatplotlib.pyplotaspltfromsklearn.model_selectionimporttrain_test_split# 分割训练集和测试集的fromsklearn.preprocessingimportStandardScaler# 数据标准化的fromsklearn.neighborsimportKNeighborsClassifier# KNN算法 分类对象fromsklearn.metricsimportaccuracy_score# 模型评估的, 计算模型预测的准确率# 1. 定义函数 dm01_loadiris(), 加载数据集.defdm01_loadiris():# 1. 加载数据集, 查看数据iris_data=load_iris()print(iris_data)# 字典形式, 键: 属性名, 值: 数据.print(iris_data.keys())# 1.1 查看数据集print(iris_data.data[:5])# 1.2 查看目标值.print(iris_data.target)# 1.3 查看目标值名字.print(iris_data.target_names)# 1.4 查看特征名.print(iris_data.feature_names)# 1.5 查看数据集的描述信息.print(iris_data.DESCR)# 1.6 查看数据文件路径print(iris_data.filename)# 2. 定义函数 dm02_showiris(), 显示鸢尾花数据.defdm02_showiris():# 1. 加载数据集, 查看数据iris_data=load_iris()# 2. 数据展示# 读取数据, 并设置 特征名为列名.iris_df=pd.DataFrame(iris_data.data,columns=iris_data.feature_names)# print(iris_df.head(5))iris_df['label']=iris_data.target# 可视化, x=花瓣长度, y=花瓣宽度, data=iris的df对象, hue=颜色区分, fit_reg=False 不绘制拟合回归线.sns.lmplot(x='petal length (cm)',y='petal width (cm)',data=iris_df,hue='label',fit_reg=False)plt.title('iris data')plt.show()# 3. 定义函数 dm03_train_test_split(), 实现: 数据集划分defdm03_train_test_split():# 1. 加载数据集, 查看数据iris_data=load_iris()# 2. 划分数据集, 即: 特征工程(预处理-标准化)x_train,x_test,y_train,y_test=train_test_split(iris_data.data,iris_data.target,test_size=0.2,random_state=22)print(f'数据总数量:{len(iris_data.data)}')print(f'训练集中的x-特征值:{len(x_train)}')print(f'训练集中的y-目标值:{len(y_train)}')print(f'测试集中的x-特征值:{len(x_test)}')# 4. 定义函数 dm04_模型训练和预测(), 实现: 模型训练和预测defdm04_model_train_and_predict():# 1. 加载数据集, 查看数据iris_data=load_iris()# 2. 划分数据集, 即: 数据基本处理x_train,x_test,y_train,y_test=train_test_split(iris_data.data,iris_data.target,test_size=0.2,random_state=22)# 3. 数据集预处理-数据标准化(即: 标准的正态分布的数据集)transfer=StandardScaler()# fit_transform(): 适用于首次对数据进行标准化处理的情况,通常用于训练集, 能同时完成 fit() 和 transform()。x_train=transfer.fit_transform(x_train)# transform(): 适用于对测试集进行标准化处理的情况,通常用于测试集或新的数据. 不需要重新计算统计量。x_test=transfer.transform(x_test)# 4. 机器学习(模型训练)estimator=KNeighborsClassifier(n_neighbors=5)estimator.fit(x_train,y_train)# 5. 模型评估.# 场景1: 对抽取出的测试集做预测.# 5.1 模型评估, 对抽取出的测试集做预测.y_predict=estimator.predict(x_test)print(f'预测结果为:{y_predict}')# 场景2: 对新的数据进行预测.# 5.2 模型预测, 对测试集进行预测.# 5.2.1 定义测试数据集.my_data=[[5.1,3.5,1.4,0.2]]# 5.2.2 对测试数据进行-数据标准化.my_data=transfer.transform(my_data)# 5.2.3 模型预测.my_predict=estimator.predict(my_data)print(f'预测结果为:{my_predict}')# 5.2.4 模型预测概率, 返回每个类别的预测概率my_predict_proba=estimator.predict_proba(my_data)print(f'预测概率为:{my_predict_proba}')# 6. 模型预估, 有两种方式, 均可.# 6.1 模型预估, 方式1: 直接计算准确率, 100个样本中模型预测正确的个数.my_score=estimator.score(x_test,y_test)print(my_score)# 0.9666666666666667# 6.2 模型预估, 方式2: 采用预测值和真实值进行对比, 得到准确率.print(accuracy_score(y_test,y_predict))# 在main方法中测试.if__name__=='__main__':# 1. 调用函数 dm01_loadiris(), 加载数据集.# dm01_loadiris()# 2. 调用函数 dm02_showiris(), 显示鸢尾花数据.# dm02_showiris()# 3. 调用函数 dm03_train_test_split(), 查看: 数据集划分# dm03_train_test_split()# 4. 调用函数 dm04_模型训练和预测(), 实现: 模型训练和预测dm04_model_train_and_predict()

超参数选择的方法

交叉验证

交叉验证是一种数据集的分割方法,将训练集划分为 n 份,其中一份做验证集、其他n-1份做训练集集

交叉验证法原理:将数据集划分为 cv=10 份:

1.第一次:把第一份数据做验证集,其他数据做训练

2.第二次:把第二份数据做验证集,其他数据做训练

3… 以此类推,总共训练10次,评估10次。

4.使用训练集+验证集多次评估模型,取平均值做交叉验证为模型得分

5.若k=5模型得分最好,再使用全部训练集(训练集+验证集) 对k=5模型再训练一边,再使用测试集对k=5模型做评估

交叉验证法,是划分数据集的一种方法,目的就是为了得到更加准确可信的模型评分。

【知道】网格搜索

交叉验证网格搜索的API:

交叉验证网格搜索在鸢尾花分类中的应用:

fromsklearn.datasetsimportload_iris# 加载鸢尾花测试集的.fromsklearn.model_selectionimporttrain_test_split,GridSearchCV# 分割训练集和测试集的, 网格搜索的fromsklearn.preprocessingimportStandardScaler# 数据标准化的fromsklearn.neighborsimportKNeighborsClassifier# KNN算法 分类对象fromsklearn.metricsimportaccuracy_score# 模型评估的, 计算模型预测的准确率# 1. 获取数据集.iris_data=load_iris()# 2. 数据基本处理-划分数据集.x_train,x_test,y_train,y_test=train_test_split(iris_data.data,iris_data.target,test_size=0.2,random_state=22)# 3. 数据集预处理-数据标准化.transfer=StandardScaler()x_train=transfer.fit_transform(x_train)x_test=transfer.transform(x_test)# 4. 模型训练.# 4.1 创建估计器对象.estimator=KNeighborsClassifier()# 4.2 使用校验验证网格搜索. 指定参数范围.param_grid={"n_neighbors":range(1,10)}# 4.3 具体的 网格搜索过程 + 交叉验证.# 参1: 估计器对象, 参2: 参数范围, 参3: 交叉验证的折数.estimator=GridSearchCV(estimator=estimator,param_grid=param_grid,cv=5)# 具体的模型训练过程.estimator.fit(x_train,y_train)# 4.4 交叉验证, 网格搜索结果查看.print(estimator.best_score_)# 模型在交叉验证中, 所有参数组合中的最高平均测试得分print(estimator.best_estimator_)# 最优的估计器对象.print(estimator.cv_results_)# 模型在交叉验证中的结果.print(estimator.best_params_)# 模型在交叉验证中的结果.# 5. 得到最优模型后, 对模型重新预测.estimator=KNeighborsClassifier(n_neighbors=6)estimator.fit(x_train,y_train)print(f'模型评估:{estimator.score(x_test,y_test)}')# 因为数据量和特征的问题, 该值可能小于上述的平均测试得分.

利用KNN算法实现手写数字识别

MNIST手写数字识别 是计算机视觉领域中 "hello world"级别的数据集

  • 1999年发布,成为分类算法基准测试的基础
  • 随着新的机器学习技术的出现,MNIST仍然是研究人员和学习者的可靠资源。

本次案例中,目标是从数万个手写图像的数据集中正确识别数字。

数据介绍

数据文件 train.csv 和 test.csv 包含从 0 到 9 的手绘数字的灰度图像。

  • 每个图像高 28 像素,宽28 像素,共784个像素。

  • 每个像素取值范围[0,255],取值越大意味着该像素颜色越深

  • 训练数据集(train.csv)共785列。第一列为 “标签”,为该图片对应的手写数字。其余784列为该图像的像素值

  • 训练集中的特征名称均有pixel前缀,后面的数字([0,783])代表了像素的序号。

像素组成图像如下:

000001002003...026027028029030031...054055056057058059...082083||||......||728729730731...754755756757758759...782783

数据集示例如下:

importmatplotlib.pyplotaspltimportpandasaspdfromsklearn.model_selectionimporttrain_test_splitfromsklearn.neighborsimportKNeighborsClassifierimportjoblibfromcollectionsimportCounter# 1. 显示图片.defshow_digit(idx):# 1.1 加载数据.data=pd.read_csv('手写数字识别.csv')# 1.2非法值校验.ifidx<0oridx>len(data)-1:return# 1.3 打印数据基本信息x=data.iloc[:,1:]y=data.iloc[:,0]print(f'数据基本信息:{x.shape})')print(f'类别数据比例:{Counter(y)}')# 显示图片# 1.4 将数据形状修改为: 28*28digit=x.iloc[idx].values.reshape(28,28)# 1.5 关闭坐标轴标签plt.axis('off')# 1.6 显示图像plt.imshow(digit,cmap='gray')# 灰色显示plt.show()# 2. 训练模型.deftrain_model():# 1. 加载数据.data=pd.read_csv('手写数字识别.csv')x=data.iloc[:,1:]y=data.iloc[:,0]# 2.数据预处理, 归一化.x=x/255# 3. 分割训练集和测试集.# stratify: 按照y的类别比例进行分割x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,stratify=y,random_state=21)# 4. 训练模型estimator=KNeighborsClassifier(n_neighbors=3)estimator.fit(x_train,y_train)# 5. 模型评估my_score=estimator.score(x_test,y_test)print(f'测试集准确率为:{my_score:.2f}')# 6. 模型保存.joblib.dump(estimator,'model/knn.pth')# 3. 测试模型.defuse_model():# 1. 读取图片img=plt.imread('data/demo.png')# 灰度图, 28*28像素plt.imshow(img,cmap='gray')plt.show()# 2. 加载模型.estimator=joblib.load('model/knn.pth')# 3. 预测图片.img=img.reshape(1,-1)# 形状从: (28, 28) => (1, 784)# print(img.shape)y_test=estimator.predict(img)print(f'您绘制的数字是:{y_test}')# 在main函数中测试if__name__=='__main__':# 1. 调用函数, 查看图片.# show_digit(0)# show_digit(10)# show_digit(100)# 2. 训练模型.# train_model()# 3. 测试模型use_model()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1159414.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

手把手AI论文工具全攻略:9款神器精准控率无压力操作指南

同学们&#xff0c;还在为论文发愁吗&#xff1f;从开题报告到文献综述&#xff0c;从数据分析到格式排版&#xff0c;每一步都感觉压力山大&#xff1f;别担心&#xff0c;AI时代已经为我们送来了强大的“学术神器”。今天&#xff0c;我将化身你的专属论文助教&#xff0c;为…

卡尔曼滤波做轨迹跟踪 鲁棒卡尔曼滤波做野值剔除后的预测 扩展卡尔曼滤波对GPS数据进行状态估计滤波

卡尔曼滤波做轨迹跟踪 鲁棒卡尔曼滤波做野值剔除后的预测 扩展卡尔曼滤波对GPS数据进行状态估计滤波 轨迹跟踪这活儿听起来玄乎&#xff0c;其实咱们每天都在用——手机导航里那个蓝色小圆点&#xff0c;背后八成藏着卡尔曼滤波的数学魔法。今天咱们扯点实在的&#xff0c;用P…

2026年PLC厂家推荐:2026年度权威评测与市场格局排名解析

摘要 在工业4.0与智能制造转型的宏观趋势下&#xff0c;可编程逻辑控制器作为自动化系统的“大脑”&#xff0c;其选型决策直接关系到生产线的可靠性、灵活性与长期数字化升级潜力。当前&#xff0c;企业决策者面临的核心焦虑在于&#xff1a;如何在技术路线日趋多元、开放与封…

国外学术论文怎么找:实用检索技巧与资源平台推荐

刚开始做科研的时候&#xff0c;我一直以为&#xff1a; 文献检索就是在知网、Google Scholar 里反复换关键词。 直到后来才意识到&#xff0c;真正消耗精力的不是“搜不到”&#xff0c;而是—— 你根本不知道最近这个领域发生了什么。 生成式 AI 出现之后&#xff0c;学术检…

langchain 使用 MessagesPlaceholder 实现会话上下文

第一步&#xff1a;创建带历史消息占位符的提示词模板from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholderprompt ChatPromptTemplate.from_messages([SystemMessage(content"你是3DM的一个技术专家&#xff0c;擅长解决各种Web开发中的技术问题…

langchain 创建智能体,并使用saver保存会话消息

简单创建智能体并调用 """ 可参考官方文档地址&#xff1a;https://docs.langchain.com/oss/python/langchain/agents 智能体会遵循 ReAct&#xff08;“推理行动”&#xff09;模式&#xff0c;交替进行简短的推理步骤和针对性工具调用&#xff0c;并将所得观察…

python基于vue的江西特色乡村综合风貌展示平台django flask pycharm

目录技术框架与开发工具功能模块设计数据库与性能优化特色创新点应用价值与推广开发技术路线相关技术介绍核心代码参考示例结论源码lw获取/同行可拿货,招校园代理 &#xff1a;文章底部获取博主联系方式&#xff01;技术框架与开发工具 该平台采用Python作为后端核心语言&…

langchian 使用外部MCP工具创建自己的MCP服务

普通智能体接入高德MCP mcp协议官网&#xff1a;https://modelcontextprotocol.info/zh-cn/ 关于什么是mcp建议大家看一下MCP官网就可以&#xff0c;首页的描述就非常贴切 “AI应用的USB-C接口” 接口通用&#xff0c;功能强大接入流程 官网示例&#xff1a;https://docs.langc…

【开题答辩全过程】以 基于Java的大学生兼职信息系统的设计与实现为例,包含答辩的问题和答案

个人简介一名14年经验的资深毕设内行人&#xff0c;语言擅长Java、php、微信小程序、Python、Golang、安卓Android等开发项目包括大数据、深度学习、网站、小程序、安卓、算法。平常会做一些项目定制化开发、代码讲解、答辩教学、文档编写、也懂一些降重方面的技巧。感谢大家的…

2026年强推新兴GEO服务商-微盟星启:抢占AI搜索心智打造品牌增长

一、AI搜索重构增长逻辑&#xff0c;品牌亟需“标准答案”破局当生成式AI成为搜索领域的核心变革力量&#xff0c;互联网用户的信息获取方式正发生根本性转变。不同于传统关键词搜索的“信息罗列”&#xff0c;AI搜索更倾向于输出“整合式、结论性”的答案&#xff0c;这直接重…

Spring 中 Servlet 容器和 Python FastAPI 对比

一、核心架构对比&#xff1a;Java Servlet vs. Python ASGI 下表清晰地展示了两个生态在对应层级上的核心组件与关系&#xff1a;架构层级核心职责Java / Servlet 生态Python / ASGI 生态1. 网络与协议层处理原始HTTP请求/响应、连接管理、线程/进程调度。Tomcat, Jetty, Unde…

虎贲等考 AI:重构学术创作新范式,一站式智能论文写作解决方案

在学术研究与论文写作的赛道上&#xff0c;研究者与学子们往往面临选题迷茫、文献繁杂、实证低效、合规棘手等多重困境。虎贲等考 AI 智能写作平台&#xff08;官网&#xff1a;https://www.aihbdk.com/&#xff09;应势而生&#xff0c;作为一款基于前沿人工智能技术打造的专业…

揭秘Emmi AI每月人均千欧的远程团队协作模式

Emmi AI是一家奥地利深度科技公司&#xff0c;致力于构建人工智能驱动的物理仿真技术&#xff0c;以加速流体动力学、多物理场和固体力学等领域的工程流程。 对于从事此类工作的公司而言&#xff0c;人员协作方式与技术本身同等重要。该公司采用了一种混合、远程优先的模式&…

拒稿率暴跌 90%!虎贲等考 AI 期刊论文功能:从初稿到录用的 “学术加速器”

《自然》期刊统计显示&#xff0c;全球 78% 的学术论文因写作问题被拒稿&#xff0c;其中结构性缺陷占 53%&#xff0c;学术规范失误占 32%。对于科研人来说&#xff0c;撰写期刊论文不仅要攻克研究难题&#xff0c;还要面对文献梳理、格式规范、查重降重等一系列 “附加关卡”…

langchain 常见提示词模板使用案例

大模型对象创建&调用 """ 大模型共用定义""" import os from dotenv import load_dotenv from langchain_openai import ChatOpenAI load_dotenv()# 创建大模型对象 llm ChatOpenAI(model"qwen-max-latest",base_url"https…

langchain的工具调用

Tools 就是给大模型安装的"手和脚"&#xff0c;让大模型能够调用外部函数/API来获取实时信息或执行具体操作。Tools 的工作流程 完整流程 用户问题 → 大模型思考 → 调用Tool → 执行Tool → 结果返回 → 大模型重新组织 → 最终回答 # 1. 用户提问 user_question …

告别熬夜做 PPT!虎贲等考 AI PPT:学术汇报的 “一键焕新” 神器

学术汇报的终极痛点是什么&#xff1f;不是论文写得不够好&#xff0c;而是熬了三个通宵做的 PPT&#xff0c;被导师批 “逻辑混乱、排版杂乱、重点不明”。从开题汇报、中期答辩到最终答辩&#xff0c;每一次 PPT 制作都像一场耗时耗力的 “硬仗”—— 既要提炼论文核心观点&a…

销售要少夸赞自己实力强,多问问客户害怕什么

制造业的销售常常会犯一个致命的错误&#xff1a;一和客户见面就急着向对方证明“我们技术领先同行”“设备精度非常高”“服务响应速度快”……但客户内心里想的却是&#xff1a;“你说得再好&#xff0c;万一出现问题&#xff0c;这个责任还是得我来承担&#xff0c;”在责任…

GetX 从 0 开始:理解 Flutter 的“对象级响应式系统”

很多人听说 GetX&#xff0c;是因为它“什么都能干”&#xff1a;状态管理、路由、依赖注入。 但如果一上来就学 API&#xff0c;很容易学成“工具集合”。 这篇文章只做一件事&#xff1a; &#x1f449; 从 0 建立对 GetX 的正确认知&#xff1a;它到底解决什么问题&#xff…

极致感知与定位:基于电鱼智能 RK3588 的 AMR 机器人高精度 vSLAM 导航方案

为什么 AMR 机器人首选 RK3588 进行 vSLAM&#xff1f;1. 多核异构算力匹配 vSLAM 任务链vSLAM 算法包含高度复杂的流水线&#xff0c;RK3588 的异构架构可以实现完美的分工&#xff1a;Cortex-A76 高大核&#xff1a;负责前端视觉里程计&#xff08;VO&#xff09;的特征点提取…