Preface
I recently spent some time running experiments with ABCNet and hit quite a few problems along the way. This post records them so that others can avoid the same pitfalls.
I. Downloading ABCNet and running the demo
1. Download
ABCNet is an efficient end-to-end scene text spotting framework; within AdelaiDet it is implemented as the BAText project.
AdelaiDet is built on Detectron2, so Detectron2 must be installed first.
My requirements:
Linux with Python 3.7.11, CUDA 10.1, PyTorch 1.8.1
pip install torch==1.8.1+cu101 torchvision==0.9.1+cu101 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
For other torch versions, pick the matching wheel from the same index.
# Download and install detectron2
git clone https://github.com/facebookresearch/detectron2.git
cd detectron2
git checkout -f 9eb4831
cd ..
python -m pip install -e detectron2/
# Download and install AdelaiDet
git clone https://github.com/aim-uofa/AdelaiDet.git
cd AdelaiDet
python setup.py build develop
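After both installs, a quick import check verifies that the packages are usable (a minimal sanity check; run it from outside the two source directories so Python resolves the installed copies):
# Sanity check: both detectron2 and AdelaiDet (package name: adet) should import
import detectron2
import adet

print("detectron2", detectron2.__version__, "and adet imported OK")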
2. Demo
First, run a quick test with pretrained weights.
Download the CTW1500 dataset:
cd AdelaiDet/datasets
# Note: wget cannot follow a Google Drive "view" link directly; download
# CTW1500.zip in a browser, or try gdown:
# pip install gdown && gdown "https://drive.google.com/uc?id=1ntlnlnQHZisDoS_bgDvrcrYFomw9iTZ0" -O CTW1500.zip
unzip CTW1500.zip
rm CTW1500.zip
Download the model, then run the demo:
# Download ctw1500_attn_R_50.pth (linked from the AdelaiDet model zoo):
wget -O ctw1500_attn_R_50.pth https://universityofadelaide.box.com/shared/static/okeo5pvul5v5rxqh4yg8pcf805tzj2no.pth
python demo/demo.py \
--config-file configs/BAText/CTW1500/attn_R_50.yaml \
--input datasets/CTW1500/ctwtest_text_image/ \
--opts MODEL.WEIGHTS ctw1500_attn_R_50.pth
II. Training on your own dataset
1. The windows_label_tool annotation tool
Link:
windows_label_tool
Extraction code: exvx
The format is as follows (example):
In the windows_label_tool annotation format, the first line is the number of annotations; each following line is one annotation, consisting of 28 values, i.e. 28/2 = 14 point coordinates (in the same clockwise order as the 8-point diagram further below), followed by the text content.
4
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84,"DOUGLASTON"
50,119,58,119,66,119,74,119,82,119,90,119,98,119,98,137,90,137,82,137,74,137,66,137,58,137,51,137,"E-313"
41,137,48,136,56,136,64,136,71,136,79,136,87,136,89,155,81,155,73,155,65,155,57,155,49,155,41,155,"L164"
39,166,56,166,74,166,92,167,110,167,128,167,146,168,140,196,123,195,107,195,90,194,74,194,57,193,41,193,"F.D.N.Y."
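To make the format concrete, here is a minimal sketch that parses one such line into its 14 (x, y) points plus the text (parse_label_line is an illustrative helper, not part of any tool):
def parse_label_line(line):
    parts = line.strip().split(',')
    coords = [float(v) for v in parts[:28]]          # 28 numbers -> 14 points
    points = list(zip(coords[0::2], coords[1::2]))   # [(x1, y1), ..., (x14, y14)]
    text = ','.join(parts[28:]).strip('"')           # rejoin in case the text contains commas
    return points, text


points, text = parse_label_line(
    '50,119,58,119,66,119,74,119,82,119,90,119,98,119,'
    '98,137,90,137,82,137,74,137,66,137,58,137,51,137,"E-313"')
print(len(points), text)  # -> 14 E-313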
2. Converting to JSON (important: if the JSON file is wrong, many problems follow)
My own label format differs slightly: each txt file contains only one line, so no annotation count is needed, and "||||" separates the coordinates from the text:
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84||||"DOUGLASTON"
Because of how the JSON conversion code below works, I reduced the 14 points to 8 points, i.e. four pairs of points:
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84||||"DOUGLASTON"
# Order of the four point pairs: 0, 3, 4, 7 are corner points; 1, 2, 5, 6 are control points
0--1--2--3
|        |
7--6--5--4
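For context: in ABCNet's annotation format these 8 points are the control points of two cubic Bezier curves, one for the top edge (points 0-3) and one for the bottom edge (points 4-7). A minimal sketch of how a cubic Bezier is sampled from its 4 control points:
import numpy as np


def cubic_bezier(p0, p1, p2, p3, n=20):
    """Sample n points along a cubic Bezier curve defined by 4 control points."""
    t = np.linspace(0, 1, n)[:, None]
    p0, p1, p2, p3 = (np.asarray(p, dtype=np.float32) for p in (p0, p1, p2, p3))
    return ((1 - t) ** 3 * p0 + 3 * (1 - t) ** 2 * t * p1
            + 3 * (1 - t) * t ** 2 * p2 + t ** 3 * p3)


# Top edge of the first example box above: points 0, 1, 2, 3
top = cubic_bezier((45, 73), (59, 67), (74, 61), (89, 56))
print(top[0], top[-1])  # starts at point 0, ends at point 3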
A classes.txt file is also required; I have only one class, so it contains just:
text
Conversion code (note: as written, the parsing logic expects the windows_label_tool format shown above, with a count line and comma-separated text; adapt the line parsing if your labels use the "||||" format):
# -*- coding: utf-8 -*-
"""
@File : convert_ann_to_json.py
@Time : 2020-8-17 16:13
@Author : yizuotian
@Description : Converts windows_label_tool annotations into the JSON format required for ABCNet training
"""
import argparse
import json
import os

import cv2
import numpy as np


def gen_abc_json(abc_gt_dir, abc_json_path, image_dir, classes_path):
    """
    Generate COCO-style JSON annotations from ABCNet-style gt annotations.
    :param abc_gt_dir: directory of annotation files produced by windows_label_tool
    :param abc_json_path: output path of the JSON annotation file used for ABCNet training
    :param image_dir: directory of the corresponding images
    :param classes_path: path to the class list file
    :return:
    """
    # Character vocabulary; adapt it to your own labels. Note that training with
    # Chinese characters requires the simsun.ttc font file (see the training section below).
    cV2 = [
        "皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑",
        "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤",
        "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁",
        "新", '0', '1', '2', '3', '4', '5', '6', '7', '8',
        '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
        'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S',
        'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

    dataset = {
        'licenses': [],
        'info': {},
        'categories': [],
        'images': [],
        'annotations': []
    }
    with open(classes_path) as f:
        classes = f.read().strip().split()
    for i, cls in enumerate(classes, 1):
        dataset['categories'].append({
            'id': i,
            'name': cls,
            'supercategory': 'beverage',
            'keypoints': ['mean',
                          'xmin',
                          'x2',
                          'x3',
                          'xmax',
                          'ymin',
                          'y2',
                          'y3',
                          'ymax',
                          'cross']  # only for BDN
        })

    def get_category_id(cls):
        for category in dataset['categories']:
            if category['name'] == cls:
                return category['id']

    # Iterate over the txt annotation files
    indexes = sorted([f.split('.')[0]
                      for f in os.listdir(abc_gt_dir)])
    print(indexes)
    j = 1  # running id of the annotation boxes
    for index in indexes:
        # if int(index) > 3: continue
        # print('Processing: ' + index)
        im = cv2.imread(os.path.join(image_dir, '{}.jpg'.format(index)))
        im_height, im_width = im.shape[:2]
        dataset['images'].append({
            'coco_url': '',
            'date_captured': '',
            'file_name': index + '.jpg',
            'flickr_url': '',
            'id': int(index.split('_')[-1]),  # e.g. img_1 -> 1
            'license': 0,
            'width': im_width,
            'height': im_height
        })
        anno_file = os.path.join(abc_gt_dir, '{}.txt'.format(index))

        with open(anno_file) as f:
            lines = [line for line in f.readlines() if line.strip()]
        # Skip files without usable annotations (the first line is the count)
        if len(lines) <= 1:
            continue
        for i, line in enumerate(lines[1:]):
            elements = line.strip().split(',')
            # polygon = np.array(elements[:28]).reshape((-1, 2)).astype(np.float32)  # [14, (x, y)]
            # control_points = bezier_utils.polygon_to_bezier_pts(polygon, im)  # [8, (x, y)]
            # Changed from 14 points to 8 points
            control_points = np.array(elements[:16]).reshape((-1, 2)).astype(np.float32)  # [8, (x, y)]
            ct = elements[-1].replace('"', '').strip()

            cls = 'text'
            # segs = [float(kkpart) for kkpart in parts[:16]]
            segs = [float(kkpart) for kkpart in control_points.flatten()]
            xt = [segs[ikpart] for ikpart in range(0, len(segs), 2)]
            yt = [segs[ikpart] for ikpart in range(1, len(segs), 2)]
            # Filter boxes that fall outside the image
            if max(xt) > im_width or max(yt) > im_height:
                print('The annotation bounding box is outside of the image: {}'.format(index))
                print("max x: {}, max y: {}, w: {}, h: {}".format(max(xt), max(yt), im_width, im_height))
                continue

            xmin = min([xt[0], xt[3], xt[4], xt[7]])
            ymin = min([yt[0], yt[3], yt[4], yt[7]])
            xmax = max([xt[0], xt[3], xt[4], xt[7]])
            ymax = max([yt[0], yt[3], yt[4], yt[7]])
            width = max(0, xmax - xmin + 1)
            height = max(0, ymax - ymin + 1)
            if width == 0 or height == 0:
                continue

            # Set max_len according to the longest label in your data
            max_len = 7
            recs = [len(cV2) + 1 for ir in range(max_len)]  # len(cV2) + 1 = padding
            ct = str(ct)
            # print('rec', ct)
            for ix, ict in enumerate(ct):
                if ix >= max_len:
                    continue
                if ict in cV2:
                    recs[ix] = cV2.index(ict)
                else:
                    recs[ix] = len(cV2)  # len(cV2) = out-of-vocabulary character

            dataset['annotations'].append({
                'area': width * height,
                'bbox': [xmin, ymin, width, height],
                'category_id': get_category_id(cls),
                'id': j,
                'image_id': int(index.split('_')[-1]),  # e.g. img_1 -> 1
                'iscrowd': 0,
                'bezier_pts': segs,
                'rec': recs
            })
            j += 1

    # Write the JSON file
    folder = os.path.dirname(abc_json_path)
    if folder and not os.path.exists(folder):
        os.makedirs(folder)
    with open(abc_json_path, 'w') as f:
        json.dump(dataset, f)


def main(args):
    gen_abc_json(args.ann_dir, args.dst_json_path, args.image_dir, args.classes_path)


if __name__ == '__main__':
    """
    Usage: python convert_ann_to_json.py \
        --ann-dir /path/to/gt \
        --image-dir /path/to/image \
        --dst-json-path train.json
    """
    parse = argparse.ArgumentParser()
    parse.add_argument("--ann-dir", type=str, default=None)  # annotation directory
    parse.add_argument("--image-dir", type=str, default=None)  # corresponding image directory
    parse.add_argument("--dst-json-path", type=str, default=None)  # output json path
    parse.add_argument("--classes-path", type=str, default='./classes.txt')
    arguments = parse.parse_args()
    main(arguments)
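Because a malformed JSON file leads to confusing errors later in training, it is worth sanity-checking the output before using it. A minimal sketch, assuming the output file is train.json and max_len = 7 as in the script:
import json

with open('train.json') as f:
    ds = json.load(f)

print(len(ds['images']), 'images,', len(ds['annotations']), 'annotations')
image_ids = {img['id'] for img in ds['images']}
for ann in ds['annotations']:
    assert len(ann['bezier_pts']) == 16, ann['id']  # 8 control points
    assert len(ann['rec']) == 7, ann['id']          # matches max_len
    assert ann['image_id'] in image_ids, ann['id']  # annotation maps to an image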
An excerpt of the generated JSON file for reference:
{
"licenses": [],
"info": {},
"categories": [
{
"id": 1,
"name": "text",
"supercategory": "beverage",
"keypoints": [
"mean",
"xmin",
"x2",
"x3",
"xmax",
"ymin",
"y2",
"y3",
"ymax",
"cross"
]
}
],
"images": [
{
"coco_url": "",
"date_captured": "",
"file_name": "000001.jpg",
"flickr_url": "",
"id": 1,
"license": 0,
"width": 720,
"height": 1160
},
...
],
"annotations": [
{
"area": 6868.0,
"bbox": [
304.0,
343.0,
101.0,
68.0
],
"category_id": 1,
"id": 1,
"image_id": 1,
"iscrowd": 0,
"bezier_pts": [
304.0,
357.0,
454.0,
341.0,
458.0,
394.0,
308.0,
410.0,
329.0,
343.0,
354.0,
345.0,
379.0,
347.0,
404.0,
349.0
],
"rec": [
19,
16,
20,
12,
19,
21,
23
]
},
...
]
}
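To interpret the "rec" field above: each entry is an index into the cV2 list from the conversion script, where len(cV2) marks an out-of-vocabulary character and len(cV2) + 1 marks padding. A small hypothetical decoder:
def decode_rec(rec, voc):
    """Map a rec index sequence back to text: indices < len(voc) are characters,
    len(voc) is an out-of-vocabulary character, len(voc) + 1 is padding."""
    chars = []
    for idx in rec:
        if idx < len(voc):
            chars.append(voc[idx])
        elif idx == len(voc):
            chars.append('?')  # out-of-vocabulary character
        # idx == len(voc) + 1 is padding and is skipped
    return ''.join(chars)


# e.g. decode_rec([19, 16, 20, 12, 19, 21, 23], cV2) with the cV2 list above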
3. Training
GPUs: 2× GeForce RTX 2080 Ti, batch_size = 2
1. Edit the relevant config files
- Put the prepared data directory under the "AdelaiDet/datasets" directory.
My directory structure:
COCO
--annotations
  --train.json
  --val.json
--train
--val
- Modify the _PREDEFINED_SPLITS_TEXT dict in "adet/data/builtin.py" to register your training and test data. Note that these paths are resolved relative to the datasets/ directory, so they start from the next level down (a quick way to verify the registration is shown after this list):
_PREDEFINED_SPLITS_TEXT = {
    "totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
    "totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
    ...
    # Added for my dataset (change to your own)
    "COCO_train": ("COCO/train/", "COCO/annotations/train.json"),
    "COCO_val": ("COCO/val/", "COCO/annotations/val.json"),
}
- Specify the datasets in the config file used for training, e.g. configs/BAText/Pretrain/Base-Pretrain.yaml:
_BASE_: "../Base-BAText.yaml"
DATASETS:
  # Changed to my own datasets
  TRAIN: ("COCO_train",)
  TEST: ("COCO_val",)
- If your labels contain Chinese, download the simsun.ttc font file and place it at AdelaiDet/simsun.ttc.
Link:
simsun.ttc
Extraction code: 7tr2
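Before launching a run, it is worth checking that the new datasets are actually registered. A minimal sketch, assuming the names above and that AdelaiDet registers its builtin datasets at import time (run from the AdelaiDet root):
# Importing the builtin module runs the register calls for _PREDEFINED_SPLITS_TEXT
from adet.data import builtin  # noqa: F401
from detectron2.data import DatasetCatalog

dicts = DatasetCatalog.get("COCO_train")  # loads COCO/annotations/train.json
print(len(dicts), "training images registered")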
2. Train
OMP_NUM_THREADS=1 python tools/train_net.py \
--config-file configs/BAText/Pretrain/v2_attn_R_50.yaml \
--num-gpus 2 \
OUTPUT_DIR output/batext/pretrain/coco01  # output path
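While training runs, Detectron2's default writers append one JSON object per logging period to metrics.json inside OUTPUT_DIR, so progress can be checked without TensorBoard (path taken from the command above):
import json

# Read the most recent logging entry written by the trainer
with open("output/batext/pretrain/coco01/metrics.json") as f:
    last = json.loads(f.readlines()[-1])

print("iteration:", last.get("iteration"), "total_loss:", last.get("total_loss"))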
3. 测试
MP_NUM_THREADS=1 python tools/train_net.py \
--config-file configs/BAText/Pretrain/v2_attn_R_50.yaml \
--eval-only \
--num-gpus 1 \
OUTPUT_DIR output/batext/pretrain/coco01_result \ # 保存路径
MODEL.WEIGHTS output/batext/pretrain/coco01/model_0019999.pth # model 路径
Summary
I ran into many problems along the way and consulted quite a few articles; I am recording everything here for future readers.
References
https://blog.csdn.net/weixin_43823854/article/details/108916498
https://www.tqwba.com/x_d/jishu/286353.html