PyTorch Tensor max, min specifications are inconvenient

Introduction

PyTorch's Tensor can basically be treated like NumPy's ndarray, but there are some differences.

In particular, max() and min(), which find the maximum and minimum values of a Tensor, are frequently used but behave differently from their NumPy counterparts and are awkward to handle. To smooth over this difference, I wrote functions that find the maximum and minimum values of a Tensor in a NumPy-like style.

Difference between PyTorch's max () and NumPy's max ()

If you want to find the overall maximum of a Tensor or ndarray, there is no difference between the two. That is, both of the following expressions return the same value (although of different types).

PyTorch


torch.tensor([1,2,3,4,5]).max()

NumPy


np.array([1,2,3,4,5]).max()
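As a quick sanity check (assuming both libraries are installed), the two calls above return the same value, differing only in type:

```python
import numpy as np
import torch

# Both find the overall maximum; only the return types differ.
t = torch.tensor([1, 2, 3, 4, 5]).max()   # 0-dim Tensor
n = np.array([1, 2, 3, 4, 5]).max()       # NumPy scalar

print(t.item(), n)  # both are 5
```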

However, the behavior differs when you specify `axis` in NumPy, that is, when you want to find the maximum value along a specific axis.

In NumPy, you can pass multiple axes to `axis` as a tuple, which is useful, for example, to find the maximum value for each channel of an image.

NumPy


x = np.array([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(axis=(1,2))    # -> [8, 9, 8]

However, **Tensor's max() can only reduce one axis at a time** (and the argument name is `dim` instead of `axis`).

Furthermore, **Tensor's max() returns the maximum values and the indices of their positions (= argmax) as a tuple.**
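A small sketch of this behavior (in recent PyTorch versions the tuple is a named tuple with `values` and `indices` fields):

```python
import torch

x = torch.tensor([[1, 5, 3],
                  [4, 2, 6]])

result = x.max(dim=1)   # returns max and argmax together
print(result.values)    # tensor([5, 6])
print(result.indices)   # tensor([1, 2])

# Indexing with [0] keeps just the maximum values:
values = x.max(dim=1)[0]
```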

This specification is quite awkward: to do the same thing as the NumPy example above, you have to do something clunky like taking the 0th element of each return value and chaining max() calls.

PyTorch


x = torch.tensor([[[8, 2, 2],
                   [6, 2, 3],
                   [8, 2, 4]],
                  [[8, 4, 9],
                   [0, 3, 9],
                   [5, 5, 3]],
                  [[6, 5, 5],
                   [4, 8, 0],
                   [1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]

Also, with code like the above, the axis indices shift each time max() is applied, so depending on the order in which dim is specified you may not get the maximum over the axes you intended. This is a breeding ground for bugs.
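A sketch of the axis-shift pitfall (using a hypothetical tensor of shape (3, 4, 5)): after reducing dim=1, the old dim=2 becomes dim=1, so the second call must be adjusted accordingly.

```python
import torch

x = torch.randn(3, 4, 5)

# Reducing dim=2 first, then dim=1, works as you would expect:
a = x.max(dim=2)[0].max(dim=1)[0]   # shape (3,)

# But reducing dim=1 first shifts the old dim=2 down to dim=1,
# so the second reduction must also use dim=1:
b = x.max(dim=1)[0].max(dim=1)[0]   # shape (3,)

# Naively writing x.max(dim=1)[0].max(dim=2) raises an IndexError,
# because the intermediate tensor only has dims 0 and 1.
```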

The function I wrote

def tensor_max(x, axis=None, keepdims=False):
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)

    for ax in axis[::-1]:
        x = x.max(dim=ax, keepdim=keepdims)[0]

    return x

def tensor_min(x, axis=None, keepdims=False):
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)

    for ax in axis[::-1]:
        x = x.min(dim=ax, keepdim=keepdims)[0]

    return x

For example, to find the overall maximum of x, write tensor_max(x). As with NumPy, you can pass `axis` and `keepdims` as arguments.

Even when multiple axes are specified, max() (or min()) is applied from the highest axis downward, so the result does not depend on the order in which the axes are given.

By using this, you can easily write the maximum and minimum values across multiple axes with PyTorch.

x = torch.tensor([[[8, 2, 2],
                   [6, 2, 3],
                   [8, 2, 4]],
                  [[8, 4, 9],
                   [0, 3, 9],
                   [5, 5, 3]],
                  [[6, 5, 5],
                   [4, 8, 0],
                   [1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]    # before
tensor_max(x, axis=(1,2))        # after
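As a sanity check, the function can be compared directly against NumPy on random data (reusing the tensor_max definition from above):

```python
import numpy as np
import torch

def tensor_max(x, axis=None, keepdims=False):
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)
    # Reduce from the highest axis down so earlier reductions
    # do not shift the indices of the remaining axes.
    for ax in axis[::-1]:
        x = x.max(dim=ax, keepdim=keepdims)[0]
    return x

x_np = np.random.randint(0, 10, size=(3, 3, 3))
x = torch.from_numpy(x_np)

result_torch = tensor_max(x, axis=(1, 2))
result_np = x_np.max(axis=(1, 2))
```

The axis order should not matter either: `tensor_max(x, axis=(2, 1))` gives the same result as `tensor_max(x, axis=(1, 2))`.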

Conclusion

PyTorch is useful, but the Tensor API is often confusing because it looks similar to NumPy yet differs in subtle ways. I wish the two specifications were unified.
