登录
首页精彩阅读机器学习之决策树(ID3)算法与Python实现
机器学习之决策树(ID3)算法与Python实现
2017-07-23
收藏

机器学习决策树(ID3)算法与Python实现

机器学习中,决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。 数据挖掘决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测。

一、决策树与ID3概述1.决策树

决策树,其结构和树非常相似,因此得其名决策树决策树具有树形的结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。

例如:

按照豆腐脑的冷热、甜咸和是否含有大蒜构建决策树,对其属性的测试,在最终的叶节点决定该豆腐脑吃还是不吃。

分类树(决策树)是一种十分常用的将决策树应用于分类的机器学习方法。他是一种监管学习,所谓监管学习就是给定一堆样本,每个样本都有一组属性(特征)和一个类别(分类信息/目标),这些类别是事先确定的,那么通过学习得到一个分类器,这个分类器能够对新出现的对象给出正确的分类。
其原理在于,每个决策树都表述了一种树型结构,它由它的分支来对该类型的对象依靠属性进行分类。每个决策树可以依靠对源数据库的分割进行数据测试。这个过程可以递归式的对树进行修剪。 当不能再进行分割或一个单独的类可以被应用于某一分支时,递归过程就完成了。

机器学习中,决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。数据挖掘决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测。从数据产生决策树机器学习技术叫做决策树学习, 通俗说就是决策树

目前常用的决策树算法有ID3算法、改进的C4.5算法和CART算法。

决策树的特点

1.多层次的决策树形式易于理解;

2.只适用于标称型数据,对连续性数据处理得不好;

2、ID3算法

ID3算法最早是由罗斯昆(J. Ross Quinlan)于1975年在悉尼大学提出的一种分类预测算法,算法以信息论为基础,其核心是“信息熵”。ID3算法通过计算每个属性的信息增益,认为信息增益高的是好属性,每次划分选取信息增益最高的属性为划分标准,重复这个过程,直至生成一个能完美分类训练样例的决策树

信息熵(Entropy):


,其中p(xi)是选择i的概率。
熵越高,表示混合的数据越多。

信息增益(Information Gain):


T是划分之后的分支集合,p(t)是该分支集合在原本的父集合中出现的概率,H(t)是该子集合的信息熵。
3.ID3算法与决策树的流程
(1)数据准备:需要对数值型数据进行离散化
(2)ID3算法构建决策树
如果数据集类别完全相同,则停止划分
否则,继续划分决策树
计算信息熵和信息增益来选择最好的数据集划分方法;
划分数据集
创建分支节点:
对每个分支进行判定是否类别相同,如果相同停止划分,不同按照上述方法进行划分。
二、Python算法实现

创建 trees.py文件,在其中创建构建决策树的函数。
首先构建一组测试数据:

0. 构造函数createDataSet:

def createDataSet():
    dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels=['no surfacing','flippers']
    return dataSet,labels

在Python控制台测试构造函数

#测试下构造的数据
import trees
myDat,labels = trees.createDataSet()
myDat
Out[4]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels

Out[5]: ['no surfacing', 'flippers']

2.1 计算信息熵

from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet) #nrows
    #为所有的分类类目创建字典
    labelCounts ={}
    for featVec in dataSet:
        currentLable=featVec[-1] #取得最后一列数据
        if currentLable not in labelCounts.keys():
            labelCounts[currentLable]=0
        labelCounts[currentLable]+=1
    #计算香农熵
    shannonEnt=0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

利用构造的数据测试calcShannonEnt:

#Python console
In [6]: trees.calcShannonEnt(myDat)
   ...:
Out[6]: 0.9709505944546686

2.2 按照最大信息增益划分数据集

#定义按照某个特征进行划分的函数splitDataSet
#输入三个变量(待划分的数据集,特征,分类值)
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value :
            reduceFeatVec=featVec[:axis]
            reduceFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet #返回不含划分特征的子集

