pytorch学习之torch.max()

  • Post author:
  • Post category:其他

官方API:https://pytorch.org/docs/stable/torch.html#torch.max

首先我们来看一个简单的例子体会一下torch.max()的用法:

import  torch
a = torch.randn(3,3)
print(a)#返回生成的随机tensor(3*3)
print(torch.max(a,0))#返回每一列中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引)
print(torch.max(a,1))#返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
print(torch.max(a))#返回tensor a 中的最大值
print(torch.max(a,0)[0] )# 只返回最大值的每个数
print(torch.max(a,0)[1])#只返回最大值的每个索引

运行结果:(pycharm)

注:每次随机生成的tensor并不相同,故结果不唯一,但原理是相同的。

总结:

torch.max(a) 返回输入tensor a中所有元素的最大值

torch.max(a,0) 返回每一中最大值的那个元素,且返回索引(返回最大元素在这一列的行索引

torch.max(a,1) 返回每一中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引

torch.max()[0], 只返回最大值的每个数

troch.max()[1], 只返回最大值的每个索引

torch.max()[1].data 只返回variable中的数据部分(去掉Variable containing:)

torch.max()[1].data.numpy() 把数据转化成numpy ndarry

torch.max()[1].data.numpy().squeeze() 把数据条目中维度为1 的删除掉

torch.max(tensor1,tensor2) element-wise 比较tensor1 和tensor2 中的元素,返回较大的那个值


版权声明:本文为Jane_JXR原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。