《检索Recall@k的计算脚本》
检索任务的评估指标是召回率,我们不光要看Recall@1也要看下Recall@3及Recall@10等,看下排序有多少空间,记录一下训练过程中对positive pair 计算 Recall的脚本;基于numpy
-
我们默认 query 和 doc 是一一匹配的 positive pair
-
query shape -> (n, c)
-
doc shape -> (n, c)
-
query.dot(doc.T) shape -> (n, n)
-
内积计算的相似度矩阵对角线上是正样本其余都是负样本
def top10_recall(query, doc, n):
# 这里的 query 和 doc 都是单位向量
top_ten = np.array([0] * 10) # 记录 top_ten 个数的列表
# measure accuracy and record loss
query = query.detach().cpu().numpy()
doc = doc.detach().cpu().numpy()
sim_matrix = query.dot(doc.T)
# 降序排序
top_sim = np.argsort(-sim_matrix, 1)[:, :10].tolist()
target = np.arange(n)
# 计算top 1 到 top 10 的 recall
for ii in range(n):
for top in range(10):
if target[ii] in top_sim[ii][:top + 1]:
top_ten[top:] += 1
break
top_ten /= n
return top_ten
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)