#定义按照最大信息增益划分数据的函数
def chooseBestFeatureToSplit(dataSet):
    numFeature=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)#香农熵
    bestInforGain=0
    bestFeature=-1
    for i in range(numFeature):
        featList=[number[i] for number in dataSet] #得到某个特征下所有值(某列)
        uniqualVals=set(featList) #set无重复的属性特征
        newEntropy=0
        for value in uniqualVals:
            subDataSet=splitDataSet(dataSet,i,value)
            prob=len(subDataSet)/float(len(dataSet)) #即p(t)
            newEntropy+=prob*calcShannonEnt(subDataSet)#对各子集香农熵求和
        infoGain=baseEntropy-newEntropy #计算信息增益
        #最大信息增益
        if (infoGain>bestInforGain):
            bestInforGain=infoGain
            bestFeature=i
    return bestFeature #返回特征

在控制台中测试这两个函数:

#测试按照特征划分数据集的函数
In [8]: from imp import reload
In [9]: reload(trees)
Out[9]: <module 'trees' from 'G:\\Workspaces\\MachineLearning\\trees.py'>
In [10]: myDat,labels=trees.createDataSet()
    ...:
In [11]: trees.splitDataSet(myDat,0,0)
    ...:
Out[11]: [[1, 'no'], [1, 'no']]
In [12]: trees.splitDataSet(myDat,0,1)
    ...:
Out[12]: [[1, 'yes'], [1, 'yes'], [0, 'no']]

#测试chooseBestFeatureToSplit函数
In [13]: reload(trees)
    ...:
Out[13]: <module 'trees' from 'G:\\Workspaces\\MachineLearning\\trees.py'>

In [14]: trees.chooseBestFeatureToSplit(myDat)
    ...:

Out[14]: 0

2.3 创建决策树构造函数createTree

import operater
#投票表决代码
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.items,key=operator.itemgetter(1),reversed=True)
    return sortedClassCount[0][0]

def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    #类别相同,停止划分
    if classList.count(classList[-1])==len(classList):
        return classList[-1]
    #长度为1,返回出现次数最多的类别
    if len(classList[0])==1:
        return majorityCnt(classList)
    #按照信息增益最高选取分类特征属性
    bestFeat=chooseBestFeatureToSplit(dataSet)#返回分类的特征序号
    bestFeatLable=labels[bestFeat] #该特征的label
    myTree={bestFeatLable:{}} #构建树的字典
    del(labels[bestFeat]) #从labels的list中删除该label
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
        subLables=labels[:] #子集合
        #构建数据的子集合,并进行递归
        myTree[bestFeatLable][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLables)
    return myTree

以之前构造的测试数据为例,对决策树构造函数进行测试,在python控制台进行输入:

#决策树构造函数测试
In [15]: reload(trees)
    ...:
Out[15]: <module 'trees' from 'G:\\Workspaces\\MachineLearning\\trees.py'>

In [16]: myTree=trees.createTree(myDat,labels)
    ...:

In [17]: myTree
Out[17]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

可以看到,最后生成的决策树myTree是一个多层嵌套的字典。

2.4 决策树运用于分类

运用决策树进行分类,首先构建一个决策树分类函数:

#输入三个变量(决策树,属性特征标签,测试的数据)
def classify(inputTree,featLables,testVec):
    firstStr=list(inputTree.keys())[0] #获取树的第一个特征属性
    secondDict=inputTree[firstStr] #树的分支,子集合Dict
    featIndex=featLables.index(firstStr) #获取决策树第一层在featLables中的位置
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLables,testVec)
            else:classLabel=secondDict[key]
    return classLabel

决策树分类函数进行测试:

In [29]: reload(trees)
    ...:
Out[29]: <module 'trees' from 'G:\\Workspaces\\MachineLearning\\trees.py'>

In [30]: myDat,labels=trees.createDataSet()
    ...:

In [31]: labels
    ...:
Out[31]: ['no surfacing', 'flippers']

In [32]: myTree=treeplotter.retrieveTree(0)
    ...:

In [33]: myTree
    ...:
Out[33]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

In [34]: trees.classify(myTree,labels,[1,0])
    ...:
Out[34]: 'no'

In [35]: trees.classify(myTree,labels,[1,1])
    ...:

Out[35]: 'yes'

2.5 决策树的存储

如果每次都需要训练样本集来构建决策树,费时费力,特别是数据很大的时候,每次重新构建决策树浪费时间。因此可以将已经创建的决策树(如字典形式)保存在硬盘上,需要使用的时候直接读取就好。
(1)存储函数

    1def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'wb') #pickle默认方式是二进制,需要制定'wb'
    pickle.dump(inputTree,fw)
    fw.close()

(2)读取函数

