哈夫曼树python实现

哈夫曼树python实现,第1张

HuffmanTree的python实现 – 潘登同学的图论笔记

文章目录
    • HuffmanTree的python实现 -- 潘登同学的图论笔记
  • 哈夫曼树
    • 构建哈夫曼树的过程
  • 树节点实现
  • HuffmanTree实现
  • 绘制HuffmanTree
    • 测试代码

哈夫曼树

当用 n 个结点(都做叶子结点且都有各自的权值)试图构建一棵树时,如果构建的这棵树的带权路径长度最小,称这棵树为“最优二叉树”,

在构建哈弗曼树时,要使树的带权路径长度最小,只需要遵循一个原则,那就是:权重越大的结点离树根越近。


在图 1 中,因为结点 a 的权值最大,所以理应直接作为根结点的孩子结点。


构建哈夫曼树的过程
  1. 在 n 个权值中选出两个最小的权值,对应的两个结点组成一个新的二叉树,且新二叉树的根结点的权值为左右孩子权值的和
  2. 在原有的 n 个权值中删除那两个最小的权值,同时将新的权值加入到 n–2 个权值的行列中,以此类推
  3. 重复 1 和 2 ,直到所以的结点构建成了一棵二叉树为止,这棵树就是哈夫曼树

话不多说,直接看代码

树节点实现

树节点基本上都是大同小异的

  • root: 该节点是否为叶节点(不是则为None)
  • value: 记录这个词
  • frq: 记录这个词出现的频次(或者是某个父节点下所有frq之和)
  • size: 记录这个某个父节点下的节点总数(主要用于画图)
class HuffmanTreeNode:
    def __init__(self,
                root=None,
                value:str=None,
                frq:int=0,
                ) -> None:
        self.root=root
        self.value = value
        self.frq = frq
        self.left = None
        self.right = None
        self.size = 1

    def Setleft(self,left):
        self.left = left
        self.frq += left.Getfrq()
        self.size += left.GetSize()
        return self

    def Setright(self,right):
        self.right = right
        self.frq += right.Getfrq()
        self.size += right.GetSize()
        return self
    
    def Getfrq(self):
        return self.frq
    
    def Getvalue(self):
        return self.value

    def GetSize(self):
        return self.size

    def Hasright(self):
        return self.right

    def Hasleft(self):
        return self.left

    def Isroot(self):
        return self.root

    def __str__(self) -> str:
        if self.root:
            return f'root, sum of frequency:{self.frq}'
        else:
            return f'value: {self.value}, frequency: {self.frq}'
HuffmanTree实现

HuffmanTree主要有两个方法

  • _buildHuffmanTree: 将词频字典输入,进行树的构建
  • _iter_node: 在构建好的树中,获得某个词的编码(因为哈夫曼树就是用于解决编码的,但是后来有很多的作用,我就是从CBOW模型过来的)
class HuffmanTree:
    def __init__(self,
                num:dict) -> None:
        # 对字典按照其values进行排序
        self.num = sorted(num.items(),key=lambda x:x[1],reverse=False)
        self.list = []  # 一个储存列表
        self.coding = {} # 编码结果
        self._buildHuffmanTree()
        self._iter_node(self.list[0])
    
    def _buildHuffmanTree(self):
        self.list = [HuffmanTreeNode(root=False,value=i[0],frq=i[1]) for i in self.num]
        while len(self.list) > 1:
            # 将两个小的节点合并  小的放左边
            right_node = self.list[1]
            left_node = self.list[0]
            # 注意pop顺序
            self.list.pop(1)
            self.list.pop(0)
            temp_node = HuffmanTreeNode(root=True)
            temp_node.Setright(right_node)
            temp_node.Setleft(left_node)
            # 将合并后的根节点放回list中
            if len(self.list) == 1:
                if temp_node.Getfrq() < self.list[0].Getfrq():
                    self.list.insert(0,temp_node)
                else:
                    self.list.insert(1,temp_node)
            elif len(self.list) == 0:
                self.list.insert(0,temp_node)
            else:
                for i in range(len(self.list)-1):
                    if i == 0 and temp_node.Getfrq() <= self.list[i].Getfrq():
                        self.list.insert(i,temp_node)
                        continue
                    elif self.list[i].Getfrq() < temp_node.Getfrq() <= self.list[i+1].Getfrq():
                        self.list.insert(i+1,temp_node)
                        continue
                    elif i == len(self.list)-2 and temp_node.Getfrq() > self.list[i+1].Getfrq():
                        self.list.insert(i+2,temp_node)
                        continue

    def getTree(self):
        return self.list[0]

    def _iter_node(self,node,code=''):
        if node:
            if not node.Isroot():
                self.coding[node.Getvalue()] = code
            self._iter_node(node.Hasleft(),code='0'+code)
            self._iter_node(node.Hasright(),code='1'+code)
    
    def getCode(self):
        return self.coding
绘制HuffmanTree

画图函数与之前画红黑树的区别不大,改一改拿来用就行

class Draw_RBTree:
    def __init__(self, tree):
        self.tree = tree

    def show_node(self, node, ax, height, index, font_size):
        if not node:
            return
        x1, y1 = None, None
        if node.left:
            x1, y1, index = self.show_node(node.left, ax, height-1, index, font_size)
        x = 100 * index - 50
        y = 100 * height - 50
        if x1:
            plt.plot((x1, x), (y1, y), linewidth=2.0,color='b')
        circle_color = 'mediumspringgreen'
        text_color = 'black'
        ax.add_artist(plt.Circle((x, y), 50, color=circle_color))
        text = str(node.Getfrq()) if node.Isroot() else node.Getvalue() + '\n' + str(node.Getfrq())
        ax.add_artist(plt.Text(x, y, text, color= text_color, fontsize=font_size, horizontalalignment="center",verticalalignment="center"))
        # print(str(node.val), (height, index))

        index += 1
        if node.right:
            x1, y1, index = self.show_node(node.right, ax, height-1, index, font_size)
            plt.plot((x1, x), (y1, y), linewidth=2.0, color='b')

        return x, y, index

    def show_hf_tree(self, title):
        fig = plt.figure(figsize=(10,6))
        ax = fig.add_subplot(111)
        left, right = self.get_left_length(), self.get_right_length(), 
        height = 2 * np.log2(self.tree.size + 1)
        # print(left, right, height)
        plt.ylim(0, height*100 + 50)
        plt.xlim(0, 100 * self.tree.size + 100)
        self.show_node(self.tree, ax, height, 1, self.get_fontsize())
        plt.axis('off')
        plt.title(title)
        plt.show()

    def get_left_length(self):
        temp = self.tree
        len = 1
        while temp:
            temp = temp.left
            len += 1
        return len

    def get_right_length(self):
        temp = self.tree
        len = 1
        while temp:
            temp = temp.right
            len += 1
        return len

    def get_fontsize(self):
        count = self.tree.size
        if count < 10:
            return 30
        if count < 20:
            return 20
        return 16
测试代码
if __name__ == '__main__':
    num = {'a':10,'b':15,'c':12,'d':3,'e':4,'f':13,'g':1}
    h = HuffmanTree(num)
    tree = h.getTree()
    d = Draw_RBTree(tree)
    d.show_hf_tree('HuffmanTree')
    print(h.getCode())

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/578459.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-11
下一篇 2022-04-11

发表评论

登录后才能评论

评论列表(0条)

保存