[PYTHON] Les spécifications Tensor max, min de PyTorch ne sont pas pratiques

introduction

Le Tensor de PyTorch peut être traité comme le ndarray de NumPy, mais il y a quelques différences.

En particulier, max () et min (), qui trouvent les valeurs maximale et minimale de Tensor, sont souvent utilisés mais ont un comportement différent de NumPy et sont difficiles à gérer. Pour atténuer cette différence, j'ai écrit une fonction pour trouver les valeurs maximales et minimales de Tensor dans un style similaire à NumPy.

Différence entre «max ()» de PyTorch et «max ()» de NumPy

Si vous voulez trouver la valeur maximale totale de Tensor ou ndarray, il n'y a aucune différence entre les deux. Autrement dit, les deux expressions suivantes renvoient la même valeur (bien que de types différents).

PyTorch


torch.tensor([1,2,3,4,5]).max()

NumPy


np.array([1,2,3,4,5]).max()

Cependant, le comportement est différent lorsque ʻaxis` est spécifié dans NumPy, c'est-à-dire lorsque vous voulez trouver la valeur maximale le long d'un axe spécifique.

Dans NumPy, vous pouvez spécifier plusieurs axes pour l'axe en appuyant sur, etc., et vous pouvez l'utiliser pour trouver la valeur maximale pour chaque canal de l'image.

NumPy


x = np.array([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(axis=(1,2))    # -> [8, 9, 8]

Cependant, ** Tensor max () ne peut spécifier qu'un seul axe à la fois ** (et le nom de l'argument est dim au lieu de ʻaxis`).

De plus, dans max () de Tensor, ** la valeur maximale et l'index (= argmax) indiquant sa position sont renvoyés sous forme de taple. ** **

Cette spécification est très maladroite, et si vous voulez faire la même chose que NumPy ci-dessus, vous devez faire quelque chose de très maladroit, en prenant la 0ème valeur de retour et en superposant max.

PyTorch


x = torch.tensor([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]

De plus, dans le cas du code ci-dessus, en fonction de l'ordre dans lequel dim est spécifié, l'indice de l'axe se décale chaque fois que max () est exécuté, de sorte que la valeur maximale de l'axe comme prévu ne peut pas être obtenue. Cela peut aussi être un foyer.

La fonction que j'ai écrite

def tensor_max(x, axis=None, keepdims=False):
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)

    for ax in axis[::-1]:
        x = x.max(dim=ax, keepdim=keepdims)[0]

    return x

def tensor_min(x, axis=None, keepdims=False):
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)

    for ax in axis[::-1]:
        x = x.min(dim=ax, keepdim=keepdims)[0]

    return x

Par exemple, pour trouver la valeur maximale de «x», écrivez comme «tensor_max (x)». Comme avec NumPy, vous pouvez spécifier ʻaxisetkeepdims` comme arguments.

Même si plusieurs axes sont spécifiés, max (min) est appliqué dans l'ordre décroissant, de sorte que le résultat ne change pas en fonction de l'ordre spécifié.

En utilisant cela, vous pouvez simplement écrire les valeurs maximales et minimales sur plusieurs axes avec PyTorch.

x = torch.tensor([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]    # before
tensor_max(x, axis=(1,2))        # after

en conclusion

PyTorch est pratique, mais il est souvent déroutant car les spécifications de Tensor semblent similaires à NumPy et subtilement différentes. Je veux que tu sois unifié.

Recommended Posts

Les spécifications Tensor max, min de PyTorch ne sont pas pratiques
[PyTorch] Échantillon ② ~ TENSOR ~
[Pytorch] numpy à tenseur