torch.maxとtorch.argmaxがゴタゴタしていたのでメモ
- torch.maxは1つ目に最大値、2つ目に最大値の入っているindexを返す。
- torch.argmaxはtorch.maxの2つ目のみを返す。
- axisを指定しないと1行にreshapeされた後にargmaxが行われる。
import torch a = torch.rand((2,3,4,5)) print(a.shape) # torch.Size([2, 3, 4, 5]) print(a.argmax(axis=0).shape) # torch.Size([3, 4, 5]) print(a.argmax(axis=1).shape) # torch.Size([2, 4, 5]) print(a.argmax(axis=2).shape) # torch.Size([2, 3, 5]) print(a.argmax(axis=3).shape) # torch.Size([2, 3, 4]) print(a.argmax(axis=-1).shape) # torch.Size([2, 3, 4]) a = torch.randn(4, 4) print(a) ''' tensor([[-0.3592, -0.0374, 0.5395, 0.6372], [-1.0343, 0.4862, -1.1469, 0.0645], [ 0.1129, -0.7580, -1.0913, -1.6897], [ 1.1991, -0.3005, 0.6481, 0.5518]]) ''' print(torch.argmax(a)) # Tensor(12) 1行に直される