【代码篇】将 Pascal VOC 的 XML 标注转换为 COCO 格式的 JSON(Python 版)

【代码篇】xml格式转为json格式(转为coco格式)python版,第1张

import json
import xml.etree.ElementTree as ET
import os
from lxml import etree
from tqdm import tqdm
import numpy as np



class GEN_Annotations:
    """Incrementally build a Pascal-VOC-style ``<annotation>`` XML tree.

    The tree is kept in ``self.root``; ``set_size`` and ``add_pic_attr``
    append further children to it and ``savefile`` serializes it to disk.
    """

    def __init__(self, filename):
        """Create the root element with its fixed header children.

        Args:
            filename: Text stored in the ``<filename>`` element.
        """
        self.root = etree.Element("annotation")

        # Direct children of the root that carry plain text.
        for tag, text in (("folder", "VOC2007"), ("filename", filename)):
            etree.SubElement(self.root, tag).text = text

        # The <source> element holds a fixed set of constant fields.
        source = etree.SubElement(self.root, "source")
        source_fields = (
            ("annotation", "PASCAL VOC2007"),
            ("database", "Unknown"),
            ("image", "flickr"),
            ("flickrid", "35435"),
        )
        for tag, text in source_fields:
            etree.SubElement(source, tag).text = text

    def set_size(self, witdh, height, channel):
        """Append a ``<size>`` element with width, height and depth.

        Args:
            witdh: Value written (via ``str``) to ``<width>``.
            height: Value written (via ``str``) to ``<height>``.
            channel: Value written (via ``str``) to ``<depth>``.
        """
        size_node = etree.SubElement(self.root, "size")
        dimensions = (("width", witdh), ("height", height), ("depth", channel))
        for tag, value in dimensions:
            etree.SubElement(size_node, tag).text = str(value)

    def savefile(self, filename):
        """Write the tree to ``filename``, pretty-printed, without XML declaration."""
        etree.ElementTree(self.root).write(
            filename,
            pretty_print=True,
            xml_declaration=False,
            encoding='utf-8',
        )

    def add_pic_attr(self, label, xmin, ymin, xmax, ymax):
        """Append an ``<object>`` element holding a label and its bounding box.

        Args:
            label: Text stored in the object's ``<name>`` element.
            xmin: Value written (via ``str``) to ``<bndbox>/<xmin>``.
            ymin: Value written (via ``str``) to ``<bndbox>/<ymin>``.
            xmax: Value written (via ``str``) to ``<bndbox>/<xmax>``.
            ymax: Value written (via ``str``) to ``<bndbox>/<ymax>``.
        """
        object_node = etree.SubElement(self.root, "object")
        etree.SubElement(object_node, "name").text = label

        box_node = etree.SubElement(object_node, "bndbox")
        corners = (("xmin", xmin), ("ymin", ymin), ("xmax", xmax), ("ymax", ymax))
        for tag, value in corners:
            etree.SubElement(box_node, tag).text = str(value)

def _object_to_annotation(obj, image_id, annotation_id, category_id):
    """Convert one VOC ``<object>`` element into a COCO annotation dict."""
    box = obj.find('bndbox')
    xmin, ymin, xmax, ymax = (
        float(box.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
    )
    bbox_w = xmax - xmin
    bbox_h = ymax - ymin
    return {
        'image_id': image_id,
        'id': annotation_id,
        'category_id': category_id,
        # COCO boxes are [x_min, y_min, width, height], not centre-based.
        'bbox': [xmin, ymin, bbox_w, bbox_h],
        'area': bbox_w * bbox_h,
        # Rectangle polygon through the four corners of the box,
        # flattened as [x1, y1, x2, y2, ...].
        'segmentation': [[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]],
        'iscrowd': 0,
    }


def voc_to_coco(jpg_path, xml_path, class_name, progress=None):
    """Build a COCO-style dataset dict from a folder of VOC XML annotations.

    Each file in ``jpg_path`` is paired with the XML file of the same stem
    in ``xml_path``. Objects whose ``<name>`` is not in ``class_name`` are
    skipped. Image, annotation and category ids all start at 1.

    Args:
        jpg_path: Directory containing the image files.
        xml_path: Directory containing one ``<stem>.xml`` per image.
        class_name: List of category names; position + 1 is the category id.
        progress: Optional callable wrapping the image iterable
            (for example ``tqdm``) to report progress.

    Returns:
        A dict with the keys ``images``, ``annotations`` and ``categories``.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """
    # Sort so that image ids are reproducible; os.listdir order is arbitrary.
    img_files = sorted(os.listdir(jpg_path))
    xml_files = os.listdir(xml_path)
    if len(img_files) != len(xml_files):
        raise ValueError(
            "image/annotation count mismatch: %d images vs %d xml files"
            % (len(img_files), len(xml_files))
        )

    category_ids = {}
    categories = []
    for index, cat in enumerate(class_name, start=1):
        # Keep the first id for a repeated name, like list.index() would.
        category_ids.setdefault(cat, index)
        categories.append({'name': cat, 'id': index, "supercategory": "none"})

    ret = {'images': [], 'annotations': [], "categories": categories}

    iterable = img_files if progress is None else progress(img_files)
    for image_id, file_name in enumerate(iterable, start=1):
        # Parse the XML that belongs to *this* image. splitext handles
        # extensions of any length (".jpg", ".jpeg", ...).
        stem = os.path.splitext(file_name)[0]
        root = ET.parse(os.path.join(xml_path, stem + ".xml")).getroot()

        size = root.find('size')
        ret['images'].append({
            'file_name': file_name,
            'height': int(size.find('height').text),
            'width': int(size.find('width').text),
            'id': image_id,
        })

        for obj in root.iter('object'):
            cls = obj.find('name').text
            if cls not in category_ids:
                continue
            ret['annotations'].append(
                _object_to_annotation(
                    obj,
                    image_id,
                    len(ret['annotations']) + 1,
                    category_ids[cls],
                )
            )

    return ret


if __name__ == '__main__':
    base_path = r"E:-mytrain_person22-04-14\valid"
    jpg_path = os.path.join(base_path, "jpg")
    xml_path = os.path.join(base_path, "ann")
    json_path = os.path.join(base_path, "train.json")
    class_name = ["person"]

    ret = voc_to_coco(jpg_path, xml_path, class_name, progress=tqdm)
    print(ret["categories"])

    # Use a context manager so the output file is flushed and closed.
    with open(json_path, 'w', encoding='utf-8') as fh:
        json.dump(ret, fh)




欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/917988.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-16
下一篇 2022-05-16

发表评论

登录后才能评论

评论列表(0条)

保存