Mask R-CNN代码分析(二)

  • Post author:
  • Post category:其他


二、Mask R-CNN代码解读

FAIR在发布detectron的同时,也发布了一系列的tutorial文件,接下来将根据

Detectron/GETTING_STARTED.md

文件来解读代码。

先来看一下detectron的文件结构。

  • config中是训练和测试的配置文件,官方的baseline的参数配置都以.yaml文件的形式存放在其中。
  • demo一些图像示例还有分割好的结果。
  • detectron核心程序,包括参数配置、数据集准备、模型、训练和测试的一些工具,都存放在其中。
  • tools运行模型时调用的工具,包括推断、训练、测试等。

使用预训练好的模型进行推断

1.目录中的图片

推断目录中的图片使用tools/infer_simple.py工具,命令如下:

python2 tools/infer_simple.py \
    --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml \
    --output-dir /tmp/detectron-visualizations \
    --image-ext jpg \
    --wts https://s3-us-west-2.amazonaws.com/detectron/35861858/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml.02_32_51.SgT4y1cO/output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl \
    demo

–cfg是之前提到的配置文件,detectron在运行程序时首先导入存放在core/config.py的所有参数的默认值,然后在调用函数merge_cfg_from_file(args.cfg),将–cfg参数引用的配置文件中存放的参数将默认值替换。举个例子,在config.py中关于数据集中的类别数有默认的定义:

# Number of classes in the dataset; must be set
# E.g., 81 for COCO (80 foreground + 1 background)
__C.MODEL.NUM_CLASSES = -1

这显然是一个默认值,需要我们在–cfg的配置文件中重新设置。故在configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml中有:

MODEL:
  TYPE: generalized_rcnn
  CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
  NUM_CLASSES: 81
  FASTER_RCNN: True
  MASK_ON: True

merge_cfg_from_file()函数的功能就是将cfg文件中的MODEL.NUM_CLASSES的值(=81)替换config.py中的__C.MODEL.NUM_CLASSES(=-1)。

–image-ext是输出图像的后缀。

–wts是模型的参数文件,其实也就意味着是训练好可以拿来直接使用的模型。这里给的是一个地址,是官方训练好上传到亚马逊云上的模型,因为这样下载会很慢,所以也可以提前下载好(用迅雷)存放在本地,将–wts参数替换为本地的地址。在运行程序中,会检测–wts后的参数是网址还是地址,自动调取模型文件。

接下来就来看infer_simple.py的代码。

if __name__ == '__main__':
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) # 对工作区的全局初始化
    setup_logging(__name__) # 日志设置
    args = parse_args() # 参数读取
    main(args)

对于日志的设置,要事先导入logging模块。

from detectron.utils.logging import setup_logging

setup_logging定义如下:

def setup_logging(name):
    FORMAT = '%(levelname)s %(filename)s:%(lineno)4d: %(message)s'
    # %(levelname)s: 打印日志级别名称
    # %(filename)s: 打印当前执行程序名
    # %(lineno)d: 打印日志的当前行号
    # %(message)s: 打印日志信息

    # Manually clear root loggers to prevent any module that may have called
    # logging.basicConfig() from blocking our logging setup
    logging.root.handlers = []
    logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
    logger = logging.getLogger(name)
    return logger

读取参数:

