DeepSnake源代码阅读笔记—-process on Sbd datasets

  • Post author:
  • Post category:其他





DeepSnake

源代码阅读笔记

  • 本篇主要分析

    Deep Snake for Real-Time Instance Segmentation

    文献中代码的实现,其开源代码地址为:https://github.com/zju3dv/snake。

  • [good18Levin](

    (6条消息) 论文阅读_fighterflyfly的博客-CSDN博客

    ) 的deepsnake代码阅读记录对我启发很大,该博客主要分析了模型测试过程中的代码。

  • 本篇笔记主要用于记录本人分析



    deepsnake在sbd数据集上训练



    的过程,本人水平有限,难免有理解不到位的地方,如有疑问欢迎交流。dongximing@webmail.hzau.edu.cn



〇、 环境配置

源码要求环境为pytorch1.1+cuda9.0,由于我显卡是RTX3090,并不支持老版本的cuda9.0,因此我采用了pytorch1.8.1+cuda11.1的环境,python版本为3.7。

  1. pytorch及cuda的安装

    pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
    

    也可以使用conda安装,但是使用conda安装会在后续编译cuda拓展的时候报错。

  2. 安装apex及编译

    git clone https://github.com/NVIDIA/apex.git
    cd apex 
    python setup.py install --cuda_ext --cpp_ext
    

    这里要注意nvidia cuda 版本要和pytorch cuda版本一致,否则可能会报错。

  3. dcn_v2的安装

    源码中的dcn_v2版本不支持高版本的pytorch及cuda,可以下载安装新版本的dcn_v2

    cd lib/csrc
    rm -rf dcn_v2
    git clone https://github.com/jinfagang/DCNv2_latest.git
    mv DCNv2_latest dcn_v2
    cd dcn_v2
    git checkout fa9b2fd740ced2a22e0e7e913c3bf3934fd08098
    python setup.py build develop
    

    因为最新版本支持pytorch1.11可能不会向下兼容,所以使用git checkout 回退之前支持pytorch1.8的版本。

  4. extreme_utils和roi_align_layer的安装按照源码进行安装即可

    cd lib/csrc/extreme_utils
    python setup.py build_ext --inplace
    cd ../roi_align_layer
    python setup.py build_ext --inplace
    



一、代码结构

  • 整体项目使用了许多种数据集并且包含了许多种网络,需要配置的参数量较多,因此项目使用了cfg读取yaml并结合parser来进行全局参数的储存。根据

    README.md

    针对sbd训练传入的参数以及

    configs/sbd_snake.yaml

    文件,可以得到cfg的参数:

    ct_score: 0.05

    demo_path:

    det_dir: data/model/snake/

    eval_ep: 5

    gpus: [0, 1, 2, 3]

    head_conv: 256

    heads:

    ct_hm: 20

    wh: 2

    model: sbd_snake

    model_dir: data/model/snake/sbd

    network: ro_34

    record_dir: data/record/snake/sbd

    result_dir: data/result/snake/sbd

    resume: True

    save_ep: 5

    segm_or_bbox: segm

    skip_eval: False

    task: snake

    test:

    batch_size: 1

    dataset: SbdMini

    epoch: -1

    train:

    batch_size: 80

    dataset: SbdTrain

    epoch: 150

    gamma: 0.5

    lr: 0.0001

    milestones: [80, 120, 150, 170]

    num_workers: 32

    optim: adam

    scheduler:

    warmup: False

    weight_decay: 0.0005

    use_gt_det: False

  • 项目训练采用了许多的封装,使得代码调用十分方便,但是对于代码阅读性也增添了些许难度,在接下来的代码阅读中我会利用cfg参数等对底层构建代码进行分析。

    #train_net.py
    from lib.config import cfg, args
    from lib.networks import make_network
    from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_model, save_model, load_network
    from lib.evaluators import make_evaluator
    import torch.multiprocessing
    
    def train(cfg, network):
        trainer = make_trainer(cfg, network) #将损失函数与网络模型封装
        optimizer = make_optimizer(cfg, network) #封装优化器
        scheduler = make_lr_scheduler(cfg, optimizer) #封装学习率
        recorder = make_recorder(cfg) #记录
        evaluator = make_evaluator(cfg) #评估
    
        begin_epoch = load_model(network, optimizer, scheduler, recorder, cfg.model_dir, resume=cfg.resume)
        # set_lr_scheduler(cfg, scheduler)
    
        train_loader = make_data_loader(cfg, is_train=True) #封装数据集构建函数
        val_loader = make_data_loader(cfg, is_train=False)
    
        for epoch in range(begin_epoch, cfg.train.epoch):
            recorder.epoch = epoch
            trainer.train(epoch, train_loader, optimizer, recorder) #训练过程
            scheduler.step()
    		#保存数据
            if (epoch + 1) % cfg.save_ep == 0:
                save_model(network, optimizer, scheduler, recorder, epoch, cfg.model_dir)
    
            if (epoch + 1) % cfg.eval_ep == 0:
                trainer.val(epoch, val_loader, evaluator, recorder)
    
        return network
    
    def main():
        network = make_network(cfg) #利用cfg参数构建网络
        if args.test:
            test(cfg, network)
        else:
            train(cfg, network)
    
    if __name__ == "__main__":
        main()
    
    



二、数据集处理

首先找到构建数据集的封装函数

lib.datasets.make_data_loader

#lib/datasets/make_data_loader.py
from .transforms import make_transforms
from . import samplers
from .dataset_catalog import DatasetCatalog
import torch
import torch.utils.data
import imp
import os
from .collate_batch import make_collator

torch.multiprocessing.set_sharing_strategy('file_system')


def _dataset_factory(data_source, task):
    module = '.'.join(['lib.datasets', data_source, task])
    path = os.path.join('lib/datasets', data_source, task+'.py')
    dataset = imp.load_source(module, path).Dataset #加载lib.datasets.sbd.snake.Dataset函数
    return dataset


def make_dataset(cfg, dataset_name, transforms, is_train=True):
    args = DatasetCatalog.get(dataset_name)#SbdTrain
    data_source = args['id']
    dataset = _dataset_factory(data_source, cfg.task) #('sbd','snake')
    del args['id']
    # args['cfg'] = cfg
    # args['transforms'] = transforms
    # args['is_train'] = is_train
    dataset = dataset(**args)
    return dataset


def make_data_sampler(dataset, shuffle):
    if shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return sampler


def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter):
    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last)
    if max_iter != -1:
        batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, max_iter)
    return batch_sampler


def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1):
    if is_train:
        batch_size = cfg.train.batch_size #80
        shuffle = True
        drop_last = False
    else:
        batch_size = cfg.test.batch_size
        shuffle = True if is_distributed else False
        drop_last = False

    dataset_name = cfg.train.dataset if is_train else cfg.test.dataset #True

    transforms = make_transforms(cfg, is_train) #转变为Tensor并归一化,但是后续并没有用到
    dataset = make_dataset(cfg, dataset_name, transforms, is_train)
    sampler = make_data_sampler(dataset, shuffle)
    batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter)
    num_workers = cfg.train.num_workers #32
    collator = make_collator(cfg)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=collator
    )

    return data_loader

使用了

DatasetCatalog

中封装的数据集的信息,以及调用了

lib.datasets.sbd.snake.Dataset

函数构建数据集。

