[PYTHON] Verwenden Sie Tensorboard mit NNabla

Tensorboard ist ein sehr nützliches Werkzeug zum Zeichnen von Verlustkurven, Histogrammen und Bildern während des Trainings. Ich habe kürzlich Sonys neuronalen Netzwerkrahmen NNabla (https://nnabla.org/) verwendet, aber ich hatte kein Visualisierungstool, also habe ich ein Python-Paket erstellt, damit NNabla auch Tensorboard verwenden kann.

https://github.com/naibo-code/nnabla_tensorboard

Die Grundlagen basieren auf "tensorboardX for pytorch".

Wie benutzt man

Grundsätzlich können Sie sehen, wie es aussieht, indem Sie demp.py ausführen. Es unterstützt das Zeichnen von Skalaren, Histogrammen, Bildern usw.

# Install
pip install 'git+https://github.com/naibo-code/nnabla_tensorboard.git'

# Demo
python examples/demo.py

Skalar

scaler

Histogramm

histogram

Zeichenausgabe

text

Visualisieren Sie das MNIST-Lernen mit NNabla + Tensorboard

NNabla bietet einige Beispiele in diesem Repository https://github.com/sony/nnabla-examples/. Dieses Mal verwenden wir MNIST-Lerncode, um die Lernergebnisse in Echtzeit auf dem Tensorboard zu visualisieren. Ich versuchte zu.

Nur diese beiden Funktionen sollten geändert werden (nur der mit "NEU" gekennzeichnete Teil). Importieren Sie außerdem das Paket am Anfang der Datei mit "from nnabla_tensorboard import SummaryWriter".

from nnabla_tensorboard import SummaryWriter


def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Computate error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)

    # For tensorboard (NEW)
    tb_writer = SummaryWriter(args.monitor_path)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation (NEW)
            validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        loss.data.cast(np.float32, ctx)
        pred.data.cast(np.float32, ctx)
        e = categorical_error(pred.d, label.d)

        # Instead of using nnabla.monitor, use nnabla_tensorboard. (NEW)
        if i % args.val_interval == 0:
            tb_writer.add_image('image/train_data_{}'.format(i), image.d[0])

        tb_writer.add_scalar('train/loss', loss.d.copy(), global_step=i)
        tb_writer.add_scalar('train/error', e, global_step=i)
        monitor_time.add(i)

    validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)

    # append F.Softmax to the prediction graph so users see intuitive outputs
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batch_size,
             'outputs': {'y': F.softmax(vpred)},
             'names': {'x': vimage}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(os.path.join(args.model_save_path,
                           '{}_result.nnp'.format(args.net)), runtime_contents)

    tb_writer.close()
def validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer):
    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        vpred.data.cast(np.float32, ctx)
        ve += categorical_error(vpred.d, vlabel.d)
    tb_writer.add_scalar('test/error', ve / args.val_iter, i)

NNabla + Tensorboard: MNIST-Ausführungsergebnis

Lernkurve mnist_curve.png

Ich zeichne auch das Eingabebild. mnist_image.png

Sie müssen nicht mit Ihrem eigenen Skript zeichnen, und Tensorboard ist schließlich praktisch.

Funktionen, die Sie hinzufügen möchten

Recommended Posts

Verwenden Sie Tensorboard mit NNabla
Verwenden Sie Tensorboard mit Chainer
Verwenden Sie mecab-ipadic-neologd mit igo-python
Verwenden Sie RTX 3090 mit PyTorch
Verwenden Sie ansible mit cygwin
Verwenden Sie pipdeptree mit virtualenv
[Python] Verwenden Sie JSON mit Python
Verwenden Sie Mock mit Pytest
Verwenden Sie den Indikator mit pd.merge
Verwenden Sie Gentelella mit Django
Verwenden Sie Mecab mit Python 3
Verwenden Sie DynamoDB mit Python
Verwenden Sie pip mit MSYS2
Verwenden Sie Python 3.8 mit Anaconda
Verwenden Sie Copyright mit Spacemacs
Die Geschichte des Versuchs, Tensorboard mit Pytorch zu verwenden
Verwenden Sie TypeScript mit Django-Kompressor
Verwenden Sie MySQL mit Django
Verwenden Sie GPS mit Edison
Verwenden Sie nim mit Jupyter
Verwenden Sie gemeinsam genutzten Speicher mit gemeinsam genutzten Bibliotheken
Verwenden Sie benutzerdefinierte Tags mit PyYAML
Verwenden Sie Richtungsdiagramme mit networkx
Verwenden Sie TensorFlow mit Intellij IDEA
Verwenden Sie die Twitter-API mit Python
Verwenden Sie pip mit Jupyter Notebook
Verwenden Sie DATE_FORMAT mit dem SQLAlchemy-Filter
Verwenden Sie TUN / TAP mit Python
Verwenden Sie sqlite3 mit NAO (Pepper)
Verwenden Sie die load_extensions von sqlite mit Pyramid
Verwenden Sie Windows 10-Schriftarten mit WSL
Verwendung von Chainer mit Jetson TK1
Verwenden Sie SSL mit Sellerie + Redis
Verwenden Sie Cython mit Jupyter Notebook
Verwenden Sie Maxout + CNN mit Pylearn2
Verwenden Sie WDC-433SU2M2 mit Manjaro Linux
Verwenden Sie OpenBLAS mit numpy, scipy
Verwenden Sie die Unterschall-API mit Python3
Verwenden von Sonicwall NetExtener mit Systemd
Verwenden Sie prefetch_related bequem mit Django
Verwenden Sie einen AWS-Interpreter mit Pycharm
Verwenden von Bokeh mit IPython Notebook
Verwenden Sie Python-ähnliche Bereiche mit Rust
Verwenden Sie MLflow mit Databricks ④ - Anrufmodell -
Verwenden Sie pyright mit CentOS7, emacs lsp-mode
Python: So verwenden Sie Async mit
Verwenden der SQL-Datenbank von Azure mit SQL Alchemy
Verwenden Sie vl53l0x mit RaspberryPi (Python)
Verwenden Sie PX-S1UD / PX-Q1UD mit Jetson Nano
Verwenden Sie die Vorschaufunktion mit aws-cli
So verwenden Sie virtualenv mit PowerShell
Verwenden Sie NAIF SPICE TOOL KIT mit Python
Verwenden Sie rospy mit virtualenv in Python3
Markdown mit Jupyter-Notebook verwenden (mit Verknüpfung)
Verwenden Sie Python in pyenv mit NeoVim
Verwenden Sie Tensorflow 2.1.0 mit Anaconda unter Windows 10!
Verwenden Sie die Windows 10-Sprachsynthese mit Python
Ich kann kein Japanisch mit Pyperclip verwenden