# 命令行解析模块
def parse_args():
    # 创建解析器
    # 创建一个ArgumentParser实例,ArgumentParser的参数都为关键字参数
    # description :help信息前显示的信息
    parser = argparse.ArgumentParser(description='End-to-end inference')
    # 添加参数选项 :add_argument
    # name or flags :参数有两种,可选参数和位置参数
    # dest :参数名
    # default :默认值
    # type :参数类型,默认为str
    parser.add_argument(
        '--cfg',
        dest='cfg',
        help='cfg model file (/path/to/model_config.yaml)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--wts',
        dest='weights',
        help='weights model file (/path/to/model_weights.pkl)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--output-dir',
        dest='output_dir',
        help='directory for visualization pdfs (default: /tmp/infer_simple)',
        default='/tmp/infer_simple',
        type=str
    )
    parser.add_argument(
        '--image-ext',
        dest='image_ext',
        help='image file name extension (default: jpg)',
        default='jpg',
        type=str
    )
    parser.add_argument(
        '--always-out',
        dest='out_when_no_box',
        help='output image even when no object is found',
        # action参数指定应该如何处理命令行参数,
        # action='store'仅仅保存参数值,为action默认值,
        # action='store_true'或'store_false'只保存True和False
        action='store_true'
    )
    parser.add_argument(
        'im_or_folder', help='image or folder of images', default=None
    )
    parser.add_argument(
        '--output-ext',
        dest='output_ext',
        help='output image file format (default: pdf)',
        default='pdf',
        type=str
    )
    if len(sys.argv) == 1:  # 如果sys.srgv的长度为1,说明运行时的输入只有文件名而没有参数
        parser.print_help() # 这时便打印命令行解析模块定义的help信息,帮助输入参数
        sys.exit(1)         # 引发一个异常然后退出程序
    return parser.parse_args() # 进行解析

接下来是主函数:

def main(args): # 读取的参数被传入主函数
    logger = logging.getLogger(__name__)

    merge_cfg_from_file(args.cfg) # 前面提到的如何将yaml文件中的参数作用到程序中,
                                  # 详解见下方代码段

详解merge_cfg_from_file(args.cfg),首先是导入该函数,位于detectron/core/config.py中。

from detectron.core.config import merge_cfg_from_file
def merge_cfg_from_file(cfg_filename):
    """Load a yaml config file and merge it into the global config."""
    """加载一个YAML配置文件并将其合并到全局配置文件中。"""
    with open(cfg_filename, 'r') as f:
        yaml_cfg = AttrDict(load_cfg(f)) 
    """A simple attribute dictionary used for representing configuration options."""
    """AttrDict类:用于表示配置选项的简单属性字典。"""
    _merge_a_into_b(yaml_cfg, __C)

上面函数中使用的load_cfg()函数:

def load_cfg(cfg_to_load):
    """Wrapper around yaml.load used for maintaining backward compatibility"""
    assert isinstance(cfg_to_load, (file, basestring)), \ # 判断cfg_to_load的类型
                                # 是否为文件或者基础字符串,如果都不是的话则引发一个异常。
        'Expected {} or {} got {}'.format(file, basestring, type(cfg_to_load))
    if isinstance(cfg_to_load, file):
        cfg_to_load = ''.join(cfg_to_load.readlines()) # list转为str
    if isinstance(cfg_to_load, basestring):
        for old_module, new_module in iteritems(_RENAMED_MODULES):
            # yaml object encoding: !!python/object/new:<module>.<object>
            old_module, new_module = 'new:' + old_module, 'new:' + new_module
            cfg_to_load = cfg_to_load.replace(old_module, new_module)
    return yaml.load(cfg_to_load) # python读取yaml文件以字典形式存放,所以这里返回的是一个字典

最后是_merge_a_into_b(yaml_cfg, __C)将yaml中的参数通过

递归

的方式替换config.py中的默认值。

