一、timm库简介
PyTorch Image Models,简称timm,是一个巨大的PyTorch代码集合,整合了常用的models、layers、utilities、optimizers、schedulers、data-loaders/augmentations和reference training/validation scripts。
二、安装
pip install timm
三、使用
- 查看所有模型
model_list = timm.list_models()
print(model_list)
- 查看具有预训练参数的模型
model_pretrain_list = timm.list_models(pretrained=True)
print(model_pretrain_list)
- 检索特定模型
model_resnet = timm.list_models('*resnet*')
print(model_resnet)
- 创建模型
x = torch.randn((1, 3, 256, 512))
modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True)
out = modle_mobilenetv2(x)
# print(out.shape)
# torch.Size([1, 1000])
- 创建模型–改变输出类别数
x = torch.randn((1, 3, 256, 512))
modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, num_classes=100)
out = modle_mobilenetv2(x)
# print(out.shape)
# torch.Size([1, 100])
- 创建模型–改变输入通道数
x = torch.randn((1, 10, 256, 512))
modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, in_chans=10)
out = modle_mobilenetv2(x)
# print(out.shape)
# torch.Size([1, 1000])
- 创建模型–只提取特征
x = torch.randn((1, 3, 256, 512))
modle_mobilenetv2 = timm.create_model('mobilenetv2_100', pretrained=True, features_only=True)
out = modle_mobilenetv2(x)
# for o in out:
# print(o.shape)
# torch.Size([1, 16, 128, 256])
# torch.Size([1, 24, 64, 128])
# torch.Size([1, 32, 32, 64])
# torch.Size([1, 96, 16, 32])
# torch.Size([1, 320, 8, 16])
版权声明:本文为weixin_44762713原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。