首先我们看一下Pytorch中torch.where函数是怎样定义的:
@overload
def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
torch.where函数的功能如下:
torch.where(condition, x, y):
condition:判断条件
x:若满足条件,则取x中元素
y:若不满足条件,则取y中元素
以具体实例看一下torch.where函数的效果:
import torch
# 条件
condition = torch.rand(3, 2)
print(condition)
# 满足条件则取x中对应元素
x = torch.ones(3, 2)
print(x)
# 不满足条件则取y中对应元素
y = torch.zeros(3, 2)
print(y)
# 条件判断后的结果
result = torch.where(condition > 0.5, x, y)
print(result)
结果如下:
tensor([[0.3224, 0.5789],
[0.8341, 0.1673],
[0.1668, 0.4933]])
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
tensor([[0., 0.],
[0., 0.],
[0., 0.]])
tensor([[0., 1.],
[1., 0.],
[0., 0.]])
可以看到torch.where函数会对condition中的元素逐一进行判断,根据判断的结果选取x或y中的值,所以要求x和y应该与condition形状相同。
版权声明:本文为tszupup原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。