ML学习笔记——决策树Python3代码实现与解析(ID3、C4.5)

ML学习笔记——决策树Python3代码实现与解析(ID3、C4.5),第1张

ML学习笔记——决策树Python3代码实现与解析(ID3、C4.5)

ML学习笔记——决策树Python3代码实现与解析(ID3、C4.5)

ID3

代码详解实现效果 C4.5

代码实现效果 代码学习技巧总结

本文以经典的西瓜问题为例,决策树原理详见文章 ML学习笔记——决策树

ID3 代码详解

代码来源:https://blog.csdn.net/leaf_zizi/article/details/82848682

数据

青绿 蜷缩 浊响 清晰 凹陷 硬滑 是 
乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 是 
乌黑 蜷缩 浊响 清晰 凹陷 硬滑 是 
青绿 蜷缩 沉闷 清晰 凹陷 硬滑 是 
浅白 蜷缩 浊响 清晰 凹陷 硬滑 是 
青绿 稍蜷 浊响 清晰 稍凹 软粘 是 
乌黑 稍蜷 浊响 稍糊 稍凹 软粘 是 
乌黑 稍蜷 浊响 清晰 稍凹 硬滑 是 
乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 否 
青绿 硬挺 清脆 清晰 平坦 软粘 否 
浅白 硬挺 清脆 模糊 平坦 硬滑 否 
浅白 蜷缩 浊响 模糊 平坦 软粘 否 
青绿 稍蜷 浊响 稍糊 凹陷 硬滑 否 
浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 否 
乌黑 稍蜷 浊响 清晰 稍凹 软粘 否 
浅白 蜷缩 浊响 模糊 平坦 硬滑 否 
青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 否

ID3主代码
源代码基于python2,在python3环境下会报错,因此以下代码在源代码的基础上稍有修改并添加了详细的注释。

import numpy as np
import pandas as pd
import sklearn.tree as st
import math
import matplotlib
import os
import matplotlib.pyplot as plt

# 以下注释主要针对西瓜数据集


# 计算信息熵
# 参数 - 数据集(二维list)
# 返回值 - 传入的数据集的信息熵(浮点数)
def calcEntropy(dataSet):
    mD = len(dataSet)   # 数据集行数
    dataLabelList = [x[-1] for x in dataSet]    # 类别list(是/否)
    dataLabelSet = set(dataLabelList)   # 类别list -> 类别set(元素不重复,只有‘是’和‘否’)
    ent = 0
    for label in dataLabelSet:  # 遍历两次,一次遍历‘是’类别,一次遍历‘否’类别,遍历结束得到信息熵ent
        mDv = dataLabelList.count(label)
        prop = float(mDv) / mD
        ent = ent - prop * np.math.log(prop, 2)

    return ent


# 删除数据集包含某一特征的所有列数据,并返回经过删除的曾包含该特征的list的数据集(二维list)
# 参数
# index - 要拆分的特征的下标(整型数),代表属性
# feature - 要拆分的特征(字符串)
# 返回值 - dataSet中所有特征是feature的那一行数据中去掉该feature的数据集(二维list)
def splitDataSet(dataSet, index, feature):
    splitedDataSet = []  # 初始化已拆分的数据集splitedDataSet,类型为list
    for data in dataSet:  # 遍历数据集的每一行
        if (data[index] == feature):  # 如果是我们要拆分出来的特征feature
            sliceTmp = data[:index]  # 新建一个列表sliceTmp,先取这一行feature前面的所有数据
            sliceTmp.extend(data[index + 1:])  # 再取这一行feature后面的所有数据
            splitedDataSet.append(sliceTmp)  # 将sliceTmp加到splitedDataSet中去
    return splitedDataSet


# 选择最优划分属性(信息增益法)
# 参数 - 数据集(二维list)
# 返回值 - 最好的特征的下标(整型数)
def chooseBestFeature(dataSet):
    entD = calcEntropy(dataSet)  # 计算信息熵
    mD = len(dataSet)  # 数据集的行数
    featureNumber = len(dataSet[0]) - 1  # 数据集的特征数,减1是因为数据集最后一列是类型而不是特征
    maxGain = -100  # 初始化最大信息增益
    maxIndex = -1   # 初始化最大信息增益对应的列号(属性)
    for i in range(featureNumber):  # 遍历数据集的所有列(属性),i:0,1,...,featureNumber-1
        entDCopy = entD
        featureI = [x[i] for x in dataSet]  # 特征向量(第i列,即第i个属性)
        featureSet = set(featureI)  # 不含重复元素的特征集合(第i列,即第i个属性)
        for feature in featureSet:  # 遍历该特征集合中的特征feature
            splitedDataSet = splitDataSet(dataSet, i, feature)  # 含feature的数据集,但不包含feature的一列
            mDv = len(splitedDataSet)  # 行数,即含feature的数据个数
            entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)
            #  为什么这里可以直接计算splitedDataSet的信息熵?
            #  因为计算信息熵时只需要数据集中最后一列的数据(类型),有无含feature的一列对计算结果无影响
        if (maxIndex == -1):
            maxGain = entDCopy
            maxIndex = i
        elif (maxGain < entDCopy):  # 找到了更大的信息增益,即找到了更优的划分属性
            maxGain = entDCopy
            maxIndex = i

    return maxIndex


