DOTA数据集测试mAP

DOTA数据集测试mAP,第1张

前言

测试mAP的程序是官方提供的DOTA_devkit,但mmdetection生成的测试结果不能直接用来测试mAP,需要一定的格式转换。

DOTA_devkit需要的文件

测试mAP需要三个文件夹

detpath = r'/home/zhangxiao/DOTA_devkit-master/map_test/predict_rusult/Task1_ship.txt'  # Path to detectionsdetpath.format(classname) should produce the detection results file.
annopath = r'/home/zhangxiao/DOTA_devkit-master/map_test/gt/SSDD_orient_annotation/{:s}.txt'  # change the directory to the path of val/labelTxt, if you want to do evaluation on the valset
imagesetfile = r'/home/zhangxiao/DOTA_devkit-master/map_test/test_img_name.txt'

detpath 是存放检测结果的文件夹,每个文件夹的格式如下。第一个数字是文件名,第二个数字是预测概率,最后八个数字对应四个点的坐标。

annopath存放标签文件,前两行不用管,第三行的0表示是否是困难样本标志,ship是框的类别。

所有标签的文件夹

imagesetfile是所要测试的图片

生成需要的预测文件 Task1_{%}.txt

自己写了一个程序,需要根据自己的需求做一些调整。

import os
import numpy as np
from mmdet.apis import init_detector, inference_detector
import torch


def obb2poly(obboxes):  # 如果生成的格式是obb,需要从obb转到poly
    obboxes = torch.tensor(obboxes)
    center, w, h, theta = torch.split(obboxes, [2, 1, 1, 1], dim=-1)
    Cos, Sin = torch.cos(theta), torch.sin(theta)

    vector1 = torch.cat(
        [w/2 * Cos, -w/2 * Sin], dim=-1)
    vector2 = torch.cat(
        [-h/2 * Sin, -h/2 * Cos], dim=-1)

    point1 = center + vector1 + vector2
    point2 = center + vector1 - vector2
    point3 = center - vector1 - vector2
    point4 = center - vector1 + vector2
    return torch.cat([point1, point2, point3, point4], dim=-1)

config_file = '/home/zhangxiao/OBBDetection-master/configs/obb/oriented_rcnn/faster_rcnn_orpn_r50_fpn_1x_dota10.py'
checkpoint_file = '/data0/zhangxiao/results/faster_rcnn/epoch_20.pth'

model = init_detector(config_file,checkpoint_file)

img_dir = '/data0/zhangxiao/SSDD/test/images'  # 需要测试图片的路径

imgs = os.listdir(img_dir)

with open('Task1_ship.txt', 'w') as f:  # 由于我检测的类别只用ship,所以就偷个懒。不适用检测多种类别的。
    for img in imgs:
        img2 = os.path.join(img_dir, img)
        result = inference_detector(model,img2)
        for res in result[0]:  # res = [273.06485  , 227.75511  ,  51.681828 ,  22.212055 ,   0.5608363,  0.9988368]
            out = obb2poly(res[:5])
            if res[5] < 0.6:
                continue
            f.write(img[:6] + ' ' + str(res[5]) + ' ' + ' '.join(str(int(a)) for a in out) + '\n')  # 这里我需要对文件名进行修改 
对dota_evaluation_task1.py进行调整

主要把之前三个文件夹的路径做一个修改,然后对应的类别数也要做一个调整。

最终得到结果

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/915572.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-16
下一篇 2022-05-16

发表评论

登录后才能评论

评论列表(0条)

保存