luckygay

  • 5

    获得赞
  • 3

    发布的文章
  • 0

    答辩的项目

YOLO(v3)PyTorch版 训练自己的数据集

YOLO PyTorch

最后更新 2020-05-11 14:00 阅读 4267

最后更新 2020-05-11 14:00

阅读 4267

YOLO PyTorch

Yolo v3比Frcnn好调试多了……就是数据集准备比较麻烦…… 但是好Debug,linux和win10差别不大……

代码链接(cpu版本):https://github.com/eriklindernoren/PyTorch-YOLOv3/issues/340

这个代码……作者说的太草率了……data怎么准备都没说清……好歹issue里面有大神解答,给了傻瓜版教程,运行他的几个脚本就好了,data文件夹就准备好啦!

data文件准备,按照这个数据集准备

虽然这个作者是用它来训练coco数据集,但是data整个是个四不像……不用json不用xml用txt……所以训练自己的比较麻烦…… 准备好data,还有修改config/yolov3.cfg文件。

参考链接:链接文字

打开yolov3.cfg文件后,搜索yolo,共有三处yolo,下面以一处的修改作为示例。

[convolutional] #紧挨着[yolo]上面的[convolutional]
size=1
stride=1
pad=1
filters=21 #filters=3*(你的class种类数+5)
activation=linear

[yolo]
mask = 6,7,8
anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
classes=2 #修改classes
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=0 #显存大的写1 反之0

除此之外,cfg中的其他参数可以参考这个链接来进行修改,比如可以修改一些数据增强的参数,如果想要接着上次训练的weight继续训练,就参考这个链接进行微调(但是我使用的代码不支持clear操作,只能使用第二种方法)。

然后开始训练吧!个人感觉yolov3学的效果不是很好,frcnn训练了十轮能达到的效果,yolov3可能要80轮左右,开始前几十轮mAP很低就多加几轮试试,issue里面提到这个代码训练100轮,coco也达不到作者所说的mAP……所以……慎重……男票用这个帮我把自己的数据集跑到了92%左右的mAP,效果还是很好的。 准备数据集的那个代码有些bug调不出来,所以还是使用第一个的代码吧。 记录一下用过的脚本之 voc的xml转化为coco的json

import os
import json
import xml.etree.ElementTree as ET
import numpy as np
import cv2

def _isArrayLike(obj):
    return hasattr(obj, '__iter__') and hasattr(obj, '__len__')

