pytorch中Linear类中weight的形状问题源码探讨

  • Post author:
  • Post category:其他


import torch
from torch import nn

m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)

print(output.size())
print(m.weight.shape)

来看一下输出:

out:

torch.Size([128, 30])

torch.Size([30, 20])

发现weight的形状是[30,20]而非[20, 30]?

所以具体看一下源码的实现方式:

  1. 在Linear类中的

    __init__

    函数中,weight形状为

    [out_features, in_features]


    在这里插入图片描述


  2. forward

    函数中调用

    F.linear

    函数,实现单层线性神经网络层的计算

    在这里插入图片描述
  3. 在F.linear函数中,使用的是

    weight.t()

    ,也就是将weight转置,再传入matmul计算。

    在这里插入图片描述

通过以上三步,pytorch就完成weight形状的维护。简单的说就是,

在定义时使用的是[out_features, in_features],而在单层线性神经网络计算时使用的是weight的转置矩阵。



版权声明:本文为dss_dssssd原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。