# 返回数量最多的类别
# 参数 - 类别向量(一维list)
# 返回值 - 类别(字符串)
def mainLabel(labelList):
    labelRec = labelList[0]
    maxLabelCount = -1  # 数量最多的类别的类别数
    labelSet = set(labelList)  # 无重复元素的类别集合
    for label in labelSet:  # 遍历该集合中的类别
        if (labelList.count(label) > maxLabelCount):  # 如果传入的类别向量中当前类别的数量比maxLabelCount大,更新数量最多的类别
            maxLabelCount = labelList.count(label)
            labelRec = label
    return labelRec


# 构建决策树
def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
    labelList = [x[-1] for x in dataSet]  # 取数据集最后一列为类别向量
    if (len(dataSet) == 0):  # 数据集为空集
        return mainLabel(labelListParent)
    if (len(dataSet[0]) == 1):  # 没有可划分的属性了
        return mainLabel(labelList)  # 选出最多的类别作为该数据集的标签
    if (labelList.count(labelList[0]) == len(labelList)):  # 全部都属于同一个类别
        return labelList[0]

    bestFeatureIndex = chooseBestFeature(dataSet)  # 最好的属性的下标
    bestFeatureName = featureNames.pop(bestFeatureIndex)  # 从featureNames中删去bestFeature,并返回bestFeature
    myTree = {bestFeatureName: {}}  # 用字典构建决策树
    featureList = featureNamesSet.pop(bestFeatureIndex)  # 从featureNamesSet中删去bestFeature对应的子属性集,并返回这个最好的属性的子属性集
    featureSet = set(featureList)  # 转换成set类型
    for feature in featureSet:
        featureNamesNext = featureNames[:]  # 这里的featureNames已经删去了最好的属性
        featureNamesSetNext = featureNamesSet[:][:]  # 这里的featureNamesSet也已经删去了最好的属性
        splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
        # 递归建树
        myTree[bestFeatureName][feature] = createFullDecisionTree(splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
    return myTree


# 返回值
# dataSet 数据集
# featureNames 属性集(一维list)
# featureNamesSet 一个二维list,其中每一个元素是对应的属性的所有子属性的集合list
def readWatermelonDataSet():
    fr = open(r'C:Users静如止水DesktopID3data.txt', encoding='utf-8')
    dataSet = [inst.strip().split(' ') for inst in fr.readlines()]
    # print("dataSet =")
    # print(dataSet)
    # print('n')
    featureNames = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    # 获取featureNamesSet
    featureNamesSet = []
    for i in range(len(dataSet[0]) - 1):
        col = [x[i] for x in dataSet]
        colSet = set(col)
        featureNamesSet.append(list(colSet))
    return dataSet, featureNames, featureNamesSet

基于matplotlib的可视化代码

以下代码与源代码一致,无删改。

# 能够显示中文
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']

# 分叉节点,也就是决策节点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")

# 叶子节点
leafNode = dict(boxstyle="round4", fc="0.8")

# 箭头样式
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    绘制一个节点
    :param nodeTxt: 描述该节点的文本信息
    :param centerPt: 文本的坐标
    :param parentPt: 点的坐标,这里也是指父节点的坐标
    :param nodeType: 节点类型,分为叶子节点和决策节点
    :return:
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    """
    获取叶节点的数目
    :param myTree:
    :return:
    """
    # 统计叶子节点的总数
    numLeafs = 0

    # 得到当前第一个key,也就是根节点
    firstStr = list(myTree.keys())[0]

    # 得到第一个key对应的内容
    secondDict = myTree[firstStr]

    # 递归遍历叶子节点
    for key in secondDict.keys():
        # 如果key对应的是一个字典,就递归调用
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        # 不是的话,说明此时是一个叶子节点
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """
    得到数的深度层数
    :param myTree:
    :return:
    """
    # 用来保存最大层数
    maxDepth = 0

    # 得到根节点
    firstStr = list(myTree.keys())[0]

    # 得到key对应的内容
    secondDic = myTree[firstStr]

    # 遍历所有子节点
    for key in secondDic.keys():
        # 如果该节点是字典,就递归调用
        if type(secondDic[key]).__name__ == 'dict':
            # 子节点的深度加1
            thisDepth = 1 + getTreeDepth(secondDic[key])

        # 说明此时是叶子节点
        else:
            thisDepth = 1

        # 替换最大层数
        if thisDepth > maxDepth:
            maxDepth = thisDepth

    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """
    计算出父节点和子节点的中间位置,填充信息
    :param cntrPt: 子节点坐标
    :param parentPt: 父节点坐标
    :param txtString: 填充的文本信息
    :return:
    """
    # 计算x轴的中间位置
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    # 计算y轴的中间位置
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    # 进行绘制
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """
    绘制出树的所有节点,递归绘制
    :param myTree: 树
    :param parentPt: 父节点的坐标
    :param nodeTxt: 节点的文本信息
    :return:
    """
    # 计算叶子节点数
    numLeafs = getNumLeafs(myTree=myTree)

    # 计算树的深度
    depth = getTreeDepth(myTree=myTree)

    # 得到根节点的信息内容
    firstStr = list(myTree.keys())[0]

    # 计算出当前根节点在所有子节点的中间坐标,也就是当前x轴的偏移量加上计算出来的根节点的中心位置作为x轴(比如说第一次:初始的x偏移量为:-1/2W,计算出来的根节点中心位置为:(1+W)/2W,相加得到:1/2),当前y轴偏移量作为y轴
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)

    # 绘制该节点与父节点的联系
    plotMidText(cntrPt, parentPt, nodeTxt)

    # 绘制该节点
    plotNode(firstStr, cntrPt, parentPt, decisionNode)

    # 得到当前根节点对应的子树
    secondDict = myTree[firstStr]

    # 计算出新的y轴偏移量,向下移动1/D,也就是下一层的绘制y轴
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD

    # 循环遍历所有的key
    for key in secondDict.keys():
        # 如果当前的key是字典的话,代表还有子树,则递归遍历
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # 计算新的x轴偏移量,也就是下个叶子绘制的x轴坐标向右移动了1/W
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            # 打开注释可以观察叶子节点的坐标变化
            # print((plotTree.xOff, plotTree.yOff), secondDict[key])
            # 绘制叶子节点
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # 绘制叶子节点和父节点的中间连线内容
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

    # 返回递归之前,需要将y轴的偏移量增加,向上移动1/D,也就是返回去绘制上一层的y轴
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """
    需要绘制的决策树
    :param inTree: 决策树字典
    :return:
    """
    # 创建一个图像
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # 计算出决策树的总宽度
    plotTree.totalW = float(getNumLeafs(inTree))
    # 计算出决策树的总深度
    plotTree.totalD = float(getTreeDepth(inTree))
    # 初始的x轴偏移量,也就是-1/2W,每次向右移动1/W,也就是第一个叶子节点绘制的x坐标为:1/2W,第二个:3/2W,第三个:5/2W,最后一个:(W-1)/2W
    plotTree.xOff = -0.5 / plotTree.totalW
    # 初始的y轴偏移量,每次向下或者向上移动1/D
    plotTree.yOff = 1.0
    # 调用函数进行绘制节点图像
    plotTree(inTree, (0.5, 1.0), '')
    # 绘制
    plt.show()

