nn.Linear和nn.BatchNorm1的维度问题

  • Post author:
  • Post category:其他


import torch
import torch.nn as nn
input=torch.randn([32,49,768])

l=nn.Linear(768,512)
out=l(input)
print(out.shape)
# torch.Size([32, 49, 512])

# l=nn.Linear(49,512)
# mat1 and mat2 shapes cannot be multiplied (1568x768 and 49x512)
# 说明了执行linear时,输入的channel只能位于最后一维
b=nn.BatchNorm1d(49)
out=b(out)
print(out.shape)
# torch.Size([32, 49, 512])
# b=nn.BatchNorm1d(512)
# RuntimeError: running_mean should contain 49 elements not 512
# 说明了执行linear时,输入的channel只能位于最后中间



版权声明:本文为qq_42820467原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。