def _merge_a_into_b(a, b, stack=None): # 将yaml中的参数加入到global参数中替换默认值
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.
    """
    assert isinstance(a, AttrDict), \
        '`a` (cur type {}) must be an instance of {}'.format(type(a), AttrDict)
    assert isinstance(b, AttrDict), \
        '`b` (cur type {}) must be an instance of {}'.format(type(b), AttrDict)

    for k, v_ in a.items():
        full_key = '.'.join(stack) + '.' + k if stack is not None else k
        # a must specify keys that are in b
        if k not in b:
            if _key_is_deprecated(full_key): # full_key是否是被弃用的键名
                continue
            elif _key_is_renamed(full_key): # full_key是否是更改了名字的键名
                _raise_key_rename_error(full_key)
            else:
                raise KeyError('Non-existent config key: {}'.format(full_key))

        v = copy.deepcopy(v_) # deepcopy一个新的v
        v = _decode_cfg_value(v) # 将从yaml文件读取的原始值变为python对象
        v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) # 检查和对应要替换的
                                                                   # 值是否在类型上一致
        # Recursively merge dicts
        if isinstance(v, AttrDict):
            try:
                stack_push = [k] if stack is None else stack + [k]
                _merge_a_into_b(v, b[k], stack=stack_push) # 递归调用,见yaml文件的结构
            except BaseException:
                raise
        else:
            b[k] = v

终于分析完了merge_cfg_from_file()回过头来发现infer_simple.py的main()函数才分析到第二行。继续分析main():

def main(args): # 读取的参数被传入主函数
    logger = logging.getLogger(__name__)

    merge_cfg_from_file(args.cfg) # 前面提到的如何将yaml文件中的参数作用到程序中,
                                  # 详解见下方代码段
    cfg.NUM_GPUS = 1 # 可以通过这种方式直接设置全局的cfg,此处是设置使用的GPU数量为1.
    args.weights = cache_url(args.weights, cfg.DOWNLOAD_CACHE)
    """对于cache_url函数,官方注释中写道:Download the file specified by the URL to the          
    cache_dir and return the path to the cached file. If the argument is not a URL,
    simply return it as is. 一目了然,故这里不作展开。
    具体代码见 : detectron.utils.io.cache_url
    """
    assert_and_infer_cfg(cache_urls=False)

    assert not cfg.MODEL.RPN_ONLY, \
        'RPN models are not supported'
    assert not cfg.TEST.PRECOMPUTED_PROPOSALS, \
        'Models that require precomputed proposals are not supported'
    # 终于到了加载模型的时候
    model = infer_engine.initialize_model_from_cfg(args.weights)
    dummy_coco_dataset = dummy_datasets.get_coco_dataset()

    if os.path.isdir(args.im_or_folder):
        im_list = glob.iglob(args.im_or_folder + '/*.' + args.image_ext)
    else:
        im_list = [args.im_or_folder]

    for i, im_name in enumerate(im_list):
        out_name = os.path.join(
            args.output_dir, '{}'.format(os.path.basename(im_name) + '.' + args.output_ext)
        )
        logger.info('Processing {} -> {}'.format(im_name, out_name))
        im = cv2.imread(im_name)
        timers = defaultdict(Timer)
        t = time.time()
        with c2_utils.NamedCudaScope(0):
            cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
                model, im, None, timers=timers
            )
        logger.info('Inference time: {:.3f}s'.format(time.time() - t))
        for k, v in timers.items():
            logger.info(' | {}: {:.3f}s'.format(k, v.average_time))
        if i == 0:
            logger.info(
                ' \ Note: inference on the first image will be slower than the '
                'rest (caches and auto-tuning need to warm up)'
            )

        vis_utils.vis_one_image(
            im[:, :, ::-1],  # BGR -> RGB for visualization
            im_name,
            args.output_dir,
            cls_boxes,
            cls_segms,
            cls_keyps,
            dataset=dummy_coco_dataset,
            box_alpha=0.3,
            show_class=True,
            thresh=0.7,
            kp_thresh=2,
            ext=args.output_ext,
            out_when_no_box=args.out_when_no_box
        )

加载模型的语句:

model = infer_engine.initialize_model_from_cfg(args.weights)

其中,infer_engine是下面的包导入的:

import detectron.core.test_engine as infer_engine

打开detectron/core/test_engine.py,找到initialize_model_from_cfg()函数。

def initialize_model_from_cfg(weights_file, gpu_id=0):
    """Initialize a model from the global cfg. Loads test-time weights and
    creates the networks in the Caffe2 workspace.
    从全局cfg初始化模型,在caffe2工作空间中加载测试时所用权值并创建网络。
    """
    model = model_builder.create(cfg.MODEL.TYPE, train=False, gpu_id=gpu_id)
    net_utils.initialize_gpu_from_weights_file(
        model, weights_file, gpu_id=gpu_id,
    )
    model_builder.add_inference_inputs(model)
    workspace.CreateNet(model.net)
    workspace.CreateNet(model.conv_body_net)
    if cfg.MODEL.MASK_ON:
        workspace.CreateNet(model.mask_net)
    if cfg.MODEL.KEYPOINTS_ON:
        workspace.CreateNet(model.keypoint_net)
    return model

其中,创建模型所使用的model_builder.create代码如下:(位于modeling/model_builder.py)

def create(model_type_func, train=False, gpu_id=0):
    """Generic model creation function that dispatches to specific model
    building functions.
    通用模型创建功能,用于特定的模型构建功能。并且可以选择gpu的个数和编号。
    By default, this function will generate a data parallel model configured to
    run on cfg.NUM_GPUS devices. However, you can restrict it to build a model
    targeted to a specific GPU by specifying gpu_id. This is used by
    optimizer.build_data_parallel_model() during test time.
    """
    model = DetectionModelHelper(
        name=model_type_func,
        train=train,
        num_classes=cfg.MODEL.NUM_CLASSES,
        init_params=train
    )
    model.only_build_forward_pass = False
    model.target_gpu_id = gpu_id
    return get_func(model_type_func)(model)

这里边模型的创建使用了DetectionModelHelper类,位于detectron/modeling/detector.py

from detectron.modeling.detector import DetectionModelHelper

这个类的实现代码太长了这里就不贴全部了,只贴一下其中的__init__,如下:

class DetectionModelHelper(cnn.CNNModelHelper): # 继承自cnn.CNNModelHelper超类
    def __init__(self, **kwargs):
        # Handle args specific to the DetectionModelHelper, others pass through
        # to CNNModelHelper
        self.train = kwargs.get('train', False) # get() 函数返回指定键的值,如果值不在字典中返回默认值
        self.num_classes = kwargs.get('num_classes', -1)
        assert self.num_classes > 0, 'num_classes must be > 0'
        for k in ('train', 'num_classes'):
            if k in kwargs:
                del kwargs[k]
        kwargs['order'] = 'NCHW'
        # Defensively set cudnn_exhaustive_search to False in case the default
        # changes in CNNModelHelper. The detection code uses variable size
        # inputs that might not play nicely with cudnn_exhaustive_search.
        kwargs['cudnn_exhaustive_search'] = False
        super(DetectionModelHelper, self).__init__(**kwargs)
        self.roi_data_loader = None
        self.losses = []
        self.metrics = []
        self.do_not_update_params = []  # Param on this list are not updated
        self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
        self.net.Proto().num_workers = cfg.NUM_GPUS * 4
        self.prev_use_cudnn = self.use_cudnn
        self.gn_params = []  # Param on this list are GroupNorm parameters

结合上边对DetectionModelHelper的调用,

# infer_simple.py的main(args)中
model = infer_engine.initialize_model_from_cfg(args.weights)
# test_engine.py的initialize_model_from_cfg(weights_file, gpu_id=0)中
model = model_builder.create(cfg.MODEL.TYPE, train=False, gpu_id=gpu_id)
# model_builder.py的create(model_type_func, train=False, gpu_id=0)中
model = DetectionModelHelper(
        name=model_type_func, # 模型名字来自于cfg.MODEL.TYPE = generalized_rcnn
        train=train, # train = train,由于是推断,所以这里是False
        num_classes=cfg.MODEL.NUM_CLASSES, # MS COCO数据集,NUM_CLASSES = 81
        init_params=train # 同样因为推断,使用训练好的模型,init_params = False
    )
# detector.py的class DetectionModelHelper()中
    def __init__(self, **kwargs):
        # Handle args specific to the DetectionModelHelper, others pass through
        # to CNNModelHelper
        self.train = kwargs.get('train', False)
        self.num_classes = kwargs.get('num_classes', -1)

DetectionModelHelper继承自cnn.CNNModelHelper,cnn.CNNModelHelper又继承自ModelHelper,ModelHelper是caffe2为了方便的构造网络而编写的一个类。继承关系如下:

class DetectionModelHelper(cnn.CNNModelHelper):
    def __init__(self, **kwargs):
        # Handle args specific to the DetectionModelHelper, others pass through
        # to CNNModelHelper

class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.
    """

