Retrieval task calculate Recall@k

Retrieval task calculate Recall@k,第1张

《检索Recall@k的计算脚本》

检索任务的评估指标是召回率,我们不光要看Recall@1也要看下Recall@3及Recall@10等,看下排序有多少空间,记录一下训练过程中对positive pair 计算 Recall的脚本;基于numpy

  • 我们默认 query 和 doc 是一一匹配的 positive pair

  • query shape -> (n, c)

  • doc shape -> (n, c)

  • query.dot(doc.T) shape -> (n, n)

  • 内积计算的相似度矩阵对角线上是正样本其余都是负样本

    def top10_recall(query, doc, n):
    # 这里的 query 和 doc 都是单位向量
    top_ten = np.array([0] * 10) # 记录 top_ten 个数的列表
    # measure accuracy and record loss
    query = query.detach().cpu().numpy()
    doc = doc.detach().cpu().numpy()
    sim_matrix = query.dot(doc.T)
    # 降序排序
    top_sim = np.argsort(-sim_matrix, 1)[:, :10].tolist()
    target = np.arange(n)
    # 计算top 1 到 top 10 的 recall
    for ii in range(n):
    for top in range(10):
    if target[ii] in top_sim[ii][:top + 1]:
    top_ten[top:] += 1
    break
    top_ten /= n
    return top_ten

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/langs/922568.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-16
下一篇 2022-05-16

发表评论

登录后才能评论

评论列表(0条)

保存