##lib/dataset/dataset_catalog.py
from lib.config import cfg
class DatasetCatalog(object):
    dataset_attrs = {  'SbdTrain': {
            'id': 'sbd',
            'data_root': 'data/sbd/img',
            'ann_file': 'data/sbd/annotations/sbd_train_instance.json',
            'split': 'train'
        }
    }
    @staticmethod
    def get(name):
        attrs = DatasetCatalog.dataset_attrs[name]
        return attrs.copy()



lib.datasets.sbd.snake.Dataset

又调用了

lib.datasets.voc.snake.Dataset

函数。

#lib/datasets/voc/snake.py
import os
from lib.utils.snake import snake_voc_utils, snake_config, visualize_utils
import cv2
import numpy as np
import math
from lib.utils import data_utils
import torch.utils.data as data
from pycocotools.coco import COCO
from lib.config import cfg

class Dataset(data.Dataset):
    def __init__(self, ann_file, data_root, split):
        super(Dataset, self).__init__()

        """
        data_root='data/sbd/img'
        ann_file='data/sbd/annotations/sbd_train_instance.json'
        split='train'
        """

        self.data_root = data_root
        self.split = split
        """
        构建COCO对象
        """
        self.coco = COCO(ann_file)
        self.anns = np.array(sorted(self.coco.getImgIds()))  # shape(5623,)
        self.anns = self.anns[:500] if split == 'mini' else self.anns
        self.json_category_id_to_contiguous_id = {v: i for i, v in enumerate(self.coco.getCatIds())}

    def process_info(self, img_id):
        """
        根据给定的imgid获得anno标注对象以及图片路径
        """
        ann_ids = self.coco.getAnnIds(imgIds=img_id)  # 获取标注id
        anno = self.coco.loadAnns(ann_ids)  # 读取标注信息
        path = os.path.join(self.data_root, self.coco.loadImgs(int(img_id))[0]['file_name'])
        return anno, path, img_id

    def read_original_data(self, anno, path):
        """
        读取图片
        将标注的segmentation由xyxyxyxyxy一维转为二维[[x,y],[x,y]] -> instance_polys
        读取标注对应的classid -> cls_ids
        """
        img = cv2.imread(path)
        instance_polys = [[np.array(poly).reshape(-1, 2) for poly in obj['segmentation']] for obj in anno]
        cls_ids = [self.json_category_id_to_contiguous_id[obj['category_id']] for obj in anno]
        return img, instance_polys, cls_ids

    def transform_original_data(self, instance_polys, flipped, width, trans_output, inp_out_hw):
        """
        因为数据增强的缘故,也需要对原始的标注进行相应的变换
        使用了trans_output进行仿射变换,使标注与图像的仿射变换保持一致
        """
        output_h, output_w = inp_out_hw[2:]
        instance_polys_ = []
        for instance in instance_polys:
            polys = [poly.reshape(-1, 2) for poly in instance]

            if flipped:
                polys_ = []
                for poly in polys:
                    poly[:, 0] = width - np.array(poly[:, 0]) - 1
                    polys_.append(poly.copy())
                polys = polys_

            polys = snake_voc_utils.transform_polys(polys, trans_output, output_h, output_w)
            instance_polys_.append(polys)
        return instance_polys_

    def get_valid_polys(self, instance_polys, inp_out_hw):
        """
        对变换后的多边形标注进行过滤:
        1.过滤掉点数小于等于四的多边形
        2.对多边形坐标进行clip,防止多边形超出output size
        3.filter删除面积小于5的多边形
        4.cw获取顺时针的poly
        5.对坐标重复的顶点进行去重,保持顶点的顺序不变
        """
        output_h, output_w = inp_out_hw[2:]
        instance_polys_ = []
        for instance in instance_polys:
            instance = [poly for poly in instance if len(poly) >= 4]
            for poly in instance:
                poly[:, 0] = np.clip(poly[:, 0], 0, output_w - 1)
                poly[:, 1] = np.clip(poly[:, 1], 0, output_h - 1)
            polys = snake_voc_utils.filter_tiny_polys(instance)
            polys = snake_voc_utils.get_cw_polys(polys)
            polys = [poly[np.sort(np.unique(poly, axis=0, return_index=True)[1])] for poly in polys]
            instance_polys_.append(polys)
        return instance_polys_

    def get_extreme_points(self, instance_polys):
        """
        获取多边形的极值点坐标(上左下右)(逆时针)
        """
        extreme_points = []
        for instance in instance_polys:
            points = [snake_voc_utils.get_extreme_points(poly) for poly in instance]
            extreme_points.append(points)
        return extreme_points

    def prepare_detection(self, box, poly, ct_hm, cls_id, wh, ct_cls, ct_ind):
        """
        计算center
        将已计算的class的id保存clsid,将center编码保存在ctind
        计算高斯模糊半径,并更新cthm
        """
        ct_hm = ct_hm[cls_id]#ct_hm.shape:(class,128,128) ->(128,128)
        ct_cls.append(cls_id)#将已经prepare的class保存

        x_min, y_min, x_max, y_max = box
        ct = np.array([(x_min + x_max) / 2, (y_min + y_max) / 2], dtype=np.float32)
        ct = np.round(ct).astype(np.int32)

        h, w = y_max - y_min, x_max - x_min
        radius = data_utils.gaussian_radius((math.ceil(h), math.ceil(w)))
        radius = max(0, int(radius))
        data_utils.draw_umich_gaussian(ct_hm, ct, radius)

        wh.append([w, h])
        ct_ind.append(ct[1] * ct_hm.shape[1] + ct[0])

        x_min, y_min = ct[0] - w / 2, ct[1] - h / 2
        x_max, y_max = ct[0] + w / 2, ct[1] + h / 2
        decode_box = [x_min, y_min, x_max, y_max]

        return decode_box

    def prepare_init(self, box, extreme_point, i_it_4pys, c_it_4pys, i_gt_4pys, c_gt_4pys, h, w):
        x_min, y_min = np.min(extreme_point[:, 0]), np.min(extreme_point[:, 1])
        x_max, y_max = np.max(extreme_point[:, 0]), np.max(extreme_point[:, 1])

        img_init_poly = snake_voc_utils.get_init(box)  # 中点连接的四边形
        """
        uniformsaple 对中点连接的四边形进行均匀采样,获得40个点;如果原始点大于40,则减去最短的路径。
        """
        img_init_poly = snake_voc_utils.uniformsample(img_init_poly, snake_config.init_poly_num)#在图片中的poly坐标
        can_init_poly = snake_voc_utils.img_poly_to_can_poly(img_init_poly, x_min, y_min, x_max, y_max)#归一化的poly坐标
        img_gt_poly = extreme_point
        can_gt_poly = snake_voc_utils.img_poly_to_can_poly(img_gt_poly, x_min, y_min, x_max, y_max)

        i_it_4pys.append(img_init_poly)
        c_it_4pys.append(can_init_poly)
        i_gt_4pys.append(img_gt_poly)
        c_gt_4pys.append(can_gt_poly)

    def prepare_evolution(self, poly, extreme_point, img_init_polys, can_init_polys, img_gt_polys, can_gt_polys):
        x_min, y_min = np.min(extreme_point[:, 0]), np.min(extreme_point[:, 1])
        x_max, y_max = np.max(extreme_point[:, 0]), np.max(extreme_point[:, 1])

        octagon = snake_voc_utils.get_octagon(extreme_point) #根据极值点建立八边形

        img_init_poly = snake_voc_utils.uniformsample(octagon, snake_config.poly_num)
        can_init_poly = snake_voc_utils.img_poly_to_can_poly(img_init_poly, x_min, y_min, x_max, y_max)

        img_gt_poly = snake_voc_utils.uniformsample(poly, len(poly) * snake_config.gt_poly_num)
        tt_idx = np.argmin(np.power(img_gt_poly - img_init_poly[0], 2).sum(axis=1))
        img_gt_poly = np.roll(img_gt_poly, -tt_idx, axis=0)[::len(poly)] #将最大值顶点移动到首位元素
        can_gt_poly = snake_voc_utils.img_poly_to_can_poly(img_gt_poly, x_min, y_min, x_max, y_max)

        img_init_polys.append(img_init_poly)
        can_init_polys.append(can_init_poly)
        img_gt_polys.append(img_gt_poly)
        can_gt_polys.append(can_gt_poly)

    def prepare_merge(self, is_id, cls_id, cp_id, cp_cls):
        cp_id.append(is_id)
        cp_cls.append(cls_id)

    def __getitem__(self, index):
        ann = self.anns[index]

        anno, path, img_id = self.process_info(ann)
        img, instance_polys, cls_ids = self.read_original_data(anno, path)

        height, width = img.shape[0], img.shape[1]
        """
        imgsize:442*500 -> inp size:512*512
        进行图片增强:随机裁剪翻转缩放,颜色增强,图片归一化,仿射变换
        trans为仿射变换矩阵
        fipped记录是否翻转
        inp_out_hw 记录输入输出宽高 (512,512,128,128),因为降采样倍数为4
        scale 随机裁剪后,仿射变换前的图片尺寸(缩放尺度)
        图片缩小了四倍。
        """
        orig_img, inp, trans_input, trans_output, flipped, center, scale, inp_out_hw = \
            snake_voc_utils.augment(
                img, self.split,
                snake_config.data_rng, snake_config.eig_val, snake_config.eig_vec,
                snake_config.mean, snake_config.std, instance_polys
            )
        instance_polys = self.transform_original_data(instance_polys, flipped, width, trans_output, inp_out_hw)
        instance_polys = self.get_valid_polys(instance_polys, inp_out_hw)
        extreme_points = self.get_extreme_points(instance_polys)

        # detection
        output_h, output_w = inp_out_hw[2:]
        ct_hm = np.zeros([cfg.heads.ct_hm, output_h, output_w], dtype=np.float32)
        wh = []
        ct_cls = []
        ct_ind = []

        # init
        i_it_4pys = []
        c_it_4pys = []
        i_gt_4pys = []
        c_gt_4pys = []

        # evolution
        i_it_pys = []
        c_it_pys = []
        i_gt_pys = []
        c_gt_pys = []

        for i in range(len(anno)):
            cls_id = cls_ids[i]
            instance_poly = instance_polys[i]
            instance_points = extreme_points[i]

            for j in range(len(instance_poly)):
                poly = instance_poly[j]
                extreme_point = instance_points[j]

                x_min, y_min = np.min(poly[:, 0]), np.min(poly[:, 1])
                x_max, y_max = np.max(poly[:, 0]), np.max(poly[:, 1])
                bbox = [x_min, y_min, x_max, y_max]
                h, w = y_max - y_min + 1, x_max - x_min + 1
                if h <= 1 or w <= 1:
                    continue

                """
                计算center
                将已计算的class的id保存clsid,将center编码保存在ctind
                计算高斯模糊半径,并更新cthm
                """
                self.prepare_detection(bbox, poly, ct_hm, cls_id, wh, ct_cls, ct_ind)
                """
                根据bbox中点连接的四边形,均匀采样四十个点生成i_it_4pys,c_it_4pys为其归一化的点坐标
                i_gt_4pys为极值点,c_gt_4pys为其归一化的点坐标
                """
                self.prepare_init(bbox, extreme_point, i_it_4pys, c_it_4pys, i_gt_4pys, c_gt_4pys, output_h, output_w)
                """
                根据极值点建立八边形,并采样128个点,生成i_it_pys
                i_gt_pys为根据标注采样的128个点
                """
                self.prepare_evolution(poly, extreme_point, i_it_pys, c_it_pys, i_gt_pys, c_gt_pys)

        ret = {'inp': inp}
        detection = {'ct_hm': ct_hm, 'wh': wh, 'ct_cls': ct_cls, 'ct_ind': ct_ind}
        init = {'i_it_4py': i_it_4pys, 'c_it_4py': c_it_4pys, 'i_gt_4py': i_gt_4pys, 'c_gt_4py': c_gt_4pys}
        evolution = {'i_it_py': i_it_pys, 'c_it_py': c_it_pys, 'i_gt_py': i_gt_pys, 'c_gt_py': c_gt_pys}
        ret.update(detection)
        ret.update(init)
        ret.update(evolution)
        # visualize_utils.visualize_snake_detection(orig_img, ret)
        # visualize_utils.visualize_snake_evolution(orig_img, ret)

        ct_num = len(ct_ind)
        meta = {'center': center, 'scale': scale, 'img_id': img_id, 'ann': ann, 'ct_num': ct_num}
        ret.update({'meta': meta})

        return ret

    def __len__(self):
        return len(self.anns)

