京公网安备 11010802034615号
经营许可证编号:京B2-20210330
简单易学的机器学习算法—Mean Shift聚类算法
一、Mean Shift算法概述
Mean Shift算法,又称为均值漂移算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在后来由Yizong Cheng对其进行扩充,主要提出了两点的改进:
定义了核函数;
增加了权重系数。
核函数的定义使得偏移值对偏移向量的贡献随之样本与被偏移点的距离的不同而不同。权重系数使得不同样本的权重不同。Mean Shift算法在聚类,图像平滑、分割以及视频跟踪等方面有广泛的应用。
二、Mean Shift算法的核心原理
2.1、核函数
在Mean Shift算法中引入核函数的目的是使得随着样本与被偏移点的距离不同,其偏移量对均值偏移向量的贡献也不同。核函数是机器学习中常用的一种方式。核函数的定义如下所示:

并且满足:
(1)、k是非负的
(2)、k是非增的
(3)、k是分段连续的
那么,函数K(x)就称为核函数。
常用的核函数有高斯核函数。高斯核函数如下所示:
其中,h称为带宽(bandwidth),不同带宽的核函数如下图所示:

上图的画图脚本如下所示:
'''
Date:201604026
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
import math
def cal_Gaussian(x, h=1):
molecule = x * x
denominator = 2 * h * h
left = 1 / (math.sqrt(2 * math.pi) * h)
return left * math.exp(-molecule / denominator)
x = []
for i in xrange(-40,40):
x.append(i * 0.5);
score_1 = []
score_2 = []
score_3 = []
score_4 = []
for i in x:
score_1.append(cal_Gaussian(i,1))
score_2.append(cal_Gaussian(i,2))
score_3.append(cal_Gaussian(i,3))
score_4.append(cal_Gaussian(i,4))
plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")
plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()
2.2、Mean Shift算法的核心思想
2.2.1、基本原理
对于Mean Shift算法,是一个迭代的步骤,即先算出当前点的偏移均值,将该点移动到此偏移均值,然后以此为新的起始点,继续移动,直到满足最终的条件。此过程可由下图的过程进行说明(图片来自参考文献3):
步骤1:在指定的区域内计算偏移均值(如下图的黄色的圈)

步骤2:移动该点到偏移均值点处

步骤3: 重复上述的过程(计算新的偏移均值,移动)

步骤4:满足了最终的条件,即退出

从上述过程可以看出,在Mean Shift算法中,最关键的就是计算每个点的偏移均值,然后根据新计算的偏移均值更新点的位置。
2.2.2、基本的Mean Shift向量形式
对于给定的d维空间Rd中的n个样本点
,则对于x点,其Mean Shift向量的基本形式为:
其中,Sh指的是一个半径为h的高维球区域,如上图中的蓝色的圆形区域。Sh的定义为:
这样的一种基本的Mean Shift形式存在一个问题:在Sh的区域内,每一个点对x的贡献是一样的。而实际上,这种贡献与x到每一个点之间的距离是相关的。同时,对于每一个样本,其重要程度也是不一样的。
2.2.3、改进的Mean Shift向量形式
基于以上的考虑,对基本的Mean Shift向量形式中增加核函数和样本权重,得到如下的改进的Mean Shift向量形式:
其中:
G(x)是一个单位的核函数。H是一个正定的对称d×d矩阵,称为带宽矩阵,其是一个对角阵。w(xi)⩾0是每一个样本的权重。对角阵H的形式为:

上述的Mean Shift向量可以改写成:

Mean Shift向量Mh(x)是归一化的概率密度梯度。
2.3、Mean Shift算法的解释
在Mean Shift算法中,实际上是利用了概率密度,求得概率密度的局部最优解。
2.3.1、概率密度梯度
对一个概率密度函数f(x),已知d维空间中n个采样点xi,i=1,⋯,n,f(x)的核函数估计(也称为Parzen窗估计)为:

其中
w(xi)⩾0是一个赋给采样点xi的权重
K(x)是一个核函数
概率密度函数f(x)的梯度▽f(x)的估计为
令
,则有:
其中,第二个方括号中的就是Mean Shift向量,其与概率密度梯度成正比。
2.3.2、Mean Shift向量的修正

Mh(x)=∑ni=1G(∥∥xi−xh∥∥2)w(xi)xi∑ni=1G(xi−xh)w(xi)−x
记:
,则上式变成:
Mh(x)=mh(x)+x
这与梯度上升的过程一致。
2.4、Mean Shift算法流程
Mean Shift算法的算法流程如下:
计算mh(x)
令x=mh(x)
如果∥mh(x)−x∥<ε,结束循环,否则,重复上述步骤
三、实验
3.1、实验数据
实验数据如下图所示(来自参考文献1):

画图的代码如下:
'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
f = open("data")
x = []
y = []
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 2:
x.append(float(lines[0]))
y.append(float(lines[1]))
f.close()
plt.plot(x, y, 'b.', label="original data")
plt.title('Mean Shift')
plt.legend(loc="upper right")
plt.show()
3.2、实验的源码
#!/bin/python
#coding:UTF-8
'''
Date:20160426
@author: zhaozhiyong
'''
import math
import sys
import numpy as np
MIN_DISTANCE = 0.000001#mini error
def load_data(path, feature_num=2):
f = open(path)
data = []
for line in f.readlines():
lines = line.strip().split("\t")
data_tmp = []
if len(lines) != feature_num:
continue
for i in xrange(feature_num):
data_tmp.append(float(lines[i]))
data.append(data_tmp)
f.close()
return data
def gaussian_kernel(distance, bandwidth):
m = np.shape(distance)[0]
right = np.mat(np.zeros((m, 1)))
for i in xrange(m):
right[i, 0] = (-0.5 * distance[i] * distance[i].T) / (bandwidth * bandwidth)
right[i, 0] = np.exp(right[i, 0])
left = 1 / (bandwidth * math.sqrt(2 * math.pi))
gaussian_val = left * right
return gaussian_val
def shift_point(point, points, kernel_bandwidth):
points = np.mat(points)
m,n = np.shape(points)
#计算距离
point_distances = np.mat(np.zeros((m,1)))
for i in xrange(m):
point_distances[i, 0] = np.sqrt((point - points[i]) * (point - points[i]).T)
#计算高斯核
point_weights = gaussian_kernel(point_distances, kernel_bandwidth)
#计算分母
all = 0.0
for i in xrange(m):
all += point_weights[i, 0]
#均值偏移
point_shifted = point_weights.T * points / all
return point_shifted
def euclidean_dist(pointA, pointB):
#计算pointA和pointB之间的欧式距离
total = (pointA - pointB) * (pointA - pointB).T
return math.sqrt(total)
def distance_to_group(point, group):
min_distance = 10000.0
for pt in group:
dist = euclidean_dist(point, pt)
if dist < min_distance:
min_distance = dist
return min_distance
def group_points(mean_shift_points):
group_assignment = []
m,n = np.shape(mean_shift_points)
index = 0
index_dict = {}
for i in xrange(m):
item = []
for j in xrange(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
print item_1
if item_1 not in index_dict:
index_dict[item_1] = index
index += 1
for i in xrange(m):
item = []
for j in xrange(n):
item.append(str(("%5.2f" % mean_shift_points[i, j])))
item_1 = "_".join(item)
group_assignment.append(index_dict[item_1])
return group_assignment
def train_mean_shift(points, kenel_bandwidth=2):
#shift_points = np.array(points)
mean_shift_points = np.mat(points)
max_min_dist = 1
iter = 0
m, n = np.shape(mean_shift_points)
need_shift = [True] * m
#cal the mean shift vector
while max_min_dist > MIN_DISTANCE:
max_min_dist = 0
iter += 1
print "iter : " + str(iter)
for i in range(0, m):
#判断每一个样本点是否需要计算偏置均值
if not need_shift[i]:
continue
p_new = mean_shift_points[i]
p_new_start = p_new
p_new = shift_point(p_new, points, kenel_bandwidth)
dist = euclidean_dist(p_new, p_new_start)
if dist > max_min_dist:#record the max in all points
max_min_dist = dist
if dist < MIN_DISTANCE:#no need to move
need_shift[i] = False
mean_shift_points[i] = p_new
#计算最终的group
group = group_points(mean_shift_points)
return np.mat(points), mean_shift_points, group
if __name__ == "__main__":
#导入数据集
path = "./data"
data = load_data(path, 2)
#训练,h=2
points, shift_points, cluster = train_mean_shift(data, 2)
for i in xrange(len(cluster)):
print "%5.2f,%5.2f\t%5.2f,%5.2f\t%i" % (points[i,0], points[i, 1], shift_points[i, 0], shift_points[i, 1], cluster[i])
3.3、实验的结果
经过Mean Shift算法聚类后的数据如下所示:

'''
Date:20160426
@author: zhaozhiyong
'''
import matplotlib.pyplot as plt
f = open("data_mean")
cluster_x_0 = []
cluster_x_1 = []
cluster_x_2 = []
cluster_y_0 = []
cluster_y_1 = []
cluster_y_2 = []
center_x = []
center_y = []
center_dict = {}
for line in f.readlines():
lines = line.strip().split("\t")
if len(lines) == 3:
label = int(lines[2])
if label == 0:
data_1 = lines[0].strip().split(",")
cluster_x_0.append(float(data_1[0]))
cluster_y_0.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
elif label == 1:
data_1 = lines[0].strip().split(",")
cluster_x_1.append(float(data_1[0]))
cluster_y_1.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
else:
data_1 = lines[0].strip().split(",")
cluster_x_2.append(float(data_1[0]))
cluster_y_2.append(float(data_1[1]))
if label not in center_dict:
center_dict[label] = 1
data_2 = lines[1].strip().split(",")
center_x.append(float(data_2[0]))
center_y.append(float(data_2[1]))
f.close()
plt.plot(cluster_x_0, cluster_y_0, 'b.', label="cluster_0")
plt.plot(cluster_x_1, cluster_y_1, 'g.', label="cluster_1")
plt.plot(cluster_x_2, cluster_y_2, 'k.', label="cluster_2")
plt.plot(center_x, center_y, 'r+', label="mean point")
plt.title('Mean Shift 2')数据分析师培训
#plt.legend(loc="best")
plt.show()
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
在用户行为分析实践中,很多从业者会陷入一个核心误区:过度关注“当前数据的分析结果”,却忽视了结果的“泛化能力”——即分析 ...
2026-03-13在数字经济时代,用户的每一次点击、浏览、停留、转化,都在传递着真实的需求信号。用户行为分析,本质上是通过收集、整理、挖掘 ...
2026-03-13在金融、零售、互联网等数据密集型行业,量化策略已成为企业挖掘商业价值、提升决策效率、控制经营风险的核心工具。而CDA(Certi ...
2026-03-13在机器学习建模体系中,随机森林作为集成学习的经典算法,凭借高精度、抗过拟合、适配多场景、可解释性强的核心优势,成为分类、 ...
2026-03-12在机器学习建模过程中,“哪些特征对预测结果影响最大?”“如何筛选核心特征、剔除冗余信息?”是从业者最常面临的核心问题。随 ...
2026-03-12在数字化转型深度渗透的今天,企业管理已从“经验驱动”全面转向“数据驱动”,数据思维成为企业高质量发展的核心竞争力,而CDA ...
2026-03-12在数字经济飞速发展的今天,数据分析已从“辅助工具”升级为“核心竞争力”,渗透到商业、科技、民生、金融等各个领域。无论是全 ...
2026-03-11上市公司财务报表是反映企业经营状况、盈利能力、偿债能力的核心数据载体,是投资者决策、研究者分析、从业者复盘的重要依据。16 ...
2026-03-11数字化浪潮下,数据已成为企业生存发展的核心资产,而数据思维,正是CDA(Certified Data Analyst)数据分析师解锁数据价值、赋 ...
2026-03-11线性回归是数据分析中最常用的预测与关联分析方法,广泛应用于销售额预测、风险评估、趋势分析等场景(如前文销售额预测中的多元 ...
2026-03-10在SQL Server安装与配置的实操中,“服务名无效”是最令初学者头疼的高频问题之一。无论是在命令行执行net start启动服务、通过S ...
2026-03-10在数据驱动业务的当下,CDA(Certified Data Analyst)数据分析师的核心价值,不仅在于解读数据,更在于搭建一套科学、可落地的 ...
2026-03-10在企业经营决策中,销售额预测是核心环节之一——无论是库存备货、营销预算制定、产能规划,还是战略布局,都需要基于精准的销售 ...
2026-03-09金融数据分析的核心价值,是通过挖掘数据规律、识别风险、捕捉机会,为投资决策、风险控制、业务优化提供精准支撑——而这一切的 ...
2026-03-09在数据驱动决策的时代,CDA(Certified Data Analyst)数据分析师的核心工作,是通过数据解读业务、支撑决策,而指标与指标体系 ...
2026-03-09在数据处理的全流程中,数据呈现与数据分析是两个紧密关联却截然不同的核心环节。无论是科研数据整理、企业业务复盘,还是日常数 ...
2026-03-06在数据分析、数据预处理场景中,dat文件是一种常见的二进制或文本格式数据文件,广泛应用于科研数据、工程数据、传感器数据等领 ...
2026-03-06在数据驱动决策的时代,CDA(Certified Data Analyst)数据分析师的核心价值,早已超越单纯的数据清洗与统计分析,而是通过数据 ...
2026-03-06在教学管理、培训数据统计、课程体系搭建等场景中,经常需要对课时数据进行排序并实现累加计算——比如,按课程章节排序,累加各 ...
2026-03-05在数据分析场景中,环比是衡量数据短期波动的核心指标——它通过对比“当前周期与上一个相邻周期”的数据,直观反映指标的月度、 ...
2026-03-05