class voc2coco:
    def __init__(self, devkit_path=None, year=None):
        # self.classes = ('__background__',
        #                 'aeroplane', 'bicycle', 'bird', 'boat',
        #                 'bottle', 'bus', 'car', 'cat', 'chair',
        #                 'cow', 'diningtable', 'dog', 'horse',
        #                 'motorbike', 'person', 'pottedplant',
        #                 'sheep', 'sofa', 'train', 'tvmonitor')
        self.classes = ('none',
                        'll', 'rr') #写你自己的class
        self.num_classes = len(self.classes)
        assert 'VOCdevkit' in devkit_path, 'VOC地址不存在: {}'.format(devkit_path)
        self.data_path = os.path.join(devkit_path, 'VOC' + year)
        self.annotaions_path = os.path.join(self.data_path, 'Annotations')
        self.image_set_path = os.path.join(self.data_path, 'ImageSets')
        self.year = year
        self.categories_to_ids_map = self._get_categories_to_ids_map()
        self.categories_msg = self._categories_msg_generator()
    def _load_annotation(self, ids=[]):
        ids = ids if _isArrayLike(ids) else [ids]
        image_msg = []
        annotation_msg = []
        annotation_id = 1
        for index in ids:
            filename = '{:0>6}'.format(index)
            json_file = os.path.join(self.data_path, 'Segmentation_json', filename + '.json')
            num=0
            if os.path.exists(json_file):
                img_file = os.path.join(self.data_path, 'JPEGImages', filename + '.jpg')
                im = cv2.imread(img_file)
                width = im.shape[1]
                height = im.shape[0]
                seg_data = json.load(open(json_file, 'r'))
                assert type(seg_data) == type(dict()), 'annotation file format {} not supported'.format(type(seg_data))
                for shape in seg_data['shapes']:
                    seg_msg = []
                    for point in shape['points']:
                        seg_msg += point
                    one_ann_msg = {"segmentation": [seg_msg],
                                   "area": self._area_computer(shape['points']),
                                   "iscrowd": 0,
                                   "image_id": int(index),
                                   "bbox": self._points_to_mbr(shape['points']),
                                   "category_id": self.categories_to_ids_map[shape['label']],
                                   "id": annotation_id,
                                   "ignore": 0
                                   }
                    annotation_msg.append(one_ann_msg)
                    annotation_id += 1
            else:
                xml_file = os.path.join(self.annotaions_path, filename + '.xml')
                tree = ET.parse(xml_file)
                size = tree.find('size')
                objs = tree.findall('object')
                width = size.find('width').text
                height = size.find('height').text
                for obj in objs:
                    bndbox = obj.find('bndbox')
                    [xmin, xmax, ymin, ymax] \
                        = [int(bndbox.find('xmin').text) - 1, int(bndbox.find('xmax').text),
                           int(bndbox.find('ymin').text) - 1, int(bndbox.find('ymax').text)]
                    if xmin < 0:
                        xmin = 0
                    if ymin < 0:
                        ymin = 0
                    bbox = [xmin, xmax, ymin, ymax]
                    one_ann_msg = {"segmentation": self._bbox_to_mask(bbox),
                                   "area": self._bbox_area_computer(bbox),
                                   "iscrowd": 0,
                                   "image_id": int(num),
                                   "bbox": [xmin, ymin, xmax - xmin, ymax - ymin],
                                   "category_id": self.categories_to_ids_map[obj.find('name').text],
                                   "id": annotation_id,
                                   "ignore": 0
                                   }
                    annotation_msg.append(one_ann_msg)
                    annotation_id += 1
            one_image_msg = {"file_name": filename + ".jpg",
                             "height": int(height),
                             "width": int(width),
                             "id": int(num)
                             }
            image_msg.append(one_image_msg)
            num=num+1
        return image_msg, annotation_msg
    def _bbox_to_mask(self, bbox):
        assert len(bbox) == 4, 'Wrong bndbox!'
        mask = [bbox[0], bbox[2], bbox[0], bbox[3], bbox[1], bbox[3], bbox[1], bbox[2]]
        return [mask]
    def _bbox_area_computer(self, bbox):
        width = bbox[1] - bbox[0]
        height = bbox[3] - bbox[2]
        return width * height
    def _save_json_file(self, filename=None, data=None):
        json_path = os.path.join(self.data_path, 'cocoformatJson')
        assert filename is not None, 'lack filename'
        if os.path.exists(json_path) == False:
            os.mkdir(json_path)
        if not filename.endswith('.json'):
            filename += '.json'
        assert type(data) == type(dict()), 'data format {} not supported'.format(type(data))
        with open(os.path.join(json_path, filename), 'w') as f:
            f.write(json.dumps(data))
    def _get_categories_to_ids_map(self):
        return dict(zip(self.classes, range(self.num_classes)))
    def _get_all_indexs(self):
        ids = []
        for root, dirs, files in os.walk(self.annotaions_path, topdown=False):
            for f in files:
                if str(f).endswith('.xml'):
                    id = int(str(f).strip('.xml'))
                    ids.append(id)
        assert ids is not None, 'There is none xml file in {}'.format(self.annotaions_path)
        return ids
    def _get_indexs_by_image_set(self, image_set=None):
        if image_set is None:
            return self._get_all_indexs()
        else:
            image_set_path = os.path.join(self.image_set_path, 'Main', image_set + '.txt')
            assert os.path.exists(image_set_path), 'Path does not exist: {}'.format(image_set_path)
            with open(image_set_path) as f:
                ids = [x.strip() for x in f.readlines()]
            return ids
    def _points_to_mbr(self, points):
        assert _isArrayLike(points), 'Points should be array like!'
        x = [point[0] for point in points]
        y = [point[1] for point in points]
        assert len(x) == len(y), 'Wrong point quantity'
        xmin, xmax, ymin, ymax = min(x), max(x), min(y), max(y)
        height = ymax - ymin
        width = xmax - xmin
        return [xmin, ymin, width, height]
    def _categories_msg_generator(self):
        categories_msg = []
        for category in self.classes:
            if category == 'none':
                continue
            one_categories_msg = {"supercategory": "none",
                                  "id": self.categories_to_ids_map[category],
                                  "name": category
                                  }
            categories_msg.append(one_categories_msg)
        return categories_msg
    def _area_computer(self, points):
        assert _isArrayLike(points), 'Points should be array like!'
        tmp_contour = []
        for point in points:
            tmp_contour.append([point])
        contour = np.array(tmp_contour, dtype=np.int32)
        area = cv2.contourArea(contour)
        return area
    def voc_to_coco_converter(self):
        img_sets = ['trainval', 'test']
        for img_set in img_sets:
            ids = self._get_indexs_by_image_set(img_set)
            img_msg, ann_msg = self._load_annotation(ids)
            result_json = {"images": img_msg,
                           "type": "instances",
                           "annotations": ann_msg,
                           "categories": self.categories_msg}
            self._save_json_file('voc_' + self.year + '_' + img_set, result_json)

def demo():
    # 转换pascal地址是'./VOC2007/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt'
    converter = voc2coco('D:\\Coding\\python\\data_myself\\VOCdevkit2007', '2007')
    converter.voc_to_coco_converter()

if __name__ == "__main__":
    demo()



本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可,转载请附上原文出处链接和本声明。
本文链接地址:https://www.flyai.com/article/452
讨论
500字
表情
发送
删除确认
是否删除该条评论?
取消 删除