执行代码

dataSet, featureNames, featureNamesSet = readWatermelonDataSet()
testTree = createFullDecisionTree(dataSet, featureNames, featureNamesSet, featureNames)
print(testTree)  # 打印字典型决策树
createPlot(testTree)  # 可视化决策树
实现效果

终端:

{'纹理': {'清晰': {'根蒂': {'稍蜷': {'色泽': {'青绿': '是', '乌黑': {'触感': {'硬滑': '是', '软粘': '否'}}, '浅白': '是'}}, '蜷缩': '是', '硬挺': '否'}}, '模糊': '否', '稍糊': {'触感': {'硬滑': '否', '软粘': '是'}}}}

matplotlib:

对照西瓜书

C4.5 代码

代码来源:https://blog.csdn.net/leaf_zizi/article/details/84866918

本文在源代码的基础上进行了数据集的更换,并增加了连续值的处理与剪枝。

数据集

青绿 蜷缩 浊响 清晰 凹陷 硬滑 0.697 0.460 1 是
乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 0.774 0.376 1 是
乌黑 蜷缩 浊响 清晰 凹陷 硬滑 0.634 0.264 1 是
青绿 蜷缩 沉闷 清晰 凹陷 硬滑 0.608 0.318 1 是
浅白 蜷缩 浊响 清晰 凹陷 硬滑 0.556 0.215 1 是
青绿 稍蜷 浊响 清晰 稍凹 软粘 0.403 0.237 1 是
乌黑 稍蜷 浊响 稍糊 稍凹 软粘 0.481 0.149 1 是
乌黑 稍蜷 浊响 清晰 稍凹 硬滑 0.437 0.211 1 是
乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 0.666 0.091 1 否
青绿 硬挺 清脆 清晰 平坦 软粘 0.243 0.267 1 否
浅白 硬挺 清脆 模糊 平坦 硬滑 0.245 0.057 1 否
浅白 蜷缩 浊响 模糊 平坦 软粘 0.343 0.099 1 否
青绿 稍蜷 浊响 稍糊 凹陷 硬滑 0.639 0.161 1 否
浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 0.657 0.198 1 否
乌黑 稍蜷 浊响 清晰 稍凹 软粘 0.360 0.370 1 否
浅白 蜷缩 浊响 模糊 平坦 硬滑 0.593 0.042 1 否
青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 0.719 0.103 1 否

主代码(C45)
与源代码一致,无删改。

# -*- coding: cp936 -*-
from math import log
import operator
import os

import re
from numpy import inf
import copy


# 计算信息熵
def calcShannonEnt(dataSet, labelIndex):
    # type: (list) -> float
    numEntries = 0  # 样本数(按权重计算)
    labelCounts = {}
    for featVec in dataSet:  # 遍历每个样本
        if featVec[labelIndex] != 'N':
            weight = float(featVec[-2])
            numEntries += weight
            currentLabel = featVec[-1]  # 当前样本的类别
            if currentLabel not in labelCounts.keys():  # 生成类别字典
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += weight  # 数据集的倒数第二个值用来标记样本权重
    shannonEnt = 0.0
    for key in labelCounts:  # 计算信息熵
        prob = float(labelCounts[key]) / numEntries
        shannonEnt = shannonEnt - prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value, LorR='N'):
    """
    type: (list, int, string or float, string) -> list
    划分数据集
    axis:按第几个特征划分
    value:划分特征的值
    LorR: N 离散属性; L 小于等于value值; R 大于value值
    """
    retDataSet = []
    featVec = []
    if LorR == 'N':  # 离散属性
        for featVec in dataSet:
            if featVec[axis] == value:
                reducedFeatVec = featVec[:axis]
                reducedFeatVec.extend(featVec[axis + 1:])
                retDataSet.append(reducedFeatVec)
    elif LorR == 'L':
        for featVec in dataSet:
            if featVec[axis] != 'N':
                if float(featVec[axis]) < value:
                    retDataSet.append(featVec)
    elif LorR == 'R':
        for featVec in dataSet:
            if featVec[axis] != 'N':
                if float(featVec[axis]) > value:
                    retDataSet.append(featVec)
    return retDataSet