最后制作的dataloader保存数据通过

for x,y in enumerate(train_loader): break

可以查看:

在这里插入图片描述



三、网络结构

#train_net.py
from lib.networks import make_network
def main():
    network = make_network(cfg)
    if args.test:
        test(cfg, network)
    else:
        train(cfg, network)

根据train_net.py找到调用函数

lib.networks.make_network

#lib/networks/make_network.py
def make_network(cfg):
    module = '.'.join(['lib.networks', cfg.task])
    path = os.path.join('lib/networks', cfg.task, '__init__.py')
    return imp.load_source(module, path).get_network(cfg)

cfg.task为snake,找到

lib.networks.snake.\__init__.get\_network

函数

#lib/networks/snkae/__init__.py
from lib.utils.snake import snake_config
from .ct_snake import get_network as get_ro

_network_factory = {
    'ro': get_ro
}
def get_network(cfg):
    arch = cfg.network
    heads = cfg.heads
    head_conv = cfg.head_conv
    num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
    arch = arch[:arch.find('_')] if '_' in arch else arch
    get_model = _network_factory[arch]
    network = get_model(num_layers, heads, head_conv, snake_config.down_ratio, cfg.det_dir)
    return network

看来最后的网络构建函数来自于

lib.networks.snake.ct_snake.get_network

#lib/networks/snake/ct_snake.py
import torch.nn as nn
from .dla import DLASeg
from .evolve import Evolution
from lib.utils import net_utils, data_utils
from lib.utils.snake import snake_decode
import torch
from lib.config import cfg


