使用pytorch训练模型时想要预先加载预训练模型,忽然出现这种错误。
原因大概是该预训练模型保存方法是完全保存:
torch.save(model, path)
该方法将模型内容全部保存,甚至包括存放路径
这导致将保存的模型换位置的后,load加载的时候可能导致路径出现问题
解决方法:
model = Model()
scripted_module = torch.jit.script(model)
torch.jit.save(scripted_module, 'pretrained_model.pt')
torch.jit.load('pretrained_model.pt')
避免该问题的方法:
在保存模型的时候只保存状态字典,不要全部保存了!
即:
torch.save(model.state_dict(), PATH)
版权声明:本文为Arcofcosmos原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。