- 数据集准备
数据集格式
文件夹格式:Data/ #保存Dota数据集的目录
Train #存放images和labelTxt的文件夹
Images#存放所有训练集图片的文件夹
labelTxt #存放所有训练集txt标注文件的文件夹
labelTxt中的txt文件可通过转换脚本ro_xml2txt.py生成:该脚本将roLabelImg标注的xml文件转换成DOTA格式的txt文件。
其中--xml_dir为需要转换的存放xml的路径。
--output_dir为转换后的数据集存放路径。
- 修改
- mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py
- 修改mmrotate/mmrotate/datasets/rolabel.py
三、训练
训练命令格式:
# 单 GPU 训练
python tools/train.py ${CONFIG_FILE} [optional arguments]
# 多 GPU 训练
bash tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]
说明:
config_file:模型配置文件的路径
gpu_num:使用 GPU 的数量
--work-dir:设置存放训练生成文件的路径
--resume-from:设置恢复训练的模型检查点文件的路径
--no-validate(不建议):设置训练时不验证模型
--seed:设置随机种子,便于复现结果
这里以oriented_rcnn为例,cd 到yuml_web目录下,运行命令:
python mmrotate/tools/train.py \
mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py
即可开始训练模型。其中训练产生的所有日志文件都保存在work_dir中。
- 验证
# 单 GPU 测试
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
[--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]
# 多 GPU 测试
bash tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} \
[--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]
config_file:模型配置文件的路径
checkpoint_file:模型检查点文件的路径
gpu_num:使用的 GPU 数量
--out:设置输出 pkl 测试结果文件的路径
--work-dir:设置存放 json 日志文件的路径
--eval:设置度量指标(voc:mAP, recall | coco:bbox, segm, proposal)
--show:设置显示有预测框的测试集图像
--show-dir:设置存放有预测框的测试集图像的路径
--show-score-thr:设置显示预测框的阈值,默认值为 0.3
--fuse-conv-bn: 设置融合卷积层和批归一化层,能够稍微提升推理速度
这里以oriented_rcnn为例,建议将work_dir中需要验证的pth模型文件复制到yuml_web/checkpoints/下,然后cd 到yuml_web目录下,运行命令:
python mmrotate/tools/test.py mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py \
checkpoints/pth模型文件名 --show
注:验证的数据集为test.txt中的图片名
转换脚本:
# -*- coding: utf-8 -*-
# @Time : 2021/5/2 15:42
# @Author : Bob.Xu
# @Site : Convert xml files annotated with rolabelimg into txt files
# @File : xml2txt.py
# @Software: PyCharm
# Convert the annotated xml files into the advanceeast training format
import os
import xml.etree.ElementTree as ET
import shutil
import glob
import time
import math
import argparse


def rotatePoint(xc, yc, xp, yp, theta):
    """Rotate the point (xp, yp) about the centre (xc, yc).

    Args:
        xc: x coordinate of the rotation centre.
        yc: y coordinate of the rotation centre.
        xp: x coordinate of the point to rotate.
        yp: y coordinate of the point to rotate.
        theta: rotation angle in radians (passed to math.cos / math.sin).

    Returns:
        tuple[str, str]: the rotated x and y coordinates, already converted
        to strings so they can be joined into an annotation line.
    """
    xoff = xp - xc
    yoff = yp - yc
    cosTheta = math.cos(theta)
    sinTheta = math.sin(theta)
    pResx = cosTheta * xoff + sinTheta * yoff
    pResy = -sinTheta * xoff + cosTheta * yoff
    return str(xc + pResx), str(yc + pResy)