def splitDataSetWithNull(dataSet, axis, value, LorR='N'):
    """
    type: (list, int, string or float, string) -> list
    划分数据集
    axis:按第几个特征划分
    value:划分特征的值
    LorR: N 离散属性; L 小于等于value值; R 大于value值
    """
    retDataSet = []
    nullDataSet = []
    featVec = []
    totalWeightV = calcTotalWeight(dataSet, axis, False)  # 非空样本权重
    totalWeightSub = 0.0
    if LorR == 'N':  # 离散属性
        for featVec in dataSet:
            if featVec[axis] == value:
                reducedFeatVec = featVec[:axis]
                reducedFeatVec.extend(featVec[axis + 1:])
                retDataSet.append(reducedFeatVec)
            elif featVec[axis] == 'N':
                reducedNullVec = featVec[:axis]
                reducedNullVec.extend(featVec[axis + 1:])
                nullDataSet.append(reducedNullVec)
    elif LorR == 'L':
        for featVec in dataSet:
            if featVec[axis] != 'N':
                if float(featVec[axis]) < value:
                    retDataSet.append(featVec)
            elif featVec[axis] == 'N':
                nullDataSet.append(featVec)
    elif LorR == 'R':
        for featVec in dataSet:
            if featVec[axis] != 'N':
                if float(featVec[axis]) > value:
                    retDataSet.append(featVec)
            elif featVec[axis] == 'N':
                nullDataSet.append(featVec)

    totalWeightSub = calcTotalWeight(retDataSet, -1, True)  # 计算此分支中非空样本的总权重
    for nullVec in nullDataSet:  # 把缺失值样本按权值比例划分到分支中
        nullVec[-2] = float(nullVec[-2]) * totalWeightSub / totalWeightV
        retDataSet.append(nullVec)

    return retDataSet


def calcTotalWeight(dataSet, labelIndex, isContainNull):
    """
    type: (list, int, bool) -> float
    计算样本集对某个特征值的总样本树(按权重计算)
    :param dataSet: 数据集
    :param labelIndex: 特征值索引
    :param isContainNull: 是否包含空值的样本
    :return: 返回样本集的总权重值
    """
    totalWeight = 0.0
    for featVec in dataSet:  # 遍历每个样本
        weight = float(featVec[-2])
        if isContainNull is False and featVec[labelIndex] != 'N':
            totalWeight += weight  # 非空样本树,按权重计算
        if isContainNull is True:
            totalWeight += weight  # 总样本数,按权重计算
    return totalWeight


def calcGain(dataSet, labelIndex, labelPropertyi):
    """
    type: (list, int, int) -> float, int
    计算信息增益,返回信息增益值和连续属性的划分点
    dataSet: 数据集
    labelIndex: 特征值索引
    labelPropertyi: 特征值类型,0为离散,1为连续
    """
    baseEntropy = calcShannonEnt(dataSet, labelIndex)  # 计算根节点的信息熵
    featList = [example[labelIndex] for example in dataSet]  # 特征值列表
    uniquevals = set(featList)  # 该特征包含的所有值
    newEntropy = 0.0
    totalWeight = 0.0
    totalWeightV = 0.0
    totalWeight = calcTotalWeight(dataSet, labelIndex, True)  # 总样本权重
    totalWeightV = calcTotalWeight(dataSet, labelIndex, False)  # 非空样本权重
    if labelPropertyi == 0:  # 对离散的特征
        for value in uniquevals:  # 对每个特征值,划分数据集, 计算各子集的信息熵
            if value != 'N':
                subDataSet = splitDataSet(dataSet, labelIndex, value)
                totalWeightSub = 0.0
                totalWeightSub = calcTotalWeight(subDataSet, labelIndex, True)
                prob = totalWeightSub / totalWeightV
                newEntropy += prob * calcShannonEnt(subDataSet, labelIndex)
    else:  # 对连续的特征
        uniquevalsList = list(uniquevals)
        if 'N' in uniquevalsList:
            uniquevalsList.remove('N')
        sortedUniquevals = sorted(uniquevalsList)  # 对特征值排序
        listPartition = []
        minEntropy = inf
        if len(sortedUniquevals) == 1:  # 如果只有一个值,可以看作只有左子集,没有右子集
            totalWeightLeft = calcTotalWeight(dataSet, labelIndex, True)
            probLeft = totalWeightLeft / totalWeightV
            minEntropy = probLeft * calcShannonEnt(dataSet, labelIndex)
        else:
            for j in range(len(sortedUniquevals) - 1):  # 计算划分点
                partValue = (float(sortedUniquevals[j]) + float(
                    sortedUniquevals[j + 1])) / 2
                # 对每个划分点,计算信息熵
                dataSetLeft = splitDataSet(dataSet, labelIndex, partValue, 'L')
                dataSetRight = splitDataSet(dataSet, labelIndex, partValue, 'R')
                totalWeightLeft = 0.0
                totalWeightLeft = calcTotalWeight(dataSetLeft, labelIndex, True)
                totalWeightRight = 0.0
                totalWeightRight = calcTotalWeight(dataSetRight, labelIndex, True)
                probLeft = totalWeightLeft / totalWeightV
                probRight = totalWeightRight / totalWeightV
                Entropy = probLeft * calcShannonEnt(dataSetLeft, labelIndex) + 
                          probRight * calcShannonEnt(dataSetRight, labelIndex)
                if Entropy < minEntropy:  # 取最小的信息熵
                    minEntropy = Entropy
        newEntropy = minEntropy
    gain = totalWeightV / totalWeight * (baseEntropy - newEntropy)
    return gain


