model.to(device)无法将自定义层的tensor转移到指定设备

  • Post author:
  • Post category:其他


有时在model内自定义的模块或参数无法被model.to(device)正确转移到指定设备。如果仅在CPU上跑是没问题的,但如果在GPU上跑,其余部分参数被转移到了GPU,这部分无法正确转移的参数却存放在CPU,就会报错。

可以考虑以下几种解决办法:

1、使用

nn.ModuleList()

而不是python的内建List来存放多个模块

例如下面自定义的Module,如果将nn.ModuleList去掉(仅使用列表生成式得到多个Linear层),则在使用module.to(device)时,这些层的parameter无法被正确转移到指定设备。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

2、使用

nn.Parameter()


这篇论文

提出了解决多任务学习的不同Loss权重难以确定的问题的方法,参考

GitHub的tensorflow实现

,可以得到pytorch的实现如下。这个例子中的两个learnable的参数sigmas,如果不使用

nn.Parameter()

,仅使用

torch.rand(..., requires_grad=True)

,则该tensor无法被

model.to(device)

正确转移到对应设备。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.sigmas = nn.Parameter((1 - 0.2) * torch.rand(2) + 0.2, requires_grad=True)

    def get_multi_task_loss(self, loss1, loss2):
        factor1 = torch.div(1.0, torch.mul(2.0, self.sigmas[0]))
        factor2 = torch.div(1.0, torch.mul(2.0, self.sigmas[1]))
        loss = torch.add(torch.mul(factor1, loss1), torch.log(self.sigmas[0]))
        loss = torch.add(loss, torch.add(torch.mul(factor2, loss2), torch.log(self.sigmas[1])))
        return loss

3、使用

register_buffer()

如果想把tensor转移到GPU中,但这些tensor又不需要更新,所以不想将其设为Parameter,则可以考虑使用

register_buffer()

。下面的例子中,使用

self.k = torch.zeros(k)

则无法使

k



model.to(device)

转移到GPU中。(PS:我觉得使用

nn.Parameter()

传入

requires_grad=False

也一样吧)

class MyModule(nn.Module):
    def __init__(self, k):
        super(MyModule, self).__init__()
        self.register_buffer('k', torch.zeros(k))
        # self.k = torch.zeros(k)
    def forward(self):
            ...
            # 这里使用 self.k 即可



参考资料

[1]

Why model.to(device) wouldn’t put tensors on a custom layer to the same device?

[2]

https://github.com/ranandalon/mtl

[3]

Pytorch学习(十九)— 模型中buffer的使用



版权声明:本文为xpy870663266原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。