Downloading and Training ABCNet: Training on Your Own Dataset






Preface

I have been running experiments with ABCNet recently and hit quite a few problems along the way. This post records them so that others can avoid the same pitfalls.




I. Downloading ABCNet and Running the Demo



1. Download

ABCNet (the BAText project in AdelaiDet) is an efficient end-to-end scene text spotting framework.

It is built on Detectron2, so Detectron2 has to be installed first.

My requirements:

Linux with Python 3.7.11, CUDA 10.1, PyTorch 1.8.1

pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

For other torch versions, see the PyTorch previous-versions page (https://pytorch.org/get-started/previous-versions/).
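
Before building anything, it is worth confirming that the environment matches. A minimal check using plain PyTorch calls (nothing ABCNet-specific):

import torch
import torchvision

# Expect 1.8.1+cu101 and 0.9.1+cu101 here, with CUDA available.
print("torch:", torch.__version__, "torchvision:", torchvision.__version__)
print("CUDA runtime:", torch.version.cuda, "available:", torch.cuda.is_available())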

# Download detectron2
git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
git checkout -f 9eb4831
cd ..
python -m pip install -e detectron2/
# Download AdelaiDet
git clone https://github.com/aim-uofa/AdelaiDet.git
cd AdelaiDet
python setup.py build develop
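
After the build, a quick way to check that both packages were installed correctly is to import them (a minimal check; detectron2 exposes __version__, and importing adet confirms the AdelaiDet build):

# Verify the installation: both packages should import without errors.
import detectron2
import adet  # noqa: F401

print("detectron2:", detectron2.__version__)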



2. Demo

First, run a quick test with a pretrained model.

Download the CTW1500 dataset:

cd AdelaiDet/datasets
# Note: wget cannot follow a Google Drive "view" link directly; if this fails,
# download the file in a browser (or with gdown) and save it as CTW1500.zip.
wget https://drive.google.com/file/d/1ntlnlnQHZisDoS_bgDvrcrYFomw9iTZ0/view?usp=sharing -O CTW1500.zip
unzip CTW1500.zip
rm CTW1500.zip

Download the model, then run the demo:

# Download ctw1500_attn_R_50.pth above
wget -O ctw1500_attn_R_50.pth https://universityofadelaide.box.com/shared/static/okeo5pvul5v5rxqh4yg8pcf805tzj2no.pth
python demo/demo.py \
    --config-file configs/BAText/CTW1500/attn_R_50.yaml \
    --input datasets/CTW1500/ctwtest_text_image/ \
    --opts MODEL.WEIGHTS ctw1500_attn_R_50.pth



II. Training on Your Own Dataset



1. Annotating with windows_label_tool

Link: windows_label_tool

Extraction code: exvx

The windows_label_tool annotation format is shown below (example). The first line gives the number of annotations; each following line is one annotation, consisting of 28/2 = 14 point coordinates (running along the top edge and back along the bottom, as in the point-order diagram in subsection 2) followed by the text content.

4
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84,"DOUGLASTON"
50,119,58,119,66,119,74,119,82,119,90,119,98,119,98,137,90,137,82,137,74,137,66,137,58,137,51,137,"E-313"
41,137,48,136,56,136,64,136,71,136,79,136,87,136,89,155,81,155,73,155,65,155,57,155,49,155,41,155,"L164"
39,166,56,166,74,166,92,167,110,167,128,167,146,168,140,196,123,195,107,195,90,194,74,194,57,193,41,193,"F.D.N.Y."
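
To make the format concrete, here is a minimal parsing sketch for such a file (my own illustration, not part of the tool):

def parse_windows_label_file(path):
    """Parse one windows_label_tool gt file: the first line is the
    annotation count, each following line is x1,y1,...,x14,y14,"TEXT"."""
    with open(path, encoding="utf-8") as f:
        lines = [ln.strip() for ln in f if ln.strip()]
    annotations = []
    for line in lines[1:1 + int(lines[0])]:
        parts = line.split(",")
        coords = list(map(float, parts[:28]))            # 14 (x, y) pairs
        points = list(zip(coords[0::2], coords[1::2]))
        text = ",".join(parts[28:]).strip('"')           # the text may itself contain commas
        annotations.append((points, text))
    return annotations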



2. Converting to JSON (important: a broken json file causes many downstream problems)

My label format has exactly one annotation line per txt file, so the count line is not needed:

45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84||||"DOUGLASTON"

Because of how the json conversion code below parses the points, I reduced the 14 points to 8 points, i.e. four pairs of points:

45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84||||"DOUGLASTON"

# Order of the four point pairs: 0, 3, 4, 7 are the corners; 1, 2, 5, 6 are control points
0--1--2--3
|        |
7--6--5--4
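
For reference, these 8 points are the control points of two cubic Bezier curves (points 0-3 for the top edge, 4-7 for the bottom). If your regions have straight edges, a common shortcut, assumed here rather than taken from any ABCNet utility, is to place the inner control points at 1/3 and 2/3 along each edge:

import numpy as np

def corners_to_bezier_pts(tl, tr, br, bl):
    """Build the 8 control points (16 numbers) from four corners.
    Order matches the diagram above: 0-1-2-3 left-to-right along the
    top edge, 4-5-6-7 right-to-left along the bottom edge."""
    tl, tr, br, bl = (np.asarray(p, dtype=float) for p in (tl, tr, br, bl))
    top = [tl, tl + (tr - tl) / 3, tl + 2 * (tr - tl) / 3, tr]
    bottom = [br, br + (bl - br) / 3, br + 2 * (bl - br) / 3, bl]
    return [float(v) for pt in top + bottom for v in pt]

# Example: an axis-aligned 100x30 box.
print(corners_to_bezier_pts((0, 0), (100, 0), (100, 30), (0, 30)))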

A classes.txt file is also required. I have only one class, so it contains just:

text

Conversion code:

# -*- coding: utf-8 -*-
"""
 @File    : convert_ann_to_json.py
 @Time    : 2020-8-17 16:13
 @Author  : yizuotian
 @Description    : Convert windows_label_tool annotations to the json format ABCNet uses for training
"""
import argparse
import json
import os

import cv2
import numpy as np

def gen_abc_json(abc_gt_dir, abc_json_path, image_dir, classes_path):
    """
    Generate COCO-format json annotations from abcnet gt files.
    :param abc_gt_dir: directory of annotation files produced by windows_label_tool
    :param abc_json_path: output path for the json annotations ABCNet trains on
    :param image_dir: directory of the corresponding images
    :param classes_path: path to the classes file
    :return:
    """
    # Character list; adapt it to your own labels. Note that Chinese characters
    # require the simsun.ttc (SimSun) font file at training time.
    cV2 = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑",
           "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤",
           "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁",
           "新", '0', '1', '2', '3', '4', '5', '6', '7', '8',
           '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
           'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S',
           'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

    dataset = {
        'licenses': [],
        'info': {},
        'categories': [],
        'images': [],
        'annotations': []
    }
    with open(classes_path) as f:
        classes = f.read().strip().split()
    for i, cls in enumerate(classes, 1):
        dataset['categories'].append({
            'id': i,
            'name': cls,
            'supercategory': 'beverage',
            'keypoints': ['mean',
                          'xmin',
                          'x2',
                          'x3',
                          'xmax',
                          'ymin',
                          'y2',
                          'y3',
                          'ymax',
                          'cross']  # only for BDN
        })

    def get_category_id(cls):
        for category in dataset['categories']:
            if category['name'] == cls:
                return category['id']

    # Iterate over the abcnet txt annotations
    indexes = sorted([f.split('.')[0]
                      for f in os.listdir(abc_gt_dir)])
    print(indexes)

    j = 1  # running id for annotation boxes
    for index in indexes:
        # if int(index) >3: continue
        # print('Processing: ' + index)
        im = cv2.imread(os.path.join(image_dir, '{}.jpg'.format(index)))
        im_height, im_width = im.shape[:2]
        dataset['images'].append({
            'coco_url': '',
            'date_captured': '',
            'file_name': index + '.jpg',
            'flickr_url': '',
            'id': int(index.split('_')[-1]),  # img_1
            'license': 0,
            'width': im_width,
            'height': im_height
        })
        anno_file = os.path.join(abc_gt_dir, '{}.txt'.format(index))

        with open(anno_file) as f:
            lines = [line for line in f.readlines() if line.strip()]
        # No usable annotations (the first line is the count): skip.
        # Note: this loop expects the windows_label_tool layout from subsection 1,
        # i.e. a count line first and comma-separated fields with the text last.
        if len(lines) <= 1:
            continue
        for i, line in enumerate(lines[1:]):
            elements = line.strip().split(',')
            # polygon = np.array(elements[:28]).reshape((-1, 2)).astype(np.float32)  # [14,(x,y)]
            # control_points = bezier_utils.polygon_to_bezier_pts(polygon, im)  # [8,(x,y)]
            # 8 points instead of the original 14
            control_points = np.array(elements[:16]).reshape((-1, 2)).astype(np.float32)  # [8,(x,y)]
            ct = elements[-1].replace('"', '').strip()

            cls = 'text'
            # segs = [float(kkpart) for kkpart in parts[:16]]
            segs = [float(kkpart) for kkpart in control_points.flatten()]
            xt = [segs[ikpart] for ikpart in range(0, len(segs), 2)]
            yt = [segs[ikpart] for ikpart in range(1, len(segs), 2)]

            # Filter out boxes that fall outside the image
            if max(xt) > im_width or max(yt) > im_height:
                print('The annotation bounding box is outside of the image:{}'.format(index))
                print("max x:{},max y:{},w:{},h:{}".format(max(xt), max(yt), im_width, im_height))
                continue
            xmin = min([xt[0], xt[3], xt[4], xt[7]])
            ymin = min([yt[0], yt[3], yt[4], yt[7]])
            xmax = max([xt[0], xt[3], xt[4], xt[7]])
            ymax = max([yt[0], yt[3], yt[4], yt[7]])
            width = max(0, xmax - xmin + 1)
            height = max(0, ymax - ymin + 1)
            if width == 0 or height == 0:
                continue
            # Set according to the maximum length of your own labels
            max_len = 7
            recs = [len(cV2) + 1 for ir in range(max_len)]  # len(cV2) + 1 is the padding index

            ct = str(ct)
            # print('rec', ct)

            for ix, ict in enumerate(ct):
                if ix >= max_len:
                    continue
                if ict in cV2:
                    recs[ix] = cV2.index(ict)
                else:
                    recs[ix] = len(cV2)

            dataset['annotations'].append({
                'area': width * height,
                'bbox': [xmin, ymin, width, height],
                'category_id': get_category_id(cls),
                'id': j,
                'image_id': int(index.split('_')[-1]),  # img_1
                'iscrowd': 0,
                'bezier_pts': segs,
                'rec': recs
            })
            j += 1

    # Write the json file
    folder = os.path.dirname(abc_json_path)
    if not os.path.exists(folder):
        os.makedirs(folder)
    with open(abc_json_path, 'w') as f:
        json.dump(dataset, f)


def main(args):
    gen_abc_json(args.ann_dir, args.dst_json_path, args.image_dir, args.classes_path)


if __name__ == '__main__':
    """
    Usage: python convert_ann_to_json.py \
    --ann-dir /path/to/gt \
    --image-dir /path/to/image \
    --dst-json-path train.json 
    """
    parse = argparse.ArgumentParser()
    parse.add_argument("--ann-dir", type=str, default=None) # 标签路径
    parse.add_argument("--image-dir", type=str, default=None) # 对应的图片路径
    parse.add_argument("--dst-json-path", type=str, default=None) # 保存json路径
    parse.add_argument("--classes-path", type=str, default='./classes.txt') 
    arguments = parse.parse_args() # sys.argv[1:]
    main(arguments)
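
Because a broken json often only surfaces as confusing errors during training, it helps to sanity-check the converted file first. Below is a minimal sketch (check_abcnet_json is my own helper, not part of ABCNet; num_chars=67 matches the length of the cV2 list above and max_len=7 matches the script's setting):

import json

def check_abcnet_json(path, num_chars=67, max_len=7):
    """Lightweight consistency checks on the converted annotations."""
    with open(path) as f:
        data = json.load(f)
    image_ids = {img["id"] for img in data["images"]}
    for ann in data["annotations"]:
        assert ann["image_id"] in image_ids, "annotation references a missing image"
        assert len(ann["bezier_pts"]) == 16, "expected 8 (x, y) control points"
        assert len(ann["rec"]) == max_len, "rec must be padded to max_len"
        # Valid values: 0..num_chars-1 real chars, num_chars unknown, num_chars+1 padding.
        assert all(0 <= r <= num_chars + 1 for r in ann["rec"]), "rec index out of range"
    print("OK:", len(data["images"]), "images,", len(data["annotations"]), "annotations")

check_abcnet_json("train.json")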

Part of the resulting json file, for reference:

{
    "licenses": [],
    "info": {},
    "categories": [
        {
            "id": 1,
            "name": "text",
            "supercategory": "beverage",
            "keypoints": [
                "mean",
                "xmin",
                "x2",
                "x3",
                "xmax",
                "ymin",
                "y2",
                "y3",
                "ymax",
                "cross"
            ]
        }
    ],
    "images": [
        {
            "coco_url": "",
            "date_captured": "",
            "file_name": "000001.jpg",
            "flickr_url": "",
            "id": 1,
            "license": 0,
            "width": 720,
            "height": 1160
        },
        ...
    ],
    "annotations": [
        {
            "area": 6868.0,
            "bbox": [
                304.0,
                343.0,
                101.0,
                68.0
            ],
            "category_id": 1,
            "id": 1,
            "image_id": 1,
            "iscrowd": 0,
            "bezier_pts": [
                304.0,
                357.0,
                454.0,
                341.0,
                458.0,
                394.0,
                308.0,
                410.0,
                329.0,
                343.0,
                354.0,
                345.0,
                379.0,
                347.0,
                404.0,
                349.0
            ],
            "rec": [
                19,
                16,
                20,
                12,
                19,
                21,
                23
            ]
        },
        ...
    ]
}        



3. Training

GPUs: two GeForce RTX 2080 Ti, batch_size = 2



1. Modify the relevant configuration files

  • Place the prepared data directory under "AdelaiDet/datasets".

    My directory structure:
COCO
	--annotations
		--train.json
		--val.json
	--train
	--val	
  • Modify the _PREDEFINED_SPLITS_TEXT value in "adet/data/builtin.py" to register your training and test data. Note that these paths are relative to the datasets directory by default, so they start from the next directory level. (A quick way to verify the registration is shown after this list.)
_PREDEFINED_SPLITS_TEXT = {
    "totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
    "totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
    ...
    # Added entries (change these to your own):
    "COCO_train": ("COCO/train/", "COCO/annotations/train.json"),
    "COCO_val": ("COCO/val/", "COCO/annotations/val.json"),
  • Specify the datasets in the config file you want to train with, e.g. configs/BAText/Pretrain/Base-Pretrain.yaml:
_BASE_: "../Base-BAText.yaml"
DATASETS:
  # Changed to my own datasets:
  TRAIN: ("COCO_train",)
  TEST: ("COCO_val",)
  • If the labels contain Chinese, download the simsun.ttc font file and place it at AdelaiDet/simsun.ttc.

    Link: simsun.ttc

    Extraction code: 7tr2
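
To check that the new datasets are actually registered, the snippet below can be run from the AdelaiDet root (a sketch assuming the data files are already in place; in AdelaiDet the registrations in adet/data/builtin.py run when the module is imported):

import adet.data.builtin  # noqa: F401 -- the datasets are registered on import
from detectron2.data import DatasetCatalog, MetadataCatalog

dicts = DatasetCatalog.get("COCO_train")  # loads COCO/annotations/train.json
print(len(dicts), "training images registered")
print(MetadataCatalog.get("COCO_train"))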



2. Train

OMP_NUM_THREADS=1 python tools/train_net.py \
    --config-file configs/BAText/Pretrain/v2_attn_R_50.yaml \
    --num-gpus 2 \
    OUTPUT_DIR output/batext/pretrain/coco01  # output directory



3. Evaluation

# OUTPUT_DIR is the output directory; MODEL.WEIGHTS is the trained model path.
OMP_NUM_THREADS=1 python tools/train_net.py \
    --config-file configs/BAText/Pretrain/v2_attn_R_50.yaml \
    --eval-only \
    --num-gpus 1 \
    OUTPUT_DIR output/batext/pretrain/coco01_result \
    MODEL.WEIGHTS output/batext/pretrain/coco01/model_0019999.pth



Summary

I ran into many problems along the way and consulted many articles; I am recording everything here for future readers' reference.



References

https://blog.csdn.net/weixin_43823854/article/details/108916498

https://www.tqwba.com/x_d/jishu/286353.html



Copyright notice: this is an original article by weixin_42319164, licensed under CC 4.0 BY-SA. Please include a link to the original source and this notice when reposting.