1.官方文档中的介绍
scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会修改原来的。
scatter_(dim, index, src) 的参数有 3 个
- dim:在哪个维度进行变换
- index:用来 scatter 的元素索引
- src:用来 scatter 的源元素
具体的转化关系可以参考下图。
注意,这个里面的i,j,k可以说都是相同的量。
2.实例介绍
利用官网中的例子
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
这里我们来看看,首先声明了一个2*5的矩阵,里面的值是[0,1),然后我们使用了scatter_,可以看出dim=0,对应上面的公式,我们可以得到。
self[index[0,0],0] = self[0,0] = src[0,0] = 00.3992
(其中i=0,j=0)
self[index[0,1],1] = self[1,1] = src[0,1] = 0.2908
(其中i=0,j=1)
以此类推…
3.实现one-hot
y_train = torch.Tensor(y_train).long()
y_train_onehot = y_train
y_train_onehot = y_train.view(-1, 1)
y_train_onehot = torch.zeros(y_train.size(0), 10).scatter_(1, y_train_onehot, 1).long()
数据说明:
y_train结构是(10000),是有10000个训练集,每个都进行了分类,一共10类(类似于[1,2,3,5,0,6,4,9…])
代码的思想就是首先生成一个全0的矩阵,然后通过对其列上面的值变换为1从而实现one-hot。
当然现在不用这个啦,pytorch中有one_hot函数
torch.nn.functional.one_hot(tensor, num_classes=-1) → LongTenso
对于上面的数据我们可以直接用
y_train = torch.Tensor(y_train).long()
y_train = torch.nn.functional.one_hot(y_train,10)
输出的结果为