[PYTHON] [PyTorch] Why you can treat an instance of CrossEntropyLoss() like a function

Instance = function???

The book [Learn While Making! Advanced Deep Learning with PyTorch](https://www.amazon.co.jp/%E3%81%A4%E3%81%8F%E3%82%8A%E3%81%AA%E3%81%8C%E3%82%89%E5%AD%A6%E3%81%B6%EF%BC%81PyTorch%E3%81%AB%E3%82%88%E3%82%8B%E7%99%BA%E5%B1%95%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0-%E5%B0%8F%E5%B7%9D-%E9%9B%84%E5%A4%AA%E9%83%8E-ebook/dp/B07VPDVNKW) contains the following code in section 1-3, "Transfer Learning". (All of the code is available on the author's GitHub.)

1-3_transfer_learning.ipynb


# Import packages
import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

(Omitted)

# Loss function settings
criterion = nn.CrossEntropyLoss()

(Omitted)

def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):

    #epoch loop
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------')

        # Training and validation loop for each epoch
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()  #Put the model in training mode
            else:
                net.eval()   #Put the model in validation mode

            epoch_loss = 0.0  #epoch loss sum
            epoch_corrects = 0  #Number of correct answers for epoch

            # Skip training in epoch 0 so that the validation performance of the untrained model can be checked
            if (epoch == 0) and (phase == 'train'):
                continue

            #Loop to retrieve mini-batch from data loader
            for inputs, labels in tqdm(dataloaders_dict[phase]):

                # Reset the gradients held by the optimizer
                optimizer.zero_grad()

                #Forward calculation
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)  #Calculate loss
                    _, preds = torch.max(outputs, 1)  #Predict label
                    
  
                    #Backpropagation during training
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    #Calculation of iteration results
                    #Update total loss
                    epoch_loss += loss.item() * inputs.size(0)  
                    # Update the total number of correct answers
                    epoch_corrects += torch.sum(preds == labels.data)

            #Display loss and correct answer rate for each epoch
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

What I would like you to pay attention to here is **criterion**. It is defined as an instance of nn.CrossEntropyLoss as follows:

1-3_transfer_learning.ipynb


criterion = nn.CrossEntropyLoss()

And **criterion** is then called as if it were a function.

1-3_transfer_learning.ipynb


loss = criterion(outputs, labels)

However, when I check the source code of torch.nn.CrossEntropyLoss, there is **no __call__ method** defined in it! So **why can an instance of CrossEntropyLoss be treated like a function?** The purpose of this article is to solve this mystery. The presence or absence of the __call__ method matters because, in Python, an instance can be called with parentheses only if its class (or one of its parent classes) defines __call__.
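As a quick illustration of why `__call__` matters, here is a minimal sketch using two hypothetical toy classes (`WithoutCall` and `WithCall`, which are not part of PyTorch). Only the instance whose class defines `__call__` can be called with parentheses.

```python
class WithoutCall:
    """Hypothetical class with no __call__ method."""

    def double(self, x):
        return 2 * x


class WithCall:
    """Hypothetical class whose instances can be called like a function."""

    def __call__(self, x):
        return 2 * x


plain = WithoutCall()
func_like = WithCall()

print(callable(plain))      # False
print(callable(func_like))  # True
print(func_like(21))        # 42, the same as func_like.__call__(21)

try:
    plain(21)
    error_message = None
except TypeError as exc:
    # Calling an instance whose class lacks __call__ raises TypeError
    error_message = str(exc)
print(error_message)
```

So if criterion(outputs, labels) works, some class in the inheritance chain of CrossEntropyLoss must be providing `__call__`.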

About class inheritance

The beginning of the source code of the CrossEntropyLoss class is written as follows.

Python:torch.nn.modules.loss


class CrossEntropyLoss(_WeightedLoss):

First of all, what does it mean to put something in parentheses when defining a class in Python? This is called **class inheritance**, and it lets the new class use the functions and methods defined in another class as they are. (The following concrete examples are quoted from another source.)

#Inheritance
class MyClass:
    def hello(self):
        print("Hello")

class MyClass2(MyClass):
    def world(self):
        print("World")

a = MyClass2()
a.hello() # Hello
a.world() # World

The caveat here is that if the parent and child classes both define a method with the same name, the parent class's method is replaced by the child class's method. This is called an override.

#override
class MyClass:
    def hello(self):
        print("Hello")

class MyClass2(MyClass):
    def hello(self):        # Override the parent class's hello() method
        print("HELLO")

a = MyClass2()
a.hello()                   # HELLO

When you want to call a method defined in the parent class from within a method of the child class, you can use the super() function.

class MyClass1:
    def __init__(self):
        self.val1 = 123

class MyClass2(MyClass1):
    def __init__(self):
        super().__init__()
        self.val2 = 456

a = MyClass2()
print(a.val1) # 123
print(a.val2) # 456

Returning to the main topic, the CrossEntropyLoss class inherits from the _WeightedLoss class. Looking a little further into the code of CrossEntropyLoss, we find the following:

Python:torch.nn.modules.loss


class CrossEntropyLoss(_WeightedLoss):

    __constants__ = ['ignore_index', 'reduction']
    ignore_index: int

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

The notation is slightly different from the example above because it is written as super(CrossEntropyLoss, self), but if you refer to the [official Python documentation](https://docs.python.org/ja/3/library/functions.html?highlight=super#super), you can see that the two forms mean exactly the same thing.

From the official documentation


class C(B):
    def method(self, arg):
        super().method(arg)    # This does the same thing as:
                               # super(C, self).method(arg)
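The equivalence can also be checked directly. The following is a minimal sketch with hypothetical classes (`Base`, `ShortForm`, `LongForm`); it only assumes Python 3, where the zero-argument form of `super()` is available inside a method.

```python
class Base:
    def __init__(self, value):
        self.value = value


class ShortForm(Base):
    def __init__(self, value):
        # Zero-argument form (Python 3 only)
        super().__init__(value)


class LongForm(Base):
    def __init__(self, value):
        # Explicit form, the style used in the PyTorch source quoted above
        super(LongForm, self).__init__(value)


short = ShortForm(1)
long_ = LongForm(1)
print(short.value, long_.value)  # both were initialized by Base.__init__
```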

Now let's take a look at the definition of the _WeightedLoss class.

Python:torch.nn.modules.loss


class _WeightedLoss(_Loss):

From this, we can see that _WeightedLoss inherits from _Loss. Now let's take a look at the definition of the _Loss class.

Python:torch.nn.modules.loss


class _Loss(Module):

From this, we can see that _Loss inherits from Module. Now let's take a look at the description of the Module class.

Python:torch.nn.modules.module


class Module:

Module does not explicitly inherit from anything! So let's start by checking the contents of Module.
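The chain traced above can be mimicked with toy classes to see how Python resolves a method through several levels of inheritance. This is only a sketch: the `Toy*` classes and the helper `defining_class` are hypothetical stand-ins, not PyTorch code, and the `forward` body is a placeholder rather than a real cross entropy.

```python
class ToyModule:
    def __call__(self, *args):
        return self.forward(*args)

    def forward(self, *args):
        raise NotImplementedError


class ToyLoss(ToyModule):
    pass


class ToyWeightedLoss(ToyLoss):
    pass


class ToyCrossEntropyLoss(ToyWeightedLoss):
    def forward(self, prediction, target):
        return abs(prediction - target)  # placeholder, not a real cross entropy


def defining_class(cls, name):
    """Return the first class in the method resolution order that defines `name` itself."""
    for klass in cls.__mro__:
        if name in vars(klass):
            return klass
    return None


mro_names = [klass.__name__ for klass in ToyCrossEntropyLoss.__mro__]
print(mro_names)
print(defining_class(ToyCrossEntropyLoss, "__call__").__name__)  # ToyModule
print(defining_class(ToyCrossEntropyLoss, "forward").__name__)   # ToyCrossEntropyLoss
print(ToyCrossEntropyLoss()(3, 5))                               # 2
```

The method resolution order also ends with object, which every Python 3 class inherits from implicitly, even when nothing is written in the parentheses.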

torch.nn.Module

The __init__ method of the _Loss class calls the __init__ method of the Module class, so let's check that first.

Python:torch.nn.modules.module


# Note: not all of the code is listed
from collections import OrderedDict, namedtuple

class Module:
    _version: int = 1

    training: bool

    dump_patches: bool = False
    
    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

You can see that many OrderedDict() instances are created here. Put simply, as the name suggests, OrderedDict() creates **an empty dictionary (dict) that keeps track of the order of its entries**. In other words, this __init__ method mostly just prepares a number of empty dictionaries.
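For reference, a small sketch of how an OrderedDict behaves (the variable name `registry` and its keys are made up for illustration):

```python
from collections import OrderedDict

registry = OrderedDict()
print(len(registry))   # 0: it starts out empty

# Entries are kept in the order in which they were inserted
registry["b"] = 2
registry["a"] = 1
registry["c"] = 3
print(list(registry))  # ['b', 'a', 'c']

# Unlike a plain dict, an OrderedDict can also reorder its entries
registry.move_to_end("b")
print(list(registry))  # ['a', 'c', 'b']
```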

And the __call__ method in question is in fact defined here, in the Module class!

Python:torch.nn.modules.module


    def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in itertools.chain(
                        _global_backward_hooks.values(),
                        self._backward_hooks.values()):
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

    __call__ : Callable[..., Any] = _call_impl

The last line, __call__ : Callable[..., Any] = _call_impl, binds __call__ to _call_impl, so calling the instance like a function executes the function above. Note that _call_impl in turn calls self.forward(*input, **kwargs), so the call ends up running the forward method of the concrete class. Callable[..., Any] is a type hint meaning "a callable that accepts any arguments and returns any type". The colon here is a variable annotation: it simply attaches that type hint to the name __call__ and does not change how the code behaves at runtime.
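To make the flow easier to follow, here is a heavily simplified sketch of the same pattern. `ToyModule` and `Add` are hypothetical classes, hooks are kept in plain lists, and tracing and backward hooks are left out entirely, so this is not how PyTorch itself is implemented.

```python
from typing import Any, Callable


class ToyModule:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module."""

    def __init__(self):
        self._forward_pre_hooks = []
        self._forward_hooks = []

    def forward(self, *inputs):
        raise NotImplementedError

    def _call_impl(self, *inputs):
        # 1. Pre-hooks may replace the inputs
        for hook in self._forward_pre_hooks:
            result = hook(self, inputs)
            if result is not None:
                inputs = result if isinstance(result, tuple) else (result,)
        # 2. The subclass's forward does the real work
        result = self.forward(*inputs)
        # 3. Forward hooks may replace the output
        for hook in self._forward_hooks:
            hook_result = hook(self, inputs, result)
            if hook_result is not None:
                result = hook_result
        return result

    # Same pattern as the quoted source: annotate __call__ and bind it to _call_impl
    __call__: Callable[..., Any] = _call_impl


class Add(ToyModule):
    def forward(self, a, b):
        return a + b


add = Add()
first = add(1, 2)
print(first)   # 3: add(...) -> _call_impl -> forward

# Register a forward hook that rescales the output
add._forward_hooks.append(lambda module, inputs, output: output * 10)
second = add(1, 2)
print(second)  # 30
```

Going through `__call__` instead of calling `forward` directly is what gives registered hooks a chance to run.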

I will walk through the meaning of this code in a separate article.

The Module class defines many other methods besides these, so check them as needed.

The rest of this article can be skimmed.

torch.nn._Loss

The __init__ method of the _WeightedLoss class calls the __init__ method of the _Loss class, so let's check it.

Python:torch.nn.modules.loss


class _Loss(Module):
    reduction: str

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction

Here a new attribute, self.reduction, is introduced. Its value is taken from the reduction argument unless size_average or reduce is given, in which case it is derived from those two legacy arguments.
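The branch can be sketched as follows. The mapping from the legacy flags to a reduction string is an assumption based on how size_average and reduce are described in the CrossEntropyLoss docstring quoted later in this article; the function names `legacy_reduction` and `resolve_reduction` are hypothetical, and this is not the actual code of _Reduction.legacy_get_string.

```python
def legacy_reduction(size_average=None, reduce=None):
    """Assumed mapping from the deprecated flags to a reduction string."""
    # Both legacy flags are documented as defaulting to True
    if size_average is None:
        size_average = True
    if reduce is None:
        reduce = True
    if not reduce:
        return "none"  # one loss per element; size_average is ignored
    return "mean" if size_average else "sum"


def resolve_reduction(size_average=None, reduce=None, reduction="mean"):
    """Mirrors the if/else in _Loss.__init__ shown above."""
    if size_average is not None or reduce is not None:
        # Specifying either legacy argument overrides `reduction`
        return legacy_reduction(size_average, reduce)
    return reduction


print(resolve_reduction())                     # mean
print(resolve_reduction(reduction="sum"))      # sum
print(resolve_reduction(size_average=False))   # sum
print(resolve_reduction(reduce=False))         # none
```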

torch.nn._WeightedLoss

The __init__ method of the _WeightedLoss class is called by the __init__ method of the CrossEntropyLoss class, so let's check it.

Python:torch.nn.modules.loss


class _WeightedLoss(_Loss):
    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)

Here, Optional[Tensor] is specified as the type annotation of weight. The explanation [here](https://python.ms/union-and-optional/) is easy to understand. Simply put, it means that weight may be either a Tensor or None.
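A small sketch of what Optional means, using a plain list as a stand-in for Tensor so that it runs without PyTorch (the function `describe_weight` is hypothetical):

```python
from typing import Optional, Union


def describe_weight(weight: Optional[list] = None) -> str:
    """Accept either a list of per-class weights or None."""
    if weight is None:
        return "no class weights"
    return f"{len(weight)} class weights"


# Optional[X] is shorthand for Union[X, None]
print(Optional[int] == Union[int, None])  # True
print(describe_weight())                  # no class weights
print(describe_weight([0.2, 0.8]))        # 2 class weights
```

Like other annotations, Optional is only a hint for readers and type checkers; Python does not enforce it at runtime.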

Let's get back to the main subject. A new method appears here, self.register_buffer, which is defined in the Module class. Below is its source code.

Python:torch.nn.modules.module


    forward: Callable[..., Any] = _forward_unimplemented

    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Adds a buffer to the module.

        This is typically used to register a buffer that should not to be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the module's state. Buffers, by
        default, are persistent and will be saved alongside parameters. This
        behavior can be changed by setting :attr:`persistent` to ``False``. The
        only difference between a persistent buffer and a non-persistent buffer
        is that the latter will not be a part of this module's
        :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.

        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor): buffer to be registered.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Example::

            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("buffer name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)

The code is fairly long, but the upper half is a docstring, and the if/elif branches above the else only raise errors, so I will skip them. In the else branch, an entry is added to self._buffers, which is a dictionary. In other words, creating an instance of the _WeightedLoss class results in the following:

self._buffers = {'weight': weight}  # The weight on the right is a Tensor or None
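The docstring above says that buffers can be accessed as attributes using the given names, which is why forward can later refer to self.weight. A minimal sketch of that idea follows, with hypothetical `ToyModule` and `ToyWeightedLoss` classes and a plain list in place of a Tensor. The real Module has far more validation and bookkeeping than this.

```python
class ToyModule:
    """Hypothetical, simplified stand-in for torch.nn.Module."""

    def __init__(self):
        self._buffers = {}

    def register_buffer(self, name, tensor):
        if not isinstance(name, str):
            raise TypeError("buffer name should be a string")
        if name == "" or "." in name:
            raise KeyError("invalid buffer name")
        self._buffers[name] = tensor

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails
        buffers = self.__dict__.get("_buffers", {})
        if name in buffers:
            return buffers[name]
        raise AttributeError(name)


class ToyWeightedLoss(ToyModule):
    def __init__(self, weight=None):
        super().__init__()
        self.register_buffer("weight", weight)


weighted = ToyWeightedLoss([0.2, 0.8])
print(weighted._buffers)         # {'weight': [0.2, 0.8]}
print(weighted.weight)           # [0.2, 0.8], looked up through __getattr__
print(ToyWeightedLoss().weight)  # None when no weight is given
```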

torch.nn.CrossEntropyLoss

Finally, we are back to the class in question. Below is the source code. It has a long docstring, but I will quote all of it.

Python:torch.nn.modules.loss


class CrossEntropyLoss(_WeightedLoss):
    r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.

    It is useful when training a classification problem with `C` classes.
    If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
    assigning weight to each of the classes.
    This is particularly useful when you have an unbalanced training set.

    The `input` is expected to contain raw, unnormalized scores for each class.

    `input` has to be a Tensor of size either :math:`(minibatch, C)` or
    :math:`(minibatch, C, d_1, d_2, ..., d_K)`
    with :math:`K \geq 1` for the `K`-dimensional case (described later).

    This criterion expects a class index in the range :math:`[0, C-1]` as the
    `target` for each value of a 1D tensor of size `minibatch`; if `ignore_index`
    is specified, this criterion also accepts this class index (this index may not
    necessarily be in the class range).

    The loss can be described as:

    .. math::
        \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
                       = -x[class] + \log\left(\sum_j \exp(x[j])\right)

    or in the case of the :attr:`weight` argument being specified:

    .. math::
        \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right)

    The losses are averaged across observations for each minibatch. If the
    :attr:`weight` argument is specified then this is a weighted average:

    .. math::
        \text{loss} = \frac{\sum^{N}_{i=1} loss(i, class[i])}{\sum^{N}_{i=1} weight[class[i]]}

    Can also be used for higher dimension inputs, such as 2D images, by providing
    an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`,
    where :math:`K` is the number of dimensions, and a target of appropriate shape
    (see below).


    Args:
        weight (Tensor, optional): a manual rescaling weight given to each class.
            If given, has to be a Tensor of size `C`
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets.
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
            be applied, ``'mean'``: the weighted mean of the output is taken,
            ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in
            the meantime, specifying either of those two args will override
            :attr:`reduction`. Default: ``'mean'``

    Shape:
        - Input: :math:`(N, C)` where `C = number of classes`, or
          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
          K-dimensional loss.
        - Output: scalar.
          If :attr:`reduction` is ``'none'``, then the same size as the target:
          :math:`(N)`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
          of K-dimensional loss.

    Examples::

        >>> loss = nn.CrossEntropyLoss()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """
    __constants__ = ['ignore_index', 'reduction']
    ignore_index: int

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

First of all, the __init__ method adds a new attribute called self.ignore_index. A method called forward() is also defined. However, the __call__ method is not defined anywhere below the Module class. Therefore, the reason an instance of the CrossEntropyLoss class can be used like a function is the __call__ method inherited from the Module class, which in turn calls the forward() method defined here.
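As a side note, the formula in the docstring above can be checked by hand without PyTorch. The sketch below covers only the unweighted case with reduction='mean'. The function names are hypothetical, and it is not how F.cross_entropy is implemented.

```python
import math


def cross_entropy_single(scores, target_class):
    """loss(x, class) = -x[class] + log(sum_j exp(x[j]))"""
    log_sum_exp = math.log(sum(math.exp(s) for s in scores))
    return -scores[target_class] + log_sum_exp


def cross_entropy_mean(batch_scores, targets):
    """Average the per-sample losses, as with reduction='mean' and no class weights."""
    losses = [cross_entropy_single(s, t) for s, t in zip(batch_scores, targets)]
    return sum(losses) / len(losses)


# With identical scores for 3 classes, the loss is log(3), roughly 1.0986
uniform = cross_entropy_single([0.0, 0.0, 0.0], 1)
print(uniform)

# A large score for the correct class gives a loss close to 0
confident = cross_entropy_single([10.0, 0.0, 0.0], 0)
print(confident)

print(cross_entropy_mean([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]], [1, 0]))
```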

In a follow-up article, I would like to take a closer look at what actually happens when an instance of CrossEntropyLoss is called like a function!
