pytorch似乎并没有提供删除指定位置的元素或者删除某个值的直接方法,但是可以使用其他方法曲线达到目标。
删除指定位置元素或者具体某个值需要先用nonzero()获取到要删除的元素的索引位置,之后使用torch.cat()进行删除操作
import torch
a = torch.randn(4,5).int() #生成一个随机的4,5 tensor
print(a, a.size())
tensor([[ 0, 1, 0, 0, 0],
[ 0, 0, 0, -1, 1],
[ 0, 1, 0, 0, -1],
[ 0, 0, 0, 1, 0],], dtype=torch.int32) torch.Size([4, 5])
# 删除tensor里面的元素1
p=(a==1).nonzero() #获取1的位置
print(p)
tensor([[0, 1],
[1, 4],
[2, 1],
[3, 3]]) #a中1的二维和一维坐标
b=torch.cat([torch.cat((a[i][0:j],a[i][j+1:])) for i, j in p]) #用torch.cat()删除所有的1
print(b) #使用cat之后矩阵变成1维的,之后可以把矩阵view()回去
tensor([ 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0, 0],
dtype=torch.int32)
b=b.view(a.size(0), a.size(1)-1) #因为每一个1维维度都删去了一个1,因此矩阵的1维维度要减1
print(b)
tensor([[ 0, 0, 0, 0],
[ 0, 0, 0, -1],
[ 0, 0, 0, -1],
[ 0, 0, 0, 0]], dtype=torch.int32)
矩阵的删除适合于每行都删除相同数量的元素,如果删除的元素数量不同就不能view成与原矩阵相同行数的矩阵,但是1维矩阵则可以删除多个值。
版权声明:本文为lhyyhlfornew原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。