1、
定义
PyTorch的
tensor
参数可以自动扩展其大小。一般的是小一点的会变大,来满足运算需求。
广播机制实际上是在运算过程中,去处理两个形状不同向量的一种手段。
理解:
核心:如果相加的两个数组的shape不同, 就会触发广播机制:
1)程序会自动执行操作使得A.shape==B.shape;
2)对应位置进行相加运算,结果的shape是:A.shape和B.shape对应位置的最大值,比如:A.shape=(1,9,4),B.shape=(15,1,4),那么A+B的shape是(15,9,4)
2、规则
满足一下情况的tensor是可以广播的。
-
至少有一个维度
#像下面这种情况下就不行,因为x不满足这个条件。
x=torch.empty((0,))
y=torch.empty(2,2)
-
两个tensor维度相等
-
维度不等,其中一个为1
-
维度不等,其中一个维度不存在
x=torch.empty(5,3,4,1)
y=torch.empty( 3,1,1)
#这两个是可以广播的
如上面代码中,首先将两个张量维度向右靠齐,从右往左看,(张量的维度要先从内部对齐)
两个张量第四维大小相等,都为1,满足上面条件a;
第三个维度大小不相等,但第二个张量第三维大小为1,满足上面条件b;
第二个维度大小相等都为3,满足上面条件a;
第一个维度第一个张量有,第二个张量没有,满足上面条件b,因此两个张量每个维度都符合上面广播条件,因此可以进行广播。
两个张量维度从右往左看,如果出现两个张量在某个维度位置上面,
维度大小不相等
,
且两个维度大小没有一个是1
,那么这两个张量一定
不能进行广播。
3、计算过程规则
-
维度不同,小维度的增加维度
-
每个维度,计算结果取大的
-
扩展维度是对数值进行复制
示例
# 广播机制
a = torch.arange(3).reshape((3, 1))
b = torch.arange(2).reshape((1, 2))
print(a + b)
tensor([[0],
[1],
[2]])
tensor([[0, 1]])
-----------------------
tensor([[0, 1],
[1, 2],
[2, 3]])