深度学习知识蒸馏学习笔记

深度学习知识蒸馏学习笔记,第1张

深度学习知识蒸馏学习笔记

1.网络最后一层(全连接层)的输出为类别得分scores(也加logits),使用sofxmax对logits映射为概率分布,且和为1,添加t参数,使网络概率分布更加平缓.如下为添加t参数方式.

import numpy as np

#原softmax函数,numpy实现
def softmax(x):
    x_exp = np.exp(x)
    return x_exp / np.sum(x_exp)

output = np.array([0.1, 1.6, 3.6])
print(softmax(output))

#添加了温度参数t的softmax函数,温度t越大,softmax输出的概率分布越明显. 蒸馏的名字就是由温度参数t而来
def softmax_t(x, t):
    x_exp = np.exp(x / t)
    return x_exp / np.sum(x_exp)

output = np.array([0.1, 1.6, 3.6])
print(softmax_t(output, 5))

2.求crossEntroyLoss,交叉熵的输入为softmax输出值和真实标签值(one_hot数组).

output = torch.tensor([[1.2, 2, 3]])
target = torch.tensor([0])

ce_loss = F.cross_entropy(output, target)
print(ce_loss)

3.先训练老师网络,定义epoch后正常训练,保存teacher_model

4.训练学生网络

从数据集读取图片data,加载学生网络,获得输出output.接着加载老师网络,获得输出teacher_output.将这两个输出,温度参数t,平衡因子alpha添加到蒸馏函数进行loss计算.

for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)  #学生网络
        teacher_output = teacher_model(data)
        teacher_output = teacher_output.detach()  # 切断老师网络的反向传播,感谢B站“淡淡的落”的提醒
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)
        loss.backward()
        optimizer.step()

附:蒸馏函数实现

def distillation(y, labels, teacher_scores, temp, alpha):
    return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (
            temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)

参考链接:Efficient-Neural-Network-Bilibili/4-Knowledge-Distillation at master · mepeichun/Efficient-Neural-Network-Bilibili · GitHub

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5689360.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存