class Network(nn.Module):
    def __init__(self, num_layers, heads, head_conv=256, down_ratio=4, det_dir=''):
        super(Network, self).__init__()
		"""
		DLASeg('dla34',heads={'ct_hm':20,'wh':2},
				pretrained=True,down_ratio=4,
				final_kernel=1,last_level=5,
				head_conv=256)
		"""
        self.dla = DLASeg('dla{}'.format(num_layers), heads,
                          pretrained=True,
                          down_ratio=down_ratio,
                          final_kernel=1,
                          last_level=5,
                          head_conv=head_conv)
        self.gcn = Evolution()

    def decode_detection(self, output, h, w):
        ct_hm = output['ct_hm']
        wh = output['wh']
        ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh)
        detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w)
        output.update({'ct': ct, 'detection': detection})
        return ct, detection

    def forward(self, x, batch=None):
        output, cnn_feature = self.dla(x)
        with torch.no_grad():
            ct, detection = self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3))
        if cfg.use_gt_det:
            self.use_gt_detection(output, batch)
        output = self.gcn(output, cnn_feature, batch)
        return output


def get_network(num_layers, heads, head_conv=256, down_ratio=4, det_dir=''):
    network = Network(num_layers, heads, head_conv, down_ratio, det_dir)
    return network

由上述代码可以看到,网络结构主要分为三个部分:

  • DLAseg,使用了dla34的特征提取网络;
  • decode_detection,解码特征图;
  • Evolution,主要涉及轮廓演变。



1. DLASeg

该DLAseg主要分为:

  • base层(dla34)
  • DLAUP和IDAUP
  • 检测头

最后输出:包含两个检测头输出的字典,CNNfeature。

class DLASeg(nn.Module):
    def __init__(self, base_name, heads, pretrained, down_ratio, final_kernel,
                 last_level, head_conv, out_channel=0):
        super(DLASeg, self).__init__()
        """
        base_name=dla34
        heads={'ct_hm':20,'wh':2}
        pretrained=True
        dwon_ratio=4
        final_kernel=1
        last_level=5
        head_conv=256
        """
        assert down_ratio in [2, 4, 8, 16]
        self.first_level = int(np.log2(down_ratio)) # 2
        self.last_level = last_level
        self.base = globals()[base_name](pretrained=pretrained) #dla34(pretrained=True)
        channels = self.base.channels
        scales = [2 ** i for i in range(len(channels[self.first_level:]))]
        self.dla_up = DLAUp(self.first_level, channels[self.first_level:], scales)

        if out_channel == 0:
            out_channel = channels[self.first_level]

        self.ida_up = IDAUp(out_channel, channels[self.first_level:self.last_level],
                            [2 ** i for i in range(self.last_level - self.first_level)])

        self.heads = heads
        
        """
        根据heads建立检测头,本任务有两个检测头:ct_hm:heatmap检测;wh:bbox的宽高。
        """
        for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(channels[self.first_level], head_conv,
                              kernel_size=3, padding=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_conv, classes,
                              kernel_size=final_kernel, stride=1,
                              padding=final_kernel // 2, bias=True))
                if 'hm' in head:
                    fc[-1].bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            else:
                fc = nn.Conv2d(channels[self.first_level], classes,
                               kernel_size=final_kernel, stride=1,
                               padding=final_kernel // 2, bias=True)
                if 'hm' in head:
                    fc.bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            self.__setattr__(head, fc)

    def forward(self, x):
        x = self.base(x)
        x = self.dla_up(x)

        y = []
        for i in range(self.last_level - self.first_level):
            y.append(x[i].clone())
        self.ida_up(y, 0, len(y))

        z = {}
        for head in self.heads:
            z[head] = self.__getattr__(head)(y[-1])
        return z, y[-1]


1.1 dla34
def dla34(pretrained=True, **kwargs):  # DLA-34
    model = DLA([1, 1, 1, 2, 2, 1],
                [16, 32, 64, 128, 256, 512],
                block=BasicBlock, **kwargs)
    if pretrained:
        model.load_pretrained_model(data='imagenet', name='dla34', hash='ba72cf86')
    return model

class DLA(nn.Module):
    def __init__(self, levels, channels, num_classes=1000,
                 block=BasicBlock, residual_root=False, linear_root=False):
        super(DLA, self).__init__()
        self.channels = channels
        self.num_classes = num_classes
        self.base_layer = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                      padding=3, bias=False),
            nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True))
        self.level0 = self._make_conv_level(
            channels[0], channels[0], levels[0])
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2)
        """
        levelroot可以看作一个标记,看以看做是否接受上一个level输出的downsample传入(下图红色方块之间的蓝色箭头);
        默认的level_root为False,通过递归创建的Tree并未特定level_root,表示与前一个Tree为同一个level。
        """
        self.level2 = Tree(levels[2], block, channels[1], channels[2], 2,
                           level_root=False,
                           root_residual=residual_root)
        self.level3 = Tree(levels[3], block, channels[2], channels[3], 2,
                           level_root=True, root_residual=residual_root)
        self.level4 = Tree(levels[4], block, channels[3], channels[4], 2,
                           level_root=True, root_residual=residual_root)
        self.level5 = Tree(levels[5], block, channels[4], channels[5], 2,
                           level_root=True, root_residual=residual_root)


    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(inplanes, planes, kernel_size=3,
                          stride=stride if i == 0 else 1,
                          padding=dilation, bias=False, dilation=dilation),
                nn.BatchNorm2d(planes, momentum=BN_MOMENTUM),
                nn.ReLU(inplace=True)])
            inplanes = planes
        return nn.Sequential(*modules)

    def forward(self, x):
        y = []
        x = self.base_layer(x)
        for i in range(6):
            x = getattr(self, 'level{}'.format(i))(x)
            y.append(x)
        return y

dla34是首先通过Conv-BN-RL基础层然后再通过五个level进行特征提取的,level0和level1都是Conv-BN-RL层,level2、3、4、5则是调用Tree对象进行构建。DLA的输入参数levels,则表示了level0-5的循环调用与否。

levels:[1, 1, 1, 2, 2, 1]

表示level3,4会循环调用Tree。

class Tree(nn.Module):
    def __init__(self, levels, block, in_channels, out_channels, stride=1,
                 level_root=False, root_dim=0, root_kernel_size=1,
                 dilation=1, root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
		"""
		当输入level等于1时直接用BasicBlock进行构建,当level大于1时进行Tree循环构建。
		"""
        if levels == 1:
            self.tree1 = block(in_channels, out_channels, stride,
                               dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1,
                               dilation=dilation)
        else:
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
                              stride, root_dim=0,
                              root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels,
                              root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
        if levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
            )

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            x = self.root(x2, x1, *children)
        else:
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x
    
class Root(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, 1,
            stride=1, bias=False, padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        children = x
        x = self.conv(torch.cat(x, 1))
        x = self.bn(x)
        if self.residual:
            x += children[0]
        x = self.relu(x)

        return x

class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
                               stride=stride, padding=dilation,
                               bias=False, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=dilation,
                               bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.stride = stride

    def forward(self, x, residual=None):
        if residual is None:
            residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)

        return out

img



