pytorch实战:详解查准率(Precision)、查全率(Recall)与F1

pytorch实战:详解查准率(Precision)、查全率(Recall)与F1,第1张

pytorch实战:详解查准率(Precision)、查全率(Recall)与F1 pytorch实战:详解查准率(Precision)、查全率(Recall)与F1 1、概述

本文首先介绍了机器学习分类问题的性能指标查准率(Precision)、查全率(Recall)与F1度量,阐述了多分类问题中的混淆矩阵及各项性能指标的计算方法,然后介绍了PyTorch中scatter函数的使用方法,借助该函数实现了对Precision、Recall、F1及正确率的计算,并对实现过程进行了解释。

观前提示:阅读本文需要你对机器学习与PyTorch框架具有一定的了解。

Tips:如果你只是想利用PyTorch计算查准率(Precision)、查全率(Recall)、F1这几个指标,不想深入了解,请直接跳到第3部分copy代码使用即可。

2、查准率(Precision)、查全率(Recall)与F1 2.1、二分类问题

对于一个二分类(正例与反例)问题,其分类结果的混淆矩阵(Confusion Matrix)如下:

预测的正例预测的反例真实的正例TP(真正例)FN(假反例)真实的反例FP(假正例)TN(真 反例)

则查准率P定义为:
P = T P T P + F P P=frac{TP}{TP+FP} P=TP+FPTP​
查全率R定义为:
R = T P T P + F N R=frac{TP}{TP+FN} R=TP+FNTP​
可见,查准率与查全率是一对相互矛盾的量,前者表示的是预测的正例真的是正例的概率,而后者表示的是将真正的正例预测为正例的概率。听上去有点绕,通俗地讲,假设这个分类问题是从一批西瓜中辨别哪些是好瓜哪些是不好的瓜,查准率高则意味着你挑选的好瓜大概率真的是好瓜,但是由于你选瓜的标准比较高,这意味着你也错失了一些好瓜(宁缺毋滥);而查全率高则意味着你能选到大部分的好瓜,但是由于你为了挑选到尽可能多的好瓜,降低了选瓜的标准,这样许多不太好的瓜也被当成了好瓜选进来。通过调整你选瓜的“门槛”,就可以调整查准率与查全率,即:“门槛”高,则查准率高而查全率低;“门槛”低,则查准率低而查全率高。通常,查准率与查全率不可兼得,在不同的任务中,对查准率与查全率的的重视程度也会有所不同。这时,我们就需要一个综合考虑查准率与查全率的性能指标了,比如 F β F_{beta} Fβ​,该值定义为:
F β = ( 1 + β 2 ) × P × R ( β 2 × P ) + R F_{beta}=frac{(1+{beta}^2) times P times R}{({beta}^2 times P)+R} Fβ​=(β2×P)+R(1+β2)×P×R​
其中 β beta β度量了查全率对查准率的相对重要性, β > 1 beta >1 β>1时,查全率影响更大, β < 1 beta <1 β<1时,查准率影响更大,当 β = 1 beta =1 β=1时,即是标准的F1度量:
F 1 = 2 × P × R P + R F1=frac{2 times P times R}{P+R} F1=P+R2×P×R​
此时查准率与查全率影响相同。

2.2、多分类问题

对于多分类问题,在计算查准率与查全率的时候,可以将其当作二分类问题,即正确类和其他类。为了方便阐述,还是借助一个例子来说明吧,假设有一个五分类的问题,每类分别标记为0、1、2、3、4,其分类结果的混淆矩阵如下:

在这个问题中,对于类别0,可以看作区分类别{0}与类别{1,2,3,4}的二分类问题,则对应的TP=m00,FP=m10+m20+m30+m40,FN=m01+m02+m03+m04,同样可以计算出相应的指标。因此,对于一个n分类问题,对应的查准率与查全率应该是一个n维向量,向量中的元素表示每类的查准率与查全率。

3、PyTorch中scatter()函数的使用

在讲解第二节中各个指标的计算方法的时候,首先来学习一下PyTorch中scatter()函数的使用方法。查看Pytorch官方文档中torch.Tensor.scatter()的说明,发现最终跳转到了torch.Tensor.scatter_(),有的朋友可能会犯迷糊,这两个函数有什么区别?其实区别很简单,scatter()不改变原来的张量,而scatter_()是在原张量上进行改变。官网的介绍为:

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

这里,self表示的是调用该函数的张量本身,其结果为:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

