首页 > 技术文章 > 修改xml成正方形,保存

crazybird123 2019-01-28 20:10 原文

import os
import xml.etree.ElementTree as ET
import cv2


# Input/output locations used by the script below.
# Each path ends with '/' so that a file name can be appended to it directly.

# Directory holding the original annotation XML files (read only).
origin_ann_dir = 'D:/Data/MyAnnoData/vmwareData/Annotations/'
# Directory that receives the rewritten annotation XML files.
new_ann_dir = 'D:/Data/MyAnnoData/vmwareData/save/Annotations/'
# Directory that receives the padded JPEG images.
new_img_dir = 'D:/Data/MyAnnoData/vmwareData/save/JPEGImages/'
# Directory holding the original JPEG images (read only).
image_dir = 'D:/Data/MyAnnoData/vmwareData/JPEGImages/'

# Text file listing the annotation XML file names to process, one per line.
LIST_FILE = "D:/Data/MyAnnoData/vmwareData/0.txt"


def square_padding(height, width):
    """Return the border widths that pad a height x width image to a square.

    The square side is ``max(height, width)``.  When the difference between
    the two dimensions is odd, the extra pixel goes to the bottom/right
    border so the padded image is exactly square.

    Returns:
        A ``(top, bottom, left, right)`` tuple of non-negative ints.
    """
    side = max(height, width)
    top = (side - height) // 2
    left = (side - width) // 2
    return top, side - height - top, left, side - width - left


def shift_annotation(root, side, left, top):
    """Update an annotation tree in place to match the padded image.

    Args:
        root: Root element of the parsed annotation XML.
        side: Side length of the padded (square) image; written into every
            ``size/width`` and ``size/height`` element.
        left: Number of pixel columns added on the left; added to every
            ``xmin`` and ``xmax``.
        top: Number of pixel rows added on the top; added to every
            ``ymin`` and ``ymax``.
    """
    for size in root.findall('size'):
        size.find('width').text = str(side)
        size.find('height').text = str(side)

    # Padding only translates the boxes, so both corners move by the same
    # offset and the box size stays unchanged.
    offsets = (('xmin', left), ('ymin', top), ('xmax', left), ('ymax', top))
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        for tag, offset in offsets:
            node = bndbox.find(tag)
            node.text = str(int(node.text) + offset)


def process_annotation(filename, ann_dir, img_dir, out_ann_dir, out_img_dir):
    """Pad one image to a square and write the adjusted annotation.

    Args:
        filename: Annotation file name (e.g. ``0001.xml``); the image is
            expected under the same stem with a ``.jpg`` extension.
        ann_dir: Directory containing the source annotation.
        img_dir: Directory containing the source image.
        out_ann_dir: Directory the updated annotation is written to.
        out_img_dir: Directory the padded image is written to.

    Returns:
        True when both the image and the annotation were written, False
        when the image could not be read or saved (the annotation is then
        left unwritten so the outputs stay consistent).
    """
    tree = ET.parse(os.path.join(ann_dir, filename))

    image_file = os.path.splitext(filename)[0] + '.jpg'
    imgpath = os.path.join(img_dir, image_file)
    img = cv2.imread(imgpath)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        print('skip {}: cannot read image {}'.format(filename, imgpath))
        return False

    # Derive the padding from the decoded pixels rather than from the XML
    # <size> element: the border is applied to the real image, and the XML
    # size may be missing or out of date.
    imgheight, imgwidth = img.shape[:2]
    top, bottom, left, right = square_padding(imgheight, imgwidth)
    saveimg = cv2.copyMakeBorder(img, top, bottom, left, right,
                                 cv2.BORDER_CONSTANT, value=0)
    out_img_path = os.path.join(out_img_dir, image_file)
    if not cv2.imwrite(out_img_path, saveimg):
        print('skip {}: cannot write image {}'.format(filename, out_img_path))
        return False

    shift_annotation(tree.getroot(), max(imgheight, imgwidth), left, top)
    tree.write(os.path.join(out_ann_dir, filename))
    return True


def main():
    """Process every annotation named in LIST_FILE."""
    # cv2.imwrite fails silently when the target directory does not exist.
    os.makedirs(new_ann_dir, exist_ok=True)
    os.makedirs(new_img_dir, exist_ok=True)

    with open(LIST_FILE) as listing:
        for line in listing:
            filename = line.strip()
            if not filename:
                continue  # tolerate blank lines in the list file
            if process_annotation(filename, origin_ann_dir, image_dir,
                                  new_ann_dir, new_img_dir):
                print(filename)


if __name__ == "__main__":
    main()

 

推荐阅读