[PYTHON] Résumons la fonction de reporting de Chainer

tl;dr

Vous pouvez obtenir la journalisation minimale gratuitement en écrivant trainer.extend (extensions.LogReport ()). Vous pouvez être heureux si vous gardez ʻextensions.ParameterStatistics`.

Objectif

Cela fait un an que Training Loop Abstraction a été introduit dans Chainer en juin 2016, v1.11. Je pense qu'il y a beaucoup de gens qui veulent utiliser les actifs (dette?) De la boucle d'apprentissage Oreore et n'y ont pas touché. Puisque je veux maîtriser Trainer et d'autres choses de manière bâclée, j'ai résumé les rapports sur les métriques d'apprentissage qui ne sont pas si populaires (mais très importantes).

Comprendre le mécanisme de reporting dans Chainer

Il est très important de surveiller diverses métriques pour l'apprentissage en profondeur. Par exemple, Tensorflow dispose d'une puissante fonctionnalité de création de rapports appelée summary.

Officiellement, il existe une classe appelée Reporter qui semble très bien rapporter Documentation, mais Official Si vous regardez l'exemple, la classe «Reporter» n'apparaît nulle part, et la précision, etc. est clairement indiquée. Il n'y a pas de place pour l'écrire. Qu'est-ce que ça veut dire?

Si vous regardez Exemple officiel L96-97, vous pouvez voir LogReport et` J'ajoute des "extensions" telles que PrintReport "à l'objet" Trainer ".

# Write a log of evaluation statistics for each epoch
trainer.extend(extensions.LogReport())

Et regardez chainer / chainer / training / trainer.py: L291-L299 En y regardant, nous déclarons with reporter.scope (self.observation) au début de la boucle d'apprentissage. Cette déclaration garantit que tous les appels à chainer.reporter.report ({'name': value_to_report}) effectués dans la boucle d'apprentissage sont stockés dans self.observation.

    def run(self):
        ....
        reporter = self.reporter
        stop_trigger = self.stop_trigger

        # main training loop
        try:
            while not stop_trigger(self):
                self.observation = {}
                with reporter.scope(self.observation):
                    update()
                    for name, entry in extensions:
                        if entry.trigger(self):
                             entry.extension(self)

En d'autres termes, même si vous ne dites pas explicitement «Reporter», les métriques sont en fait collectées dans les coulisses. Au fait, concernant les données collectées, en appelant ʻentry.extension (self) `, [chainer / chainer / training / extensions / log_report.py: L67-L88](https://github.com/chainer/chainer/ Il est passé à blob / v2.0.0 / chainer / training / extensions / log_report.py # L67-L88).

    def __call__(self, trainer):
        # accumulate the observations
        keys = self._keys
        observation = trainer.observation
        summary = self._summary

        if keys is None:
            summary.add(observation)
        else:
            summary.add({k: observation[k] for k in keys if k in observation})

        if self._trigger(trainer):
            # output the result
            stats = self._summary.compute_mean()
            stats_cpu = {}
            for name, value in six.iteritems(stats):
                stats_cpu[name] = float(value)  # copy to CPU

            updater = trainer.updater
            stats_cpu['epoch'] = updater.epoch
            stats_cpu['iteration'] = updater.iteration
            stats_cpu['elapsed_time'] = trainer.elapsed_time

Le nombre d'époques, etc. est ajouté dans cette fonction et sorti à l'endroit approprié. Si aucune extension avec une fonction de rapport n'est enregistrée, les données seront simplement supprimées.

Maintenant, je vois pourquoi je n'ai pas lu explicitement Reporter dans l'exemple officiel. Mais pourquoi la précision (ʻaccuracy) et la perte (loss) sont-elles alors que je n'ai jamais appelé chainer.reporter.report`?

Alors, jetez un œil à chainer / chainer / links / model / classifier.py Si vous regardez, vous pouvez voir que chainer.reporter.report est appelé dans l'implémentation officielle.

        self.loss = self.lossfun(self.y, t)
        reporter.report({'loss': self.loss}, self)
        if self.compute_accuracy:
            self.accuracy = self.accfun(self.y, t)
            reporter.report({'accuracy': self.accuracy}, self)

En d'autres termes, si vous écrivez simplement trainer.extend (extensions.LogReport ()), vous obtiendrez la journalisation minimale requise, et si vous appelez simplement chainer.reporter.report dans votre modèle, vous pouvez faire n'importe quel rapport. Tu peux le faire. Pratique.

Au fait, si vous exécutez l'exemple ci-dessus, vous obtiendrez le rapport suivant dans result / log.

[{u'elapsed_time': 6.940603971481323,
  u'epoch': 1,
  u'iteration': 600,
  u'main/accuracy': 0.9213500021273892,
  u'main/loss': 0.2787705701092879,
  u'validation/main/accuracy': 0.9598000049591064,
  u'validation/main/loss': 0.13582063710317016},
 {u'elapsed_time': 14.360282897949219,
  u'epoch': 2,
  u'iteration': 1200,
  ...

Essayez d'augmenter le contenu des rapports

C'est assez pratique, mais utiliser ʻextensions.ParameterStatistics` est riche comme le [tf.summary.histogram] de Tensorflow (https://www.tensorflow.org/api_docs/python/tf/summary/histogram). La surveillance est possible.

...
trainer.extend(extensions.ParameterStatistics(model))
...

La valeur représentative de chaque matrice de lien incluse dans le modèle est automatiquement collectée et ajoutée au résultat. C'est très pratique.

[{u'None/predictor/l1/W/data/max': 0.18769985591371854,
  u'None/predictor/l1/W/data/mean': 0.0006860141372822189,
  u'None/predictor/l1/W/data/min': -0.21658104345202445,
  u'None/predictor/l1/W/data/percentile/0': -0.1320047355272498,
  u'None/predictor/l1/W/data/percentile/1': -0.08497818301255008,
  u'None/predictor/l1/W/data/percentile/2': -0.04122352957670082,
  u'None/predictor/l1/W/data/percentile/3': 0.0008963784146650747,
  u'None/predictor/l1/W/data/percentile/4': 0.0428067545834066,
  ...

Le résultat de l'exécution ci-dessus est dans gist.

Recommended Posts

Résumons la fonction de reporting de Chainer
Résumons Apache
Essayez de dessiner une fonction logistique
Résumons brièvement LPIC niveau 1 (102)
Résumons brièvement LPIC niveau 1 (101 éditions)