pytorch中scatter_介绍

  • Post author:
  • Post category:其他




1.官方文档中的介绍

scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会修改原来的。

scatter_(dim, index, src) 的参数有 3 个

  • dim:在哪个维度进行变换
  • index:用来 scatter 的元素索引
  • src:用来 scatter 的源元素

具体的转化关系可以参考下图。

在这里插入图片描述

注意,这个里面的i,j,k可以说都是相同的量。



2.实例介绍

利用官网中的例子

>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

这里我们来看看,首先声明了一个2*5的矩阵,里面的值是[0,1),然后我们使用了scatter_,可以看出dim=0,对应上面的公式,我们可以得到。


self[index[0,0],0] = self[0,0] = src[0,0] = 00.3992


(其中i=0,j=0)


self[index[0,1],1] = self[1,1] = src[0,1] = 0.2908


(其中i=0,j=1)

以此类推…



3.实现one-hot

y_train = torch.Tensor(y_train).long()
y_train_onehot = y_train
y_train_onehot = y_train.view(-1, 1)
y_train_onehot = torch.zeros(y_train.size(0), 10).scatter_(1, y_train_onehot, 1).long()

数据说明:

y_train结构是(10000),是有10000个训练集,每个都进行了分类,一共10类(类似于[1,2,3,5,0,6,4,9…])

代码的思想就是首先生成一个全0的矩阵,然后通过对其列上面的值变换为1从而实现one-hot。


当然现在不用这个啦,pytorch中有one_hot函数

torch.nn.functional.one_hot(tensor, num_classes=-1) → LongTenso

在这里插入图片描述

对于上面的数据我们可以直接用

y_train = torch.Tensor(y_train).long()
y_train = torch.nn.functional.one_hot(y_train,10)

输出的结果为

在这里插入图片描述



版权声明:本文为weixin_43977647原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。