1.网络最后一层(全连接层)的输出为类别得分scores(也加logits),使用sofxmax对logits映射为概率分布,且和为1,添加t参数,使网络概率分布更加平缓.如下为添加t参数方式.
import numpy as np #原softmax函数,numpy实现 def softmax(x): x_exp = np.exp(x) return x_exp / np.sum(x_exp) output = np.array([0.1, 1.6, 3.6]) print(softmax(output)) #添加了温度参数t的softmax函数,温度t越大,softmax输出的概率分布越明显. 蒸馏的名字就是由温度参数t而来 def softmax_t(x, t): x_exp = np.exp(x / t) return x_exp / np.sum(x_exp) output = np.array([0.1, 1.6, 3.6]) print(softmax_t(output, 5))
2.求crossEntroyLoss,交叉熵的输入为softmax输出值和真实标签值(one_hot数组).
output = torch.tensor([[1.2, 2, 3]]) target = torch.tensor([0]) ce_loss = F.cross_entropy(output, target) print(ce_loss)
3.先训练老师网络,定义epoch后正常训练,保存teacher_model
4.训练学生网络
从数据集读取图片data,加载学生网络,获得输出output.接着加载老师网络,获得输出teacher_output.将这两个输出,温度参数t,平衡因子alpha添加到蒸馏函数进行loss计算.
for batch_idx, (data, target) in enumerate(train_loader): output = model(data) #学生网络 teacher_output = teacher_model(data) teacher_output = teacher_output.detach() # 切断老师网络的反向传播,感谢B站“淡淡的落”的提醒 loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7) loss.backward() optimizer.step()
附:蒸馏函数实现
def distillation(y, labels, teacher_scores, temp, alpha): return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * ( temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
参考链接:Efficient-Neural-Network-Bilibili/4-Knowledge-Distillation at master · mepeichun/Efficient-Neural-Network-Bilibili · GitHub
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)