def calcGainRatio(dataSet, labelIndex, labelPropertyi):
    """
    type: (list, int, int) -> float, int
    计算信息增益率,返回信息增益率和连续属性的划分点
    dataSet: 数据集
    labelIndex: 特征值索引
    labelPropertyi: 特征值类型,0为离散,1为连续
    """
    baseEntropy = calcShannonEnt(dataSet, labelIndex)  # 计算根节点的信息熵
    featList = [example[labelIndex] for example in dataSet]  # 特征值列表
    uniquevals = set(featList)  # 该特征包含的所有值
    newEntropy = 0.0
    bestPartValuei = None
    IV = 0.0
    totalWeight = 0.0
    totalWeightV = 0.0
    totalWeight = calcTotalWeight(dataSet, labelIndex, True)  # 总样本权重
    totalWeightV = calcTotalWeight(dataSet, labelIndex, False)  # 非空样本权重
    if labelPropertyi == 0:  # 对离散的特征
        for value in uniquevals:  # 对每个特征值,划分数据集, 计算各子集的信息熵
            subDataSet = splitDataSet(dataSet, labelIndex, value)
            totalWeightSub = 0.0
            totalWeightSub = calcTotalWeight(subDataSet, labelIndex, True)
            if value != 'N':
                prob = totalWeightSub / totalWeightV
                newEntropy += prob * calcShannonEnt(subDataSet, labelIndex)
            prob1 = totalWeightSub / totalWeight
            IV -= prob1 * log(prob1, 2)
    else:  # 对连续的特征
        uniquevalsList = list(uniquevals)
        if 'N' in uniquevalsList:
            uniquevalsList.remove('N')
            # 计算空值样本的总权重,用于计算IV
            totalWeightN = 0.0
            dataSetNull = splitDataSet(dataSet, labelIndex, 'N')
            totalWeightN = calcTotalWeight(dataSetNull, labelIndex, True)
            probNull = totalWeightN / totalWeight
            if probNull > 0.0:
                IV += -1 * probNull * log(probNull, 2)

        sortedUniquevals = sorted(uniquevalsList)  # 对特征值排序
        listPartition = []
        minEntropy = inf

        if len(sortedUniquevals) == 1:  # 如果只有一个值,可以看作只有左子集,没有右子集
            totalWeightLeft = calcTotalWeight(dataSet, labelIndex, True)
            probLeft = totalWeightLeft / totalWeightV
            minEntropy = probLeft * calcShannonEnt(dataSet, labelIndex)
            IV = -1 * probLeft * log(probLeft, 2)
        else:
            for j in range(len(sortedUniquevals) - 1):  # 计算划分点
                partValue = (float(sortedUniquevals[j]) + float(
                    sortedUniquevals[j + 1])) / 2
                # 对每个划分点,计算信息熵
                dataSetLeft = splitDataSet(dataSet, labelIndex, partValue, 'L')
                dataSetRight = splitDataSet(dataSet, labelIndex, partValue, 'R')
                totalWeightLeft = 0.0
                totalWeightLeft = calcTotalWeight(dataSetLeft, labelIndex, True)
                totalWeightRight = 0.0
                totalWeightRight = calcTotalWeight(dataSetRight, labelIndex, True)
                probLeft = totalWeightLeft / totalWeightV
                probRight = totalWeightRight / totalWeightV
                Entropy = probLeft * calcShannonEnt(
                    dataSetLeft, labelIndex) + probRight * calcShannonEnt(dataSetRight, labelIndex)
                if Entropy < minEntropy:  # 取最小的信息熵
                    minEntropy = Entropy
                    bestPartValuei = partValue
                    probLeft1 = totalWeightLeft / totalWeight
                    probRight1 = totalWeightRight / totalWeight
                    IV += -1 * (probLeft1 * log(probLeft1, 2) + probRight1 * log(probRight1, 2))

        newEntropy = minEntropy
    gain = totalWeightV / totalWeight * (baseEntropy - newEntropy)
    if IV == 0.0:  # 如果属性只有一个值,IV为0,为避免除数为0,给个很小的值
        IV = 0.0000000001
    gainRatio = gain / IV
    return gainRatio, bestPartValuei


# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet, labelProperty):
    """
    type: (list, int) -> int, float
    :param dataSet: 样本集
    :param labelProperty: 特征值类型,1 连续, 0 离散
    :return: 最佳划分属性的索引和连续属性的划分值
    """
    numFeatures = len(labelProperty)  # 特征数
    bestInfoGainRatio = 0.0
    bestFeature = -1
    bestPartValue = None  # 连续的特征值,最佳划分值
    gainSum = 0.0
    gainAvg = 0.0
    for i in range(numFeatures):  # 对每个特征循环
        infoGain = calcGain(dataSet, i, labelProperty[i])
        gainSum += infoGain
    gainAvg = gainSum / numFeatures
    for i in range(numFeatures):  # 对每个特征循环
        infoGainRatio, bestPartValuei = calcGainRatio(dataSet, i, labelProperty[i])
        infoGain = calcGain(dataSet, i, labelProperty[i])
        if infoGainRatio > bestInfoGainRatio and infoGain > gainAvg:  # 取信息增益高于平均增益且信息增益率最大的特征
            bestInfoGainRatio = infoGainRatio
            bestFeature = i
            bestPartValue = bestPartValuei
    return bestFeature, bestPartValue


# 通过排序返回出现次数最多的类别
def majorityCnt(classList, weightList):
    classCount = {}
    for i in range(len(classList)):
        if classList[i] not in classCount.keys():
            classCount[classList[i]] = 0.0
        classCount[classList[i]] += round(float(weightList[i]),1)

    # python 2.7
    # sortedClassCount = sorted(classCount.iteritems(),
    #                         key=operator.itemgetter(1), reverse=True)
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    if len(sortedClassCount) == 1:
        return (sortedClassCount[0][0],sortedClassCount[0][1],0.0)
    return (sortedClassCount[0][0], sortedClassCount[0][1], sortedClassCount[1][1])