1.2 DLAUP与IDAUP
class DLAUp(nn.Module):
    def __init__(self, startp, channels, scales, in_channels=None):
        super(DLAUp, self).__init__()
        """
        startp=2
        channels=[64,128,256,512]
        sclaes[1,2,4,8]
        """
        self.startp = startp
        if in_channels is None:
            in_channels = channels
        self.channels = channels
        channels = list(channels)
        scales = np.array(scales, dtype=int)
        
        for i in range(len(channels) - 1): #range(3)
            j = -i - 2
            setattr(self, 'ida_{}'.format(i),
                    IDAUp(channels[j], in_channels[j:],
                          scales[j:] // scales[j]))
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        out = [layers[-1]]  # start with 32
        for i in range(len(layers) - self.startp - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            ida(layers, len(layers) - i - 2, len(layers))
            out.insert(0, layers[-1])
        return out

class IDAUp(nn.Module):
    def __init__(self, o, channels, up_f):
        super(IDAUp, self).__init__()
        for i in range(1, len(channels)):
            c = channels[i]
            f = int(up_f[i])
            """
            使用可形变卷积
            """
            proj = DeformConv(c, o)
            node = DeformConv(o, o)
			"""
			转置卷积,上采样
			"""
            up = nn.ConvTranspose2d(o, o, f * 2, stride=f,
                                    padding=f // 2, output_padding=0,
                                    groups=o, bias=False)
            fill_up_weights(up)

            setattr(self, 'proj_' + str(i), proj)
            setattr(self, 'up_' + str(i), up)
            setattr(self, 'node_' + str(i), node)

    def forward(self, layers, startp, endp):
        for i in range(startp + 1, endp):
            upsample = getattr(self, 'up_' + str(i - startp))
            project = getattr(self, 'proj_' + str(i - startp))
            layers[i] = upsample(project(layers[i]))
            node = getattr(self, 'node_' + str(i - startp))
            layers[i] = node(layers[i] + layers[i - 1])

在这里插入图片描述

DLAUP与IDAUP结构图,图中红框表示了本项目中使用的部分,4s,8s,16s,32s分别表示了leve2、3、4、5的输出,将特征融合最终输出4s特征图。



1.3 检测头
       """
       DLASeg中检测头的构造
       """
    	for head in self.heads:
            classes = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(channels[self.first_level], head_conv,
                              kernel_size=3, padding=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_conv, classes,
                              kernel_size=final_kernel, stride=1,
                              padding=final_kernel // 2, bias=True))
                if 'hm' in head:
                    fc[-1].bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            else:
                fc = nn.Conv2d(channels[self.first_level], classes,
                               kernel_size=final_kernel, stride=1,
                               padding=final_kernel // 2, bias=True)
                if 'hm' in head:
                    fc.bias.data.fill_(-2.19)
                else:
                    fill_fc_weights(fc)
            self.__setattr__(head, fc)
   
"""
forward中的传递,y[-1]为IDAUP输出的4s特征图
"""
	z = {}
    for head in self.heads:
        z[head] = self.__getattr__(head)(y[-1])



2.decode_detection

还是先来看一下在ct_snake中decode_detection的使用:

    def forward(self, x, batch=None):
        output, cnn_feature = self.dla(x)
        """
        output为检测头的输出、cnn_feature为检测头之前的骨干网络提取的特征
        """
        with torch.no_grad():
            ct, detection = self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3))
        if cfg.use_gt_det: #False
            self.use_gt_detection(output, batch)
        output = self.gcn(output, cnn_feature, batch)
        return output
    
    def decode_detection(self, output, h, w):
        ct_hm = output['ct_hm']
        wh = output['wh']
        ct, detection = snake_decode.decode_ct_hm(torch.sigmoid(ct_hm), wh)
        detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w) #防止越界
        output.update({'ct': ct, 'detection': detection})
        return ct, detection

主要使用了

snake_decode.decode_ct_hm

函数

def decode_ct_hm(ct_hm, wh, reg=None, K=100):
    """
    假设传入:ct_hm [80,20,128,128],wh [80,2,128,128]
    """
    batch, cat, height, width = ct_hm.size()
    ct_hm = nms(ct_hm) #非极大值抑制
	"""
	输出topk的得分值,索引,类别,xy值 [80,100]
	"""
    scores, inds, clses, ys, xs = topk(ct_hm, K=K)
    wh = transpose_and_gather_feat(wh, inds) #获取topk对应的宽高 [80,100,2]
    wh = wh.view(batch, K, 2)

    if reg is not None:
        reg = transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1)
        ys = ys.view(batch, K, 1)

    clses = clses.view(batch, K, 1).float() #[80,100,1]
    scores = scores.view(batch, K, 1) #[80,100,1]
    ct = torch.cat([xs, ys], dim=2) #构建center [80,100,2]
    bboxes = torch.cat([xs - wh[..., 0:1] / 2,
                        ys - wh[..., 1:2] / 2,
                        xs + wh[..., 0:1] / 2,
                        ys + wh[..., 1:2] / 2], dim=2) #构建bounding box [80,100,4]
    detection = torch.cat([bboxes, scores, clses], dim=2) #[80,100,4(bbox)+1(scores)+1(clses)]

    return ct, detection 

def nms(heat, kernel=3):
    """
    使用3*3最大值池化进行非极大值抑制
    """
    pad = (kernel - 1) // 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep

def topk(scores, K=40):
    #假设现在scores维度为[80,10,128,128]
    batch, cat, height, width = scores.size()
	"""
	选择每类中最大的K个值 ->topk_scores [80,10,100]
	对应的索引 -> topk_inds [80,10,100]
	注意上述是在scores.view(batch,cat,-1):[80,10,128*128]中的索引。
	"""
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

    topk_inds = topk_inds % (height * width)    #这一行貌似没有作用?
    topk_ys = (topk_inds / width).int().float() #获取对应的图片的第几行
    topk_xs = (topk_inds % width).int().float()	#第几列
	"""
	选择topk_scores所有类中最大的k个值 ->topk_score [80,100]
	对应的索引-> topk_ind [80,100]
	注意上述是在topk_scores.view(batch,-1):[80,10*100]中的索引。
	"""
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    topk_clses = (topk_ind / K).int() #最大值对应的类别 [80,100]
    """
    获取在原始输入中最大K值对应的索引以及xy值
    """
    topk_inds = gather_feat(
        topk_inds.view(batch, -1, 1), topk_ind).view(batch, K) #[80,100]
    topk_ys = gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)#[80,100]
    topk_xs = gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)#[80,100]

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

def gather_feat(feat, ind, mask=None):
    """
    以上述gather_feat(topk_inds.view(batch,-1,1),topk_ind)为例
    传入feat : topk_inds.view(batch,-1,1) : topk_inds[80,10,100]-> [80,10*100,1]
    传入ind:   topk_ind [80,100]
    """
    dim = feat.size(2) #1
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) #[80,100]->[80,100,1]
    """
    topk_inds 对应的是topk_scores的索引,也即在图片中位置的索引;
    而topk_ind对应的是topk_score的索引,是在topk_scores中最大的K个值的索引;
    因此使用topk_inds.gather(1,topk_ind)即可实现将topk_ind映射回在图片中对应位置的索引。
    """
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat #[80,100]

def transpose_and_gather_feat(feat, ind):
    """
    传入wh:[80,2,128,128] ->feat
    topk的索引[80,100] -> ind
    """
    feat = feat.permute(0, 2, 3, 1).contiguous() #[80,128,128,2]
    feat = feat.view(feat.size(0), -1, feat.size(3))#[80,128*128,2]
    feat = gather_feat(feat, ind) #返回topk对应的宽高 [80,100,2]
    return feat

  • 最后的结果是在output内更新了两个键值’ct’: [batch,topk,2] 和’detection’: [batch,topk,6],其中detection[…,:4]为bounding box,[…,4]为topk得分,[…,5]为class。
  • ct和detection貌似没有在训练过程中使用,而是在测试过程中使用。



3. Evolution

Evolution模块主要可以分为:

  • snake层
  • 轮廓演化:

    • 训练
    • 测试


