登录
首页精彩阅读简单易学的机器学习算法—Rosenblatt感知机的对偶解法
简单易学的机器学习算法—Rosenblatt感知机的对偶解法
2017-03-21
收藏

简单易学的机器学习算法—Rosenblatt感知机的对偶解法

一、Rosenblatt感知机回顾
    在博文“简单易学的机器学习算法——Rosenblatt感知机”中介绍了Rosenblatt感知机的基本概念。Rosenblatt感知机是针对线性可分问题的二分类算法。通过构造分离超平面将正类和负类区分开。构造了如下的输入空间到输出空间的函数:

其中,w为权重,b为偏置。为符号函数:

求解这个函数的重点就是求解函数中的参数:和。Rosenblatt感知机通过构造损失函数,并求得使得这样的损失函数达到最小时的w和b。
    其中,为:

这里的为所有误分类的点的集合。我们的目标是求得损失函数的最小值:。
    通过梯度下降法(详细请见“简单易学的机器学习算法——Rosenblatt感知机”),我们得到了w和b的更新公式:


其中,为学习率。
二、Rosenblatt感知机的对偶形式
    对偶形式的基本想法是,将w和b表示为实例的线性组合的形式,通过求解其系数而求得
    通过上面的的更新公式,我们发现,是一个累加的过程。如果令,则可以表示为:
其中,
    此时的感知机模型就变为:
三、算法流程
初始化,
选择误分类数据点,即,更新a和b

直到没有误分类的点,否则重复步骤2
计算出
四、实验的仿真
    利用博文“简单易学的机器学习算法——Rosenblatt感知机”中的数据集,原始数据集如下图所示:

(原始数据点)

MATLAB代码
[plain] view plain copy 在CODE上查看代码片派生到我的代码片
%% Rosenblatt感知机的对偶解法  
clear all;  
clc;  
 
%读入数据  
x=[3,3;4,3;1,1];  
y=[1;1;-1];  
[m,n] = size(x);%取得数据集的大小  
 
%% 画出原始的点  
hold on  
axis([0 5 0 5]);%axis一般用来设置axes的样式,包括坐标轴范围,可读比例等  
for i = 1:m  
    plot(x(i,1),x(i,2),'.');  
end  
 
%% 初始化  
alpha = zeros(1,m);  
b = 0;  
yita = 1;%学习率  
gram = zeros(m,m);  
 
%% 计算Gram矩阵  
for i = 1:m  
    for j = 1:m  
        gram(i,j)=x(i,:)*x(j,:)';  
    end  
end  
 
%% 更新  
for i = 1:m  
    tmp = 0;  
    for j = 1:m  
        tmp = tmp + alpha(j)*y(j)*gram(i,j);  
    end  
    tmp = tmp + b;  
    tmp = y(i)*tmp;  
    if tmp <= 0  
        alpha(i) = alpha(i)+yita;  
        b = b + y(i);  
    end  
end  
% 要使得数据集中没有误分类的点  
flag = 0;%标志位,用于标记有没有误分类的点  
i = 1;  
while flag~=1  
    while i <= 3  
        tmp = 0;  
        for j = 1:m  
            tmp = tmp + alpha(j)*y(j)*gram(i,j);  
        end  
        tmp = tmp + b;  
        tmp = y(i)*tmp;  
        if tmp <= 0  
            alpha(i) = alpha(i)+yita;  
            b = b + y(i);  
            i = 1;%重置i  
            break;  
        else  
            i = i+1;  
        end  
        if i == 4  
            flag = 1;  
        end  
    end  
end  
 
%% 重新计算w和b  
for i = 1:m  
    x_new(i,:) = x(i,:) * y(i);  
end  
w = alpha * x_new;  
 
%% 画出分隔线  
x_1 = (0:3);  
y_1 = (-b-w(1,1)*x_1)./w(1,2);  
plot(x_1,y_1);  

最终的分离超平面:

(最终分离超平面)

数据分析咨询请扫描二维码

客服在线
立即咨询