多标签分类

多标签分类,第1张

多标签分类

https://www.jianshu.com/p/ac3bec3dde3e

多标签分类任务损失函数
在二分类、多分类任务中通常使用交叉熵损失函数,即Pytorch中的CrossEntorpy,但是在多标签分类任务中使用的是BCEWithLogitsLoss函数。

BCEWithLogitsLoss与CrossEntorpy的不同之处在于计算样本所属类别概率值时使用的计算函数不同:
1)CrossEntorpy使用softmax函数,即将模型输出值作为softmax函数的输入,进而计算样本属于每个类别的概率,softmax计算得到的类别概率值加和为1。
2)BCEWithLogitsLoss使用sigmoid函数,将模型输出值作为sigmoid函数的输入,计算得到的多个类别概率值加和不一定为1。

共同点是计算概率值后都继续计算预测概率值和真实标签之间的交叉熵作为最终的损失函数值。

为什么在多标签任务中使用BCEWithLogitsLoss(sigmoid)函数呢?个人理解如下:
1)二分类/多分类任务是在两个/多个类别中取出一个类别,并且各个类别之间是互斥的,因此要保证多个类别的概率值加和为1(在类别概率值加和为1的情况下,一个类别概率值增加时必然有其他类别概率值减小,体现了各个类别之间的互斥),并且最终取出概率值最大的类别。
2)多分类任务是在多个类别中取出一个或多个类别,各个类别之间不互斥,因此无需保证各类别概率加和为1,只需要计算样本属于每一个类别的概率,如果样本属于某一类别的概率高于阈值则代表样本属于该类别(此时类别概率值加和不一定为1),例如样本A经过BCEWithLogitsLoss函数计算后得到属于类别1、类别2、类别3和类别4的概率值分别为[0.6,0.7,0.3,0.4],阈值为0.5,则样本A同时属于类别1和类别2。

使用方法如下:

import torch
import torch.nn as nn

#创建输入
input = torch.tensor([[0.1,0.2,0.3],[0.4,0.5,0.6]])#共有两个输入样本
target = torch.tensor([[1,0,1],[0,1,1]])#每个样本的标签值都是多标签

#创建模型并计算
model = nn.Linear(input_dim = 3,hidden_dim=5)#此处随便写一个模型示意
model_out = model(input)

#计算损失函数值
loss = torch.nn.BCEWithLogitsLoss(model_out,target)

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5670416.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存