文章目录
- 1. k近邻算法
- 2. k近邻模型
- 2.1 模型
- 2.2 距离度量
- 2.2.1 距离计算代码 Python
- 2.3 kkk 值的选择
- 2.4 分类决策规则
- 3. 实现方法, kd树
- 3.1 构造 kdkdkd 树
- Python 代码
- 3.2 搜索 kdkdkd 树
- Python 代码
- 4. 鸢尾花KNN分类
- 4.1 KNN实现
- 4.2 sklearn KNN
- 5. 文章完整代码
k近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法。
- 输入:实例的特征向量,对应于特征空间的点
- 输出:实例的类别,可以取多类
- 假设:给定一个训练数据集,其中的实例类别已定。
- 分类:对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预测。因此,k近邻法不具有显式的学习过程。
- k近邻法实际上利用训练数据集对特征向量空间进行划分,并作为其分类的“模型”。
k近邻法1968年由Cover
和Hart
提出。
1. k近邻算法
输入:一组训练数据集,特征向量 xix_ixi,及其类别 yiy_iyi,给定实例特征向量 xxx
输出:实例 xxx 所属的类 yyy
- 根据距离度量,在训练集中找出与 xxx 最邻近的 kkk 个点,涵盖这 kkk 个点的 xxx 的邻域记为 Nk(x)N_k(x)Nk(x)
- 在 Nk(x)N_k(x)Nk(x) 中根据分类决策规则(如,多数表决)决定 xxx 的类别 yyy
y=arg maxcj∑xi∈Nk(x)I(yi=cj),i=1,2,...,N,j=1,2,...,Ky = \argmax\limits_{c_j} \sum\limits_{x_i \in N_k(x) } I(y_i = c_j),\quad i=1,2,...,N, j = 1,2,...,Ky=cjargmaxxi∈Nk(x)∑I(yi=cj),i=1,2,...,N,j=1,2,...,K
III 为指示函数,表示当 yi=cjy_i=c_jyi=cj 时 III 为 1, 否则 III 为 0
当 k=1k=1k=1 时,特殊情况,称为最近邻算法,跟它距离最近的点作为其分类
2. k近邻模型
三要素:k值的选择、距离度量、分类决策规则
2.1 模型
- kkk 近邻模型,三要素确定后,对于任何一个新的输入实例,它的类唯一确定。
- 这相当于根据上述要素将特征空间划分为一些子空间,确定子空间里的每个点所属的类。这一事实从最近邻算法中可以看得很清楚。
2.2 距离度量
空间中两个点的距离是两个实例相似程度的反映。
- LpL_pLp 距离:
设特征 xix_ixi 是 nnn 维的,Lp(xi,xj)=(∑l=1n∣xi(l)−xj(l)∣p)1pL_p(x_i,x_j) = \bigg(\sum\limits_{l=1}^n |x_i^{(l)}-x_j^{(l)}|^p \bigg)^{\frac{1}{p}}Lp(xi,xj)=(l=1∑n∣xi(l)−xj(l)∣p)p1 - 欧氏距离:上面 p=2p=2p=2 时,L2(xi,xj)=(∑l=1n∣xi(l)−xj(l)∣2)12L_2(x_i,x_j) = \bigg(\sum\limits_{l=1}^n |x_i^{(l)}-x_j^{(l)}|^2 \bigg)^{\frac{1}{2}}L2(xi,xj)=(l=1∑n∣xi(l)−xj(l)∣2)21
- 曼哈顿距离:上面 p=1p=1p=1 时,L1(xi,xj)=∑l=1n∣xi(l)−xj(l)∣L_1(x_i,x_j) = \sum\limits_{l=1}^n |x_i^{(l)}-x_j^{(l)}|L1(xi,xj)=l=1∑n∣xi(l)−xj(l)∣
- 切比雪夫距离:当 p=∞p=\inftyp=∞ 时,它是坐标距离的最大值:L∞(xi,xj)=maxl∣xi(l)−xj(l)∣L_\infty(x_i,x_j) = \max\limits_l |x_i^{(l)}-x_j^{(l)}|L∞(xi,xj)=lmax∣xi(l)−xj(l)∣
2.2.1 距离计算代码 Python
import mathdef L_p(xi, xj, p=2):if len(xi) == len(xj) and len(xi) > 0:sum = 0for i in range(len(xi)):sum += math.pow(abs(xi[i] - xj[i]), p)return math.pow(sum, 1 / p)else:return 0
x1 = [1, 1]
x2 = [5, 1]
x3 = [4, 4]
X = [x1, x2, x3]
for i in range(len(X)):for j in range(i + 1, len(X)):for p in range(1, 5):print("x%d,x%d的L%d距离是:%.2f" % (i + 1, j + 1, p, L_p(X[i], X[j], p)))
x1,x2的L1距离是:4.00
x1,x2的L2距离是:4.00
x1,x2的L3距离是:4.00
x1,x2的L4距离是:4.00
x1,x3的L1距离是:6.00
x1,x3的L2距离是:4.24
x1,x3的L3距离是:3.78
x1,x3的L4距离是:3.57
x2,x3的L1距离是:4.00
x2,x3的L2距离是:3.16
x2,x3的L3距离是:3.04
x2,x3的L4距离是:3.01
2.3 kkk 值的选择
-
k值的选择会对k近邻法的结果产生重大影响。
-
选较小的 k 值,相当于用较小的邻域中的训练实例进行预测,“学习”的近似误差(approximation error)会减小,只有与输入实例较近的(相似的)训练实例才会对预测结果起作用。但缺点是“学习”的估计误差(estimation error)会增大,预测结果会对近邻的实例点非常敏感。
-
如果邻近的实例点恰巧是噪声,预测就会出错。换句话说,k值的减小就意味着整体模型变得复杂,容易发生过拟合。
-
选较大的 k 值,相当于用较大邻域中的训练实例进行预测。优点是可以减少学习的估计误差,但缺点是学习的近似误差会增大。这时与输入实例较远的(不相似的)训练实例也会对预测起作用,使预测发生错误。
-
k值的增大就意味着整体的模型变得简单。
-
如果 k=N,无论输入实例是什么,都将简单地预测它属于在训练实例中最多的类。模型过于简单,完全忽略大量有用信息,不可取。
-
应用中,k 值一般取一个比较小的数值。通常采用交叉验证法来选取最优的 k 值。
2.4 分类决策规则
- 多数表决(
majority voting rule
)
假设损失函数为0-1损失,对于 xix_ixi 的近邻域 Nk(x)N_k(x)Nk(x) 的分类是 cjc_jcj,那么误分类率是:
1k∑xi∈Nk(x)I(yi≠cj)=1−1k∑xi∈Nk(x)I(yi=cj)\frac{1}{k} \sum\limits_{x_i \in N_k(x) }I(y_i \neq c_j) = 1- \frac{1}{k}\sum\limits_{x_i \in N_k(x) } I(y_i = c_j)k1xi∈Nk(x)∑I(yi=cj)=1−k1xi∈Nk(x)∑I(yi=cj)
要使误分类率最小,那么就让 ∑xi∈Nk(x)I(yi=cj)\sum\limits_{x_i \in N_k(x) } I(y_i = c_j)xi∈Nk(x)∑I(yi=cj) 最大,所以选多数的那个类(经验风险最小化)
3. 实现方法, kd树
-
算法实现时,需要对大量的点进行距离计算,复杂度是 O(n2)O(n^2)O(n2),训练集很大时,效率低,不可取
-
考虑特殊的结构存储训练数据,以减少计算距离次数,如 kdkdkd 树
3.1 构造 kdkdkd 树
kdkdkd 树是一种对 k 维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。
- kdkdkd 树是二叉树,表示对k维空间的一个划分(partition)。
- 构造 kdkdkd 树相当于不断地用垂直于坐标轴的超平面将 k 维空间切分,构成一系列的k维超矩形区域。
- kdkdkd 树的每个结点对应于一个 k 维超矩形区域。
构造 kdkdkd 树的方法:
- 根结点:使根结点对应于k维空间中包含所有实例点的超矩形区域;通过递归方法,不断地对 k 维空间进行切分,生成子结点
- 在超矩形区域(结点)上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,将当前超矩形区域切分为左右两个子区域(子结点)
- 实例被分到两个子区域。这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。
Python 代码
class KdNode():def __init__(self, dom_elt, split, left, right):self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)self.split = split # 整数(进行分割维度的序号)self.left = left # 该结点分割超平面左子空间构成的kd-treeself.right = right # 该结点分割超平面右子空间构成的kd-treeclass KdTree():def __init__(self, data):k = len(data[0]) # 实例的向量维度def CreatNode(split, data_set):if not data_set:return Nonedata_set.sort(key=lambda x: x[split])split_pos = len(data_set) // 2 # 整除median = data_set[split_pos]split_next = (split + 1) % kreturn KdNode(median, split,CreatNode(split_next, data_set[:split_pos]),CreatNode(split_next, data_set[split_pos + 1:]))self.root = CreatNode(0, data)def preorder(self, root):if root:print(root.dom_elt)if root.left:self.preorder(root.left)if root.right:self.preorder(root.right)
data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
kd = KdTree(data)
kd.preorder(kd.root)
运行结果:
[7, 2]
[5, 4]
[2, 3]
[4, 7]
[9, 6]
[8, 1]
3.2 搜索 kdkdkd 树
给定目标点,搜索其最近邻。
- 先找到包含目标点的叶结点
- 从该叶结点出发,依次回退到父结点;不断查找与目标点最邻近的结点
- 当确定不可能存在更近的结点时终止。
- 这样搜索就被限制在空间的局部区域上,效率大为提高。
- 目标点的最近邻一定在以目标点为中心并通过当前最近点的超球体的内部。
- 然后返回当前结点的父结点,如果父结点的另一子结点的超矩形区域与超球体相交,那么在相交的区域内寻找与目标点更近的实例点。
- 如果存在这样的点,将此点作为新的当前最近点。算法转到更上一级的父结点,继续上述过程。
- 如果父结点的另一子结点的超矩形区域与超球体不相交,或不存在比当前最近点更近的点,则停止搜索。
Python 代码
from collections import namedtuple# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple","nearest_point nearest_dist nodes_visited")def find_nearest(tree, point):k = len(point) # 数据维度def travel(kd_node, target, max_dist):if kd_node is None:return result([0] * k, float("inf"), 0)# python中用float("inf")和float("-inf")表示正负无穷nodes_visited = 1s = kd_node.split # 进行分割的维度pivot = kd_node.dom_elt # 进行分割的“轴”if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)nearer_node = kd_node.left # 下一个访问节点为左子树根节点further_node = kd_node.right # 同时记录下右子树else: # 目标离右子树更近nearer_node = kd_node.right # 下一个访问节点为右子树根节点further_node = kd_node.lefttemp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”dist = temp1.nearest_dist # 更新最近距离nodes_visited += temp1.nodes_visitedif dist < max_dist:max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离if max_dist < temp_dist: # 判断超球体是否与超平面相交return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断# ----------------------------------------------------------------------# 计算目标点与分割点的欧氏距离p = np.array(pivot)t = np.array(target)temp_dist = np.linalg.norm(p-t)if temp_dist < dist: # 如果“更近”nearest = pivot # 更新最近点dist = temp_dist # 更新最近距离max_dist = dist # 更新超球体半径# 检查另一个子结点对应的区域是否有更近的点temp2 = travel(further_node, target, max_dist)nodes_visited += temp2.nodes_visitedif temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离nearest = temp2.nearest_point # 更新最近点dist = temp2.nearest_dist # 更新最近距离return result(nearest, dist, nodes_visited)return travel(tree.root, point, float("inf")) # 从根节点开始递归
from time import time
from random import randomdef random_point(k):return [random() for _ in range(k)]def random_points(k, n):return [random_point(k) for _ in range(n)]ret = find_nearest(kd, [3, 4.5])
print(ret)N = 400000
t0 = time()
kd2 = KdTree(random_points(3, N))#40万个3维点(坐标值0-1之间)
ret2 = find_nearest(kd2, [0.1, 0.5, 0.8])
t1 = time()
print("time: ", t1 - t0, " s")
print(ret2)
运行结果:40万个点,只用了4s就搜索完毕,找到最近邻点
Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)
time: 4.314465284347534 s
Result_tuple(nearest_point=[0.10186986970329936, 0.5007753108096316, 0.7998708312483109], nearest_dist=0.002028350099282986, nodes_visited=49)
4. 鸢尾花KNN分类
4.1 KNN实现
# -*- coding:utf-8 -*-
# @Python Version: 3.7
# @Time: 2020/3/2 22:44
# @Author: Michael Ming
# @Website: https://michael.blog.csdn.net/
# @File: 3.KNearestNeighbors.py
# @Reference: https://github.com/fengdu78/lihang-code
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counterclass KNearNeighbors():def __init__(self, X_train, y_train, neighbors=3, p=2):self.n = neighborsself.p = pself.X_train = X_trainself.y_train = y_traindef predict(self, X):knn_list = []# 先在训练集中取n个点出来,计算距离for i in range(self.n):dist = np.linalg.norm(X - self.X_train[i], ord=self.p)knn_list.append((dist, self.y_train[i]))# 再在剩余的训练集中取出剩余的,计算距离,有距离更近的,替换knn_list里最大的for i in range(self.n, len(self.X_train)):max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))dist = np.linalg.norm(X - self.X_train[i], ord=self.p)if knn_list[max_index][0] > dist:knn_list[max_index] = (dist, self.y_train[i])# 取出所有的n个最近邻点的标签knn = [k[-1] for k in knn_list]count_pairs = Counter(knn)# 次数最多的标签,排序后最后一个 标签:出现次数max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]return max_countdef score(self, X_test, y_test):right_count = 0for X, y in zip(X_test, y_test): # zip 同时遍历多个对象label = self.predict(X)if math.isclose(label, y, rel_tol=1e-5): # 浮点型相等判断right_count += 1print("准确率:%.4f" % (right_count / len(X_test)))return right_count / len(X_test)if __name__ == '__main__':# ---------鸢尾花K近邻----------------iris = load_iris()df = pd.DataFrame(iris.data, columns=iris.feature_names)df['label'] = iris.targetplt.scatter(df[:50][iris.feature_names[0]], df[:50][iris.feature_names[1]], label=iris.target_names[0])plt.scatter(df[50:100][iris.feature_names[0]], df[50:100][iris.feature_names[1]], label=iris.target_names[1])plt.xlabel(iris.feature_names[0])plt.ylabel(iris.feature_names[1])data = np.array(df.iloc[:100, [0, 1, -1]]) # 取前2种花,前两个特征X, y = data[:, :-1], data[:, -1]# 切分数据集,留20%做测试数据X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)# KNN算法,近邻选择20个,距离度量L2距离clf = KNearNeighbors(X_train, y_train, 20, 2)# 预测测试点,统计正确率clf.score(X_test, y_test)# 随意给一个点,用KNN预测其分类test_point = [4.75, 2.75]test_point_flower = '测试点' + iris.target_names[int(clf.predict(test_point))]print("测试点的类别是:%s" % test_point_flower)plt.plot(test_point[0], test_point[1], 'bx', label=test_point_flower)plt.rcParams['font.sans-serif'] = 'SimHei' # 消除中文乱码plt.rcParams['axes.unicode_minus'] = False # 正常显示负号plt.legend()plt.show()
准确率:1.0000
测试点的类别是:测试点setosa
4.2 sklearn KNN
sklearn.neighbors.KNeighborsClassifier
class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights='uniform',
algorithm='auto', leaf_size=30, p=2, metric='minkowski',
metric_params=None, n_jobs=None, **kwargs)
- n_neighbors: 临近点个数
- p: 距离度量
- algorithm: 近邻算法,可选{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}
- weights: 确定近邻的权重
from sklearn.neighbors import KNeighborsClassifier
clf_skl = KNeighborsClassifier(n_neighbors=50, p=4, algorithm='kd_tree')
start = time.time()
sum = 0
for i in range(100):clf_skl.fit(X_train, y_train)sum += clf_skl.score(X_test, y_test)
end = time.time()
print("平均准确率:%.4f" % (sum/100))
print("花费时间:%0.4f ms" % (1000*(end - start)/100))
5. 文章完整代码
# -*- coding:utf-8 -*-
# @Python Version: 3.7
# @Time: 2020/3/2 22:44
# @Author: Michael Ming
# @Website: https://michael.blog.csdn.net/
# @File: 3.KNearestNeighbors.py
# @Reference: https://github.com/fengdu78/lihang-code
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter
import timedef L_p(xi, xj, p=2):if len(xi) == len(xj) and len(xi) > 0:sum = 0for i in range(len(xi)):sum += math.pow(abs(xi[i] - xj[i]), p)return math.pow(sum, 1 / p)else:return 0class KNearNeighbors():def __init__(self, X_train, y_train, neighbors=3, p=2):self.n = neighborsself.p = pself.X_train = X_trainself.y_train = y_traindef predict(self, X):knn_list = []# 先在训练集中取n个点出来,计算距离for i in range(self.n):dist = np.linalg.norm(X - self.X_train[i], ord=self.p)knn_list.append((dist, self.y_train[i]))# 再在剩余的训练集中取出剩余的,计算距离,有距离更近的,替换knn_list里最大的for i in range(self.n, len(self.X_train)):max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))dist = np.linalg.norm(X - self.X_train[i], ord=self.p)if knn_list[max_index][0] > dist:knn_list[max_index] = (dist, self.y_train[i])# 取出所有的n个最近邻点的标签knn = [k[-1] for k in knn_list]count_pairs = Counter(knn)# 次数最多的标签,排序后最后一个 标签:出现次数max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]return max_countdef score(self, X_test, y_test):right_count = 0for X, y in zip(X_test, y_test): # zip 同时遍历多个对象label = self.predict(X)if math.isclose(label, y, rel_tol=1e-5): # 浮点型相等判断right_count += 1print("准确率:%.4f" % (right_count / len(X_test)))return right_count / len(X_test)class KdNode():def __init__(self, dom_elt, split, left, right):self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)self.split = split # 整数(进行分割维度的序号)self.left = left # 该结点分割超平面左子空间构成的kd-treeself.right = right # 该结点分割超平面右子空间构成的kd-treeclass KdTree():def __init__(self, data):k = len(data[0]) # 实例的向量维度def CreatNode(split, data_set):if not data_set:return Nonedata_set.sort(key=lambda x: x[split])split_pos = len(data_set) // 2 # 整除median = data_set[split_pos]split_next = (split + 1) % kreturn KdNode(median, split,CreatNode(split_next, data_set[:split_pos]),CreatNode(split_next, data_set[split_pos + 1:]))self.root = CreatNode(0, data)def preorder(self, root):if root:print(root.dom_elt)if root.left:self.preorder(root.left)if root.right:self.preorder(root.right)from collections import namedtuple# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple","nearest_point nearest_dist nodes_visited")def find_nearest(tree, point):k = len(point) # 数据维度def travel(kd_node, target, max_dist):if kd_node is None:return result([0] * k, float("inf"), 0)# python中用float("inf")和float("-inf")表示正负无穷nodes_visited = 1s = kd_node.split # 进行分割的维度pivot = kd_node.dom_elt # 进行分割的“轴”if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)nearer_node = kd_node.left # 下一个访问节点为左子树根节点further_node = kd_node.right # 同时记录下右子树else: # 目标离右子树更近nearer_node = kd_node.right # 下一个访问节点为右子树根节点further_node = kd_node.lefttemp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”dist = temp1.nearest_dist # 更新最近距离nodes_visited += temp1.nodes_visitedif dist < max_dist:max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离if max_dist < temp_dist: # 判断超球体是否与超平面相交return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断# ----------------------------------------------------------------------# 计算目标点与分割点的欧氏距离p = np.array(pivot)t = np.array(target)temp_dist = np.linalg.norm(p - t)if temp_dist < dist: # 如果“更近”nearest = pivot # 更新最近点dist = temp_dist # 更新最近距离max_dist = dist # 更新超球体半径# 检查另一个子结点对应的区域是否有更近的点temp2 = travel(further_node, target, max_dist)nodes_visited += temp2.nodes_visitedif temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离nearest = temp2.nearest_point # 更新最近点dist = temp2.nearest_dist # 更新最近距离return result(nearest, dist, nodes_visited)return travel(tree.root, point, float("inf")) # 从根节点开始递归if __name__ == '__main__':# ---------计算距离----------------x1 = [1, 1]x2 = [5, 1]x3 = [4, 4]X = [x1, x2, x3]for i in range(len(X)):for j in range(i + 1, len(X)):for p in range(1, 5):print("x%d,x%d的L%d距离是:%.2f" % (i + 1, j + 1, p, L_p(X[i], X[j], p)))# ---------鸢尾花K近邻----------------iris = load_iris()df = pd.DataFrame(iris.data, columns=iris.feature_names)df['label'] = iris.targetplt.scatter(df[:50][iris.feature_names[0]], df[:50][iris.feature_names[1]], label=iris.target_names[0])plt.scatter(df[50:100][iris.feature_names[0]], df[50:100][iris.feature_names[1]], label=iris.target_names[1])plt.xlabel(iris.feature_names[0])plt.ylabel(iris.feature_names[1])data = np.array(df.iloc[:100, [0, 1, -1]]) # 取前2种花,前两个特征X, y = data[:, :-1], data[:, -1]# 切分数据集,留20%做测试数据X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)# KNN算法,近邻选择20个,距离度量L2距离clf = KNearNeighbors(X_train, y_train, 20, 2)# 预测测试点,统计正确率clf.score(X_test, y_test)# 随意给一个点,用KNN预测其分类test_point = [4.75, 2.75]test_point_flower = '测试点' + iris.target_names[int(clf.predict(test_point))]print("测试点的类别是:%s" % test_point_flower)plt.plot(test_point[0], test_point[1], 'bx', label=test_point_flower)plt.rcParams['font.sans-serif'] = 'SimHei' # 消除中文乱码plt.rcParams['axes.unicode_minus'] = False # 正常显示负号plt.legend()plt.show()# ---------sklearn KNN----------from sklearn.neighbors import KNeighborsClassifierclf_skl = KNeighborsClassifier(n_neighbors=50, p=4, algorithm='kd_tree')start = time.time()sum = 0for i in range(100):clf_skl.fit(X_train, y_train)sum += clf_skl.score(X_test, y_test)end = time.time()print("平均准确率:%.4f" % (sum / 100))print("花费时间:%0.4f ms" % (1000 * (end - start) / 100))# ------build KD Tree--------------data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]kd = KdTree(data)kd.preorder(kd.root)# ------search in KD Tree-----------from time import timefrom random import randomdef random_point(k):return [random() for _ in range(k)]def random_points(k, n):return [random_point(k) for _ in range(n)]ret = find_nearest(kd, [3, 4.5])print(ret)N = 400000t0 = time()kd2 = KdTree(random_points(3, N))ret2 = find_nearest(kd2, [0.1, 0.5, 0.8])t1 = time()print("time: ", t1 - t0, " s")print(ret2)