[PYTHON] PyTorch Tensor max, min Spezifikationen sind unpraktisch

Einführung

PyTorchs Tensor kann grundsätzlich wie NumPys Ndarray behandelt werden, es gibt jedoch einige Unterschiede.

Insbesondere werden häufig "max ()" und "min ()" verwendet, die die Maximal- und Minimalwerte von Tensor ermitteln, sich jedoch von NumPy unterscheiden und schwer zu handhaben sind. Um diesen Unterschied zu verringern, habe ich eine Funktion geschrieben, um die Maximal- und Minimalwerte von Tensor in einem ähnlichen Stil wie NumPy zu ermitteln.

Unterschied zwischen "max ()" von PyTorch und "max ()" von NumPy

Wenn Sie den Gesamtmaximalwert von Tensor oder ndarray ermitteln möchten, gibt es keinen Unterschied zwischen den beiden. Das heißt, beide der folgenden Ausdrücke geben denselben Wert zurück (obwohl sie unterschiedlichen Typs sind).

PyTorch


torch.tensor([1,2,3,4,5]).max()

NumPy


np.array([1,2,3,4,5]).max()

Das Verhalten ist jedoch anders, wenn Sie in NumPy "Achse" angeben, dh wenn Sie den Maximalwert entlang einer bestimmten Achse ermitteln möchten.

In NumPy können Sie durch Tippen usw. mehrere Achsen für die Achse angeben und damit den Maximalwert für jeden Kanal des Bildes ermitteln.

NumPy


x = np.array([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(axis=(1,2))    # -> [8, 9, 8]

** Tensors "max ()" kann jedoch jeweils nur eine Achse angeben ** (und der Argumentname lautet "dim" anstelle von "axis").

Außerdem werden in Tensors "max ()" ** der Maximalwert und der Index (= argmax), der seine Position angibt, als Taple zurückgegeben. ** ** **

Diese Spezifikation ist sehr umständlich, und wenn Sie dasselbe wie NumPy oben tun möchten, müssen Sie etwas sehr Klobiges tun, indem Sie den 0. Rückgabewert nehmen und max.

PyTorch


x = torch.tensor([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]

Im Fall des obigen Codes verschiebt sich der Index der Achse in Abhängigkeit von der Reihenfolge, in der dim angegeben ist, jedes Mal, wenn max () ausgeführt wird, so dass der beabsichtigte Maximalwert der Achse nicht erhalten werden kann. Es kann auch eine Brutstätte sein.

Die Funktion, die ich geschrieben habe

def tensor_max(x, axis=None, keepdims=False):
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)

    for ax in axis[::-1]:
        x = x.max(dim=ax, keepdim=keepdims)[0]

    return x

def tensor_min(x, axis=None, keepdims=False):
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)

    for ax in axis[::-1]:
        x = x.min(dim=ax, keepdim=keepdims)[0]

    return x

Um beispielsweise den Maximalwert von "x" zu ermitteln, schreiben Sie als "tensor_max (x)". Wie bei NumPy können Sie "axis" und "keepdims" als Argumente angeben.

Selbst wenn mehrere Achsen angegeben sind, wird max (min) in absteigender Reihenfolge angewendet, sodass sich das Ergebnis in Abhängigkeit von der angegebenen Reihenfolge nicht ändert.

Auf diese Weise kann PyTorch einfach die Maximal- und Minimalwerte über mehrere Achsen schreiben.

x = torch.tensor([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]    # before
tensor_max(x, axis=(1,2))        # after

abschließend

PyTorch ist praktisch, aber oft verwirrend, da die Spezifikationen von Tensor NumPy ähnlich und leicht unterschiedlich zu sein scheinen. Ich möchte, dass du vereint bist.

Recommended Posts

PyTorch Tensor max, min Spezifikationen sind unpraktisch
[PyTorch] Probe ② ~ TENSOR ~
[Pytorch] numpy bis Tensor