函数原型
torch.argmax(input, dim=None, keepdim=False)
作用:返回指定维度最大值的序号。
示例:
x = torch.randint(12, size=(3, 4))
print(x)
y = torch.argmax(x, dim=0)#返回每列最大值对应的行号
print(f'y.shape{y.shape}')
print(y)
z = torch.argmax(x, dim=1)#每行最大值对应的列号
print(f'z.shape{z.shape}')
print(z)
输出结果:
tensor([[ 1, 11, 9, 7],[10, 4, 7, 10],[ 7, 7, 5, 0]])
y.shapetorch.Size([4])
tensor([1, 0, 0, 1])
z.shapetorch.Size([3])
tensor([1, 0, 0])
如果参数中不写dim,则先把张量展平,然后返回最大值对应的索引。
x = torch.tensor([[0.1, 0.08, 0.52, 0.92],
[0.55, 0.2, 0.9, 0.88]])
index = torch.argmax(x)
print(x)
print(index)
>>>
tensor([[0.10, 0.08, 0.52, 0.92],[0.55, 0.20, 0.90, 0.88]])
tensor(3)
也就是共有8个元素,最大的第四个元素(索引为3)。
函数原型
torch.max(input) → Tensor
例子1:
x = torch.randn(5)
print(f'x: {x}')
y = torch.max(x)
print(y)
x1 = torch.randn(3, 4)
print(f'x1:{x1}')
y1 = torch.max(x1)
print(y1)
输出结果:
x: tensor([ 0.25, 0.57, 0.28, 3.31, -0.08])
tensor(3.31)
x1:tensor([[-0.55, -0.69, 0.04, -0.28],[-0.63, -0.08, 0.09, 0.28],[-1.82, -2.04, -0.74, -1.92]])
tensor(0.28)
函数原型
torch.max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
参数
input (Tensor) :输入张量
dim (int) : 指定的维度
keepdim(bool):输出张量是否保留dim.默认为False.
输出:
max (Tensor, optional) : 结果张量,包含给定维度上的最大值
max_indices (LongTensor, optional) : 结果张量,包含给定维度上每个最大值的位置索引
作用:返回输入张量给定维度上每行的最大值,并同时返回最大值的位置索引。
x1 = torch.randn(3, 4)
print(f'x1:{x1}')
max_values, max_indices = torch.max(x1, dim=1, keepdim=True)
print(max_values)
print(max_indices)
output:
x1:tensor([[ 0.63, -0.31, 2.00, 2.39],[ 0.59, 1.51, -0.23, 0.68],[ 1.09, -0.02, -1.06, -0.07]])
tensor([[2.39],[1.51],[1.09]])
tensor([[3],[1],[0]])