# 决策树是一种简单高效并且具有强解释性的模型,广泛应用于数据分析领域。其本质是一颗由多个判断节点组成的树,在使用模型进行预测时,根据输入参数依次在各个判断节点进行判断游走,最后到叶子节点即为预测结果。
# (A decision tree is a simple, efficient and highly interpretable model, widely used in data analysis.
# It is essentially a tree of decision nodes: a prediction walks the tree, testing the input at each
# node, until it reaches a leaf, which is the predicted result.)

"""
文件trees
"""
from math import log
import operator
# %matplotlib inline  # Jupyter magic enabling inline plot display -- invalid syntax in a plain .py file, kept as a comment

"""
创建数据集
"""
def createDataSet():
    dataSet = [[1, 1, 'yes'],
              [1, 1, 'yes'],
              [1, 0, 'no'],
              [0, 1, 'no'],
              [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels

"""
计算给定数据集的信息熵(香农)
"""
def calcuShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    """
    为所有可能分能创建字典
    """
    for featVec in dataSet:
        currentLabel = featVec[-1]
        #方法1
        #   if currentLabel not in labelCounts.keys():
        #        labelCounts[currentLabel] = 0
        #   labelCounts[currentLabel] += 1
        #方法2
        labelCounts[currentLabel] = labelCounts.get(currentLabel,0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        """
        每个类别所占的比
        """
        prob = float(labelCounts[key])/numEntries
        """
        求对数
        """
        shannonEnt -= prob * log(prob,2)
    return shannonEnt

"""
按照给定特征划分数据集
三个参数:待划分的数据集、划分数据集的特征、特征的返回值
"""
def splitDataSet(dataSet, axis, value):
    """
    创建新的list对象
    """
    """
    理解:按axis这列来划分,若这列的数=value归到一类,并创建一个新列表返回
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            """
            抽取
            """
            # 应该是清空列表
            reducedFeatVec = featVec[:axis] 
            # [0+1:]取下标1之后的,这里是取后两位,将元素塞入reducedFeatVec
            reducedFeatVec.extend(featVec[axis+1:]) 
            #将reducedFeatVec列表塞入retDataSet
            retDataSet.append(reducedFeatVec)
    return retDataSet

    """
    选择最好的数据集划分方式
    此处使用ID3算法,获取信息增益最大的
    """
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature index with the largest information gain (ID3).

    The last column of each record is the class label and is never a
    split candidate. Returns -1 when no split has positive gain.
    """
    featureCount = len(dataSet[0]) - 1        # last column is the label
    baseEntropy = calcuShannonEnt(dataSet)
    bestGain, bestIndex = 0.0, -1
    for index in range(featureCount):
        # Distinct values this feature takes across the data set.
        values = set(record[index] for record in dataSet)
        # Weighted entropy after partitioning on this feature:
        # each partition's entropy is weighted by its relative size.
        splitEntropy = 0.0
        for value in values:
            subset = splitDataSet(dataSet, index, value)
            weight = len(subset) / float(len(dataSet))
            splitEntropy += weight * calcuShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:
            bestGain, bestIndex = gain, index
    return bestIndex

"""
类似于投票表决的方法,挑选次数出现最多的类别
"""
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.key():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

"""
创建树的代码
"""
def createTree(dataSet,labels):
    # 把dataSet最后一列放到classList
    classList = [example[-1] for example in dataSet]
    #     print classList
    """
    类别完全相同则停止继续划分
    """
    #     对classList第一个元素进行统计,如果与总长度相同,表示只有一个分类
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    """
    遍历完所有特征时返回出现次数最多的
    """
    #     只剩最后一项的时候,按较多的
    #     print dataSet[0]
    #     计算dataSet第一个元素的长度
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    #     获取到最佳特征
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    """
    得到列表包含的所有属性值
    """
    # 从标签里删除最佳特征标签
    del(labels[bestFeat])
    #将dataSet每个元素的第一列拿出
    featValues = [example[bestFeat] for example in dataSet]
    #使用set无重复的存入uniqueVals
    uniqueVals = set(featValues)
    for value in uniqueVals:
    #将标签复制一份,保证每次调用函数时不改变原始列表的内容,使用新变量代替原始列表
        subLabels = labels[:]
        #splitDataSet按当前分类方式进行分类,并将其他项作为新列表返回
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

"""
使用决策树的分类函数
"""
def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    #将标签字符串转换成索引
    featIndex = featLabels.index(firstStr)
#     for key in secondDict.keys():
#         if testVec[featIndex] == key:
#             classLabel = classify(secondDict[key],featLabels,testVec)
#         else:
#             classLabel = secondDict[key]
    #下面是源码中的内容 可以运行
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict): 
        # 非叶子节点继续递归判断
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel
"""
决策树的存储,每次调用使用已经构造好的决策树,节省时间
"""
"""
使用pickle模块存储决策树
"""
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'w')
    #将对象保存到文件中
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    """Load and return a pickled decision tree from `filename`.

    Fix: the file must be opened in binary mode ('rb') -- required on
    Python 3 and safer on Windows under Python 2 -- and the original
    leaked the file object; ``with`` closes it deterministically.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
"""
测试决策树分类函数
"""
myDat,labels=createDataSet()
myTree=retrieveTree(0)
classify(myTree,labels,[1,1])
'yes'
"""
测试存储效果
"""
myTree=retrieveTree(0)
storeTree(myTree,'classfierStorage.txt')
grabTree('classfierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
"""
生成决策树字典
"""
myData,labels = createDataSet()
myTree = createTree(myData,labels)
myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
"""
测试选择最好的数据集划分方式
"""
myData,labels = createDataSet()
chooseBestFeatureToSplit(myData)
0
"""
测试使用创建的数据集计算信息熵
"""
myData,labels = createDataSet()
#增加分类熵变大
myData[0][-1]='maybe'
calcuShannonEnt(myData)
1.3709505944546687
"""
测试划分数据集
"""
myData,labels = createDataSet()
splitDataSet(myData,0,1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
"""
文件treePlotter
"""
"""
使用matplot文本注解回执树节点
"""
import matplotlib.pyplot as plt
"""
定义文本框和箭头格式
"""
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
"""
绘制带箭头的注解
"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
                           xytext=centerPt, textcoords='axes fraction', \
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot1():
    """Demo: draw one decision node and one leaf node on a fresh figure."""
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # NOTE: the axes handle is stored on createPlot (not createPlot1)
    # because plotNode reads createPlot.ax1.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # Two sample nodes; the box style differs per node type.
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

"""
画一个完整的树,我们需要知道有多少个叶节点(确定x),树有多少层(确定y)
"""
#获取叶节点的数目
def getNumLeafs(myTree):
    numLeafs = 0
    #获取第一个键
    #树的第一层一定只有一个节点
    firstStr = myTree.keys()[0]
    #获取第一个键所对应的值
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #判断节点数据(键的值)类型是否为字典
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
# Count the number of levels (depth) of a tree.
def getTreeDepth(myTree):
    """Return the depth of the nested-dict tree (a single decision
    node over leaves has depth 1).

    Fix: ``myTree.keys()[0]`` only works on Python 2; ``next(iter(...))``
    works on both 2 and 3. ``isinstance`` replaces the
    ``type(...).__name__`` string comparison.
    """
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # Internal node: one level plus the subtree below it.
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1   # leaf directly below this node
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Pre-built trees so plotting experiments don't have to rebuild one
# each time -- saves a little practice time.
def retrieveTree(i):
    """Return the i-th canned decision tree (i is 0 or 1)."""
    simpleTree = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: 'no', 1: 'yes'}},
        },
    }
    deeperTree = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {
                0: {'head': {0: 'no', 1: 'yes'}, 1: 'no'},
            }},
        },
    }
    return [simpleTree, deeperTree][i]

