知行编程网知行编程网  2022-02-24 14:00 知行编程网 隐藏边栏 |   抢沙发  43 
文章评分 0 次,平均分 0.0

在上篇文章中通俗易懂讲解感知机(一)--模型与学习策略 我已经表达清楚了感知机的模型以及学习策略,明白了感知机的任务是解决二分类问题,学习策略是优化损失函数

通俗易懂讲解感知机(二)--学习算法及python代码剖析

那么我们怎么来进行学习呢?根据书中例子给出python代码实现!

1学习算法

当我们已经有了一个目标是最小化损失函数,如下图:

通俗易懂讲解感知机(二)--学习算法及python代码剖析

我们就可以用常用的梯度下降方法来进行更新,对w,b参数分别进行求偏导可得:

通俗易懂讲解感知机(二)--学习算法及python代码剖析

那么我们任意初始化w,b之后,碰到误分类点时,采取的权值更新为w,b分别为:

通俗易懂讲解感知机(二)--学习算法及python代码剖析

通俗易懂讲解感知机(二)--学习算法及python代码剖析

好了,当我们碰到误分类点的时候,我们就采取上面的更新步骤进行更新参数即可!

 

但李航博士在书中并不是用到所有误分类点的数据点来进行更新,而是采取随机梯度下降法(stochastic gradient descent)。

步骤如下,首先,任取一个超平面w0,b0,然后用梯度下降法不断地极小化目标函数极小化过程中不是一次是M中所有误分类点的梯度下降而是一次随机选取一个误分类点使其梯度下降(有证明可以证明随机梯度下降可以收敛,并更新速度快于批量梯度下降,在这里不是我们考虑的重点,我们默认为它能收敛到最优点即可,后面我会写一篇文章说明一下随机梯度下降与批梯度下降区别与代码实现

那么碰到误分类点的时候,采取的权值更新w,b分别为:

 

通俗易懂讲解感知机(二)--学习算法及python代码剖析

好了,到这里我们可以给出整个感知机学习过程算法!如下:

(1)选定初值w0,b0,(当于初始给了一个超平面

(2)在训练集中选取数据(xi,yi)(任意抽取数据点,判断是否所有数据点判断完成没有误分累点了,如果没有了,直接结束算法,如果还有进入(3)

(3)如果yi(w*xi+b)<0(说明是误分类点,就需要更新参数)

那么进行参数更新!更新方式如下:

 

通俗易懂讲解感知机(二)--学习算法及python代码剖析

这种更新方式,我们也有直观上的感觉,可以可视化理解一下,如下图:

通俗易懂讲解感知机(二)--学习算法及python代码剖析

当我们数据点应该分类为y=+1的时候,我们分错了,分成-1(说明w*x<0,代表w与x向量夹角大于90度),这个时候应该调整,更新过程为w=w+1*x,往x向量方向更接近了!

第二种更新过程如下图:

通俗易懂讲解感知机(二)--学习算法及python代码剖析

当我们数据点应该分类为y=-1的时候,我们分错了,分成+1(说明w*x>0,代表w与x向量夹角小于90度),这个时候应该调整,更新过程为w=w-1*x,往远离x向量方向更接近了!

 

(4)转到(2),直到训练集中没有误分类点(能够证明在有限次更新后,收敛,下篇文章会讲到!)

到这里为止,其实感知机算法理论部分已经全部讲完了,下面我给出算法python代码实现以及详细的代码注释!

2代码讲解

书上例子讲解:

通俗易懂讲解感知机(二)--学习算法及python代码剖析通俗易懂讲解感知机(二)--学习算法及python代码剖析通俗易懂讲解感知机(二)--学习算法及python代码剖析

根据上述例子和算法讲解,我实现了python代码如下,其中过程用详细注释解释了!

核心算法流程图如下:

通俗易懂讲解感知机(二)--学习算法及python代码剖析

<span style="color: #808080;"># -*- coding: utf-8 -*-
</span><span style="color: #cc7832;">import </span>copy

trainint_set = [[(3,3),1],[(4,3),1],[(1,1),-1]] #输入数据
w = [0,0] #初始化w参数
b = 0              #初始化b参数

def update(item):
global w,b
w[0] += 1*item[1]*item[0][0] #w的第一个分量更新
   w[1] += 1*item[1]*item[0][1] #w的第二个分量更新
   b += 1*item[1]
print 'w = ',w,'b=',b #打印出结果

def judge(item): #返回y = yi(w*x+b)的结果
   res = 0
   for i in range(len(item[0])):
res +=item[0][i]*w[i] #对应公式w*x
   res += b #对应公式w*x+b
   res *= item[1] #对应公式yi(w*x+b)
   return res

def check(): #检查所有数据点是否分对了
   flag = False
   for item in trainint_set:
if judge(item)<=0: #如果还有误分类点,那么就小于等于0
           flag = True
           update(item) #只要有一个点分错,我就更新
   return flag #flag为False,说明没有分错的了

if __name__ == '__main__':
flag = False
   for i in range(1000):
if not check(): #如果已经没有分错的话
           flag = True
           break
   if flag:
print "在1000次以内全部分对了"
   else:
print "很不幸,1000次迭代还是没有分对"

程序运行结果如下:

通俗易懂讲解感知机(二)--学习算法及python代码剖析

实验证明这与我们书本上的结果是对应的。到这里已经讲完了本次要讲的内容,希望对大家理解有帮助~欢迎大家指错交流!

 

通俗易懂讲解感知机(二)--学习算法及python代码剖析

 

推荐阅读:

通俗易懂讲解感知机(一)--模型与学习策略

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享