def _object_to_line(obj):
    """Build one annotation line for an <object> element.

    The line holds the four corner points (top-left, top-right, bottom-right,
    bottom-left), the class name and the difficult flag, space separated.
    Returns None when the object has neither a usable <bndbox> nor a usable
    <robndbox> child.
    """
    box = obj.find("bndbox")
    # ``len(box)`` keeps the original "element exists and has children" test
    # without relying on the deprecated truth value of an Element.
    if box is not None and len(box):
        xmin, ymin, xmax, ymax = (box[i].text for i in range(4))
        corners = [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]
        return " ".join(
            corners + [obj.find("name").text, obj.find("difficult").text])

    rbox = obj.find("robndbox")
    if rbox is not None and len(rbox):
        cx, cy, w, h, angle = (float(rbox[i].text) for i in range(5))
        corners = []
        # Corner offsets in the same order as the axis-aligned case.
        for dx, dy in ((-1, -1), (1, -1), (1, 1), (-1, 1)):
            corners.extend(
                rotatePoint(cx, cy, cx + dx * w / 2, cy + dy * h / 2, -angle))
        return " ".join(
            corners + [obj.find("name").text, obj.find("difficult").text])

    return None


def main():
    """Convert every xml file in --xml_dir into a txt file in --output_dir."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--xml_dir', type=str, default='yaodai_data/Annotations', help='Directory of images and xml.')
    parser.add_argument('--output_dir', type=str, default='yaodai_data/labelTxt/', help='Directory of output.')
    args = parser.parse_args()

    xml_dir = args.xml_dir
    output_dir = args.output_dir
    if not os.path.isdir(xml_dir):
        # The original failed here too (os.listdir); keep the same error type.
        raise FileNotFoundError(xml_dir)
    os.makedirs(output_dir, exist_ok=True)

    n = 0
    for filename in sorted(glob.glob(os.path.join(xml_dir, "*.xml"))):
        stem, ext = os.path.splitext(os.path.basename(filename))
        if ext != ".xml":
            continue
        n = n + 1
        print("正在处理第{}个xml文件,名称为{}".format(n, stem + ".xml"))

        # Parse first so that a malformed xml does not leave an empty txt.
        root = ET.parse(filename).getroot()
        lines = []
        for obj in root.iter("object"):
            line = _object_to_line(obj)
            if line is not None:
                lines.append(line + "\n")

        # Write into --output_dir (the original ignored it and wrote next to
        # the xml file) and make sure the handle is closed.
        out_path = os.path.join(output_dir, stem + ".txt")
        with open(out_path, 'w', encoding='utf-8') as f:
            f.writelines(lines)


if __name__ == '__main__':
    main()
替换:
/mmrotate/mmrotate/datasets/__init__.py
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_dataset # noqa: F401, F403
from .dota import DOTADataset # noqa: F401, F403
from .hrsc import HRSCDataset # noqa: F401, F403
from .pipelines import * # noqa: F401, F403
from .sar import SARDataset # noqa: F401, F403
from .rolabel import ROLDataset # noqa: F401, F403
# Names exported by this package; ROLDataset is the class imported from .rolabel.
__all__ = ['SARDataset', 'DOTADataset', 'build_dataset', 'HRSCDataset', 'ROLDataset']
"/mmrotate/mmrotate/datasets/rolabel.py"
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import re
import tempfile
import time
import zipfile
from collections import defaultdict
from functools import partial
import mmcv
import numpy as np
import torch
from mmcv.ops import nms_rotated
from mmdet.datasets.custom import CustomDataset
from mmrotate.core import obb2poly_np, poly2obb_np
from mmrotate.core.evaluation import eval_rbbox_map
from .builder import ROTATED_DATASETS
@ROTATED_DATASETS.register_module()
class ROLDataset(CustomDataset):
    """Rotated-box detection dataset read from DOTA-format txt annotations.

    The class list is read from ``class_dir`` (one name per line) while the
    class body executes, i.e. when this module is imported, so the relative
    path must resolve from the current working directory.

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        version (str, optional): Angle representations. Defaults to 'oc'.
        difficulty (int, optional): The difficulty threshold of GT; boxes
            whose difficulty is greater than this are dropped.
            Defaults to 100.
        image_type (str, optional): Image file extension appended to each
            annotation id to build the image file name. Defaults to '.png'.
    """

    class_dir = 'yuml_web/data/class.txt'
    with open(class_dir, "r", encoding="utf-8") as f:
        a = [i.strip() for i in f.readlines()]
    CLASSES = a

    def __init__(self,
                 ann_file,
                 pipeline,
                 version='oc',
                 difficulty=100,
                 image_type='.png',
                 **kwargs):
        self.version = version
        self.difficulty = difficulty
        self.image_type = image_type

        super(ROLDataset, self).__init__(ann_file, pipeline, **kwargs)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_folder):
        """Load annotations from a folder of DOTA-format txt files.

        When the folder holds no txt file it is treated as a folder of
        images (test phase) and every image gets empty annotations.

        Args:
            ann_folder: folder that contains DOTA v1 annotations txt files

        Returns:
            list[dict]: one entry per image with ``filename`` and ``ann``.
        """
        cls_map = {c: i
                   for i, c in enumerate(self.CLASSES)
                   }  # in mmdet v2.0 label is 0-based
        ann_files = glob.glob(ann_folder + '/*.txt')
        data_infos = []
        # Kept in step with data_infos so ids never depend on slicing a fixed
        # number of characters off the file name.
        img_ids = []
        if not ann_files:  # test phase
            ann_files = glob.glob(ann_folder + '/*' + self.image_type)
            for ann_file in ann_files:
                img_id = osp.splitext(osp.basename(ann_file))[0]
                data_info = {
                    'filename': img_id + self.image_type,
                    'ann': {'bboxes': [], 'labels': []},
                }
                data_infos.append(data_info)
                img_ids.append(img_id)
        else:
            for ann_file in ann_files:
                # Images with an empty annotation file are skipped entirely.
                if os.path.getsize(ann_file) == 0:
                    continue

                img_id = osp.splitext(osp.basename(ann_file))[0]
                data_info = {'filename': img_id + self.image_type, 'ann': {}}
                gt_bboxes = []
                gt_labels = []
                gt_polygons = []
                # Never filled in this class; kept so the ann dict always
                # carries the *_ignore keys.
                gt_bboxes_ignore = []
                gt_labels_ignore = []
                gt_polygons_ignore = []

                # The conversion script writes UTF-8 and class.txt is read as
                # UTF-8, so read the annotations the same way.
                with open(ann_file, encoding='utf-8') as f:
                    for si in f:
                        bbox_info = si.split()
                        poly = np.array(bbox_info[:8], dtype=np.float32)
                        try:
                            x, y, w, h, a = poly2obb_np(poly, self.version)
                        except Exception:
                            # Best effort: skip lines whose polygon cannot be
                            # converted, but let KeyboardInterrupt through.
                            continue
                        cls_name = bbox_info[8]
                        difficulty = int(bbox_info[9])
                        label = cls_map[cls_name]
                        if difficulty <= self.difficulty:
                            gt_bboxes.append([x, y, w, h, a])
                            gt_labels.append(label)
                            gt_polygons.append(poly)

                ann = data_info['ann']
                if gt_bboxes:
                    ann['bboxes'] = np.array(gt_bboxes, dtype=np.float32)
                    ann['labels'] = np.array(gt_labels, dtype=np.int64)
                    ann['polygons'] = np.array(gt_polygons, dtype=np.float32)
                else:
                    ann['bboxes'] = np.zeros((0, 5), dtype=np.float32)
                    ann['labels'] = np.array([], dtype=np.int64)
                    ann['polygons'] = np.zeros((0, 8), dtype=np.float32)

                if gt_polygons_ignore:
                    ann['bboxes_ignore'] = np.array(
                        gt_bboxes_ignore, dtype=np.float32)
                    ann['labels_ignore'] = np.array(
                        gt_labels_ignore, dtype=np.int64)
                    ann['polygons_ignore'] = np.array(
                        gt_polygons_ignore, dtype=np.float32)
                else:
                    ann['bboxes_ignore'] = np.zeros((0, 5), dtype=np.float32)
                    ann['labels_ignore'] = np.array([], dtype=np.int64)
                    ann['polygons_ignore'] = np.zeros(
                        (0, 8), dtype=np.float32)

                data_infos.append(data_info)
                img_ids.append(img_id)

        self.img_ids = img_ids
        return data_infos

    def _filter_imgs(self):
        """Filter images without ground truths."""
        # len() works for both the ndarray labels of the training phase and
        # the plain list used in the test phase.
        return [
            i for i, data_info in enumerate(self.data_infos)
            if len(data_info['ann']['labels']) > 0
        ]

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        All set to 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None,
                 nproc=4):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Only 'mAP'
                is supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000). Not used by this method.
            iou_thr (float | list[float]): IoU threshold. It must be a float
                when evaluating mAP, and can be a list when evaluating recall.
                Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
                Default: None.
            nproc (int): Processes used for computing TP and FP.
                Default: 4.

        Returns:
            dict: ``{'mAP': mean_ap}``.

        Raises:
            KeyError: If ``metric`` is not supported.
        """
        # os.cpu_count() may return None when the count is undetermined.
        nproc = min(nproc, os.cpu_count() or 1)
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = {}
        if metric == 'mAP':
            assert isinstance(iou_thr, float)
            mean_ap, _ = eval_rbbox_map(
                results,
                annotations,
                scale_ranges=scale_ranges,
                iou_thr=iou_thr,
                dataset=self.CLASSES,
                logger=logger,
                nproc=nproc)
            eval_results['mAP'] = mean_ap
        else:
            raise NotImplementedError

        return eval_results

    def merge_det(self, results, nproc=4):
        """Merging patch bboxes into full image.

        Image ids of the form ``<name>__<n>__<x>___<y>`` are treated as
        patches of ``<name>`` and their boxes are shifted by ``(x, y)``.
        Ids without that suffix are treated as whole images (zero offset)
        instead of raising IndexError.

        Args:
            results (list): Testing results of the dataset.
            nproc (int): number of process. Default: 4.

        Returns:
            zip: the merged results transposed into (ids, per-class dets).
        """
        patch_offset = re.compile(r'__(\d+)___(\d+)')
        collector = defaultdict(list)
        for img_id, result in zip(self.img_ids, results):
            oriname = img_id.split('__')[0]
            match = patch_offset.search(img_id)
            if match is None:
                x, y = 0, 0
            else:
                x, y = int(match.group(1)), int(match.group(2))

            new_result = []
            for i, dets in enumerate(result):
                bboxes, scores = dets[:, :-1], dets[:, [-1]]
                ori_bboxes = bboxes.copy()
                # Only the box centre moves with the patch offset.
                ori_bboxes[..., :2] = ori_bboxes[..., :2] + np.array(
                    [x, y], dtype=np.float32)
                labels = np.zeros((bboxes.shape[0], 1)) + i
                new_result.append(
                    np.concatenate([labels, ori_bboxes, scores], axis=1))

            new_result = np.concatenate(new_result, axis=0)
            collector[oriname].append(new_result)

        merge_func = partial(_merge_func, CLASSES=self.CLASSES, iou_thr=0.1)
        if nproc <= 1:
            print('Single processing')
            merged_results = mmcv.track_iter_progress(
                (map(merge_func, collector.items()), len(collector)))
        else:
            print('Multiple processing')
            merged_results = mmcv.track_parallel_progress(
                merge_func, list(collector.items()), nproc)

        return zip(*merged_results)

    def _results2submission(self, id_list, dets_list, out_folder=None):
        """Generate the submission of full images.

        Writes one ``Task1_<class>.txt`` per class into ``out_folder`` and
        zips them into ``<out_folder name>.zip`` inside the same folder.

        Args:
            id_list (list): Id of images.
            dets_list (list): Detection results of per class.
            out_folder (str, optional): Folder of submission.

        Returns:
            list[str]: paths of the per-class txt files.

        Raises:
            ValueError: If ``out_folder`` already exists.
        """
        if osp.exists(out_folder):
            raise ValueError(f'The out_folder should be a non-exist path, '
                             f'but {out_folder} is existing')
        os.makedirs(out_folder)

        files = [
            osp.join(out_folder, 'Task1_' + cls + '.txt')
            for cls in self.CLASSES
        ]
        file_objs = []
        try:
            for path in files:
                file_objs.append(open(path, 'w'))
            for img_id, dets_per_cls in zip(id_list, dets_list):
                for f, dets in zip(file_objs, dets_per_cls):
                    if dets.size == 0:
                        continue
                    bboxes = obb2poly_np(dets, self.version)
                    for bbox in bboxes:
                        txt_element = [img_id, str(bbox[-1])
                                       ] + [f'{p:.2f}' for p in bbox[:-1]]
                        f.write(' '.join(txt_element) + '\n')
        finally:
            # Close every handle even if writing fails part-way.
            for f in file_objs:
                f.close()

        target_name = osp.split(out_folder)[-1]
        with zipfile.ZipFile(
                osp.join(out_folder, target_name + '.zip'), 'w',
                zipfile.ZIP_DEFLATED) as t:
            for f in files:
                t.write(f, osp.split(f)[-1])

        return files

    def format_results(self, results, submission_dir=None, nproc=4, **kwargs):
        """Format the results to submission text (standard format for DOTA
        evaluation).

        Args:
            results (list): Testing results of the dataset.
            submission_dir (str, optional): The folder that contains submission
                files. If not specified, a temp folder will be created.
                Default: None.
            nproc (int, optional): number of process.

        Returns:
            tuple:

                - result_files (list): paths of the per-class txt files
                - tmp_dir (tempfile.TemporaryDirectory | None): the temporary
                  directory created for saving the files when
                  submission_dir is not specified, otherwise None.
        """
        nproc = min(nproc, os.cpu_count() or 1)
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            f'The length of results is not equal to '
            f'the dataset len: {len(results)} != {len(self)}')
        if submission_dir is None:
            # Keep the TemporaryDirectory object for the caller and write
            # into a not-yet-existing child, because _results2submission
            # requires a path that does not exist.
            tmp_dir = tempfile.TemporaryDirectory()
            submission_dir = osp.join(tmp_dir.name, 'submission')
        else:
            tmp_dir = None

        print('\nMerging patch bboxes into full image!!!')
        start_time = time.time()
        id_list, dets_list = self.merge_det(results, nproc)
        stop_time = time.time()
        print(f'Used time: {(stop_time - start_time):.1f} s')

        result_files = self._results2submission(id_list, dets_list,
                                                submission_dir)

        return result_files, tmp_dir
