Machine Learning in Action (3): Decision Trees
A decision tree is a simple, efficient, and highly interpretable model that is widely used in data analysis. At its core it is a tree of decision nodes: to make a prediction, the input is routed through the decision nodes one after another, and the leaf node it finally reaches is the predicted result.
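Before the code, here are the two quantities it computes, written out (these are the standard ID3 definitions; the notation is mine, not the book's): the Shannon entropy of a dataset $D$ with $K$ classes, and the information gain of splitting $D$ on a feature $A$:

$$H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k$$

$$\mathrm{Gain}(D, A) = H(D) - \sum_{v \in \mathrm{values}(A)} \frac{|D_v|}{|D|}\, H(D_v)$$

where $p_k$ is the fraction of samples in class $k$ and $D_v$ is the subset of $D$ whose feature $A$ takes value $v$. calcuShannonEnt below computes $H(D)$, and chooseBestFeatureToSplit computes the gain for every feature.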
"""
文件trees
"""
from math import log
import operator
%matplotlib inline #这一句设置在线显示
"""
创建数据集
"""
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers']
return dataSet, labels
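For reference, this dataset corresponds to the book's Table 3-1, the toy fish-identification example (the column wording below is my paraphrase):

no surfacing | flippers | fish?
      1      |    1     | yes
      1      |    1     | yes
      1      |    0     | no
      0      |    1     | no
      0      |    1     | no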
"""
计算给定数据集的信息熵(香农)
"""
def calcuShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
"""
为所有可能分能创建字典
"""
for featVec in dataSet:
currentLabel = featVec[-1]
#方法1
# if currentLabel not in labelCounts.keys():
# labelCounts[currentLabel] = 0
# labelCounts[currentLabel] += 1
#方法2
labelCounts[currentLabel] = labelCounts.get(currentLabel,0) + 1
shannonEnt = 0.0
for key in labelCounts:
"""
每个类别所占的比
"""
prob = float(labelCounts[key])/numEntries
"""
求对数
"""
shannonEnt -= prob * log(prob,2)
return shannonEnt
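As a sanity check (a hand calculation, not part of the original post): the sample dataset has 2 'yes' and 3 'no' labels, so

$$H = -\tfrac{2}{5}\log_2\tfrac{2}{5} - \tfrac{3}{5}\log_2\tfrac{3}{5} \approx 0.9710$$

which is the 0.9709505944546686 that calcuShannonEnt returns for the unmodified dataset.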
"""
按照给定特征划分数据集
三个参数:待划分的数据集、划分数据集的特征、特征的返回值
"""
def splitDataSet(dataSet, axis, value):
"""
创建新的list对象
"""
"""
理解:按axis这列来划分,若这列的数=value归到一类,并创建一个新列表返回
"""
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
"""
抽取
"""
# 应该是清空列表
reducedFeatVec = featVec[:axis]
# [0+1:]取下标1之后的,这里是取后两位,将元素塞入reducedFeatVec
reducedFeatVec.extend(featVec[axis+1:])
#将reducedFeatVec列表塞入retDataSet
retDataSet.append(reducedFeatVec)
return retDataSet
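A quick aside on the extend/append pair used above (a tiny illustration of the list semantics, not part of the original):

a = [1, 2]
a.extend([3, 4])   # a == [1, 2, 3, 4]: the elements are merged in flat
b = [1, 2]
b.append([3, 4])   # b == [1, 2, [3, 4]]: the whole list becomes one nested element

extend keeps each reduced row flat, while append keeps each row as a separate element of retDataSet.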
"""
选择最好的数据集划分方式
此处使用ID3算法,获取信息增益最大的
"""
def chooseBestFeatureToSplit(dataSet):
#dataSet[0]列数
#只用前两列进行分类 -1
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcuShannonEnt(dataSet)
#信息增益
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
"""
创建唯一的分类标签列表
"""
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)# 创建集合
newEntropy = 0.0
"""
计算每种划分方式的信息熵
第一列或第二列的划分方式 表3-1 将其他两列分类
"""
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
"""
将分完的两个类分别计算信息熵,乘以每个分类所出现的概率,相加后得到新的熵
"""
newEntropy += prob * calcuShannonEnt(subDataSet)
"""
判断信息增益是否大于0
"""
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
"""
计算最好的信息增益
"""
bestInfoGain = infoGain
# i为列数,对应表3-1
bestFeature = i
return bestFeature
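Working the numbers by hand on the toy dataset (my arithmetic, consistent with the test further down that returns 0): with $H(D) \approx 0.9710$,

$$\mathrm{Gain}_0 = 0.9710 - \left[\tfrac{3}{5}\,H(\{\text{yes},\text{yes},\text{no}\}) + \tfrac{2}{5}\,H(\{\text{no},\text{no}\})\right] = 0.9710 - \tfrac{3}{5}(0.9183) \approx 0.420$$

$$\mathrm{Gain}_1 = 0.9710 - \left[\tfrac{4}{5}\,H(\{\text{yes},\text{yes},\text{no},\text{no}\}) + \tfrac{1}{5}\,H(\{\text{no}\})\right] = 0.9710 - \tfrac{4}{5}(1.0) \approx 0.171$$

Feature 0 ('no surfacing') has the larger gain, so chooseBestFeatureToSplit picks index 0.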
"""
类似于投票表决的方法,挑选次数出现最多的类别
"""
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.key():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
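On a recent Python, collections.Counter gives an equivalent one-liner (an alternative sketch, not the book's listing):

from collections import Counter

def majorityCnt(classList):
    # most_common(1) returns [(label, count)] for the most frequent label
    return Counter(classList).most_common(1)[0][0]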
"""
创建树的代码
"""
def createTree(dataSet,labels):
# 把dataSet最后一列放到classList
classList = [example[-1] for example in dataSet]
# print classList
"""
类别完全相同则停止继续划分
"""
# 对classList第一个元素进行统计,如果与总长度相同,表示只有一个分类
if classList.count(classList[0]) == len(classList):
return classList[0]
"""
遍历完所有特征时返回出现次数最多的
"""
# 只剩最后一项的时候,按较多的
# print dataSet[0]
# 计算dataSet第一个元素的长度
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 获取到最佳特征
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
"""
得到列表包含的所有属性值
"""
# 从标签里删除最佳特征标签
del(labels[bestFeat])
#将dataSet每个元素的第一列拿出
featValues = [example[bestFeat] for example in dataSet]
#使用set无重复的存入uniqueVals
uniqueVals = set(featValues)
for value in uniqueVals:
#将标签复制一份,保证每次调用函数时不改变原始列表的内容,使用新变量代替原始列表
subLabels = labels[:]
#splitDataSet按当前分类方式进行分类,并将其他项作为新列表返回
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
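One gotcha worth noting (my observation, not in the book): createTree removes entries from labels with del, so the list you pass in gets mutated. If you still need the original labels afterwards, e.g. for the classify call further down, pass in a copy:

myData, labels = createDataSet()
myTree = createTree(myData, labels[:])   # labels[:] keeps the caller's list intact
classify(myTree, labels, [1, 1])         # labels is still complete here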
"""
使用决策树的分类函数
"""
def classify(inputTree,featLabels,testVec):
firstStr = inputTree.keys()[0]
secondDict = inputTree[firstStr]
#将标签字符串转换成索引
featIndex = featLabels.index(firstStr)
# for key in secondDict.keys():
# if testVec[featIndex] == key:
# classLabel = classify(secondDict[key],featLabels,testVec)
# else:
# classLabel = secondDict[key]
#下面是源码中的内容 可以运行
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
# 非叶子节点继续递归判断
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat
return classLabel
"""
决策树的存储,每次调用使用已经构造好的决策树,节省时间
"""
"""
使用pickle模块存储决策树
"""
def storeTree(inputTree,filename):
import pickle
fw = open(filename,'w')
#将对象保存到文件中
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename)
#从文件中读取
return pickle.load(fr)
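The same pair written with context managers, so the file handles are closed even on error (an equivalent sketch, not the book's listing):

import pickle

def storeTree(inputTree, filename):
    # 'wb': pickle requires binary mode
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    with open(filename, 'rb') as fr:
        return pickle.load(fr)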
"""
测试决策树分类函数
"""
myDat,labels=createDataSet()
myTree=retrieveTree(0)
classify(myTree,labels,[1,1])
'yes'
"""
测试存储效果
"""
myTree=retrieveTree(0)
storeTree(myTree,'classfierStorage.txt')
grabTree('classfierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
"""
生成决策树字典
"""
myData,labels = createDataSet()
myTree = createTree(myData,labels)
myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
"""
测试选择最好的数据集划分方式
"""
myData,labels = createDataSet()
chooseBestFeatureToSplit(myData)
0
"""
测试使用创建的数据集计算信息熵
"""
myData,labels = createDataSet()
#增加分类熵变大
myData[0][-1]='maybe'
calcuShannonEnt(myData)
1.3709505944546687
"""
测试划分数据集
"""
myData,labels = createDataSet()
splitDataSet(myData,0,1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
"""
文件treePlotter
"""
"""
使用matplot文本注解回执树节点
"""
import matplotlib.pyplot as plt
"""
定义文本框和箭头格式
"""
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
"""
绘制带箭头的注解
"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
xytext=centerPt, textcoords='axes fraction', \
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot1():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # plotNode reads createPlot.ax1, so the axis is stored on createPlot even here
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # decisionNode / leafNode are node styles: the box drawn around the text differs
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
"""
画一个完整的树,我们需要知道有多少个叶节点(确定x),树有多少层(确定y)
"""
#获取叶节点的数目
def getNumLeafs(myTree):
numLeafs = 0
#获取第一个键
#树的第一层一定只有一个节点
firstStr = myTree.keys()[0]
#获取第一个键所对应的值
secondDict = myTree[firstStr]
for key in secondDict.keys():
#判断节点数据(键的值)类型是否为字典
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
# count the levels of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
# pre-stored trees, so practice runs don't have to rebuild one every time
def retrieveTree(i):
    listOfTrees = [
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}}
    ]
    return listOfTrees[i]
"""
在父子节点间填充文本信息
"""
def plotMidText(cntrPt, parentPt, txtString):
#计算文本信息的位置
#① 0.5 1
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    # width and height of this subtree
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    # root label of this subtree
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # on the first call this is (0.5, 1), with parentPt (0.5, 1) and nodeTxt ""
    # add the text between parent and child
    plotMidText(cntrPt, parentPt, nodeTxt)
    # draw the node: node text, child coordinates, parent coordinates, node style
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # descend to the next level
    secondDict = myTree[firstStr]
    # decrease the y offset by one level; coordinates run 0-1
    # and we draw top-down, so y shrinks as we go
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    # check whether each child of this node is a leaf
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # not a leaf: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # a leaf: advance the x offset (e.g. -1/8 + 1/4 when totalW is 4)
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # done with this level: move the y offset back up (e.g. 1/2 + 1/2)
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    # facecolor sets the background
    fig = plt.figure(1, facecolor='white')
    # clear the figure
    fig.clf()
    # empty the ticks, i.e. the axis tick marks
    axprops = dict(xticks=[], yticks=[])
    # this version has no ticks:
    #createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # this version keeps the ticks:
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # stash the leaf count and depth as attributes on plotTree
    plotTree.totalW = float(getNumLeafs(inTree))
    print("totalW: %f" % plotTree.totalW)
    plotTree.totalD = float(getTreeDepth(inTree))
    print("totalD: %f" % plotTree.totalD)
    # e.g. -0.5/4 when there are four leaves
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    # anchor the root at (0.5, 1.0)
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
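To make the offsets concrete (my own walkthrough, using the totalW = 3, totalD = 2 tree plotted below):

xOff starts at -0.5/3 = -1/6, and each leaf advances it by 1/totalW = 1/3,
so the three leaves land at x = 1/6, 1/2 and 5/6, i.e. centered in three
equal columns; yOff starts at 1.0 and drops by 1/totalD = 1/2 per level,
so the rows sit at y = 1.0, 0.5 and 0.0 from the root down to the deepest leaves.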
"""
测试获取树的叶子节点数,树的层数
"""
myTree = retrieveTree(0)
getNumLeafs(myTree)
getTreeDepth(myTree)
2
"""
创建树节点
"""
"""
前面的createPlot()为了做区分 改名为createPlot1
"""
createPlot1()
"""
获取树信息
"""
myTree=retrieveTree(0)
"""
创建树
"""
createPlot(myTree)
totalW: 3.000000
totalD: 2.000000
"""
添加节点,测试输出效果
"""
myTree=retrieveTree(0)
myTree['no surfacing'][3]='maybe'
createPlot(myTree)
"""
Example: predicting contact lens type with the lenses dataset
"""
fr = open(r'F:\study\lenses.txt')   # raw string so the backslashes survive
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'preascript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
lensesTree
{'tearRate': {'normal': {'astigmatic': {'no': {'age': {'pre': 'soft',
      'presbyopic': {'preascript': {'hyper': 'soft', 'myope': 'no lenses'}},
      'young': 'soft'}},
    'yes': {'preascript': {'hyper': {'age': {'pre': 'no lenses',
        'presbyopic': 'no lenses',
        'young': 'hard'}},
      'myope': 'hard'}}}},
  'reduced': 'no lenses'}}
createPlot(lensesTree)