"""
在父子节点间填充文本信息
"""
def plotMidText(cntrPt, parentPt, txtString):
    #计算文本信息的位置
    #① 0.5 1
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw `myTree` below `parentPt`, labelling the
    incoming edge with `nodeTxt`.

    Shared layout state lives as attributes on this function (set by
    createPlot before the first call):
      plotTree.totalW / totalD -- total leaf count / tree depth
      plotTree.xOff / yOff     -- current pen position in axes fractions

    Fixes: ``myTree.keys()[0]`` is Python-2-only (next(iter(...)) works
    on both 2 and 3); ``isinstance`` replaces the type-name string
    comparison; the unused ``depth`` local was removed.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))   # root label of this subtree
    # Centre this node horizontally above the leaves it owns.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    # Label the edge from the parent, then draw the node itself.
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Descend one level: y runs from 1.0 (top) toward 0.0 (bottom).
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # Internal node: recurse with this node as the new parent.
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance x by one leaf slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Climb back up a level before returning to the caller.
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    """Render the full decision tree `inTree` in a matplotlib figure.

    Seeds the shared layout state on plotTree (totalW, totalD, xOff,
    yOff) and delegates the recursive drawing to plotTree.

    Fix: the Python-2 ``print`` statements were replaced with the
    function-call form, which is valid on both Python 2 and 3; a stray
    trailing semicolon was dropped.
    """
    fig = plt.figure(1, facecolor='white')   # white background
    fig.clf()                                # clear any previous figure
    # Empty ticks -- used by the tick-free variant kept below.
    axprops = dict(xticks=[], yticks=[])
    # Variant without ticks (kept commented out, as in the original):
    # createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # Variant with ticks (the one actually used):
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # Layout totals, shared with plotTree via function attributes.
    plotTree.totalW = float(getNumLeafs(inTree))
    print("totalW: %f" % plotTree.totalW)
    plotTree.totalD = float(getTreeDepth(inTree))
    print("totalD: %f" % plotTree.totalD)
    # Start half a leaf-slot to the left so the first leaf lands in its slot.
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    # The root is centred at the top of the axes.
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
"""
测试获取树的叶子节点数,树的层数
"""
myTree = retrieveTree(0)
getNumLeafs(myTree)
getTreeDepth(myTree)
2
"""
创建树节点
"""
"""
前面的createPlot()为了做区分 改名为createPlot1
"""
createPlot1()

图像输出

"""
获取树信息
"""
myTree=retrieveTree(0)
"""
创建树
"""
createPlot(myTree)
totalW: 3.000000
totalD: 2.000000

图像输出

"""
添加节点,测试输出效果
"""
myTree=retrieveTree(0)
myTree['no surfacing'][3]='maybe'
createPlot(myTree)

图像输出

fr=open('F:\study\lenses.txt')
lences=[inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','preascript','astigmatic','tearRate']
lensesTree = createTree(lences,lensesLabels)
lensesTree
{'tearRate': {'normal': {'astigmatic': {'no': {'age': {'pre': 'soft',
      'presbyopic': {'preascript': {'hyper': 'soft', 'myope': 'no lenses'}},
      'young': 'soft'}},
    'yes': {'preascript': {'hyper': {'age': {'pre': 'no lenses',
        'presbyopic': 'no lenses',
        'young': 'hard'}},
      'myope': 'hard'}}}},
  'reduced': 'no lenses'}}
createPlot(lensesTree)

图像输出