3.1 snake

Snake层主要是对应于论文中的循环卷积。

  • input[batch,66,顶点数]

  • 首先构造head,将输入卷积(feature_dim66)至state_dim128;

  • 然后使用空洞圆卷积构建七层残差层,这里空洞圆卷积是通过将input后方的特征拼接到input前,再将input前方的特征拼接到input后通过fc实现循环卷积;

  • 接着将head和残差层按通道进行拼接后输入fuse层,输出通道256;

  • 取fuse层输出每个特征通道的最大值(shape=[batch,256,1]),并拓展至[batch,256,顶点数];

  • 将head+七层残差层+上一步的输出按照通道进行融合,传入prediction层,最后输出[batch,2,顶点数]。

"""
空洞圆卷积
可以看到在构造过程中与普通fc层无异,卷积核为9,dialtion根据Snake构造函数定义
重点是在forward函数中,先将input特征进行拼接:
	input后面的-n_adj*dialtion + input + input前面的n_adj*dialtion
形成环形,再进行fc卷积。
"""
class DilatedCircConv(nn.Module):
    def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1):
        super(DilatedCircConv, self).__init__()

        self.n_adj = n_adj
        self.dilation = dilation
        out_state_dim = state_dim if out_state_dim is None else out_state_dim
        self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1, dilation=self.dilation)

    def forward(self, input, adj):
        """
		此处传入adj邻接矩阵,但是好像并没有用到,推测可能是原来想用于GCN的
		"""
        if self.n_adj != 0:
            input = torch.cat([input[..., -self.n_adj*self.dilation:], input, input[..., :self.n_adj*self.dilation]], dim=2)
        return self.fc(input)

"""
Conv-relu-bn,Conv为空洞圆卷积
"""
class BasicBlock(nn.Module):
    def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1):
        super(BasicBlock, self).__init__()

        self.conv = _conv_factory[conv_type](state_dim, out_state_dim, n_adj, dilation) # DilatedCircConv
        self.relu = nn.ReLU(inplace=True)
        self.norm = nn.BatchNorm1d(out_state_dim)

    def forward(self, x, adj=None):
        x = self.conv(x, adj)
        x = self.relu(x)
        x = self.norm(x)
        return x


class Snake(nn.Module):
    def __init__(self, state_dim, feature_dim, conv_type='dgrid'):
        super(Snake, self).__init__()
		"""
		主要构造
		1.head为普通basicBlock
		2.七个残差层
		3.融合层
		4.预测层
		"""
        self.head = BasicBlock(feature_dim, state_dim, conv_type)

        self.res_layer_num = 7
        dilation = [1, 1, 1, 2, 2, 4, 4] # 每个残差层的dilation参数
        #循环构建七个残差层
        for i in range(self.res_layer_num):
            conv = BasicBlock(state_dim, state_dim, conv_type, n_adj=4, dilation=dilation[i])
            self.__setattr__('res'+str(i), conv)
		
        fusion_state_dim = 256 
        #将head及残差层所有的特征进行融合
        self.fusion = nn.Conv1d(state_dim * (self.res_layer_num + 1), fusion_state_dim, 1)
        #输出偏移量
        self.prediction = nn.Sequential(
            nn.Conv1d(state_dim * (self.res_layer_num + 1) + fusion_state_dim, 256, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 64, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 2, 1)
        )

    def forward(self, x, adj):
        """
        假设传入x[80,66,40],state_dim=128,feature_dim=66
        """
        states = []

        x = self.head(x, adj) #[80,128,40]
        states.append(x)
        for i in range(self.res_layer_num):
            x = self.__getattr__('res'+str(i))(x, adj) + x
            states.append(x) #[80,128,40]
		#将head和res进行拼接
        state = torch.cat(states, dim=1)  #[80,128*8,40]
        """
         torch.max([80,256,40],dim=2,keepdim=True)[0] ->[80,256,1]
         相当于取每个通道的最大值
        """
        global_state = torch.max(self.fusion(state), dim=2, keepdim=True)[0]
        global_state = global_state.expand(global_state.size(0), global_state.size(1), state.size(2)) #[80,256,40]
        state = torch.cat([global_state, state], dim=1) #[80,256+128*8,40]
        x = self.prediction(state)

        return x


3.2 Evolution
class Evolution(nn.Module):
    def __init__(self):
        super(Evolution, self).__init__()

        self.fuse = nn.Conv1d(128, 64, 1)
        self.init_gcn = Snake(state_dim=128, feature_dim=64 + 2, conv_type='dgrid')
        self.evolve_gcn = Snake(state_dim=128, feature_dim=64 + 2, conv_type='dgrid')
        self.iter = 2
        for i in range(self.iter):
            evolve_gcn = Snake(state_dim=128, feature_dim=64 + 2, conv_type='dgrid')
            self.__setattr__('evolve_gcn' + str(i), evolve_gcn)

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


3.2.1训练过程
  • 训练过程中是先收集了标注的信息进行初始化,获得四十个点的初始坐标;
  • 将初始坐标(利用bbox中点构建的四边形)传入snake,然后返回了极值点的预测(因为gt为极值点),更新进ret中;
  • 根据标注的轮廓(利用极值点+bbox中点构建的八边形)进行初始轮廓演化,得到128个点的轮廓
  • 将得到的轮廓传入下一次演化,迭代两次
  • 将演化的轮廓更新进ret键值
def forward(self, output, cnn_feature, batch=None):
    ret = output

    if batch is not None and 'test' not in batch['meta']:
        with torch.no_grad():
            """
            init:
            ct_01记录了batch中每张图片里有几个poly,以及对应的poly位置
            collect_training将batch中所有的poly拼接在一起 size:(polynum,40,2)
            ct_num记录每张图片内有几个ct
            4py_ind记录了拼接后的poly对应的图片 size:(polynum,) 如:【0,1,2,2,3,4,4,4】
            output:
            更新了i_init_py..等四个py
            """
            init = self.prepare_training(output, batch)
        # 返回了预测的四个bbox中点坐标(极值点)
        ex_pred = self.init_poly(self.init_gcn, cnn_feature, init['i_it_4py'], init['c_it_4py'], init['4py_ind'])
        ret.update({'ex_pred': ex_pred, 'i_gt_4py': output['i_gt_4py']}) # ret 和output是一起改变的 为什么还要更新gt呢?

        # with torch.no_grad():
        #     init = self.prepare_training_evolve(output, batch, init)

        py_pred = self.evolve_poly(self.evolve_gcn, cnn_feature, init['i_it_py'], init['c_it_py'], init['py_ind'])
        py_preds = [py_pred]
        for i in range(self.iter): #range(2)
            py_pred = py_pred / snake_config.ro
            c_py_pred = snake_gcn_utils.img_poly_to_can_poly(py_pred)
            evolve_gcn = self.__getattr__('evolve_gcn' + str(i))
            py_pred = self.evolve_poly(evolve_gcn, cnn_feature, py_pred, c_py_pred, init['py_ind'])
            py_preds.append(py_pred)
        ret.update({'py_pred': py_preds, 'i_gt_py': output['i_gt_py'] * snake_config.ro})

这里训练过程其实是用到了标注的信息,收集标注中信息。

