洋書、時々プログラミング

博士課程修了→メーカーという経路を辿っている人の日常

pytorch argmaxの挙動

torch.maxとtorch.argmaxがゴタゴタしていたのでメモ

  • torch.maxは1つ目に最大値、2つ目に最大値の入っているindexを返す。
  • torch.argmaxはtorch.maxの2つ目のみを返す。
  • axisを指定しないと1行にreshapeされた後にargmaxが行われる。
import torch
a = torch.rand((2,3,4,5))
print(a.shape)
# torch.Size([2, 3, 4, 5])

print(a.argmax(axis=0).shape)
# torch.Size([3, 4, 5])

print(a.argmax(axis=1).shape)
# torch.Size([2, 4, 5])

print(a.argmax(axis=2).shape)
# torch.Size([2, 3, 5])

print(a.argmax(axis=3).shape)
# torch.Size([2, 3, 4])

print(a.argmax(axis=-1).shape)
# torch.Size([2, 3, 4])

a = torch.randn(4, 4)
print(a)

'''
tensor([[-0.3592, -0.0374,  0.5395,  0.6372],
        [-1.0343,  0.4862, -1.1469,  0.0645],
        [ 0.1129, -0.7580, -1.0913, -1.6897],
        [ 1.1991, -0.3005,  0.6481,  0.5518]])
'''
print(torch.argmax(a))
# Tensor(12) 1行に直される