# 创建树, 样本集 特征 特征属性(0 离散, 1 连续)
def createTree(dataSet, labels, labelProperty):
    classList = [example[-1] for example in dataSet]  # 类别向量
    weightList = [example[-2] for example in dataSet]  # 权重向量
    if classList.count(classList[0]) == len(classList):  # 如果只有一个类别,返回
        totalWeiht = calcTotalWeight(dataSet,0,True)
        return (classList[0], round(totalWeiht,1),0.0)
    #totalWeight = calcTotalWeight(dataSet, 0, True)
    if len(dataSet[0]) == 1:  # 如果所有特征都被遍历完了,返回出现次数最多的类别
        return majorityCnt(classList)
    bestFeat, bestPartValue = chooseBestFeatureToSplit(dataSet,
                                                       labelProperty)  # 最优分类特征的索引
    if bestFeat == -1:  # 如果无法选出最优分类特征,返回出现次数最多的类别
        return majorityCnt(classList, weightList)
    if labelProperty[bestFeat] == 0:  # 对离散的特征
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        labelsNew = copy.copy(labels)
        labelPropertyNew = copy.copy(labelProperty)
        del (labelsNew[bestFeat])  # 已经选择的特征不再参与分类
        del (labelPropertyNew[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniquevalue = set(featValues)  # 该特征包含的所有值
        uniquevalue.discard('N')
        for value in uniquevalue:  # 对每个特征值,递归构建树
            subLabels = labelsNew[:]
            subLabelProperty = labelPropertyNew[:]
            myTree[bestFeatLabel][value] = createTree(
                splitDataSetWithNull(dataSet, bestFeat, value), subLabels,
                subLabelProperty)
    else:  # 对连续的特征,不删除该特征,分别构建左子树和右子树
        bestFeatLabel = labels[bestFeat] + '<' + str(bestPartValue)
        myTree = {bestFeatLabel: {}}
        subLabels = labels[:]
        subLabelProperty = labelProperty[:]
        # 构建左子树
        valueLeft = 'Y'
        myTree[bestFeatLabel][valueLeft] = createTree(
            splitDataSetWithNull(dataSet, bestFeat, bestPartValue, 'L'), subLabels,
            subLabelProperty)
        # 构建右子树
        valueRight = 'N'
        myTree[bestFeatLabel][valueRight] = createTree(
            splitDataSetWithNull(dataSet, bestFeat, bestPartValue, 'R'), subLabels,
            subLabelProperty)
    return myTree


# 测试算法
def classify(inputTree, classList, featLabels, featLabelProperties, testVec):
    firstStr = list(inputTree.keys())[0]  # 根节点
    firstLabel = firstStr
    lessIndex = str(firstStr).find('<')
    if lessIndex > -1:  # 如果是连续型的特征
        firstLabel = str(firstStr)[:lessIndex]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstLabel)  # 跟节点对应的特征
    classLabel = {}
    for classI in classList:
        classLabel[classI] = 0.0
    for key in secondDict.keys():  # 对每个分支循环
        if featLabelProperties[featIndex] == 0:  # 离散的特征
            if testVec[featIndex] == key:  # 测试样本进入某个分支
                if type(secondDict[key]).__name__ == 'dict':  # 该分支不是叶子节点,递归
                    classLabelSub = classify(secondDict[key], classList, featLabels,
                                          featLabelProperties, testVec)
                    for classKey in classLabel.keys():
                        classLabel[classKey] += classLabelSub[classKey]
                else:  # 如果是叶子, 返回结果
                    for classKey in classLabel.keys():
                        if classKey == secondDict[key][0]:
                            classLabel[classKey] += secondDict[key][1]
                        else:
                            classLabel[classKey] += secondDict[key][2]
            elif testVec[featIndex] == 'N':  # 如果测试样本的属性值缺失,则进入每个分支
                if type(secondDict[key]).__name__ == 'dict':  # 该分支不是叶子节点,递归
                    classLabelSub = classify(secondDict[key], classList, featLabels,
                                          featLabelProperties, testVec)
                    for classKey in classLabel.keys():
                        classLabel[classKey] += classLabelSub[key]
                else:  # 如果是叶子, 返回结果
                    for classKey in classLabel.keys():
                        if classKey == secondDict[key][0]:
                            classLabel[classKey] += secondDict[key][1]
                        else:
                            classLabel[classKey] += secondDict[key][2]
        else:
            partValue = float(str(firstStr)[lessIndex + 1:])
            if testVec[featIndex] == 'N':  # 如果测试样本的属性值缺失,则对每个分支的结果加和
                # 进入左子树
                if type(secondDict[key]).__name__ == 'dict':  # 该分支不是叶子节点,递归
                    classLabelSub = classify(secondDict[key], classList, featLabels,
                                          featLabelProperties, testVec)
                    for classKey in classLabel.keys():
                        classLabel[classKey] += classLabelSub[classKey]
                else:  # 如果是叶子, 返回结果
                    for classKey in classLabel.keys():
                        if classKey == secondDict[key][0]:
                            classLabel[classKey] += secondDict[key][1]
                        else:
                            classLabel[classKey] += secondDict[key][2]
            elif float(testVec[featIndex]) <= partValue and key == 'Y':  # 进入左子树
                if type(secondDict['Y']).__name__ == 'dict':  # 该分支不是叶子节点,递归
                    classLabelSub = classify(secondDict['Y'], classList, featLabels,
                                             featLabelProperties, testVec)
                    for classKey in classLabel.keys():
                        classLabel[classKey] += classLabelSub[classKey]
                else:  # 如果是叶子, 返回结果
                    for classKey in classLabel.keys():
                        if classKey == secondDict[key][0]:
                            classLabel[classKey] += secondDict['Y'][1]
                        else:
                            classLabel[classKey] += secondDict['Y'][2]
            elif float(testVec[featIndex]) > partValue and key == 'N':
                if type(secondDict['N']).__name__ == 'dict':  # 该分支不是叶子节点,递归
                    classLabelSub = classify(secondDict['N'], classList, featLabels,
                                             featLabelProperties, testVec)
                    for classKey in classLabel.keys():
                        classLabel[classKey] += classLabelSub[classKey]
                else:  # 如果是叶子, 返回结果
                    for classKey in classLabel.keys():
                        if classKey == secondDict[key][0]:
                            classLabel[classKey] += secondDict['N'][1]
                        else:
                            classLabel[classKey] += secondDict['N'][2]

    return classLabel


# 存储决策树
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


# 读取决策树, 文件不存在返回None
def grabTree(filename):
    import pickle
    if os.path.isfile(filename):
        fr = open(filename)
        return pickle.load(fr)
    else:
        return None


# 测试决策树正确率
def testing(myTree, classList, data_test, labels, labelProperties):
    error = 0.0
    for i in range(len(data_test)):
        classLabelSet = classify(myTree, classList, labels, labelProperties, data_test[i])
        maxWeight = 0.0
        classLabel = ''
        for item in classLabelSet.items():
            if item[1] > maxWeight:
                classLabel = item[0]
        if classLabel !=  data_test[i][-1]:
            error += 1
    return float(error)


# 测试投票节点正确率
def testingMajor(major, data_test):
    error = 0.0
    for i in range(len(data_test)):
        if major[0] != data_test[i][-1]:
            error += 1
    # print 'major %d' %error
    return float(error)


# 后剪枝
def postPruningTree(inputTree, classSet, dataSet, data_test, labels, labelProperties):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    classList = [example[-1] for example in dataSet]
    weightList = [example[-2] for example in dataSet]
    featkey = copy.deepcopy(firstStr)
    if '<' in firstStr:  # 对连续的特征值,使用正则表达式获得特征标签和value
        featkey = re.compile("(.+<)").search(firstStr).group()[:-1]
        featvalue = float(re.compile("(<.+)").search(firstStr).group()[1:])
    labelIndex = labels.index(featkey)
    temp_labels = copy.deepcopy(labels)
    temp_labelProperties = copy.deepcopy(labelProperties)
    if labelProperties[labelIndex] == 0:  # 离散特征
        del (labels[labelIndex])
        del (labelProperties[labelIndex])
    for key in secondDict.keys():  # 对每个分支
        if type(secondDict[key]).__name__ == 'dict':  # 如果不是叶子节点
            if temp_labelProperties[labelIndex] == 0:  # 离散的
                subDataSet = splitDataSet(dataSet, labelIndex, key)
                subDataTest = splitDataSet(data_test, labelIndex, key)
            else:
                if key == 'Y':
                    subDataSet = splitDataSet(dataSet, labelIndex, featvalue,
                                              'L')
                    subDataTest = splitDataSet(data_test, labelIndex,
                                               featvalue, 'L')
                else:
                    subDataSet = splitDataSet(dataSet, labelIndex, featvalue,
                                              'R')
                    subDataTest = splitDataSet(data_test, labelIndex,
                                               featvalue, 'R')
            if len(subDataTest) > 0:
                inputTree[firstStr][key] = postPruningTree(secondDict[key], classSet,
                                                       subDataSet, subDataTest,
                                                       copy.deepcopy(labels),
                                                       copy.deepcopy(
                                                           labelProperties))
    if testing(inputTree, classSet, data_test, temp_labels,
               temp_labelProperties) <= testingMajor(majorityCnt(classList, weightList),
                                                     data_test):
        return inputTree
    return majorityCnt(classList,weightList)

可视化代码(treePlotter)
与源代码一致,无删改。

# -*- coding: cp936 -*-
import matplotlib.pyplot as plt

# 能够显示中文
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['font.serif'] = ['SimHei']

# 设置决策节点和叶节点的边框形状、边距和透明度,以及箭头的形状
decisionNode = dict(boxstyle="square,pad=0.5", fc="0.9")
leafNode = dict(boxstyle="round4, pad=0.5", fc="0.9")
arrow_args = dict(arrowstyle="<-", connectionstyle="arc3", shrinkA=0,
                  shrinkB=16)


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    newTxt = nodeTxt
    if type(nodeTxt).__name__ == 'tuple':
        newTxt = nodeTxt[0] + 'n'
        for strI in nodeTxt[1:-1]:
            newTxt += str(strI) + ','
        newTxt+= str(nodeTxt[-1])

    createPlot.ax1.annotate(newTxt, xy=parentPt,
                            xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="top", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


def retrieveTree(i):
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {
            'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    ]
    return listOfTrees[i]


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree)) + 0.5
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def createPlot0():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    # plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
    # plotNode('叶节点', (0.8, 0.1), (0.3, 0.75), leafNode)

    '''an1 = createPlot.ax1.annotate(unicode("决策节点", 'cp936'), xy=(0.5, 0.5),
                                  xycoords="data",
                                  va="center", ha="center",
                                  bbox=dict(box, fc="w"))

    createPlot.ax1.annotate(unicode('叶节点', 'cp936'),
                            xytext=(0.2, 0.3), arrowprops=dict(arrow),
                            xycoords=an1,
                            textcoords='axes fraction',
                            va="bottom", ha="left",
                            bbox=leafNode)'''

    an1 = createPlot.ax1.annotate("Test 1", xy=(0.5, 0.5), xycoords="data",
                                  va="center", ha="center",
                                  bbox=dict(boxstyle="round", fc="w"))
    an2 = createPlot.ax1.annotate("Test 2", xy=(0, 0.5), xycoords=an1,
                                  # (1,0.5) of the an1's bbox
                                  xytext=(-50, -50), textcoords="offset points",
                                  va="center", ha="center",
                                  bbox=dict(boxstyle="round", fc="w"),
                                  arrowprops=dict(arrowstyle="<-"))

    plt.show()


'''
    an1 = createPlot.ax1.annotate(unicode('决策节点', 'cp936'), xy=(0.5, 0.6),
                            xycoords='axes fraction',
                            textcoords='axes fraction',
                            va="bottom", ha="center", bbox=decisionNode)

    createPlot.ax1.annotate(unicode('叶节点', 'cp936'), xy=(0.8, 0.1),
                            xycoords=an1,
                            textcoords='axes fraction',
                            va="bottom", ha="center", bbox=leafNode,
                            arrowprops=arrow_args)
'''

if __name__ == '__main__':
    createPlot0()

执行代码

import C45
import treePlotter

# 读取数据文件
fr = open(r'C:Users静如止水DesktopC4.5决策树data.txt', encoding='utf-8')
# 生成数据集
lDataSet = [inst.strip().split(' ') for inst in fr.readlines()]
# 样本特征标签
labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感', '密度', '含糖率']
# 样本特征类型,0为离散,1为连续
labelProperties = [0, 0, 0, 0, 0, 0, 1, 1]
# 类别向量
classList = ['是', '否']
# 验证集
dataSet_test = [['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '0.697', '0.460', '是'],
                ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '0.774', '0.376', '是']]
# 构建决策树
trees = C45.createTree(lDataSet, labels, labelProperties)
# 绘制决策树
treePlotter.createPlot(trees)
# 利用验证集对决策树剪枝
C45.postPruningTree(trees, classList, lDataSet, dataSet_test, labels, labelProperties)
# 绘制剪枝后的决策树
treePlotter.createPlot(trees)
实现效果

对照西瓜书

代码学习技巧总结
    如何提取数据集dataSet(二维list)某一列:
dataLabelList = [x[index] for x in dataSet] # index表示要提取的列的下标
    如何实现不重复地提取出list中的元素:可将其转换成集合set类型
dataLabelSet = set(dataLabelList)
    如何删除二维list中的某一列:遍历每一行,先取这一行要删除的列的前面的所有数据,再取这一行要删除的列的后面的所有数据,最后加到一个新的list中去。
# 代码有删改
splitedDataSet = []  # 初始化已拆分的数据集splitedDataSet,类型为list
for data in dataSet:  # 遍历数据集的每一行
    sliceTmp = data[:index] # 新建一个列表sliceTmp,先取这一行index前面的所有数据
    sliceTmp.extend(data[index + 1:])  # 再取这一行index后面的所有数据
    splitedDataSet.append(sliceTmp)
    如何在matplotlib中显示中文:在前面加上这两行代码即可。
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['font.serif'] = ['SimHei']

参考书籍:周志华《机器学习》

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5710745.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存