pytorch 删除tensor中的指定位置元素

  • Post author:
  • Post category:其他


pytorch似乎并没有提供删除指定位置的元素或者删除某个值的直接方法,但是可以使用其他方法曲线达到目标。

删除指定位置元素或者具体某个值需要先用nonzero()获取到要删除的元素的索引位置,之后使用torch.cat()进行删除操作

import torch
a = torch.randn(4,5).int()  #生成一个随机的4,5 tensor
print(a, a.size())
tensor([[ 0,  1,  0,  0,  0],
        [ 0,  0,  0, -1,  1],
        [ 0,  1,  0,  0, -1],
        [ 0,  0,  0,  1,  0],], dtype=torch.int32) torch.Size([4, 5])
# 删除tensor里面的元素1
p=(a==1).nonzero()  #获取1的位置
print(p)
tensor([[0, 1],
        [1, 4],
        [2, 1],
        [3, 3]])  #a中1的二维和一维坐标
b=torch.cat([torch.cat((a[i][0:j],a[i][j+1:])) for i, j in p])  #用torch.cat()删除所有的1
print(b)  #使用cat之后矩阵变成1维的,之后可以把矩阵view()回去
tensor([ 0,  0,  0,  0,  0,  0,  0, -1,  0,  0,  0, -1,  0,  0,  0,  0],
       dtype=torch.int32)

b=b.view(a.size(0), a.size(1)-1)  #因为每一个1维维度都删去了一个1,因此矩阵的1维维度要减1
print(b)
tensor([[ 0,  0,  0,  0],
        [ 0,  0,  0, -1],
        [ 0,  0,  0, -1],
        [ 0,  0,  0,  0]], dtype=torch.int32)


矩阵的删除适合于每行都删除相同数量的元素,如果删除的元素数量不同就不能view成与原矩阵相同行数的矩阵,但是1维矩阵则可以删除多个值。



版权声明:本文为lhyyhlfornew原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。