class ModelHelper(object):
    """A helper model so we can manange models more easily. It contains net def
    and parameter storages. You can add an Operator yourself, e.g.

        model = model_helper.ModelHelper(name="train_net")
        # init your weight and bias as w and b
        w = model.param_init_net.XavierFill(...)
        b = model.param_init_net.ConstantFill(...)
        fc1 = model.FC([input, w, b], output, **kwargs)

    or you can use helper functions in brew module without manually
    defining parameter initializations and operators.

        model = model_helper.ModelHelper(name="train_net")
        fc1 = brew.fc(model, input, output, dim_in, dim_out, **kwargs)

    """

回到model_builder.py,在使用DetectionModelHelper实例化了一个model对象后,又设置了only_build_forward_pass和target_gpu_id两个参数,根据命名能清楚知道它们的含义。对这两个参数的使用均在detectron/modeling/optimizer.py中,具体原理在此不做深究。

    model.only_build_forward_pass = False
    model.target_gpu_id = gpu_id

最后,返回创建的模型。

   return get_func(model_type_func)(model)

get_func函数:根据名字返回函数对象。

def get_func(func_name): # 传入参数为模型类型,如:generalized_rcnn
    """Helper to return a function object by name. func_name must identify a
    function in this module or the path to a function relative to the base
    'modeling' module.
    """
    if func_name == '':
        return None
    new_func_name = name_compat.get_new_name(func_name)
    """name_compat.py中存放着修改过名称的模型,
    其中的get_new_name()是根据旧名称得到新名称,
    如果确实发现传入参数的名称是之前修改过的,则发出警告。
    """
    if new_func_name != func_name:
        logger.warn(
            'Remapping old function name: {} -> {}'.
            format(func_name, new_func_name)
        )
        func_name = new_func_name
    try:
        parts = func_name.split('.')
        # Refers to a function in this module
        if len(parts) == 1:
            return globals()[parts[0]] # 全局变量(一个字典)中的part[0]键名对应的键字
        # Otherwise, assume we're referencing a module under modeling
        module_name = 'detectron.modeling.' + '.'.join(parts[:-1]) # 取最前面的一个(第一个'.'之前)
        module = importlib.import_module(module_name) # 动态导入模块
        return getattr(module, parts[-1]) # 返回对应的属性
    except Exception:
        logger.error('Failed to find function: {}'.format(func_name))
        raise