看上去有点绕,用大白话讲就是将参数index中给定元素作为一位索引(具体是哪位由dim参数决定),将src中的值填入self张量。这里的src也可以是个标量,如果是标量就直接填充就行。之所以用“填充”这个词,就是因为并不是self中所有的元素都会改变,只有少数张量会改变,具体改变的位置由index与dim决定。话不多说,直接看代码

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

该代码样例中对一个维度为(3,5)的全0张量进行了计算,参数dim=0,结果张量中,值发生变化的位置分别是(0,0)、(1,1)、(2,2)、(0,3)。细心的朋友应该能发现,这四个位置的第一个维度(dim=0)刚好是张量index的各个元素,而第二个维度则是index中对应元素的另一个维度的值。再看另一个例子

>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

该代码中值发生变化的位置是(0,0)、(0,1)、(0,2)、(1,0)、(1,1)、(1,4),同样的,这几个位置的第二个维度(dim=1)是index中的各个元素,而第一个维度则是index中对应元素的另一个维度的值。在涉及到高维张量的时候,这个 *** 作的空间意义我们可能难以想想,但是记住这点而不去追求理解其空间意义,再高的维度,也能很好地理解这个计算的具体 *** 作。

4、PyTorch实战与代码解析

接下来就是在PyTorch实现对查准率、查全率与F1的计算了,在该实例中,我们用到了scatter_()函数,首先来看完整的数据集测试过程代码,便于各位理解与取用,然后再对具体的计算代码进行阐述。

def test(valid_queue, net, criterion):
    net.eval()
    test_loss = 0 
    target_num = torch.zeros((1, n_classes)) # n_classes为分类任务类别数量
    predict_num = torch.zeros((1, n_classes))
    acc_num = torch.zeros((1, n_classes))

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            
            pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.)
            predict_num += pre_mask.sum(0)  # 得到数据中每类的预测量
            tar_mask = torch.zeros(outputs.size()).scatter_(1, targets.data.cpu().view(-1, 1), 1.)
            target_num += tar_mask.sum(0)  # 得到数据中每类的数量
            acc_mask = pre_mask * tar_mask 
            acc_num += acc_mask.sum(0) # 得到各类别分类正确的样本数量

        recall = acc_num / target_num
        precision = acc_num / predict_num
        F1 = 2 * recall * precision / (recall + precision)
        accuracy = 100. * acc_num.sum(1) / target_num.sum(1)

        print('Test Acc {}, recal {}, precision {}, F1-score {}'.format(accuracy, recall, precision, F1))

    return accuracy

首先看下面这行代码,产生一个大小为(batch_size , n_classes)的全0张量,然后将predicted的维度变成(batch_size,1),每个元素都代表的是其分类对应的编号,通过scatter_()函数,将1写入了全0张量中的对应位置。得到的张量pre_mask就是每次预测结果的one-hot编码。

pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.)

然后将pre_mask的第一个维度上的所有值进行加和,就得到了一个维度为(1 ,n_classes)的张量,该张量中的每个元素都是预测结果中对应的类的数量,对其进行累加,便得到了整个测试数据集的预测结果predict_num。

predict_num += pre_mask.sum(0)

tar_mask与target_num同理,此处不再赘述。

acc_mask = pre_mask * tar_mask

然后通过pre_mask * tar_mask,得到了一个表示分类正确的样本与对应类别的矩阵acc_mask,其维度为(batch_size , n_classes),对于其中一个元素acc_mask[i][j]=1,表示这个batch_size 中的第i个样本分类正确,类别为j。

acc_num += acc_mask.sum(0)

将acc_mask的第一个维度上的所有值进行加和,便可得到该batch_size数据中每个类别预测正确的数量,累加即可得到整个验证数据集中各类正确预测的样本数。

recall = acc_num / target_num
precision = acc_num / predict_num
F1 = 2 * recall * precision / (recall + precision)
accuracy = 100. * acc_num.sum(1) / target_num.sum(1)

然后就可以计算各个指标了,很好理解,就不再解释了。

希望对你有所帮助,也欢迎在评论区提出你的想法与意见。

5、参考

[1].《机器学习》,周志华

[2].https://blog.csdn.net/qq_16234613/article/details/80039080

[3].https://zhuanlan.zhihu.com/p/46204175

[4].https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html?highlight=scatter_

rticle/details/80039080

[3].https://zhuanlan.zhihu.com/p/46204175

[4].https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html?highlight=scatter_

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5580533.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-15
下一篇 2022-12-14

发表评论

登录后才能评论

评论列表(0条)

保存