Pytorch学习 (二十六)—- torch.scatter的使用

  • Post author:
  • Post category:其他




总说

一个非常有用的函数,主要是用于“group index”的操作。

先安装一下 https://github.com/rusty1s/pytorch_scatter

from torch_scatter import scatter
import torch

src = (torch.rand(2, 6, 2)*4).int()
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")

print(src)
print(index)
print(out)

输出

tensor([[[1, 3],
         [3, 3],
         [3, 2],
         [2, 1],
         [1, 0],
         [0, 2]],

        [[0, 3],
         [3, 0],
         [2, 1],
         [2, 2],
         [3, 0],
         [0, 3]]], dtype=torch.int32)
tensor([0, 1, 0, 1, 2, 1])
tensor([[[4, 5],
         [5, 6],
         [1, 0]],

        [[2, 4],
         [5, 5],
         [3, 0]]], dtype=torch.int32)

这里沿着“dim=1”来分别找不同的index,对于“index 0”来说,在第“0”个和第“2”个位置出现,对应的数据是

[[1, 3], [0, 3]]  以及 [[3, 2], [2, 1]]

这里”sum”,那么后,就变成了

[[4, 5], [2, 4]]

同理解释其他的index。



版权声明:本文为Hungryof原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。