总说
一个非常有用的函数,主要是用于“group index”的操作。
先安装一下 https://github.com/rusty1s/pytorch_scatter
from torch_scatter import scatter
import torch
src = (torch.rand(2, 6, 2)*4).int()
index = torch.tensor([0, 1, 0, 1, 2, 1])
# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")
print(src)
print(index)
print(out)
输出
tensor([[[1, 3],
[3, 3],
[3, 2],
[2, 1],
[1, 0],
[0, 2]],
[[0, 3],
[3, 0],
[2, 1],
[2, 2],
[3, 0],
[0, 3]]], dtype=torch.int32)
tensor([0, 1, 0, 1, 2, 1])
tensor([[[4, 5],
[5, 6],
[1, 0]],
[[2, 4],
[5, 5],
[3, 0]]], dtype=torch.int32)
这里沿着“dim=1”来分别找不同的index,对于“index 0”来说,在第“0”个和第“2”个位置出现,对应的数据是
[[1, 3], [0, 3]] 以及 [[3, 2], [2, 1]]
这里”sum”,那么后,就变成了
[[4, 5], [2, 4]]
同理解释其他的index。
版权声明:本文为Hungryof原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。