MMPretrain

  • Post author:
  • Post category:其他



title: mmpretrain实战

date: 2023-06-07 16:04:01

tags: [image classification,mmlab]




mmpretrain实战

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ccTl9bOl-1686129437336)(null)]

在这里插入图片描述

主要讲解了安装,还有使用教程.安装教程直接参考官网.下面讲解一下mmpretrain使用



实战教程



2.1简单使用

我们可以直接从定义好的模型来进行推理,首先list_model可以列出所有的分类,然后通过关键字可以识别出来resnet所有的模型,然后我们通过get_model,输入关键字就可以得到模型,之后,我们通过使用inference来进行传入模型,还有ckp,还有图形就可以直接来进行推理.



2.2自定义使用

首先整个mmlab都是通过使用cfg来进行配置的,所以我们如果要进行自己的resnet50配置,我们可以从官网的cfg来进行参考.

首先是模型,模型分为backbone骨干网络,head就是输出头,使用neck来进行连接网络.然后最后的loss,实在模型里就定义号了,使用的是topk

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=33,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ),
    init_cfg = dict(type='Pretrained',checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth')
    )

之后就是dataset的配置,我们使用的type是自定义的type,设置输入的train,还有val路径,之后设置val的评估指标,使用top1.

下面就是训练时候的配置,循环次数,还有优化器

最后就是训练时候的配置,自动保存权重最高的,还有值保留最近5个文件

剩下的地方可以设置args参数 例如load_file还有work-dir

work_dir = './exp'
  checkpoint=dict(type='CheckpointHook', interval=1,max_keep_ckpts=5,save_best='auto'),



版权声明:本文为qq_42186444原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。