def grabTree(filename):
    import pickle
    fr=open(filename,'rb')#需要制定'rb',以byte形式读取
    return pickle.load(fr)

对这两个函数进行测试(Python console):

In [36]: myTree
Out[36]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

In [37]: trees.storeTree(myTree,'classifierStorage.txt')

In [38]: trees.grabTree('classifierStorage.txt')
Out[38]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

在工作目录下存在一个名为’classifierStorage.txt’的txt文档,该文档 保存了myTree的决策树信息,需要使用的时候直接调出使用。

三、使用Matplotlib绘制决策树

import matplotlib.pyplot as plt

from pylab import *  
mpl.rcParams['font.sans-serif'] = ['SimHei'] #否则中文无法正常显示

decisionNode=dict(boxstyle='sawtooth',fc='0.8') #决策点样式
leafNode=dict(boxstyle='round4',fc='0.8') #叶节点样式
arrow_args=dict(arrowstyle='<-') #箭头样式

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
                            xytext=centerPt,textcoords='axes fraction',
                            va='center',ha='center',bbox=nodeType,arrowprops=arrow_args)

def createPlot():
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1=plt.subplot(111,frameon=False)
    plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

#测试
#获取叶节点数量(广度)
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]#'dict_keys' object does not support indexing
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else:numLeafs+=1
    return numLeafs

#获取树的深度的函数(深度)
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth > maxDepth:
            maxDepth=thisDepth
    return maxDepth
#定义一个预先创建树的函数
def retrieveTree(i):
    listOfTrees=[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                 {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head':{0:'no', 1: 'yes'}},1:'no'}}}}
                 ]
    return listOfTrees[i]

#定义在父子节点之间填充文本信息的函数
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

#定义树绘制的函数    
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.yOff=plotTree.yOff -1/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1/plotTree.totalD

 #定义主函数,来调用其它函数   
def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

对绘制决策树图的函数进行测试(控制台):

In [26]: reload(treeplotter)
    ...:
Out[26]: <module 'treeplotter' from 'G:\\Workspaces\\MachineLearning\\treeplotter.py'>

In [27]: myTree=treeplotter.retrieveTree(0)
    ...:

In [28]: treeplotter.createPlot(myTree)
    ...:

得到决策树图:

四、实例(使用决策树预测隐形眼镜类型)

隐形眼镜的数据集包含了患者的四个属性age,prescript,stigmatic,tearRate,利用这些数据构建决策树,并通过Matplotlib绘制出决策树的树状图。
附lenses.txt数据:

young   myope   no  reduced no lenses
young   myope   no  normal  soft
young   myope   yes reduced no lenses
young   myope   yes normal  hard
young   hyper   no  reduced no lenses
young   hyper   no  normal  soft
young   hyper   yes reduced no lenses
young   hyper   yes normal  hard
pre myope   no  reduced no lenses
pre myope   no  normal  soft
pre myope   yes reduced no lenses
pre myope   yes normal  hard
pre hyper   no  reduced no lenses
pre hyper   no  normal  soft
pre hyper   yes reduced no lenses
pre hyper   yes normal  no lenses
presbyopic  myope   no  reduced no lenses
presbyopic  myope   no  normal  no lenses
presbyopic  myope   yes reduced no lenses
presbyopic  myope   yes normal  hard
presbyopic  hyper   no  reduced no lenses
presbyopic  hyper   no  normal  soft
presbyopic  hyper   yes reduced no lenses
presbyopic  hyper   yes normal  no lenses

In [40]: fr=open('machinelearninginaction/Ch03/lenses.txt')

In [41]: lenses=[inst.strip().split('\t') for inst in fr.readlines()]

In [42]: lensesLabels=['age','prescript','astigmatic','tearRate']

In [43]: lensesTree=trees.createTree(lenses,lensesLabels)

In [44]: lensesTree
Out[44]:
{'tearRate': {'normal': {'astigmatic': {'no': {'age': {'pre': 'soft',
      'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}},
      'young': 'soft'}},
    'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses',
        'presbyopic': 'no lenses',
        'young': 'hard'}},
      'myope': 'hard'}}}},
  'reduced': 'no lenses'}}
  In [45]:  treeplotter.createPlot(lensesTree)

得到图

数据分析咨询请扫描二维码

客服在线
立即咨询