RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the

  • Post author:
  • Post category:其他


bug:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

源代码如下:

if __name__ == "__main__":
    from torchsummary import summary
    model = UNet()
    print(model)
    summary(model, input_size=(1, 480, 480))

在使用torchsummary可视化模型时候报错,报这个错误是因为类型不匹配,根据报错内容可以看出Input type为torch.FloatTensor(CPU数据类型),而weight type(即网络权重参数这些)为torch.cuda.FloatTensor(GPU数据类型)。


我们将model传到GPU上便可

。将代码如下修改便可正常运行:

if __name__ == "__main__":
    from torchsummary import summary
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)	# modify
    print(model)
    summary(model, input_size=(1, 480, 480))



版权声明:本文为qq_43369406原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。