大多数现有的方法,都假设每个训练示例的候选标记是由真实标记和随机选取的不正确的标记组成的。
然而,这种假设是不切实际的,因为候选标记总是依赖示例的。
本期 AI Drive,东南大学计算机科学与工程学院硕士生-乔聪玉,解读其团队发表于 NeurIPS 2021 的最新工作:示例依赖的偏标记学习。
在这项研究中,他们考虑了示例依赖的偏标记学习,并假设每个示例都与每个标记的潜在标记分布相关联,其中标记分布代表了每个标记描述特征的程度。
描述程度越高的不正确标记更有可能被注释为候选标记。
因此,潜在标记分布是部分标记示例中必不可少的标记信息,值得用于预测模型训练。
本文将主要分为以下 5 个部分进行介绍:·Introduction·Related work·Proposed Method·Experiment·Conclusion其中,第一部分(Introduction)介绍偏标记学习定义、研究示例依赖的偏标记学习的理由,以及其中运用到的其他技巧。
第二部分(Related work),简要介绍目前在偏标记领域所用到的主流方法(包括五种传统算法、近两年兴起的基于深度学习算法在偏标记领域内的应用)。
第三部分(Proposed Method)部分,是我们提出算法的部分,此处会详细介绍有关的算法细节。
最后两部分(Experiments 以及 Conclusion)介绍实验结果和结论。
1 偏标记学习传统监督学习框架在建模时采用强监督的假设。
即对象的类别标记信息是单一、明确的。
传统监督学习框架已经取得了巨大成功。
值得注意的是,强监督假设虽然为学习建模的过程提供了便利,但却是对真实世界问题的一种简化处理的方式,在许多情况下,并不成立。
实际上会受外部环境问题特性、物理资源等各方面因素的制约,学习系统往往只能从训练样本中获取有限的标记信息及弱监督信息。
如何在弱监督信息条件下有效进行学习建模,已经成为机器学习领域研究的热点问题。
在偏标记学习的框架下,每个对象可同时获得多个语义标记,但其中仅有一个标记反映对象的真实语义,该形式的学习场景在现实问题中广泛存在。
例如,在医疗诊断中,医生虽然可以排除病人患有某些疾病的可能性,却难以从若干症状相似的疾病中给予确诊。
在互联网应用中,用户可以自由为各种在线对象提供标注,但在对象获得的多个标注中,可能仅有一个是正确的。
再举个例子,人们可以从图像附属标题文本内,获取图像中各个人物名称作为语义标记,但对于图像中特定人物、人脸而言,他与各个语义标记以及具体人物的名称对应关系却并未确定。
以上两个例子都是偏标记的应用的场景。
简而言之,在互联网应用中,用户可以自由为各种在线对象提供标注,但在对象获得的多个标注中,可能仅有一个是正确的。
以下是其形式化表达之一。
如下图,在偏标记训练集内,每个事例 x 对应一个候选集合 s,真实标记隐藏在候选集合中。
最终的学习目标是得到,能将示例x映射到真实标记 y 分类器(用 f 表示)。
在我们团队工作之前,算法偏标记的生成过程都是这样产生:除真实标记外,其他候选标记都是经过随机抽取出来的。
这是一种非常朴素的假设,这种假设常用来从非偏标记的数据集手动生成偏标记数据集。
比如说手写数字数据集 MNIST,对于手写数字 1,通过算法随机取 2 和 5 作为偏标记数据的假阳性标记,和 1 共同作为候选标记集合。
再举个例子,CIFAR10 数据集中的一张飞机图片,如果手动取猫、路,和飞机三者组成的图片作为候选标记集合。
那么这种假设显然不合理。
例如考虑三个正常人标记数据集,对于瘦长数字1而言,标注时在两眼发昏的情况下,是更容易把图片中瘦长的数字标注成它的候选标注成1的候选集,而不是宽大的数字。
因为1的特征之一就是瘦长,所以也有可能把写的瘦长的 6、7 标注成候选集,不太可能把写得宽大的 6、7 标注为 1。
对于飞机而言,更可能把背景看上去像蓝天或图像中长得像翅膀的目标标注成飞机,而不太可能把公路上运输的卡车标注为飞机,这是常识。
这也说明真在真实场景下,偏标记集合以视力依赖型为特征,而不是随机选出来的。
视力依赖型的偏标记,也更加符合实际偏标记的生成过程,因而针对其设计的算法也更加实用。
所以本文介绍的工作就是,提出示例依赖性的偏标记学习,并为其设计相应算法。
最后在 benchmark 数据集(还有 minist、fashion minist Kuzushiji minist, CIFAR10 数据集)、UCI 数据集、真实场景的偏标记数据集,这三大数据集上验证本文提出算法的有效性。
此处引入一个概念——标记分布 Label Distribution。
近两年,软标记的方法比较流行。
比如说 label smoothing、蒸馏等方法。
较早提出软标记学习的是我的导师耿新老师提出的标记分布 Label Distribution。
标注是标记多义性问题,是机器学习领域的热门方向之一。
在现有的机器学习范式中,主要存在两种数据标注方式:一是一个示例分配一个标记,二是一个示例分配多个标记。
单标记学习(Single Label Learning),假设训练集内所有示例都是用第一种方式标记。
多标记学习(Multiple Label Learning),允许训练示例用第二种方式标记,所以多标记学习可以处理的示例属于多个类别的多义性情况。
但总而言之,无论是单标记学习还是多标记学习,都只在回答一个最本质的问题——哪些标记可以描述具体事例?但却都没有直接回答另外更深层的问题——每个标记如何描述该示例?或每个标记对该示例的相对重要性程度如何?对于真实世界中的许多问题,不同标记的重要程度往往不一致。
例如,一幅自然场景图像被标注了天空、水、森林和云等多个标记,而这些标记具体描述该图像的程度却有所不同。
再例如,在人脸情感分析中,人的面部表情常常是多种基础情感,例如快乐、悲伤、惊讶、愤怒、厌恶、恐惧等基础感情。
而这些基础情感会在具体的表情中表达出不同强度。
从而呈现出纷繁复杂的情感。
类似的例子还有很多。
一般情况下,一旦一个事例与多个标记同时相关,这些标记对该事例不会恰好都一样重要,会有主次先后之分。
对于类似上述例子的应用,有一种很自然的方法。
对于一个示例x,将实数 d_xy(如图)赋予每一个可能的标记,y 描述 x 的程度。
这就是一个标记分布。
然而实践中,一般标注都是以 0、1 逻辑符号数据去标注。
其表达是或否的逻辑关系,所以对一个示例而言,所有标记逻辑值,构成的逻辑向量被称为逻辑标记。
例如常见的 one-hot 向量,这也是对问题的简化方式之一。
尽管如此,数据中的监督信息,本质上是遵循某种标记分布的。
例如鸟是有翅膀的,所以能飞。
那显然它可能会被标注为 bird 或 airplane,而不太可能被标注为 frog。
所以对于两者而言,对鸟图片的描述程度是不一样的。
但是目前的工作就是需要从逻辑标记(例如 one-hot),转化为置信度、描述度问题。
这个过程就属于标记增强过程,简而言之,标记增强就是将训练样本中的原始逻辑标记转化为标记分布的过程。
对于示例依赖的偏标记学习而言,如何描述偏标记集合中,元素之间的关系?其实就是利用标记分布,通过标记增强的方法,恢复其中潜在的标记分布。
还是刚刚的例子,对于数字 1,它的候选集合可能是 3 或 6,但这两者中,是3对1的描述度高?还是 6 对 1 的描述度高?1 对 3 和 1 对 6 哪个相关度更高?对飞机而言,到底是鸟标记对飞机的描述度更高,还是卡车的描述度更高?飞机跟鸟更相关,还是跟卡车更相关?例如以上这类信息的挖掘,需要借助标记增强,增强逻辑标记的描述度和相关性,这就是标记分布。
2 偏标记学习领域相关工作偏标记算法从直觉上来说,可以把不正确的标记找出来,学习、使用算法时将其排除,这个过程被称为消歧。
对于消歧的策略,分为两种,一是基于辨识的消歧,二是平均消歧。
在辨识消歧中,真实标记被当成隐变量,并以迭代的方式逐渐被识别出来。
在平均消歧策略中,所有候选标记都是被同等对待的,最终的预测,取自于模型最后输出的平均值。
现有大多数算法,都通过结合广泛使用机器学习技术与偏标记数据相匹配,完成学习任务。
例如观察每个部分标记训练示例的可能性,定义在其候选标记集上,而不是未知的 ground -truth 标记。
K 近邻技术也可以解决偏标记问题,其通过在相似示例的候选标记中投票来确定不可见示例的类别。
对于最大边界的技术,通过区分后验标记和非后验标记的建模输出,定义了偏标记示例的权重及候选标记的置信度。
传统机器学习算法中也有标记增强技术运用。
每个偏标记的训练示例的权重,以及后验标记的置信度,在每轮增强后都会更新。
接下来介绍深度学习方法在偏标记领域中的应用。
首先最开始的是 D2CNN,D2CNN 是通过为图像数据设计两个特定的网络,再不断学习偏标记。
这之后有一篇文章介为偏标记学习设计了广泛适用的算法框架。
这也是我们实验室一位师姐的文章,她提出了具有一致性的风险估计和渐进的识别算法,其算法可以兼容任意深度模型和随机优化器。
这篇文章正式开启了深度学习在偏标记领域的应用。
随后重庆大学的冯磊教授,提出了 RC、CC 这两种算法。
分别是风险一致和分类器一致的方法。
但是他们所提出的这些算法,都是假设偏标记是随机生成,例如 RC 和 CC,都是假设生成 uniform 的过程,最终的算法也是基于推导出来的。
PRODEN 算法在实验时,除了真实标记,其他每个偏标记都赋予一个伯努利概率 p,对于非真实标记,也有一定的概率被翻转成真实标记。
3 本次研究的新方法接下来介绍我们的算法,整个算法流程并不复杂。
下图形式化的表达之一。
以下是算法模型结构图,便于更好的理解整个算法流程。
模型分为上、下两层。
上层是辅助性网络。
最后需要用到估计出的标记分布,去监督下层网络,下面网络是分类器,也就是目标网络。
例如,一张图片,首先会进入 low level 层,推断标记分布。
其中需要用到很多信息,例如被抽取的特征、邻接矩阵等。
benchmark 数据集内是没有这个邻接矩阵的,所以需要首先要抽取特征。
因为 cifar10 是原始图像数据,直接做建模,就是邻接矩阵直接生成的话,肯定是不准确的。
例如,卷积神经网络效果为什么这么好,因为其有一定的频率不变性。
那么对于 cifar10,就需要做特征抽取,然后用 resnet32 网络收取,抽取出来后,利用编码器和解码器,就是一个 VGAE 编码器。
与以前的方法不一样的在于我们通过编码器参数化的 Dirichlet,从 Dirichlet 分布中取到值 D。
我们认为这就是一个标记分布。
下层的网络也不难,例如 high level,可以采用 MLP、感知机,作为聚合然后输出,得出最终的结果。
上面增强出来的 Label Distribution 标记分布,就用作下层网络的监督信息,使最终得出的结果更好。
上层网络,可以认为是不断挖掘潜在标记分布的过程。
以上所提出的算法是端到端的学习过程。
模型训练分为几个阶段:第一阶段,是模型的预热阶段,在提到要抽取特征,此前就需要预热一下。
这时用的是 minimal loss。
直觉上讲损失函数值最小的标记,可能就是真实标记。
对于抽取出来的特征,用 KNN 做邻接矩阵。
K 的值是超参。
第二阶段,是标记增强的阶段。
VALEN 算法在标语增强阶段,目标是推断出已知逻辑标记邻接矩阵特征的条件后验— p(D)。
但是如果想直接精确计算p(D)是不太现实的,所以此时需要用到一些技巧。
例如我们用 q(D) 去估算 p(D), q(D) 是用 Dirichlet 作为建模。
对于前面模型编码器输出的 α,就作为 Dirichlet 的参数。
采样后,采出来的就是需要的标记分布。
为了更好聚合拓扑关系,可以采用图卷积神经网络。
以下是贝叶斯变分推断技巧,具体的可以参考我们论文的补充材料。
与论文结合起来,了解详细的推导过程。
在本文就不展开介绍了,但也是从那边推算演化过来的。
除此之外,对于标记分布 D,则需要给其加上限制条件。
对于以下的网络输出,可以认为是一种置信度。
上文的实验(例如 PRODEN),也相应证明了网络输出对真实标记的置信度可能是最大。
所以增强后的标记分布,不能距离置信度太远。
简而言之,不能偏离置信度。
同时,对于偏标记候选集合之外的标记,我认为其置信度为零。
这是一个比较直观的假设。
例如上文提到的,鸟与飞机相关性相对较强。
在标注的时候,可能就只标注为飞机和鸟,对于其他(例如 frog)类别的置信度就为零。
因为这些类别相关度太低。
最后,会介绍为什么采用迪利克雷分布。
因为狄利克雷分布从直观上来看,分布采样得到的值与标记分布的值很相似。
其现实条件也是一样的。
因为标记分布的要求之一就是 ∑ 为 1,通过迪利克雷采样得出的值就是类似的形式。
其次,迪利克雷分布属于类别分布,类别分布可以作为真实标记分布。
所以可以采用 Dirichlet 分布表示,去挖掘潜在的标记分布。
最后在模型的训练阶段,下图为损失函数,我们采用的是交叉熵 log 值,再加上权重。
这个权重就是标记分布,通过以上函数不断训练,得出好的效果。
4 实验结果实验部分,首先是对于数据集问题,如何生成示例依赖型的偏标记数据?其实就是用干净的标记去训练网络,对于网络输出的值,每一个输出的值我们认为就是这个示例在这个标记上的置信度,每个标记对应的置信度与除了真实标记外最大的置信度相除,再用大部分的式子规划一下,那么就可以得出每个标记被翻转出来的概率。
即 one-hot 中的 0 的标记有一定几率被翻转成为 1。
这样就可以得出示例依赖的偏标记数据集了。
其背后的思想是把神经网络看成一个打分者,例如我在这个标记上犯错误的概率是多少?它就有相应的可能被翻转过来变为 1。
Benchmark datasets 和 UCI datasets 都是经过上述步骤生成。
对于真实场景下的偏标记数据,是来自各方各面的领域,有人脸、目标检测、甚至还有音频方面、都有涉及到。
对于下图的 BirdSong、Soccer Player、Yahoo news,这三个数据的标记训练集的个数是庞大的。
在示例依赖型的数据集上,我们的方法比其他几个深度的方法都要高很多。
在 uniform 数据。
对于随机抽取一些随机生成的偏标记过程中,我们的方法也是可比的,均值基本上都是最高的(除了在 MNIST 上)。
MNIST 数据集稍微有点落后,和 UCI 数据集一样。
对比于传统方法,因为大数据集的图像数据维度较大。
所以传统方法并不太适用。
但对于小数据,我们也将传统数据增加进去了,传统方法在小数据集上得出的效果也是很好的。
一些传统方法的表现也是很好的,如下图。
在真实场景下,一些传统方法处理小数据级得出的效果明显优于我们的算法。
上文提出的算法可能更适合处理大规模数据集,但是我们算法与深度方法相比,还是优于深度方法。
5 总结我们最主要的贡献,在于首次提出示例依赖的偏标记的学习框架。
关键技巧,就是分为两个网络,一个是辅助网络,另一个是主要的目标网络。
辅助网络通过迭代的方式,去恢复潜在的标记分布。
然后利用这个标记分布,在每个阶段训练预测模型。
对于未来的工作,我们会去继续探究其他更好的方法去学习示例依赖的偏标记学习。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)