京公网安备 11010802034615号
经营许可证编号:京B2-20210330
机器学习之k-近邻(kNN)算法与Python实现
k-近邻算法(kNN,k-NearestNeighbor),是最简单的机器学习分类算法之一,其核心思想在于用距离目标最近的k个样本数据的分类来代表目标的分类(这k个样本数据和目标数据最为相似)。
一 k-近邻(kNN)算法概述
1.概念
kNN算法的核心思想是用距离最近的k个样本数据的分类来代表目标数据的分类。
其原理具体地讲,存在一个训练样本集,这个数据训练样本的数据集合中的每个样本都包含数据的特征和目标变量(即分类值),输入新的不含目标变量的数据,将该数据的特征与训练样本集中每一个样本进行比较,找到最相似的k个数据,这k个数据出席那次数最多的分类,即输入的具有特征值的数据的分类。
例如,训练样本集中包含一系列数据,这个数据包括样本空间位置(特征)和分类信息(即目标变量,属于红色三角形还是蓝色正方形),要对中心的绿色数据的分类。运用kNN算法思想,距离最近的k个样本的分类来代表测试数据的分类,那么:
当k=3时,距离最近的3个样本在实线内,具有2个红色三角和1个蓝色正方形**,因此将它归为红色三角。
当k=5时,距离最近的5个样本在虚线内,具有2个红色三角和3个蓝色正方形**,因此将它归为蓝色正方形。
2.特点
优点
(1)监督学习:可以看到,kNN算法首先需要一个训练样本集,这个集合中含有分类信息,因此它属于监督学习。
(2)通过计算距离来衡量样本之间相似度,算法简单,易于理解和实现。
(3)对异常值不敏感
缺点 (4)需要设定k值,结果会受到k值的影响,通过上面的例子可以看到,不同的k值,最后得到的分类结果不尽相同。k一般不超过20。(5)计算量大,需要计算样本集中每个样本的距离,才能得到k个最近的数据样本。 (6)训练样本集不平衡导致结果不准确问题。当样本集中主要是某个分类,该分类数量太大,导致近邻的k个样本总是该类,而不接近目标分类。
3.kNN算法流程
一般情况下,kNN有如下流程:
(1)收集数据:确定训练样本集合测试数据;
(2)计算测试数据和训练样本集中每个样本数据的距离;
常用的距离计算公式:
(3)按照距离递增的顺序排序;
(4)选取距离最近的k个点;
(5)确定这k个点中分类信息的频率;
(6)返回前k个点中出现频率最高的分类,作为当前测试数据的分类。二 、Python算法实现
1.KNN算法分类器
建立一个名为“KNN.py”的文件,构造一个kNN算法分类器的函数:
from numpy import *
import operator
#定义KNN算法分类器函数
#函数参数包括:(测试数据,训练数据,分类,k值)
def classify(inX,dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX,(dataSetSize,1))-dataSet
sqDiffMat=diffMat**2
sqDistances=sqDiffMat.sum(axis=1)
distances=sqDistances**0.5 #计算欧式距离
sortedDistIndicies=distances.argsort() #排序并返回index
#选择距离最近的k个值
classCount={}
for i in range(k):
voteIlabel=labels[sortedDistIndicies[i]]
#D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.
classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
#排序
sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
在KNN.py中定义一个生成“训练样本集”的函数:
#定义一个生成“训练样本集”的函数,包含特征和分类信息在Python控制台先将当前目录设置为“KNN.py”所在的文件目录,将测试数据[0,0]进行KNN算法分类测试,输入:
import KNN
#生成训练样本
group,labels=KNN.createDataSet()
#对测试数据[0,0]进行KNN算法分类测试
KNN.classify([0,0],group,labels,3)
Out[3]: 'B'
可以看到该分类器函数将[0,0]分类为B组,符合实际情况,分入了符合逻辑的正确的类别。但如何知道KNN分类的正确性呢?
2.kNN算法用于约会网站配对
2.1准备数据
该数据在文本文件datingTestSet2.txt中,该数据具有1000行,4列,分别是特征数据(每年获得的飞行常客里程数,玩视频游戏所耗时间百分比,每周消费的冰淇淋公升数),和目标变量/分类数据(是否喜欢(1表示不喜欢,2表示魅力一般,3表示极具魅力)),部分数据展示如下:
完整地数据下载地址如下:
约会网站测试数据
(1)将文本记录转为成numpy
在python控制台输入:
in [5]:datingDataMat,datingLabels=KNN.file2matrix('G:\Workspaces\MachineLearning\machinelearninginaction\Ch02\datingTestSet2.txt')#括号是文件路径
(2)可视化分析数据
运用Matplotlib创建散点图来分析数据:
import matplotlib
import matplotlib.pyplot as plt
#对第二列和第三列数据进行分析:
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(datingDataMat[:,1],datingDataMat[:,2],c=datingLabels)
plt.xlabel('Percentage of Time Spent Playing Video Games')
plt.ylabel('Liters of Ice Cream Consumed Per Week')
#对第一列和第二列进行分析:
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(datingDataMat[:,0],datingDataMat[:,1],c=datingLabels)
plt.xlabel('Miles of plane Per year')
plt.ylabel('Percentage of Time Spent Playing Video Games')
ax.legend(loc='best')