def prepare_training(self, output, batch):
    """
    ct_01记录了batch中每张图片里有几个poly,以及对应的poly位置
    collect_training将batch中所有的poly拼接在一起 size:(polynum,40,2)
    ct_num记录每张图片内有几个ct
    4py_ind记录了拼接后的poly对应的图片 size:(polynum,) 如:【0,1,2,2,3,4,4,4】
    """
    init = snake_gcn_utils.prepare_training(output, batch)
    output.update({'i_it_4py': init['i_it_4py'], 'i_it_py': init['i_it_py']})
    output.update({'i_gt_4py': init['i_gt_4py'], 'i_gt_py': init['i_gt_py']})
    return init

#snake_gcn_utils.prepare_training
def prepare_training(ret, batch):
    """
    ct_01记录了batch中每张图片里有几个poly,以及对应的poly位置 size(80,一张图片内最多的数量) 比如80张图片内最多有13个poly: (80,13)
    polynum=ct_01.sum
    collect_training将batch中所有的poly拼接在一起 size:(polynum,40,2)
    ct_num记录每张图片内有几个ct
    4py_ind记录了拼接后的poly对应的图片 shape:(polynum,) 如:【0,1,2,2,3,4,4,4】表示第一张图片有一个poly,位置为0;第二张图片有一个poly,位置1;第三张图片有两个poly,位置2,3;第四张图片有一个poly,位置4;第五张图片有三个poly,位置为5,6,7。
    """
    ct_01 = batch['ct_01'].byte()
    init = {}
    init.update({'i_it_4py': collect_training(batch['i_it_4py'], ct_01)})
    init.update({'c_it_4py': collect_training(batch['c_it_4py'], ct_01)})
    init.update({'i_gt_4py': collect_training(batch['i_gt_4py'], ct_01)})
    init.update({'c_gt_4py': collect_training(batch['c_gt_4py'], ct_01)})

    init.update({'i_it_py': collect_training(batch['i_it_py'], ct_01)})
    init.update({'c_it_py': collect_training(batch['c_it_py'], ct_01)})
    init.update({'i_gt_py': collect_training(batch['i_gt_py'], ct_01)})
    init.update({'c_gt_py': collect_training(batch['c_gt_py'], ct_01)})

    ct_num = batch['meta']['ct_num']
    init.update({'4py_ind': torch.cat([torch.full([ct_num[i]], i) for i in range(ct_01.size(0))], dim=0)})
    init.update({'py_ind': init['4py_ind']})

    if snake_config.train_pred_box:#Flase
        prepare_training_box(ret, batch, init)

    init['4py_ind'] = init['4py_ind'].to(ct_01.device)
    init['py_ind'] = init['py_ind'].to(ct_01.device)

    return init

def collect_training(poly, ct_01):
    """
    将batch中所有的poly拼接在一起 size:(polynum,40,2)
    """
    batch_size = ct_01.size(0)
    poly = torch.cat([poly[i][ct_01[i]] for i in range(batch_size)], dim=0)
    return poly

初始化轮廓

  • 获取i_it_4py(四十个中点连接采样多边形的坐标)的特征,[polynum,64,40]
  • 获取中心坐标特征[polynum,64,1],将中心点坐标特征与i_it_4py特征进行拼接,[polynum,128,40]
  • 传入fuse,获得i_init_feature: [polynum,64,40]
  • 将归一化的坐标合并 [polynum,66,40],传入snake,获得坐标偏移量[polynum,40,2]
  • 将坐标偏移量加上原始i_it_4py,并采样,获得预测的极值点坐标( 因为标注是将极值点作为gt):[polynum,4,2].
