解决RuntimeError:Error(s)in loading state_dict for YoloBody

  • Post author:
  • Post category:其他


解决RuntimeError:Error(s)in loading state_dict for YoloBody



问题描述

本人运行的是官网下载的基于pytorch实现的YOLOV3模型。在尝试利用训练好的权值文件去进行预测时产生一下问题:

RuntimeError: Error(s) in loading state_dict for YoloBody:
	size mismatch for last_layer0.6.weight: copying a param with shape torch.Size([45, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([255, 1024, 1, 1]).
	size mismatch for last_layer0.6.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([255]).
	size mismatch for last_layer1.6.weight: copying a param with shape torch.Size([45, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([255, 512, 1, 1]).
	size mismatch for last_layer1.6.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([255]).
	size mismatch for last_layer2.6.weight: copying a param with shape torch.Size([45, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([255, 256, 1, 1]).
	size mismatch for last_layer2.6.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([255]).

即加载的模型参数和初始化的YOLOV3结构不匹配的问题。



分析

通过报错信息,不匹配的地方主要是在最后的层数上,且只有一个数据对不上,即本人训练得到的数据中,第一个参数数量是45,而初始换的结构式定下的参数数量是255。通过对比YOLOV3的结构参数可以知道该参数应该为最后输出的channel数。本人数据集总共分为10类,每个图片三个bbox,3*(10+4+1)=45,符合我的训练模型。而初始化的参数数量为255,即对应80个类。原因在于初始化结构是默认根据COCO数据集的80个类建立的,相关代码如下:

class YOLO(object):
    _defaults = {
        "model_path"        : 'model_data/yolo_weights.pth',
        "classes_path"      : 'model_data/coco_classes.txt',
        "model_image_size"  : (416, 416, 3),
        "confidence"        : 0.5,
        "iou"               : 0.3,
        "cuda"              : True,

这部分代码在yolo.py文件之中。

可以看到classes_path的路径为model_data/coco_classes.txt,点开该文件可以看到这里面保存着coco数据集80个类别的信息

在这里插入图片描述

说明想法正确



解决方法

直接将model_data/coco_classes.txt文件中的内容改成自己的类别信息,或者新建文件夹存储自己的类别信息,并且将上述代码中的路径改成自己的文件即可。



补充

由于好像比较多人无法解决问题,本人在此对使用YOLOv3模型的流程做一个详细一点补充说明,希望可以帮助有疑惑的人发现问题。

在使用YOLOV3之前,当然是要将自己的图片数据和标注数据分别放在

VOCdevkit/VOC2007/Annotations



VOCdevkit/VOC2007/JPEGImages

中。然后就可以运行

voc2yolo3.py

。这一步的目的其实是根据

voc2yolo3.py

中设置的比例将数据集分成训练集和测试集。然后将其中的图片名分别保存在

VOCdevkit/VOC2007/ImageSets/Main/



txt

文件之中。保存结果如下图所示。下述字符串是我各个图片的名字。

在这里插入图片描述

接着你就需要运行

voc_annotation.py

,这一步的目的是根据上一步中划分的训练集和数据集,将对应的图片信息。包括位置,类别,标记框位置等,写入2007_test.txt等文件中。

这时你需要修改

voc_annotation.py

中的classes,将其改成你的类

。比如你需要识别apple和orange。这时候你的classes=[apple,orange],那么将信息写入txt的时候,苹果图片中的类别信息就是0,橙子图片的位置信息就是1.例子如下

在这里插入图片描述

接着,你就可以运行train.py训练你的模型了。注意到train.py中模型的定义为

 #------------------------------------------------------#

    model = YoloBody(Config)

    #------------------------------------------------------#

点开YoloBody的定义,你会发现里面有很多类似的语句:

final_out_filter0 = len(config["yolo"]["anchors"][0]) * (5 + config["yolo"]["classes"])

即YoloBody会根据Config中的参数创建YOLOv3模型,而在Config中:


 "yolo": {
        "anchors": [[[116, 90], [156, 198], [373, 326]],
                    [[30, 61], [62, 45], [59, 119]],
                    [[10, 13], [16, 30], [33, 23]]],

        #"anchors":[[[310, 328], [281, 279], [237, 300]],
        #            [[176, 251], [131, 111], [87, 150]],
        #           [[85, 65], [59, 99], [46, 61]]],
        "classes": 10,
    },


所以在训练开始时你需要修改Config中的“classes”,改成你要分的类别数,再运行train.py

训练完成,你会获得模型的权重数据,此时你需要用你的权重数据初始化模型看看你的训练结果,这个过程你可以调用video.py,调用摄像头拍摄图片进行检验,也可以运行predict.py,选取本地图片进行检测。无论哪一种方式,其代码都会有下列语句

yolo = YOLO()#创建yolo类

这个是用来创建YOLO类的,而点开 YOLO代码,我们可以看到上述分析中的内容。即此时构建

YOLO不是根据Config中的“classes”来构建的,而是根据”classes_path”中类别的数目来确定的 ,因此你需要修改”classes_path”,将其指向包含你的分类类别的txt。否则由于构建YOLO使用的参数不一样,在加载你训练好的权重数据时会报错。

以上便是使用YOLOv3训练你的模型所所需要注意的地方。另外也要注意,在

voc_annotation.py

的classes和在

YOLO中”classes_path”

中保存的类别信息顺序应当一致,否在识别时会输出错误的类别信息。这是因为生成类别信息编号时参考的是

voc_annotation.py

的classes,而输出类别信息时是根据模型输出的编号查找

YOLO

中**“classes_path”** 中保存的类别信息,因此两者需要一一对应,否则输出类别信息会错乱。

一般而言,注意上述步骤便不会出现问题,仍有加载数据报错的话请仔细检查上述步骤中哪一步出了问题。出了上述问题之外,还有一个点会造成加载数据出错,就是加载其他电脑上训练出来的权重数据。本人曾利用笔记本电脑加载服务器训练出来的数据,会出现报错。这是因为两者训练的环境不一样,比如cuda版本,torch版本,cudnn版本等,这些可能会导致最后训练出来的权重数据不可共用。比如我利用服务器训练的权重数据文件大小会比用笔记本训练出来权重数据文件略大,两者无法共用。



版权声明:本文为xhxzdxm原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。