Seq2seq模型蒸馏方法

Seq2seq模型蒸馏方法,第1张

一. Seq2seq模型蒸馏方法总体过程如下

1. 训练teacher模型

2. 产生student模型

3. 利用teacher模型预测的logits和来自语料的true labels来计算student 模型的训练过程中的loss。

二. 涉及的具体步骤和参数有

1. 训练参数量相对较大的teacher模型。

2. 生成student模型,可以从teacher模型结构中抽取部分层组成,也可以随机初始化student模型的参数。

        如果从teacher模型中抽取,则可以在训练时固定某些层,例如可在训练时freeze_embeds.

     如果student模型的encoder和teacher模型的encoder完全一致,在训练时,可以考虑freeze_encoder。其他情况则不考虑freeze_encoder。

3. 根据teacher logits产生的时间不同,模型蒸馏可分为在线蒸馏和离线蒸馏。

        离线蒸馏是采用teacher模型,预先将decoder端每个token对应的词表(或类别)大小的概率分布预测出来,在训练时和true label一起输入来计算loss。

        在线蒸馏是同时将teacher模型和student模型加载到训练机上,在训练时利用teacher模型来预测每个token位置的概率分布(logits), 同时和true label一起参与loss的计算。

        在线蒸馏时,teacher模型参数固定,只有student模型的参数为trainable状态。

三. 关于loss的计算

1. Loss共有3部分构成,即来自teacher_logits的loss_ce, 来自true_labels的loss_mlm, 和来自中间层的loss_hid.

        对应3个loss部分在总的loss中的比例系数可以分别用alpha_ce, alpha_mlm, 和alpha_hid表示。因此总的loss可以表示为:

        loss_total = (alpha_ce * loss_ce) + (alpha_mlm * loss_mlm) + (alpha_hid * loss_hid)

        其中,

        loss_ce = distill_loss_fn(student_logits, teacher_logits,temperature)

        loss_mlm = loss_fn(student_logits, true_labels)

        loss_hid = mse_loss(student_hid, teacher_hid).

2. 关于loss_hid可以这样理解,采用teacher中的某些层来监督student中的各层的结果。例如采用一个12层的teacher模型,来蒸馏一个3层的student模型,如果只关注encoder端,可以用teacher_encoder  [0, 6, 11]层来分别监督student_encoder  [0, 1, 2]层的训练结果。

        如果是离线蒸馏,并且需要在loss中计算student各层的损失,则在需要将teacher模型各层的结果,和teacher logits一起预先计算并保存。

3. loss中涉及的3个部分的损失函数不同,其中mlm对应的是一般的cross_entropy, hid对应的为mse,ce部分对应的为和温度相关的KLDivLoss, loss_ce具体可以描述为:

        loss_ce = KLDivLoss(

                         softmax(student_logits/temperature, dim=-1),  # vocab_size

                         softmax(teacher_logits/temperature, dim=-1)

                        ) * (temperature ^ 2)

        关于最后需要乘温度的平方,可以阅读【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 - 知乎,简单表述为,loss_ce乘上 temperature^2 后,与loss_mlm的值相当,因此为了平衡loss_ce对损失的贡献,要乘temperature^2 。

4. 机器翻译(生成式)中使用离线蒸馏的问题

        尝试将100句中英语料的teacher logtis预测并保存,发现保存后的文件为869M(.npy格式).这很大原因是因为label的维度过大导致的,因为teacher logits的最后一个维度为词表大小,词表大小为5w左右(裁剪后的mbart50模型)。

        考虑到机器翻译的语句对经常为千万级别,对teacher logtis的存储空间要求较高,因此离线蒸馏在现有方法改进之前,并不适用机器翻译。

5. 综合来看,蒸馏涉及的主要参数有

        --teacher_model

        --student_encoder_layers=3

        --student_decoder_layers=3

        --temperature=2

        --alpha_ce=0.5

        --alpha_mlm=0.5

        --alpha_hid=0

        --freeze_encoder=False

        --freeze_embeds

        --max_sentence_length=64

        --train_batch_size

        --train_epochs=5

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/921621.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-16
下一篇 2022-05-16

发表评论

登录后才能评论

评论列表(0条)

保存