II. Mask R-CNN Code Walkthrough
Alongside Detectron, FAIR also released a set of tutorial files; the walkthrough below follows the Detectron/GETTING_STARTED.md file.
First, a look at Detectron's directory layout.
- configs holds the training and testing configuration files; the parameter settings for the official baselines are all stored here as .yaml files.
- demo contains a few sample images together with example outputs (detections and segmentations).
- detectron is the core package: parameter configuration, dataset preparation, the models, and the training/testing utilities all live here.
- tools contains the entry-point scripts invoked when running models, covering inference, training, testing, and so on.
Inference with a pretrained model
1. Images in a directory
Inference on a directory of images uses the tools/infer_simple.py tool, with the following command:
python2 tools/infer_simple.py \
--cfg configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml \
--output-dir /tmp/detectron-visualizations \
--image-ext jpg \
--wts https://s3-us-west-2.amazonaws.com/detectron/35861858/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml.02_32_51.SgT4y1cO/output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl \
demo
--cfg is the configuration file mentioned above. When the program runs, Detectron first imports the default values of every parameter from detectron/core/config.py, and then calls merge_cfg_from_file(args.cfg) to overwrite those defaults with the parameters stored in the file that --cfg points to. As an example, config.py defines a default for the number of classes in the dataset:
# Number of classes in the dataset; must be set
# E.g., 81 for COCO (80 foreground + 1 background)
__C.MODEL.NUM_CLASSES = -1
This is clearly a placeholder default that we must override in the --cfg configuration file. Hence configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml contains:
MODEL:
  TYPE: generalized_rcnn
  CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
  NUM_CLASSES: 81
  FASTER_RCNN: True
  MASK_ON: True
The job of merge_cfg_from_file() is to replace the value of __C.MODEL.NUM_CLASSES in config.py (= -1) with the value of MODEL.NUM_CLASSES from the cfg file (= 81).
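A minimal sketch of that effect, using the config path from the command above (cfg is the global config object exported by detectron.core.config):

from detectron.core.config import cfg, merge_cfg_from_file

print(cfg.MODEL.NUM_CLASSES)  # -1, the placeholder default from config.py
merge_cfg_from_file('configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml')
print(cfg.MODEL.NUM_CLASSES)  # 81, taken from the yaml file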
--image-ext is the filename extension of the input images (default jpg); it is used to collect the images in the given folder. (The format of the output files is controlled by --output-ext, not by this flag.)
--wts is the model weights file, in other words a trained model that can be used directly. Here it is given as a URL, pointing to an official pretrained model hosted on Amazon S3. Downloading from there can be slow, so you can also fetch the file in advance with a download manager, store it locally, and replace the --wts argument with the local path. At runtime the program checks whether the --wts argument is a URL or a local path and fetches the model file accordingly.
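That check lives in detectron.utils.io.cache_url; a rough sketch of the idea (the helper below is hypothetical, not the verbatim implementation):

import os
import re

def cache_url_sketch(url_or_file, cache_dir):
    """Hypothetical sketch: return a local path, downloading first if needed."""
    if re.match(r'^(?:http)s?://', url_or_file) is None:
        return url_or_file  # already a local path: return it as is
    cached = os.path.join(cache_dir, url_or_file.split('/')[-1])
    if not os.path.exists(cached):
        pass  # download url_or_file to `cached` here, e.g. with urllib
    return cached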
Now let's look at the code of infer_simple.py.
if __name__ == '__main__':
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])  # global workspace initialization
    setup_logging(__name__)  # logging setup
    args = parse_args()      # argument parsing
    main(args)
Setting up the logging requires the logging helper to be imported first:
from detectron.utils.logging import setup_logging
setup_logging is defined as follows:
def setup_logging(name):
    FORMAT = '%(levelname)s %(filename)s:%(lineno)4d: %(message)s'
    # %(levelname)s: the name of the log level
    # %(filename)s:  the file name of the currently executing program
    # %(lineno)d:    the line number the log call was made from
    # %(message)s:   the log message itself
    # Manually clear root loggers to prevent any module that may have called
    # logging.basicConfig() from blocking our logging setup
    logging.root.handlers = []
    logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
    logger = logging.getLogger(name)
    return logger
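A quick usage sketch of the resulting format (the file name and line number depend on the call site; the values below are hypothetical):

logger = setup_logging(__name__)
logger.info('model loaded')
# prints: INFO infer_simple.py:  42: model loaded
# i.e. <level> <file>:<line>: <message>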
Parsing the arguments:
# command-line parsing
def parse_args():
    # Create the parser: an ArgumentParser instance, all of whose own
    # arguments are keyword arguments.
    # description: the text displayed before the help information.
    parser = argparse.ArgumentParser(description='End-to-end inference')
    # Add options with add_argument:
    #   name or flags: arguments are either optional or positional
    #   dest:    attribute name the parsed value is stored under
    #   default: default value
    #   type:    argument type, str by default
    parser.add_argument(
        '--cfg',
        dest='cfg',
        help='cfg model file (/path/to/model_config.yaml)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--wts',
        dest='weights',
        help='weights model file (/path/to/model_weights.pkl)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--output-dir',
        dest='output_dir',
        help='directory for visualization pdfs (default: /tmp/infer_simple)',
        default='/tmp/infer_simple',
        type=str
    )
    parser.add_argument(
        '--image-ext',
        dest='image_ext',
        help='image file name extension (default: jpg)',
        default='jpg',
        type=str
    )
    parser.add_argument(
        '--always-out',
        dest='out_when_no_box',
        help='output image even when no object is found',
        # `action` specifies how the command-line argument is handled:
        # action='store' (the default) simply stores the value;
        # action='store_true'/'store_false' store True or False instead.
        action='store_true'
    )
    parser.add_argument(
        'im_or_folder', help='image or folder of images', default=None
    )
    parser.add_argument(
        '--output-ext',
        dest='output_ext',
        help='output image file format (default: pdf)',
        default='pdf',
        type=str
    )
    if len(sys.argv) == 1:   # if sys.argv has length 1, only the script name
        parser.print_help()  # was given; print the parser's help text to
        sys.exit(1)          # guide the user, then exit (raises SystemExit)
    return parser.parse_args()  # do the actual parsing
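As a usage sketch: feeding the command line from the start of this section through parse_args() yields a namespace roughly like this (URL abbreviated):

args = parse_args()
# For the infer_simple.py command shown earlier this gives, roughly:
#   args.cfg             = 'configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml'
#   args.weights         = 'https://s3-us-west-2.amazonaws.com/detectron/...'  (abbreviated)
#   args.output_dir      = '/tmp/detectron-visualizations'
#   args.image_ext       = 'jpg'
#   args.out_when_no_box = False   # --always-out was not passed
#   args.im_or_folder    = 'demo'
#   args.output_ext      = 'pdf'   # default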
Next comes the main function:
def main(args):  # the parsed arguments are passed into main
    logger = logging.getLogger(__name__)
    merge_cfg_from_file(args.cfg)  # applies the yaml parameters to the
                                   # program; a detailed look follows below
Now for a closer look at merge_cfg_from_file(args.cfg). First the import; the function lives in detectron/core/config.py:
from detectron.core.config import merge_cfg_from_file
def merge_cfg_from_file(cfg_filename):
    """Load a yaml config file and merge it into the global config."""
    with open(cfg_filename, 'r') as f:
        # AttrDict is "a simple attribute dictionary used for representing
        # configuration options" (detectron.utils.collections)
        yaml_cfg = AttrDict(load_cfg(f))
    _merge_a_into_b(yaml_cfg, __C)
The load_cfg() function used above:
def load_cfg(cfg_to_load):
    """Wrapper around yaml.load used for maintaining backward compatibility"""
    # check that cfg_to_load is a file or a (base)string; otherwise raise
    assert isinstance(cfg_to_load, (file, basestring)), \
        'Expected {} or {} got {}'.format(file, basestring, type(cfg_to_load))
    if isinstance(cfg_to_load, file):
        cfg_to_load = ''.join(cfg_to_load.readlines())  # list of lines -> str
    if isinstance(cfg_to_load, basestring):
        for old_module, new_module in iteritems(_RENAMED_MODULES):
            # yaml object encoding: !!python/object/new:<module>.<object>
            old_module, new_module = 'new:' + old_module, 'new:' + new_module
            cfg_to_load = cfg_to_load.replace(old_module, new_module)
    return yaml.load(cfg_to_load)  # yaml documents parse into a dict,
                                   # so a dict is what gets returned
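To make the replace loop concrete, assume (this is an assumption about the contents of _RENAMED_MODULES) that it maps the pre-refactor module path utils.collections to detectron.utils.collections. An old serialized config would then be patched like this:

# hypothetical before/after for one _RENAMED_MODULES entry
old_tag = '!!python/object/new:utils.collections.AttrDict'
new_tag = old_tag.replace('new:utils.collections',
                          'new:detectron.utils.collections')
print(new_tag)  # !!python/object/new:detectron.utils.collections.AttrDict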
Finally, _merge_a_into_b(yaml_cfg, __C) replaces the defaults in config.py with the yaml parameters, working recursively through the nested structure.
def _merge_a_into_b(a, b, stack=None):  # merge the yaml parameters into the
                                        # global config, replacing the defaults
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.
    """
    assert isinstance(a, AttrDict), \
        '`a` (cur type {}) must be an instance of {}'.format(type(a), AttrDict)
    assert isinstance(b, AttrDict), \
        '`b` (cur type {}) must be an instance of {}'.format(type(b), AttrDict)
    for k, v_ in a.items():
        full_key = '.'.join(stack) + '.' + k if stack is not None else k
        # a must specify keys that are in b
        if k not in b:
            if _key_is_deprecated(full_key):  # is full_key a deprecated key?
                continue
            elif _key_is_renamed(full_key):   # is full_key a renamed key?
                _raise_key_rename_error(full_key)
            else:
                raise KeyError('Non-existent config key: {}'.format(full_key))
        v = copy.deepcopy(v_)     # deepcopy a fresh v
        v = _decode_cfg_value(v)  # convert the raw value read from the yaml
                                  # file into a python object
        v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
        # check that the new value matches the type of the value it replaces
        # Recursively merge dicts
        if isinstance(v, AttrDict):
            try:
                stack_push = [k] if stack is None else stack + [k]
                _merge_a_into_b(v, b[k], stack=stack_push)  # recurse, mirroring
                                                            # the yaml structure
            except BaseException:
                raise
        else:
            b[k] = v
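A toy trace of the recursion (illustrative only; note that _merge_a_into_b is module-private to detectron.core.config):

from detectron.core.config import _merge_a_into_b
from detectron.utils.collections import AttrDict

b = AttrDict({'MODEL': AttrDict({'NUM_CLASSES': -1, 'MASK_ON': False})})
a = AttrDict({'MODEL': AttrDict({'NUM_CLASSES': 81})})
_merge_a_into_b(a, b)
print(b.MODEL.NUM_CLASSES, b.MODEL.MASK_ON)  # 81 False -- only the key
                                             # present in `a` is clobbered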
Having finally finished with merge_cfg_from_file(), we return to infer_simple.py only to find that the analysis of main() has reached just its second line. Continuing with main():
def main(args):  # the parsed arguments are passed into main
    logger = logging.getLogger(__name__)
    merge_cfg_from_file(args.cfg)  # apply the yaml parameters, as analyzed above
    cfg.NUM_GPUS = 1  # the global cfg can also be set directly like this;
                      # here the number of GPUs in use is set to 1
    args.weights = cache_url(args.weights, cfg.DOWNLOAD_CACHE)
    # For cache_url, the official docstring reads: "Download the file
    # specified by the URL to the cache_dir and return the path to the
    # cached file. If the argument is not a URL, simply return it as is."
    # That says it all, so it is not expanded here; see
    # detectron.utils.io.cache_url for the code.
    assert_and_infer_cfg(cache_urls=False)
    assert not cfg.MODEL.RPN_ONLY, \
        'RPN models are not supported'
    assert not cfg.TEST.PRECOMPUTED_PROPOSALS, \
        'Models that require precomputed proposals are not supported'
    # at last, time to load the model
    model = infer_engine.initialize_model_from_cfg(args.weights)
    dummy_coco_dataset = dummy_datasets.get_coco_dataset()
    if os.path.isdir(args.im_or_folder):
        im_list = glob.iglob(args.im_or_folder + '/*.' + args.image_ext)
    else:
        im_list = [args.im_or_folder]
    for i, im_name in enumerate(im_list):
        out_name = os.path.join(
            args.output_dir, '{}'.format(os.path.basename(im_name) + '.' + args.output_ext)
        )
        logger.info('Processing {} -> {}'.format(im_name, out_name))
        im = cv2.imread(im_name)
        timers = defaultdict(Timer)
        t = time.time()
        with c2_utils.NamedCudaScope(0):
            cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
                model, im, None, timers=timers
            )
        logger.info('Inference time: {:.3f}s'.format(time.time() - t))
        for k, v in timers.items():
            logger.info(' | {}: {:.3f}s'.format(k, v.average_time))
        if i == 0:
            logger.info(
                ' \ Note: inference on the first image will be slower than the '
                'rest (caches and auto-tuning need to warm up)'
            )
        vis_utils.vis_one_image(
            im[:, :, ::-1],  # BGR -> RGB for visualization
            im_name,
            args.output_dir,
            cls_boxes,
            cls_segms,
            cls_keyps,
            dataset=dummy_coco_dataset,
            box_alpha=0.3,
            show_class=True,
            thresh=0.7,
            kp_thresh=2,
            ext=args.output_ext,
            out_when_no_box=args.out_when_no_box
        )
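The out_name construction above is plain os.path bookkeeping; a standalone sketch with a hypothetical file name:

import os

output_dir, output_ext = '/tmp/detectron-visualizations', 'pdf'
im_name = 'demo/some_image.jpg'  # hypothetical input image
out_name = os.path.join(output_dir, os.path.basename(im_name) + '.' + output_ext)
print(out_name)  # /tmp/detectron-visualizations/some_image.jpg.pdf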
The line that loads the model:
model = infer_engine.initialize_model_from_cfg(args.weights)
Here infer_engine comes from the following import:
import detectron.core.test_engine as infer_engine
Open detectron/core/test_engine.py and locate the initialize_model_from_cfg() function.
def initialize_model_from_cfg(weights_file, gpu_id=0):
    """Initialize a model from the global cfg. Loads test-time weights and
    creates the networks in the Caffe2 workspace.
    """
    model = model_builder.create(cfg.MODEL.TYPE, train=False, gpu_id=gpu_id)
    net_utils.initialize_gpu_from_weights_file(
        model, weights_file, gpu_id=gpu_id,
    )
    model_builder.add_inference_inputs(model)
    workspace.CreateNet(model.net)
    workspace.CreateNet(model.conv_body_net)
    if cfg.MODEL.MASK_ON:
        workspace.CreateNet(model.mask_net)
    if cfg.MODEL.KEYPOINTS_ON:
        workspace.CreateNet(model.keypoint_net)
    return model
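For the e2e_mask_rcnn_R-101-FPN_2x.yaml config used in this walkthrough (MASK_ON: True, keypoints off), the branches above therefore register three nets in the Caffe2 workspace:

# model.net            -- the detection net (RPN + box head)
# model.conv_body_net  -- the backbone on its own
# model.mask_net       -- the mask head, created because cfg.MODEL.MASK_ON is True
# (model.keypoint_net is skipped here since cfg.MODEL.KEYPOINTS_ON is False)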
The model_builder.create function used to build the model is shown below (it lives in modeling/model_builder.py):
def create(model_type_func, train=False, gpu_id=0):
    """Generic model creation function that dispatches to specific model
    building functions.

    By default, this function will generate a data parallel model configured to
    run on cfg.NUM_GPUS devices. However, you can restrict it to build a model
    targeted to a specific GPU by specifying gpu_id. This is used by
    optimizer.build_data_parallel_model() during test time.
    """
    model = DetectionModelHelper(
        name=model_type_func,
        train=train,
        num_classes=cfg.MODEL.NUM_CLASSES,
        init_params=train
    )
    model.only_build_forward_pass = False
    model.target_gpu_id = gpu_id
    return get_func(model_type_func)(model)
The model here is created with the DetectionModelHelper class, which lives in detectron/modeling/detector.py:
from detectron.modeling.detector import DetectionModelHelper
The class is too long to reproduce in full, so only its __init__ is shown:
class DetectionModelHelper(cnn.CNNModelHelper):  # subclasses cnn.CNNModelHelper
    def __init__(self, **kwargs):
        # Handle args specific to the DetectionModelHelper, others pass through
        # to CNNModelHelper
        self.train = kwargs.get('train', False)  # get() returns the value for
        # the given key, or the supplied default if the key is missing
        self.num_classes = kwargs.get('num_classes', -1)
        assert self.num_classes > 0, 'num_classes must be > 0'
        for k in ('train', 'num_classes'):
            if k in kwargs:
                del kwargs[k]
        kwargs['order'] = 'NCHW'
        # Defensively set cudnn_exhaustive_search to False in case the default
        # changes in CNNModelHelper. The detection code uses variable size
        # inputs that might not play nicely with cudnn_exhaustive_search.
        kwargs['cudnn_exhaustive_search'] = False
        super(DetectionModelHelper, self).__init__(**kwargs)
        self.roi_data_loader = None
        self.losses = []
        self.metrics = []
        self.do_not_update_params = []  # Params on this list are not updated
        self.net.Proto().type = cfg.MODEL.EXECUTION_TYPE
        self.net.Proto().num_workers = cfg.NUM_GPUS * 4
        self.prev_use_cudnn = self.use_cudnn
        self.gn_params = []  # Params on this list are GroupNorm parameters
Putting the calls to DetectionModelHelper together:
# in infer_simple.py, main(args)
model = infer_engine.initialize_model_from_cfg(args.weights)

# in test_engine.py, initialize_model_from_cfg(weights_file, gpu_id=0)
model = model_builder.create(cfg.MODEL.TYPE, train=False, gpu_id=gpu_id)

# in model_builder.py, create(model_type_func, train=False, gpu_id=0)
model = DetectionModelHelper(
    name=model_type_func,               # model name from cfg.MODEL.TYPE = generalized_rcnn
    train=train,                        # False here, since we are doing inference
    num_classes=cfg.MODEL.NUM_CLASSES,  # 81 for the MS COCO dataset
    init_params=train                   # likewise False: the weights come from
                                        # an already trained model
)

# in detector.py, class DetectionModelHelper
def __init__(self, **kwargs):
    # Handle args specific to the DetectionModelHelper, others pass through
    # to CNNModelHelper
    self.train = kwargs.get('train', False)
    self.num_classes = kwargs.get('num_classes', -1)
DetectionModelHelper subclasses cnn.CNNModelHelper, which in turn subclasses ModelHelper, a class Caffe2 provides to make network construction convenient. The inheritance chain:
class DetectionModelHelper(cnn.CNNModelHelper):
    def __init__(self, **kwargs):
        # Handle args specific to the DetectionModelHelper, others pass through
        # to CNNModelHelper

class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.
    """

class ModelHelper(object):
    """A helper model so we can manange models more easily. It contains net def
    and parameter storages. You can add an Operator yourself, e.g.

        model = model_helper.ModelHelper(name="train_net")
        # init your weight and bias as w and b
        w = model.param_init_net.XavierFill(...)
        b = model.param_init_net.ConstantFill(...)
        fc1 = model.FC([input, w, b], output, **kwargs)

    or you can use helper functions in brew module without manually
    defining parameter initializations and operators.

        model = model_helper.ModelHelper(name="train_net")
        fc1 = brew.fc(model, input, output, dim_in, dim_out, **kwargs)
    """
Back in model_builder.py: after instantiating the model object with DetectionModelHelper, create() also sets the attributes only_build_forward_pass and target_gpu_id, whose meanings are clear from the names. Both are used in detectron/modeling/optimizer.py; we will not dig into the details here.
model.only_build_forward_pass = False
model.target_gpu_id = gpu_id
Finally, the model that was built is returned:
return get_func(model_type_func)(model)
The get_func function returns a function object given its name:
def get_func(func_name):  # the argument is the model type, e.g. generalized_rcnn
    """Helper to return a function object by name. func_name must identify a
    function in this module or the path to a function relative to the base
    'modeling' module.
    """
    if func_name == '':
        return None
    new_func_name = name_compat.get_new_name(func_name)
    # name_compat.py records model names that have been renamed;
    # get_new_name() maps an old name to its new name, and if the given
    # name turns out to be an old one, a warning is emitted.
    if new_func_name != func_name:
        logger.warn(
            'Remapping old function name: {} -> {}'.
            format(func_name, new_func_name)
        )
        func_name = new_func_name
    try:
        parts = func_name.split('.')
        # Refers to a function in this module
        if len(parts) == 1:
            return globals()[parts[0]]  # the value stored under key parts[0]
                                        # in this module's globals() dict
        # Otherwise, assume we're referencing a module under modeling
        module_name = 'detectron.modeling.' + '.'.join(parts[:-1])
        # everything before the last '.' names the module
        module = importlib.import_module(module_name)  # dynamic import
        return getattr(module, parts[-1])  # return the matching attribute
    except Exception:
        logger.error('Failed to find function: {}'.format(func_name))
        raise
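Tracing the two dispatch paths by hand with names from this walkthrough:

get_func('generalized_rcnn')
# parts = ['generalized_rcnn'], len(parts) == 1
#   -> looked up directly in model_builder's globals()

get_func('FPN.add_fpn_ResNet101_conv5_body')
# parts = ['FPN', 'add_fpn_ResNet101_conv5_body']
#   -> module_name = 'detectron.modeling.FPN'
#   -> importlib.import_module('detectron.modeling.FPN')
#   -> getattr(module, 'add_fpn_ResNet101_conv5_body')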
With get_func(func_name) understood, we can go back to the final return statement of create():
return get_func(model_type_func)(model)
For example, with model_type_func = generalized_rcnn here, get_func(model_type_func) returns the function object named generalized_rcnn. Since generalized_rcnn() is likewise defined in model_builder.py, the line above is equivalent to:
return generalized_rcnn(model)
generalized_rcnn() is defined as follows:
def generalized_rcnn(model):
    """This model type handles:
      - Fast R-CNN
      - RPN only (not integrated with Fast R-CNN)
      - Faster R-CNN (stagewise training from NIPS paper)
      - Faster R-CNN (end-to-end joint training)
      - Mask R-CNN (stagewise training from NIPS paper)
      - Mask R-CNN (end-to-end joint training)
    """
    return build_generic_detection_model(
        model,
        get_func(cfg.MODEL.CONV_BODY),
        add_roi_box_head_func=get_func(cfg.FAST_RCNN.ROI_BOX_HEAD),
        add_roi_mask_head_func=get_func(cfg.MRCNN.ROI_MASK_HEAD),
        add_roi_keypoint_head_func=get_func(cfg.KRCNN.ROI_KEYPOINTS_HEAD),
        freeze_conv_body=cfg.TRAIN.FREEZE_CONV_BODY
    )
Next, the build_generic_detection_model() function it calls (too long to reproduce; only the name and parameters are shown):
def build_generic_detection_model(
    model,
    add_conv_body_func,
    add_roi_box_head_func=None,
    add_roi_mask_head_func=None,
    add_roi_keypoint_head_func=None,
    freeze_conv_body=False
):
At this point we can explain how Detectron composes an object detection model. Every model is assembled from three parts, TYPE, BODY, and HEAD, which matches the layout of the MODEL parameters in the yaml files; the decomposition mirrors the structure of the papers. For example, a Fast R-CNN model with a ResNet-50-C4 backbone is configured as follows (taken from the official comments):
"""
Generic recomposable model builders
For example, you can create a Fast R-CNN model with the ResNet-50-C4 backbone
with the configuration:
MODEL:
TYPE: generalized_rcnn
CONV_BODY: ResNet.add_ResNet50_conv4_body
ROI_HEAD: ResNet.add_ResNet_roi_conv5_head
"""
Given this description of the network structure under MODEL, get_func(func_name) maps each function name to the concrete function that builds that piece of the network, and a complete model snaps together in what is essentially a single step. The price is everything that came before: to let this final step assemble any network we want from the yaml configuration like building blocks, a very large amount of preparatory machinery is required. From another angle, though, the official code already ships every building block used in object detection, so we only need to use them; we could write a new yaml file and piece together an arbitrary structure for training and testing. That said, given how complete the platform is, most viable combinations have presumably already been tried by the authors, and the sheer number of configuration files in configs/12_2017_baselines shows how large the official body of experiments is.
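For instance, the Mask R-CNN yaml used at the start of this section plugs its blocks together roughly as follows (the two head entries below are drawn from the official baseline configs; treat the exact function names as illustrative):

MODEL:
  TYPE: generalized_rcnn
  CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
  NUM_CLASSES: 81
  FASTER_RCNN: True
  MASK_ON: True
FAST_RCNN:
  ROI_BOX_HEAD: fast_rcnn_heads.add_roi_2mlp_head
MRCNN:
  ROI_MASK_HEAD: mask_rcnn_heads.mask_rcnn_fcn_head_v1up4convs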
The arguments to build_generic_detection_model() are exactly those function objects, and conv_body, box_head, and mask_head are all built through successive function calls. As an example, consider add_conv_body_func: if cfg.MODEL.CONV_BODY is ResNet.add_ResNet50_conv4_body, then add_conv_body_func = add_ResNet50_conv4_body. In modeling/ResNet.py:
def add_ResNet50_conv4_body(model):
    return add_ResNet_convX_body(model, (3, 4, 6))
add_ResNet_convX_body:
def add_ResNet_convX_body(model, block_counts):
    """Add a ResNet body from input data up through the res5 (aka conv5) stage.
    The final res5/conv5 stage may be optionally excluded (hence convX, where
    X = 4 or 5)."""
    freeze_at = cfg.TRAIN.FREEZE_AT
    assert freeze_at in [0, 2, 3, 4, 5]

    # add the stem (by default, conv1 and pool1 with bn; can support gn)
    p, dim_in = globals()[cfg.RESNETS.STEM_FUNC](model, 'data')

    dim_bottleneck = cfg.RESNETS.NUM_GROUPS * cfg.RESNETS.WIDTH_PER_GROUP
    (n1, n2, n3) = block_counts[:3]
    s, dim_in = add_stage(model, 'res2', p, n1, dim_in, 256, dim_bottleneck, 1)
    if freeze_at == 2:
        model.StopGradient(s, s)
    s, dim_in = add_stage(
        model, 'res3', s, n2, dim_in, 512, dim_bottleneck * 2, 1
    )
    if freeze_at == 3:
        model.StopGradient(s, s)
    s, dim_in = add_stage(
        model, 'res4', s, n3, dim_in, 1024, dim_bottleneck * 4, 1
    )
    if freeze_at == 4:
        model.StopGradient(s, s)
    if len(block_counts) == 4:
        n4 = block_counts[3]
        s, dim_in = add_stage(
            model, 'res5', s, n4, dim_in, 2048, dim_bottleneck * 8,
            cfg.RESNETS.RES5_DILATION
        )
        if freeze_at == 5:
            model.StopGradient(s, s)
        return s, dim_in, 1. / 32. * cfg.RESNETS.RES5_DILATION
    else:
        return s, dim_in, 1. / 16.
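The block_counts tuple is what distinguishes the ResNet depths; add_ResNet50_conv4_body passes only the first three entries, so the body stops after res4 and the function returns a feature scale of 1/16, while a four-entry tuple adds res5 and returns 1/32 times RES5_DILATION. The standard stage sizes:

# bottleneck-block counts per stage (res2, res3, res4, res5)
RESNET_BLOCK_COUNTS = {
    'ResNet-50':  (3, 4, 6, 3),
    'ResNet-101': (3, 4, 23, 3),
    'ResNet-152': (3, 8, 36, 3),
}
# e.g. add_ResNet50_conv4_body(model) == add_ResNet_convX_body(model, (3, 4, 6))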
(To be continued.)