机器学习-决策树

机器学习-决策树,第1张

手写笔记

决策树算法实现
""" @File   : DecisionTree
    @Author : BabyMuu
    @Time   : 2022/4/12 13:00
"""
import pandas as pd
import numpy as np


class DecisionTree:
    def __init__(self, feature, target, labels, max_depth=None):
        """"""
        self.feature = feature.copy()
        self.target = target.copy()
        self.labels = labels
        self.feature_labels = []
        self.tree_height = self.feature.shape[1]
        if max_depth:
            self.max_depth = max_depth
        else:
            self.max_depth = self.tree_height
        self.tree = self.init(self.feature, self.target, self.labels)

    def init(self, feature, target, labels):
        """创建决策树"""
        # 设置停止条件
        if target.nunique() == 1:
            # 如果当前target的值全为一个值 则直接返回该值即可
            return target.unique()[0]
        if self.tree_height - feature.shape[1] == self.max_depth:
            # 如果所有特征全部分完, 仍然没有完全分清, 则根据少数服从多数进行判断属于哪个类别
            return self.majority_cnt(target)
        # 获得 特征中 评价指标最好的 特征索引
        best_feature_label = \
            self.choose_best_feature(feature, target)
        # 将选出的特征放入特征列表中
        self.feature_labels.append(best_feature_label)
        # 删除已经放入树中的特征
        labels = labels.drop(best_feature_label)
        cur_tree = {best_feature_label: {}}  # 创建根节点
        # 获取当前特征中有多少个不同的值
        unique_vals = feature[best_feature_label].unique()
        for value in unique_vals:
            index = feature[feature[best_feature_label] == value].index
            f = feature[labels].loc[index]
            t = target.loc[index]
            cur_tree[best_feature_label][value] = self.init(f, t, labels)
        return cur_tree

    @staticmethod
    def majority_cnt(target: pd.DataFrame):
        """计算当前节点中哪一个类别的比较多"""
        return target.describe()['top']

    def choose_best_feature(self, features, target: pd.DataFrame):
        """选择信息增益最高的特征"""
        # 1, 计算基础信息熵
        base_entropy = self.cal_entropy(target.value_counts())
        print(base_entropy)
        # 2, 计算所有特征的信息增益
        info_gain = self.cal_info_gain(features, target, base_entropy)
        # 3, 返回所有特征中信息增益最大的特征名称
        return sorted(info_gain, key=lambda x: info_gain[x], reverse=True)[0]

    def cal_info_gain(self, features: pd.DataFrame, target, base_info_gain):
        """计算信息增益"""
        # 1, 组合特征和标签
        features['target'] = target
        label_counts = {}
        feature_entropy = {}
        # 2, 遍历每个特征, 获取(特征值, target值) : 数量
        for feature in features[features.columns.drop('target')]:
            # (特征值, target值) : 数量
            value_counts = features.groupby(feature)['target'].value_counts()
            # 特征值
            feature_unique = features[feature].unique().tolist()
            label_counts[feature] = (value_counts, feature_unique)
        # 删除添加的标签
        features.drop(columns='target', inplace=True)
        # 3, 遍历获取每一个特征的信息增益
        for feature in label_counts:
            label_count, value_count = label_counts[feature]
            entropy = 0  # 信息熵
            for value in value_count:
                # 计算每个特征值对应的信息熵
                entropy += self.cal_entropy(
                    label_count[value]) * label_count[value].sum()
            # 计算当前 feature 的信息增益
            feature_entropy[feature] = \
                base_info_gain - entropy / label_count.sum()
        # 4, 返回信息增益列表
        return feature_entropy

    @staticmethod
    def cal_entropy(target_count):
        """计算信息熵"""
        value = target_count / target_count.sum()
        print(value)
        return -np.dot(np.log2(value), value.T)
决策树可视化
""" @File   : draw_tree
    @Author : BabyMuu
    @Time   : 2022/4/13 9:09
"""
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False


# 树信息统计 叶子节点数量 和 树深度
def getTreeSize(decisionTree):
    nodeName = list(decisionTree.keys())[0]
    nodeValue = decisionTree[nodeName]
    leafNum = 0
    treeDepth = 0
    leafDepth = 0
    for val in nodeValue.keys():
        if type(nodeValue[val]) == dict:
            leafNum += getTreeSize(nodeValue[val])[0]
            leafDepth = 1 + getTreeSize(nodeValue[val])[1]
        else:
            leafNum += 1
            leafDepth = 1
        treeDepth = max(treeDepth, leafDepth)
    return leafNum, treeDepth


