import torch
import torch.nn as nn
input=torch.randn([32,49,768])
l=nn.Linear(768,512)
out=l(input)
print(out.shape)
# torch.Size([32, 49, 512])
# l=nn.Linear(49,512)
# mat1 and mat2 shapes cannot be multiplied (1568x768 and 49x512)
# 说明了执行linear时,输入的channel只能位于最后一维
b=nn.BatchNorm1d(49)
out=b(out)
print(out.shape)
# torch.Size([32, 49, 512])
# b=nn.BatchNorm1d(512)
# RuntimeError: running_mean should contain 49 elements not 512
# 说明了执行linear时,输入的channel只能位于最后中间
版权声明:本文为qq_42820467原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。