import torch
from torch import nn
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print(output.size())
print(m.weight.shape)
来看一下输出:
out:
torch.Size([128, 30])
torch.Size([30, 20])
发现weight的形状是[30,20]而非[20, 30]?
所以具体看一下源码的实现方式:
-
Linear类的源码网址:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html
-
functional模块的源码网址:
https://pytorch.org/docs/stable/_modules/torch/nn/functional.html
-
在Linear类中的
__init__
函数中,weight形状为
[out_features, in_features]
-
在
forward
函数中调用
F.linear
函数,实现单层线性神经网络层的计算
-
在F.linear函数中,使用的是
weight.t()
,也就是将weight转置,再传入matmul计算。
通过以上三步,pytorch就完成weight形状的维护。简单的说就是,
在定义时使用的是[out_features, in_features],而在单层线性神经网络计算时使用的是weight的转置矩阵。
版权声明:本文为dss_dssssd原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。