- 摘要
- intro
- 相关工作
- Unsupervised domain adaptation
- Unsupervised representation learning
- Learning from noisy labels
- 方法
伪标签的缺点: 由于源域和目标域的不一致,标签会有噪声,特征会不一致
做法
- 估计伪标签的可能性,以便在线修正;
- 根据同一目标的两个不同视图的相对特征距离对齐原型分配,产生更紧凑的目标特征空间;
- 蒸馏预训练好的模型可以提升性能;
uda的两种方式
- 对齐源域和目标域的分布
- self-training(先生成伪标签,再重新训练网络)
self-training的缺点
- 依赖于置信度阈值,而高分并不是一定正确
- 存在域gap,在目标数据中,只会考虑接近源分布的数据
对应解决方法
- 在线进行伪标签去噪:根据距每个类的距离估计类别相关的熵
- 学习紧凑的目标结构:借鉴Deepcluster,为同目标的不同视图对齐软分配
- 域对齐方法侧重于通过优化差异(optimize divergence)或在图像级别、特征图级别、输出级别采用对抗性训练来减少分布不匹配。
- 对齐全局分布不能保证目标域上的误差很小
- 最近的方法也有从逐类别对齐特征的,但只要特征分离好,就没有必要严格对齐分布
- SSL(半监督学习)效果不错:熵最小化及其变体会对错误预测过度自信,所以self-training可以解决这问题,但是伪标签会有各种噪音(解决方法:对置信度正则、预测置信度分数、基于类中心生成伪标签,但是这些方法的label都是固定的)
- 本文提出了一种在线伪标签更新方案,其中根据目标域中估计的原型上下文来纠正错误预测
- 早期方法注重于auxiliary tasks的优化,之后有一些contrastive learning的工作比如MoCo拉近了和监督学习的距离
- 网络能够学习丰富的语义特征,只要它们在不同的增强视图下保持一致,但没有在像素级任务上应用
- 本文提出了一致学习和聚类表示学习能适用于uda
- 设计损失:难以处理真实的数据
- Self-label correction:同时训练两个或多个学习器,并利用他们的预测一致性来衡量标签的可靠性
- 本文方法根据一些复杂的启发式算法确定的原型在线纠正不正确的标签,并且是即时的
源域生成的伪标签不变变为给标签加权
- 根据源域生成的伪标签初始化prototype( prototype实际上是类别的中心embedding),并且用EMA去更新prototype,可以理解为学习源域的知识并对其利用目标域进行优化
- 用生成的伪标签训练目标域分割(用SCE是一个借鉴的创新点)
ℓ s c e t = α ℓ c e ( p t , y ^ t ) + β ℓ c e ( y ^ t , p t ) ell_{s c e}^{t}=alpha ell_{c e}left(p_{t}, hat{y}_{t}right)+beta ell_{c e}left(hat{y}_{t}, p_{t}right) ℓscet=αℓce(pt,y^t)+βℓce(y^t,pt) - 由于伪标记数据不能覆盖目标域中的整个分布,受无监督启发,作者用 momentum encoder 计算弱增强的图像,用当前模型计算强增强的图像,使两者KL散度对齐概率
ℓ r e g t = − ∑ i = 1 H × W ∑ j = 1 K log p t ( i , k ) ell_{r e g}^{t}=-sum_{i=1}^{H times W} sum_{j=1}^{K} log p_{t}^{(i, k)} ℓregt=−i=1∑H×Wj=1∑Klogpt(i,k)(避免某一类为空)
ℓ total = ℓ c e s + ℓ s c e t + γ 1 ℓ k l t + γ 2 ℓ r e g t ell_{text {total }}=ell_{c e}^{s}+ell_{s c e}^{t}+gamma_{1} ell_{k l}^{t}+gamma_{2} ell_{r e g}^{t} ℓtotal =ℓces+ℓscet+γ1ℓklt+γ2ℓregt - 用 SimCLRv2 预训练的权重初始化网络后,用前一个阶段的模型计算目标域的伪标签,用下面三个损失训练网络:
ℓ K D = ℓ c e s ( p s , y s ) + ℓ c e t ( p t † , ξ ( p t ) ) + β K L ( p t ∥ p t † ) ell_{mathrm{KD}}=ell_{c e}^{s}left(p_{s}, y_{s}right)+ell_{c e}^{t}left(p_{t}^{dagger}, xileft(p_{t}right)right)+beta mathrm{KL}left(p_{t} | p_{t}^{dagger}right) ℓKD=ℓces(ps,ys)+ℓcet(pt†,ξ(pt))+βKL(pt∥pt†)
第一项是源域的监督损失,第二项是用 teacher 模型获得的伪标签进行的监督损失,最后一项是利用 teacher 模型的输出进行蒸馏
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)