decisionNodeStyle = dict(boxstyle="sawtooth", fc="0.8")
leafNodeStyle = {"boxstyle": "round4", "fc": "0.8"}
arrowArgs = {"arrowstyle": "<-"}


# 画节点
def plotNode(nodeText, centerPt, parentPt, nodeStyle):
    createPlot.ax1.annotate(nodeText, xy=parentPt, xycoords="axes fraction",
                            xytext=centerPt
                            , textcoords="axes fraction", va="center",
                            ha="center", bbox=nodeStyle, arrowprops=arrowArgs)


# 添加箭头上的标注文字
def plotMidText(centerPt, parentPt, lineText):
    xMid = (centerPt[0] + parentPt[0]) / 2.0
    yMid = (centerPt[1] + parentPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, lineText)


# 画树
def plotTree(decisionTree, parentPt, parentValue):
    # 计算宽与高
    leafNum, treeDepth = getTreeSize(decisionTree)
    # 在 1 * 1 的范围内画图,因此分母为 1
    # 每个叶节点之间的偏移量
    plotTree.xOff = plotTree.figSize / (plotTree.totalLeaf - 1)
    # 每一层的高度偏移量
    plotTree.yOff = plotTree.figSize / plotTree.totalDepth
    # 节点名称
    nodeName = list(decisionTree.keys())[0]
    # 根节点的起止点相同,可避免画线;如果是中间节点,则从当前叶节点的位置开始,
    #      然后加上本次子树的宽度的一半,则为决策节点的横向位置
    centerPt = (plotTree.x + (leafNum - 1) * plotTree.xOff / 2.0, plotTree.y)
    # 画出该决策节点
    plotNode(nodeName, centerPt, parentPt, decisionNodeStyle)
    # 标记本节点对应父节点的属性值
    plotMidText(centerPt, parentPt, parentValue)
    # 取本节点的属性值
    treeValue = decisionTree[nodeName]
    # 下一层各节点的高度
    plotTree.y = plotTree.y - plotTree.yOff
    # 绘制下一层
    for val in treeValue.keys():
        # 如果属性值对应的是字典,说明是子树,进行递归调用; 否则则为叶子节点
        if type(treeValue[val]) == dict:
            plotTree(treeValue[val], centerPt, str(val))
        else:
            plotNode(treeValue[val], (plotTree.x, plotTree.y), centerPt,
                     leafNodeStyle)
            plotMidText((plotTree.x, plotTree.y), centerPt, str(val))
            # 移到下一个叶子节点
            plotTree.x = plotTree.x + plotTree.xOff
    # 递归完成后返回上一层
    plotTree.y = plotTree.y + plotTree.yOff


# 画出决策树
def createPlot(decisionTree):
    fig = plt.figure(1, facecolor="white")
    fig.clf()
    axprops = {"xticks": [], "yticks": []}
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # 定义画图的图形尺寸
    plotTree.figSize = 1.0
    # 初始化树的总大小
    plotTree.totalLeaf, plotTree.totalDepth = getTreeSize(decisionTree)
    # 叶子节点的初始位置x 和 根节点的初始层高度y
    plotTree.x = 0
    plotTree.y = plotTree.figSize
    plotTree(decisionTree, (plotTree.figSize / 2.0, plotTree.y), "")
    plt.show()

简单测试
""" @File   : demo1
    @Author : BabyMuu
    @Time   : 2022/5/6 13:01
"""
from pprint import pprint

import pandas as pd

from handwritten_algorithm_model.decision_tree.DecisionTree import DecisionTree
from handwritten_algorithm_model.template.draw.draw_tree import createPlot

data_path = '../_data/sales_data.xls'
data = pd.read_excel(data_path, index_col='序号')
feature = data[data.columns.drop('销量')]
target = data['销量']
tree = DecisionTree(feature, target, feature.columns)
pprint(tree.tree)
createPlot(tree.tree)

可视化结果

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/875447.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-13
下一篇 2022-05-13

发表评论

登录后才能评论

评论列表(0条)

保存