[PYTHON] Utiliser tensorboard avec NNabla

tensorboard est un outil très utile pour dessiner des courbes de perte, des histogrammes et des images pendant l'entraînement. J'ai récemment utilisé le cadre de réseau neuronal de Sony NNabla (https://nnabla.org/), mais je n'avais pas d'outil de visualisation, j'ai donc créé un package python pour que NNabla puisse également utiliser tensorboard.

https://github.com/naibo-code/nnabla_tensorboard

Les bases sont basées sur "tensorboardX for pytorch".

Comment utiliser

En gros, vous pouvez voir à quoi il ressemble en exécutant demp.py. Il prend en charge le dessin de scalaires, d'histogrammes, d'images, etc.

# Install
pip install 'git+https://github.com/naibo-code/nnabla_tensorboard.git'

# Demo
python examples/demo.py

Scalaire

scaler

histogramme

histogram

Sortie de caractères

text

Visualisez l'apprentissage MNIST avec NNabla + tensorboard

NNabla fournit quelques exemples dans ce référentiel https://github.com/sony/nnabla-examples/. Cette fois, nous utiliserons le code d'apprentissage MNIST pour visualiser les résultats d'apprentissage en temps réel sur le tensorboard. J'ai essayé de.

Seules ces deux fonctions doivent être modifiées (uniquement la partie marquée «NEW»). De même, importez le package au début du fichier avec from nnabla_tensorboard import SummaryWriter.

from nnabla_tensorboard import SummaryWriter


def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Computate error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)

    # For tensorboard (NEW)
    tb_writer = SummaryWriter(args.monitor_path)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation (NEW)
            validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        loss.data.cast(np.float32, ctx)
        pred.data.cast(np.float32, ctx)
        e = categorical_error(pred.d, label.d)

        # Instead of using nnabla.monitor, use nnabla_tensorboard. (NEW)
        if i % args.val_interval == 0:
            tb_writer.add_image('image/train_data_{}'.format(i), image.d[0])

        tb_writer.add_scalar('train/loss', loss.d.copy(), global_step=i)
        tb_writer.add_scalar('train/error', e, global_step=i)
        monitor_time.add(i)

    validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)

    # append F.Softmax to the prediction graph so users see intuitive outputs
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batch_size,
             'outputs': {'y': F.softmax(vpred)},
             'names': {'x': vimage}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(os.path.join(args.model_save_path,
                           '{}_result.nnp'.format(args.net)), runtime_contents)

    tb_writer.close()
def validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer):
    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        vpred.data.cast(np.float32, ctx)
        ve += categorical_error(vpred.d, vlabel.d)
    tb_writer.add_scalar('test/error', ve / args.val_iter, i)

NNabla + tensorboard: résultat de l'exécution MNIST

Courbe d'apprentissage mnist_curve.png

Je trace également l'image d'entrée. mnist_image.png

Vous n'êtes pas obligé de dessiner avec votre propre script, et tensorboard est pratique après tout.

Fonctionnalités que vous souhaitez ajouter

Recommended Posts

Utiliser tensorboard avec NNabla
Utiliser tensorboard avec Chainer
Utilisez mecab-ipadic-neologd avec igo-python
Utilisez RTX 3090 avec PyTorch
Utiliser ansible avec cygwin
Utiliser pipdeptree avec virtualenv
[Python] Utiliser JSON avec Python
Utilisez Mock avec pytest
Utiliser l'indicateur avec pd.merge
Utiliser Gentelella avec Django
Utiliser mecab avec Python 3
Utiliser DynamoDB avec Python
Utiliser pip avec MSYS2
Utilisez Python 3.8 avec Anaconda
Utiliser les droits d'auteur avec Spacemacs
Histoire d'essayer d'utiliser Tensorboard avec Pytorch
Utiliser TypeScript avec django-compresseur
Utiliser MySQL avec Django
Utiliser le GPS avec Edison
Utilisez nim avec Jupyter
Utiliser la mémoire partagée avec une bibliothèque partagée
Utiliser des balises personnalisées avec PyYAML
Utiliser des graphiques directionnels avec networkx
Utiliser TensorFlow avec Intellij IDEA
Utiliser l'API Twitter avec Python
Utiliser pip avec Jupyter Notebook
Utiliser DATE_FORMAT avec le filtre SQLAlchemy
Utiliser TUN / TAP avec Python
Utilisez sqlite3 avec NAO (Pepper)
Utilisez les load_extensions de sqlite avec Pyramid
Utiliser les polices Windows 10 avec WSL
Utilisation du chainer avec Jetson TK1
Utiliser SSL avec Celery + Redis
Utiliser Cython avec Jupyter Notebook
Utilisez Maxout + CNN avec Pylearn2
Utilisez WDC-433SU2M2 avec Manjaro Linux
Utilisez OpenBLAS avec numpy, scipy
Utiliser l'API subsonique avec python3
Utilisation de Sonicwall NetExtener avec Systemd
Utilisez prefetch_related commodément avec Django
Utiliser l'interpréteur AWS avec Pycharm
Utilisation de Bokeh avec IPython Notebook
Utiliser une plage de type Python avec Rust
Utiliser MLflow avec Databricks ④ --Call model -
Utiliser pyright avec CentOS7, emacs lsp-mode
Python: comment utiliser async avec
Utilisation de la base de données SQL d'Azure avec SQL Alchemy
Utilisez vl53l0x avec RaspberryPi (python)
Utilisez PX-S1UD / PX-Q1UD avec Jetson nano
Utilisez la fonction de prévisualisation avec aws-cli
Pour utiliser virtualenv avec PowerShell
Utilisez NAIF SPICE TOOL KIT avec Python
Utiliser rospy avec virtualenv dans Python3
Utiliser Markdown avec le notebook Jupyter (avec raccourci)
Utiliser Python mis en pyenv avec NeoVim
Utilisez Tensorflow 2.1.0 avec Anaconda sur Windows 10!
Utiliser la synthèse vocale Windows 10 avec Python
Je ne peux pas utiliser le japonais avec pyperclip