def _merge_func(info, CLASSES, iou_thr):
    """Merging patch bboxes into full image.

    Runs rotated NMS separately for every class on the detections collected
    for one image.

    Args:
        info (tuple): ``(img_id, label_dets)``; ``label_dets`` is a list of
            arrays whose first column is the class index and whose last
            column is the score. The first five columns after the class
            index are passed to ``nms_rotated`` as the boxes.
        CLASSES (list): Label category.
        iou_thr (float): Threshold of IoU.

    Returns:
        tuple: ``(img_id, big_img_results)`` with one array per class.
    """
    img_id, label_dets = info
    label_dets = np.concatenate(label_dets, axis=0)
    labels, dets = label_dets[:, 0], label_dets[:, 1:]

    big_img_results = []
    for i in range(len(CLASSES)):
        # Select this class once instead of re-evaluating the mask.
        cls_dets_np = dets[labels == i]
        if len(cls_dets_np) == 0:
            big_img_results.append(cls_dets_np)
            continue

        cls_dets = torch.from_numpy(cls_dets_np)
        try:
            cls_dets = cls_dets.cuda()
        except Exception:
            # Best effort: stay on the CPU when CUDA is unavailable, without
            # swallowing KeyboardInterrupt/SystemExit like a bare except.
            pass
        nms_dets, keep_inds = nms_rotated(cls_dets[:, :5], cls_dets[:, -1],
                                          iou_thr)
        big_img_results.append(nms_dets.cpu().numpy())
    return img_id, big_img_results
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)