Pytorch中的广播机制(Broadcast)

  • Post author:
  • Post category:其他


1、

定义

PyTorch的

tensor

参数可以自动扩展其大小。一般的是小一点的会变大,来满足运算需求。


广播机制实际上是在运算过程中,去处理两个形状不同向量的一种手段。

理解:

核心:如果相加的两个数组的shape不同, 就会触发广播机制:

1)程序会自动执行操作使得A.shape==B.shape;

2)对应位置进行相加运算,结果的shape是:A.shape和B.shape对应位置的最大值,比如:A.shape=(1,9,4),B.shape=(15,1,4),那么A+B的shape是(15,9,4)


2、规则

满足一下情况的tensor是可以广播的。

  • 至少有一个维度

#像下面这种情况下就不行,因为x不满足这个条件。
x=torch.empty((0,))
y=torch.empty(2,2)
  • 两个tensor维度相等

  • 维度不等,其中一个为1

  • 维度不等,其中一个维度不存在

x=torch.empty(5,3,4,1)
y=torch.empty( 3,1,1)
#这两个是可以广播的

如上面代码中,首先将两个张量维度向右靠齐,从右往左看,(张量的维度要先从内部对齐)

两个张量第四维大小相等,都为1,满足上面条件a;

第三个维度大小不相等,但第二个张量第三维大小为1,满足上面条件b;

第二个维度大小相等都为3,满足上面条件a;

第一个维度第一个张量有,第二个张量没有,满足上面条件b,因此两个张量每个维度都符合上面广播条件,因此可以进行广播。


两个张量维度从右往左看,如果出现两个张量在某个维度位置上面,



维度大小不相等







且两个维度大小没有一个是1



,那么这两个张量一定



不能进行广播。

3、计算过程规则

  • 维度不同,小维度的增加维度

  • 每个维度,计算结果取大的

  • 扩展维度是对数值进行复制

示例

# 广播机制
a = torch.arange(3).reshape((3, 1))
 
b = torch.arange(2).reshape((1, 2))
print(a + b)
tensor([[0],
        [1],
        [2]])
tensor([[0, 1]])
-----------------------
tensor([[0, 1],
        [1, 2],
        [2, 3]])



版权声明:本文为hlllllllhhhhh原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。