(3)数据归一化
由于不同的数据在大小上差别较大,在计算欧式距离,整体较大的数据明细所占的比重更高,因此需要对数据进行归一化处理。
在Python控制台输入:
reload(KNN)数据的准备工作完成,下一步对算法进行测试。
2.2 算法测试
kNN算法分类的结果的效果,可以使用正确率/错误率来衡量,错误率为0,则表示分类很完美,如果错误率为1,表示分类完全错误。我们使用1000条数据中的90%作为训练样本集,其中的10%来测试错误率。
#定义测试算法的函数在控制台输入命令来测试错误率:
reload(KNN)
Out[150]: <module 'KNN' from 'G:\\Workspaces\\MachineLearning\\KNN.py'>
KNN.datingClassTest()
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
... ...
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
the classifier came back with: 3,the real answer is: 1
the total error rate is : 0.050000
可以看到KNN算法分类器处理约会数据的错误率是5%,具有较高额正确率。
可以在datingClassTest函数中传入参数h来改变测试数据比例,来看修改后Ration后错误率有什么样的变化。
KNN.datingClassTest(0.2)
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 1,the real answer is: 1
... ...
the classifier came back with: 2,the real answer is: 2
the classifier came back with: 3,the real answer is: 3
the classifier came back with: 2,the real answer is: 2
the total error rate is : 0.080000
减小训练样本集数据,增加测试数据,错误率增加到8%。
2.3 使用KNN算法进行预测
def classifypersion():测试一下:
reload(KNN)
Out[153]: <module 'KNN' from 'G:\\Workspaces\\MachineLearning\\KNN.py'>
KNN.classifypersion()
percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice creamconsued per year?0.5
You will probably like this persion :not at all
3. KNN算法用于手写识别系统
已经将图片转化为32*32 的文本格式,文本格式如下:
00000000000111110000000000000000
00000000001111111000000000000000
00000000011111111100000000000000
00000000111111111110000000000000
00000001111111111111000000000000
00000011111110111111100000000000
00000011111100011111110000000000
00000011111100001111110000000000
00000111111100000111111000000000
00000111111100000011111000000000
00000011111100000001111110000000
00000111111100000000111111000000
00000111111000000000011111000000
00000111111000000000011111100000
00000111111000000000011111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000111111000000000001111100000
00000011111000000000001111100000
00000011111100000000011111100000
00000011111100000000111111000000
00000001111110000000111111100000
00000000111110000001111111000000
00000000111110000011111110000000
00000000111111000111111100000000
00000000111111111111111000000000
00000000111111111111110000000000
00000000011111111111100000000000
00000000001111111111000000000000
00000000000111111110000000000000
3.1数据准备
(1)将32*32的文本格式转为成1*2014的向量
在控制台中输入命令测试下函数:
reload(KNN)
3.2 算法测试
使用kNN算法测试手写数字识别
#引入os模块的listdir函数,列出给定目录的文件名
from os impor listdir
def handwritingClassTest():
hwLabels=[]
trainingFileList=listdir('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/trainingDigits')#列出文件名
m=len(trainingFileList) #文件数目
trainMat=zeros((m,1024))
#从文件名中解析分类信息,如0_13.txt
for i in range(m):
fileNameStr=trainingFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumber=int(fileStr.split('_')[0])
hwLabels.append(classNumber)
trainMat[i]=img2vector('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/trainingDigits/%s'%fileNameStr)
testFileList=listdir('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/testDigits')
errorCount=0
#同上,解析测试数据的分类信息
mTest=len(testFileList)
for i in range(mTest):
fileNameStr=testFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumber=int(fileStr.split('_')[0])
vectorUnderTest=img2vector('G:/Workspaces/MachineLearning/machinelearninginaction/Ch02/testDigits/%s'%fileNameStr)
classifierResult=classify(vectorUnderTest,trainMat,hwLabels,3)
print('the classifier came back with :%d,the real answer is:%d'%(classifierResult,classNumber))
if(classifierResult!=classNumber):errorCount+=1
print('\n the total number of errors is: %d'%errorCount)
print('\n total error rate is %f'%(errorCount/float(mTest)))
接下来在Python控制台输入命令来测试手写数字识别:
reload(KNN)
KNN.handwritingClassTest()
the classifier came back with :0,the real answer is:0
the classifier came back with :0,the real answer is:0
the classifier came back with :0,the real answer is:0
... ...
the classifier came back with :9,the real answer is:9
the classifier came back with :9,the real answer is:9
the classifier came back with :9,the real answer is:9
the total number of errors is: 10
total error rate is 0.010571
错误利率1.057%,具有较高的准确率。
CDA学员免费下载查看报告全文:2026全球数智化人才指数报告【CDA数据科学研究院】.pdf
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
在中介效应分析中,人口统计学变量(如年龄、性别、学历、收入、职业等)是常见的控制变量或调节变量,其处理方式直接影响分析结 ...
2026-04-29在SQL数据库实操中,日期数据的存储与显示是高频需求,而“数字日期”(如20240520、20241231、45321)是很多开发者、数据分析师 ...
2026-04-29 很多分析师在设计标签时思路清晰,但真到落地环节却面临“数据在手,不知如何转化为可用标签”的困境:或因加工方式选择不当 ...
2026-04-29在手游行业竞争日趋白热化的当下,“流量为王”早已升级为“留存为王”,而付费用户留存率更是衡量一款手游盈利能力、运营质量的 ...
2026-04-28在日常MySQL数据库运维与开发中,经常会遇到“同一台服务器上,两个不同数据库(以下简称“源库”“目标库”)的表数据需要保持 ...
2026-04-28 很多分析师每天和数据打交道,但当被问到“标签是什么”“标签和指标有什么区别”“标签体系如何设计”时,却常常答不上来。 ...
2026-04-28箱线图(Box Plot)作为一种经典的数据可视化工具,广泛应用于统计学、数据分析、科研实证等领域,核心价值在于直观呈现数据的集 ...
2026-04-27实证分析是社会科学、自然科学、经济管理等领域开展研究的核心范式,其核心逻辑是通过对多维度数据的收集、分析与解读,揭示变量 ...
2026-04-27 很多数据分析师精通Excel函数和数据透视表,但当被问到“数据从哪里来”“表和视图有什么区别”“数据库管理系统和SQL是什么 ...
2026-04-27在大数据技术飞速迭代、数字营销竞争日趋激烈的今天,“精准触达、高效转化、成本可控”已成为企业营销的核心诉求。传统广告投放 ...
2026-04-24在游戏行业竞争白热化的当下,用户流失已成为制约游戏生命周期、影响营收增长的核心痛点。据行业报告显示,2024年移动游戏平均次 ...
2026-04-24 很多业务负责人开会常说“我们要数据驱动”,最后却变成“看哪张报表数据多就用哪个”,往往因为缺乏一套结构性的方法去搭建 ...
2026-04-24在Power BI数据可视化分析中,切片器是连接用户与数据的核心交互工具,其核心价值在于帮助使用者快速筛选目标数据、聚焦分析重点 ...
2026-04-23以数为据,以析促优——数据分析结果指导临床技术改进的实践路径 临床技术是医疗服务的核心载体,其水平直接决定患者诊疗效果、 ...
2026-04-23很多数据分析师每天盯着GMV、DAU、转化率,但当被问到“哪些指标是所有企业都需要的”“哪些指标是因行业而异的”“北极星指标和 ...
2026-04-23近日,由 CDA 数据科学研究院重磅发布的《2026 全球数智化人才指数报告》,被中国教育科学研究院官方账号正式收录, ...
2026-04-22在数字化时代,客户每一次点击、浏览、下单、咨询等行为,都在传递其潜在需求与决策倾向——这些按时间顺序串联的行为轨迹,构成 ...
2026-04-22数据是数据分析、建模与业务决策的核心基石,而“数据清洗”作为数据预处理的核心环节,是打通数据从“原始杂乱”到“干净可用” ...
2026-04-22 很多数据分析师每天盯着GMV、转化率、DAU等数字看,但当被问到“什么是指标”“指标和维度有什么区别”“如何搭建一套完整的 ...
2026-04-22在数据分析与业务决策中,数据并非静止不变的数值,而是始终处于动态波动之中——股市收盘价的每日涨跌、企业月度销售额的起伏、 ...
2026-04-21