了解了get_func(func_name)以后,我们再回过头来看create()的最后一句return。

   return get_func(model_type_func)(model)

举个例子,比如这里我的model_type_func = generalized_rcnn,那么get_func(model_type_func)所返回的就是名字为generalized_rcnn的函数对象。 generalized_rcnn()同样定义在model_builder.py中,这样上面一行代码就等价于:

    return generalized_rcnn(model)

generalized_rcnn()定义如下:

def generalized_rcnn(model):
    """This model type handles:
      - Fast R-CNN
      - RPN only (not integrated with Fast R-CNN)
      - Faster R-CNN (stagewise training from NIPS paper)
      - Faster R-CNN (end-to-end joint training)
      - Mask R-CNN (stagewise training from NIPS paper)
      - Mask R-CNN (end-to-end joint training)
    """
    return build_generic_detection_model(
        model,
        get_func(cfg.MODEL.CONV_BODY),
        add_roi_box_head_func=get_func(cfg.FAST_RCNN.ROI_BOX_HEAD),
        add_roi_mask_head_func=get_func(cfg.MRCNN.ROI_MASK_HEAD),
        add_roi_keypoint_head_func=get_func(cfg.KRCNN.ROI_KEYPOINTS_HEAD),
        freeze_conv_body=cfg.TRAIN.FREEZE_CONV_BODY
    )

再来看其中用到的build_generic_detection_model()函数:(太长只截取函数名和参数部分)

