详解 torch.max 函数

  • Post author:
  • Post category:其他


torch.max()

  • 返回输入张量所有元素的最大值。

参数:

  • input (Tensor) – 输入张量

例子:

>>> a = torch.randn(1, 3)
>>> a

 0.4729 -0.2266 -0.2085
[torch.FloatTensor of size 1x3]

>>> torch.max(a)
0.4729
torch.max(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)
返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。

输出形状中,将dim维设定为1,其它与输入形状保持一致。

参数:

  • input (Tensor) – 输入张量
  • dim (int) – 指定的维度
  • max (Tensor, optional) – 结果张量,包含给定维度上的最大值
  • max_indices (LongTensor, optional) – 结果张量,包含给定维度上每个最大值的位置索引

例子:

>> a = torch.randn(4, 4)
>> a

0.0692  0.3142  1.2513 -0.5428
0.9288  0.8552 -0.2073  0.6409
1.0695 -0.0101 -2.4507 -1.2230
0.7426 -0.7666  0.4862 -0.6628
torch.FloatTensor of size 4x4]

>>> torch.max(a, 1)
(
 1.2513
 0.9288
 1.0695
 0.7426
[torch.FloatTensor of size 4x1]
,
 2
 0
 0
 0
[torch.LongTensor of size 4x1]
)
torch.max(input, other, out=None) → Tensor
返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。 
即,( out_i=max(input_i,other_i) \)

输出形状中,将dim维设定为1,其它与输入形状保持一致。

参数:

  • input (Tensor) – 输入张量
  • other (Tensor) – 输出张量
  • out (Tensor, optional) – 结果张量

例子:

>>> a = torch.randn(4)
>>> a

 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]

>>> b = torch.randn(4)
>>> b

 1.0067
-0.8010
 0.6258
 0.3627
[torch.FloatTensor of size 4]

>>> torch.max(a, b)

 1.3869
 0.3912
 0.6258
 0.3627
[torch.FloatTensor of size 4]



版权声明:本文为ViatorSun原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。