KNN 算法–图像分类算法
找到最近的K个邻居,在前k个最近样本中选择最近的占比最高的类别作为预测类别。
- 给定测试对象,计算它与训练集中每个对象的距离。
- 圈定距离最近的k个训练对象,作为测试对象的邻居。
- 根据这k个紧邻对象所属的类别,找到占比最高的那个类别作为测试对象的预测类别。
影响因素:
- 计算测试对象与训练集中各个对象的距离。
- k的选择。
import operatorimport numpy as np
import matplotlib.pyplot as pltdef create_data_set():group = np.array([[1.0, 2.0], [1.2, 0.1], [0.1, 1.4], [0.3, 3.5], [1.1, 1.0], [0.5, 1.5]])labels = np.array(['A', 'A', 'B', 'B', 'A', 'B'])return group, labelsdef knn_classify(k, dis, X_train, x_train, Y_test):assert dis == 'E' or dis == 'M', 'dis must E or M, E 代表欧式距离,M代表曼哈顿距离'num_test = Y_test.shape[0]label_list = []if dis == 'E':for i in range(num_test):distances = np.sqrt(np.sum(((X_train - np.tile(Y_test[i], (X_train.shape[0], 1))) ** 2), axis=1))nearest_k = np.argsort(distances)topK = nearest_k[:k]print(topK)classCount = {}for i in topK:classCount[x_train[i]] = classCount.get(x_train[i], 0) + 1sorted_class_count = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)label_list.append(sorted_class_count[0][0])return np.array(label_list)if __name__ == '__main__':group, labels = create_data_set()plt.scatter(group[labels == 'A', 0], group[labels == 'A', 1], color='r', marker='*')plt.scatter(group[labels == 'B', 0], group[labels == 'B', 1], color='g', marker='+')y_test_pred = knn_classify(1, 'E', group, labels, np.array([[1.0, 2.1], [0.4, 2.0]]))print(y_test_pred)plt.show()