def init_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
    """
    Snake:Snake(state_dim=128,feature_dim=64+2,conv_type='dgrid'
    cnn_feature: size (80,64,128,128)
    i_it_poly: init['i_it_4py'] (199,40,2) 中点连接采样的多边形
    ind:       init['py_ind'] (199,)
    """
    if len(i_it_poly) == 0:
        return torch.zeros([0, 4, 2]).to(i_it_poly)

    h, w = cnn_feature.size(2), cnn_feature.size(3)
    init_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, h, w)  # size (199,64,40)
    center = (torch.min(i_it_poly, dim=1)[0] + torch.max(i_it_poly, dim=1)[0]) * 0.5  # 每列中最小的行和最大的行 size (199,2)
    ct_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, center[:, None], ind, h,w)  # poly中心的feature size(199,64,1)
    init_feature = torch.cat([init_feature, ct_feature.expand_as(init_feature)], dim=1) # size(199,128,40)
    init_feature = self.fuse(init_feature) #size(199,64,40)

    init_input = torch.cat([init_feature, c_it_poly.permute(0, 2, 1)], dim=1) #将归一化的坐标合并到特征 (199,64,40)+(199,2,40)-》(199,64,40)
    adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num, init_input.size(2), init_input.device)
    i_poly = i_it_poly + snake(init_input, adj).permute(0, 2, 1) #这里adj实际上并没有用到 snake.permute -> (199,40,2)可以看作是偏移量
    i_poly = i_poly[:, ::snake_config.init_poly_num // 4] #size(199,4,2) 

    return i_poly



def get_gcn_feature(cnn_feature, img_poly, ind, h, w):
    img_poly = img_poly.clone()
    """
    对poly进行归一化
    直接除以w,h不就行了吗?为什么还要乘2-1? 
    img_poly[...,0]/w 取值范围为【0,1],此时*2-1 取值范围就变为了【-1,1】,可以应用到grid_sample中。
    """
    img_poly[..., 0] = img_poly[..., 0] / (w / 2.) - 1
    img_poly[..., 1] = img_poly[..., 1] / (h / 2.) - 1

    batch_size = cnn_feature.size(0)
    """
    gcn_feature size : (多边形数量,特征维度,多边形顶点数)
    """
    gcn_feature = torch.zeros([img_poly.size(0), cnn_feature.size(1), img_poly.size(1)]).to(img_poly.device)
    for i in range(batch_size):
        poly = img_poly[ind == i].unsqueeze(0)
        feature = torch.nn.functional.grid_sample(cnn_feature[i:i + 1], poly)[0].permute(1, 0, 2)
        gcn_feature[ind == i] = feature

    return gcn_feature

这里跟init_poly很相似,不过没有将中心点特征进行融合;乘四倍是因为此时进行了4s下采样,后面gt也乘了四倍。

def evolve_poly(self, snake, cnn_feature, i_it_poly, c_it_poly, ind):
    if len(i_it_poly) == 0:
        return torch.zeros_like(i_it_poly)
    h, w = cnn_feature.size(2), cnn_feature.size(3)
    init_feature = snake_gcn_utils.get_gcn_feature(cnn_feature, i_it_poly, ind, h, w) #199,64,128
    c_it_poly = c_it_poly * snake_config.ro #将归一化的坐标乘以4倍
    init_input = torch.cat([init_feature, c_it_poly.permute(0, 2, 1)], dim=1) #将乘以四倍的归一化的坐标合并到特征
    adj = snake_gcn_utils.get_adj_ind(snake_config.adj_num, init_input.size(2), init_input.device)
    i_poly = i_it_poly * snake_config.ro + snake(init_input, adj).permute(0, 2, 1) #(199,128,2)
    return i_poly


3.2.2 损失函数

这里损失函数比较简单易懂:loss=ct_loss+0.1wh_loss+ex_loss+py_loss

class NetworkWrapper(nn.Module):
    def __init__(self, net):
        super(NetworkWrapper, self).__init__()

        self.net = net

        self.ct_crit = net_utils.FocalLoss()
        self.wh_crit = net_utils.IndL1Loss1d('smooth_l1')
        self.reg_crit = net_utils.IndL1Loss1d('smooth_l1')
        self.ex_crit = torch.nn.functional.smooth_l1_loss
        self.py_crit = torch.nn.functional.smooth_l1_loss

    def forward(self, batch):
        output = self.net(batch['inp'], batch)

        scalar_stats = {}
        loss = 0

        ct_loss = self.ct_crit(net_utils.sigmoid(output['ct_hm']), batch['ct_hm'])
        scalar_stats.update({'ct_loss': ct_loss})
        loss += ct_loss

        wh_loss = self.wh_crit(output['wh'], batch['wh'], batch['ct_ind'], batch['ct_01'])
        scalar_stats.update({'wh_loss': wh_loss})
        loss += 0.1 * wh_loss

        # reg_loss = self.reg_crit(output['reg'], batch['reg'], batch['ct_ind'], batch['ct_01'])
        # scalar_stats.update({'reg_loss': reg_loss})
        # loss += reg_loss

        ex_loss = self.ex_crit(output['ex_pred'], output['i_gt_4py'])
        scalar_stats.update({'ex_loss': ex_loss})
        loss += ex_loss

        py_loss = 0
        output['py_pred'] = [output['py_pred'][-1]]
        for i in range(len(output['py_pred'])):
            py_loss += self.py_crit(output['py_pred'][i], output['i_gt_py']) / len(output['py_pred'])
        scalar_stats.update({'py_loss': py_loss})
        loss += py_loss

        scalar_stats.update({'loss': loss})
        image_stats = {}

        return output, loss, scalar_stats, image_stats


3.2.3 训练过程

与训练过程中不同的地方主要在于:

  • 利用ouput[‘detection’]进行了bbox中点连接四边形的初始化,在训练中是直接用了标注信息;
  • 利用预测的ex建立八边形并采样,作为轮廓演化的初值,在训练中也是直接用了标注信息。
  • 最后输出的是ct值超过ct_score的多边形poly。
def forward(self, output, cnn_feature, batch=None):
    ret = output

    if not self.training:
        with torch.no_grad():
            """
            初始化,通过ct_score更新output['detection'] shape(>ct_score,6)
            将预测i_it_4py更新到it_ex:shape[>ct_score,40,2]
            """
            init = self.prepare_testing_init(output)
            ex = self.init_poly(self.init_gcn, cnn_feature, init['i_it_4py'], init['c_it_4py'], init['ind']) #[>ct_score,4,2]
            ret.update({'ex': ex})
			"""
			根据ex建立八边形并采样,获取轮廓it_py([>ct_score,128,2])更新进入output['it_py']
			evole则包含两个键:'i_it_py','c_it_py'shape与it_py相同
			"""
            evolve = self.prepare_testing_evolve(output, cnn_feature.size(2), cnn_feature.size(3))
            py = self.evolve_poly(self.evolve_gcn, cnn_feature, evolve['i_it_py'], evolve['c_it_py'], init['ind'])
            pys = [py / snake_config.ro]
            for i in range(self.iter):
                py = py / snake_config.ro
                c_py = snake_gcn_utils.img_poly_to_can_poly(py)
                evolve_gcn = self.__getattr__('evolve_gcn' + str(i))
                py = self.evolve_poly(evolve_gcn, cnn_feature, py, c_py, init['ind'])
                pys.append(py / snake_config.ro)
            ret.update({'py': pys})

    return output

test初始化,利用detection获取i_it_4py(ex)

def prepare_testing_init(self, output):
    """
    output['detection'] shape(batch,100,6)
    init通过detection获得超过ct_score的i_it_4py,c_it_4py,以及对应的索引ind
    output也通过ct_score筛选,并将i_it_4py更新入键值it_ex
    """
    init = snake_gcn_utils.prepare_testing_init(output['detection'][..., :4], output['detection'][..., 4])
    output['detection'] = output['detection'][output['detection'][..., 4] > snake_config.ct_score]
    output.update({'it_ex': init['i_it_4py']})
    return init

#snake_gcn_utils.repare_testing_init
def prepare_testing_init(box, score):
    """
    box=output['detection'][..., :4] shape(batch,100,4)
    score=output['detection'][..., 4] shape(batch,100)
    """
    i_it_4pys = snake_decode.get_init(box) #根据中点建立四边形 [batch,100,4,2]
    i_it_4pys = uniform_upsample(i_it_4pys, snake_config.init_poly_num) #采样40个点 [batch,100,40,2]
    c_it_4pys = img_poly_to_can_poly(i_it_4pys) #归一化

    ind = score > snake_config.ct_score #根据阈值筛选
    i_it_4pys = i_it_4pys[ind] #[过阈值数量,40,2]
    c_it_4pys = c_it_4pys[ind] #[过阈值数量,40,2]
    """
    ind shape[超过阈值数量,]
    for example:[0,0,0,1,2,2]表示第一个batch有3个超过阈值的,第二个batch有1个,第三个batch有2个。
    """
    ind = torch.cat([torch.full([ind[i].sum()], i) for i in range(ind.size(0))], dim=0)
    init = {'i_it_4py': i_it_4pys, 'c_it_4py': c_it_4pys, 'ind': ind}

    return init

test轮廓演化初始化,利用ex建立八边形轮廓

def prepare_testing_evolve(self, output, h, w):
    """
    根据极值点建立八边形,并进行采样,获得轮廓i_py ([ct_score,128,2])更新进output['it_py']
    """
    ex = output['ex'] #[>ct_score,4,2]
    ex[..., 0] = torch.clamp(ex[..., 0], min=0, max=w - 1)#越界修正
    ex[..., 1] = torch.clamp(ex[..., 1], min=0, max=h - 1)
    evolve = snake_gcn_utils.prepare_testing_evolve(ex) 
    output.update({'it_py': evolve['i_it_py']})
    return evolve
#snake_gcn_utils.prepare_testing_evolve
def prepare_testing_evolve(ex):
    if len(ex) == 0:
        i_it_pys = torch.zeros([0, snake_config.poly_num, 2]).to(ex)
        c_it_pys = torch.zeros_like(i_it_pys)
    else:
        i_it_pys = snake_decode.get_octagon(ex[None]) #根据极值点建立八边形 #[1,>ct_score,12,2]
        i_it_pys = uniform_upsample(i_it_pys, snake_config.poly_num)[0] #[>ct_score,128,2]
        c_it_pys = img_poly_to_can_poly(i_it_pys)
    evolve = {'i_it_py': i_it_pys, 'c_it_py': c_it_pys}
    return evolve



四、总结

  • 项目用了大量的高级封装调用起来很简单,阅读起来就比较麻烦,但也是让我学到了很多封装的技巧。
  • 由于项目的环境都比较老,迁移到本地的时候环境配置也需要调整,我也是花了不少时间搜索了不少资料才完成了cuda拓展的编译。
  • 该项目采用了DLA34作为特征提取网络,而没有接触过DLA的我第一次看DLA网络很头大,因为代码写的比较抽象吧,看懂backbone大概花了我一个星期吧,DLA系列看起来还挺有趣的,感兴趣的可以深入学习一下。
  • 项目数据集的处理和轮廓演化是相互配合的,需要结合一起学习。



版权声明:本文为qq_41396770原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。