torch.load报错:No module named ‘models‘

  • Post author:
  • Post category:其他


使用pytorch训练模型时想要预先加载预训练模型,忽然出现这种错误。

原因大概是该预训练模型保存方法是完全保存:

torch.save(model, path)

该方法将模型内容全部保存,甚至包括存放路径

这导致将保存的模型换位置的后,load加载的时候可能导致路径出现问题

解决方法:

model = Model()
scripted_module = torch.jit.script(model)
torch.jit.save(scripted_module, 'pretrained_model.pt')
torch.jit.load('pretrained_model.pt')


参考自

避免该问题的方法:

在保存模型的时候只保存状态字典,不要全部保存了!

即:

torch.save(model.state_dict(), PATH)


参考自



版权声明:本文为Arcofcosmos原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。