def build_generic_detection_model(
    model,
    add_conv_body_func,
    add_roi_box_head_func=None,
    add_roi_mask_head_func=None,
    add_roi_keypoint_head_func=None,
    freeze_conv_body=False
):

这时候我们就可以来解释detectron对于物体检测模型的构成原理了。对于一般的模型构建,detectron均遵循TYPE,BODY, HEAD三大部分的方式来构建,这也和yaml文件中MODEL中的参数设置是一致的。比如构建一个Fast R-CNN模型,使用ResNet-50-C4的主干网络:(截取自官方注释)

为什么要这样来构建,这里参考论文结构。

"""
    Generic recomposable model builders

    For example, you can create a Fast R-CNN model with the ResNet-50-C4 backbone
    with the configuration:

    MODEL:
        TYPE: generalized_rcnn
        CONV_BODY: ResNet.add_ResNet50_conv4_body
        ROI_HEAD: ResNet.add_ResNet_roi_conv5_head
"""

根据MODEL中的这个对网络结构的配置,再使用get_func(func_name)通过函数对象名返回具体的代表网络结构的函数,就可以构成一个完整的网络了。所以说detectron在最后一步组成网络上几乎可以说是一步到位的。而它麻烦的地方也正是在于,为了最后一步能够直接根据yaml中的配置像拼积木一样拼成任意我们想要的网络,前面所做的各种准备工作量相当巨大。不过从另一个角度说,官方已经给出了在图像检测领域所有的积木块,那我们只需要去使用就好了。比如我可以去定义一个新的YAML文件,随意拼一个结构出来进行训练和测试。但是我觉得既然官方发布了一个这么完备的平台,那么可以做的排列组合应该都已经被官方实验过了。从config/12_2017_baselines里众多的配置文件就可以看出来官方所做的实验量还是十分巨大的。

build_generic_detection_model()中的参数就是函数对象的方法,对于conv_body,box_head和mask_head都是函数的不断调用。我们来看一个例子,add_conv_body_func,假设cfg.MODEL.CONV_BODY: ResNet.add_ResNet50_conv4_body,则add_conv_body_func = add_ResNet50_conv4_body。在modeling/ResNet.py中:

def add_ResNet50_conv4_body(model):
    return add_ResNet_convX_body(model, (3, 4, 6))

add_ResNet_convX_body:

def add_ResNet_convX_body(model, block_counts):
    """Add a ResNet body from input data up through the res5 (aka conv5) stage.
    The final res5/conv5 stage may be optionally excluded (hence convX, where
    X = 4 or 5)."""
    freeze_at = cfg.TRAIN.FREEZE_AT
    assert freeze_at in [0, 2, 3, 4, 5]

    # add the stem (by default, conv1 and pool1 with bn; can support gn)
    p, dim_in = globals()[cfg.RESNETS.STEM_FUNC](model, 'data')

    dim_bottleneck = cfg.RESNETS.NUM_GROUPS * cfg.RESNETS.WIDTH_PER_GROUP
    (n1, n2, n3) = block_counts[:3]
    s, dim_in = add_stage(model, 'res2', p, n1, dim_in, 256, dim_bottleneck, 1)
    if freeze_at == 2:
        model.StopGradient(s, s)
    s, dim_in = add_stage(
        model, 'res3', s, n2, dim_in, 512, dim_bottleneck * 2, 1
    )
    if freeze_at == 3:
        model.StopGradient(s, s)
    s, dim_in = add_stage(
        model, 'res4', s, n3, dim_in, 1024, dim_bottleneck * 4, 1
    )
    if freeze_at == 4:
        model.StopGradient(s, s)
    if len(block_counts) == 4:
        n4 = block_counts[3]
        s, dim_in = add_stage(
            model, 'res5', s, n4, dim_in, 2048, dim_bottleneck * 8,
            cfg.RESNETS.RES5_DILATION
        )
        if freeze_at == 5:
            model.StopGradient(s, s)
        return s, dim_in, 1. / 32. * cfg.RESNETS.RES5_DILATION
    else:
        return s, dim_in, 1. / 16.

(未完待续)



版权声明:本文为lscelory原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。