[PYTHON] Use tensorboard with NNabla

TensorBoard is a very handy tool for plotting loss curves, histograms, and images during training. I recently started using NNabla (https://nnabla.org/), Sony's neural network framework, but it had no such visualization tool, so I made a Python package that lets NNabla use TensorBoard as well.

https://github.com/naibo-code/nnabla_tensorboard

The basic design follows tensorboardX for PyTorch.

How to use

You can get a feel for it by running examples/demo.py. It supports plotting scalars, histograms, images, text, and more.

# Install
pip install 'git+https://github.com/naibo-code/nnabla_tensorboard.git'

# Demo
python examples/demo.py
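Outside the demo, the writer can also be called directly. Below is a minimal sketch, assuming nnabla_tensorboard's SummaryWriter exposes add_scalar(tag, value, global_step) and add_text(tag, text, global_step) with the same signatures as tensorboardX (the MNIST code below confirms add_scalar; add_text is an assumption). The helper log_step and the RecordingWriter stand-in are hypothetical; the helper accepts any writer-like object, so it can be tried without TensorBoard installed.

```python
from typing import Any, Dict


def log_step(writer: Any, step: int, scalars: Dict[str, float], note: str = "") -> None:
    """Hypothetical helper: write several scalars (and an optional text note) for one step.

    `writer` is assumed to behave like tensorboardX's SummaryWriter, i.e. to provide
    add_scalar(tag, value, global_step) and add_text(tag, text, global_step).
    """
    for tag, value in scalars.items():
        # One curve per tag; global_step becomes the x-axis in TensorBoard.
        writer.add_scalar(tag, float(value), global_step=step)
    if note:
        writer.add_text("notes", note, global_step=step)


class RecordingWriter:
    """Stand-in writer that only records calls, for trying the helper offline."""

    def __init__(self) -> None:
        self.calls = []

    def add_scalar(self, tag, value, global_step=None):
        self.calls.append(("scalar", tag, value, global_step))

    def add_text(self, tag, text, global_step=None):
        self.calls.append(("text", tag, text, global_step))


rec = RecordingWriter()
log_step(rec, 10, {"train/loss": 0.25, "train/error": 0.1}, note="lr=0.001")
```

With the real package, the same call would be something like log_step(SummaryWriter("runs/demo"), i, {"train/loss": loss_value}) inside a loop, followed by writer.close() at the end ("runs/demo" and loss_value are placeholder names).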

Scalar

(Figure: scalar plot)

Histogram

(Figure: histogram)

Text output

(Figure: text output)
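The histogram view is handy for watching weights drift during training. A minimal sketch, assuming SummaryWriter.add_histogram(tag, values, global_step) mirrors tensorboardX: the hypothetical helper below takes a mapping from parameter names to objects with a NumPy .d attribute, which is what nn.get_parameters() returns for NNabla Variables, and logs one histogram per parameter. The fake parameters and RecordingWriter are stand-ins so the sketch runs without NNabla.

```python
from types import SimpleNamespace
from typing import Any, Mapping

import numpy as np


def log_param_histograms(writer: Any, params: Mapping[str, Any], step: int) -> int:
    """Hypothetical helper: one histogram per parameter, tagged 'params/<name>'.

    `params` is expected to map names to objects exposing a NumPy array as `.d`
    (as nn.get_parameters() does for NNabla Variables). Returns the number logged.
    """
    count = 0
    for name, var in params.items():
        # Flatten: a histogram only needs the values, not the tensor shape.
        values = np.asarray(var.d).ravel()
        writer.add_histogram("params/" + name, values, global_step=step)
        count += 1
    return count


class RecordingWriter:
    """Stand-in writer that records (tag, values shape, step) for each call."""

    def __init__(self):
        self.calls = []

    def add_histogram(self, tag, values, global_step=None):
        self.calls.append((tag, values.shape, global_step))


fake_params = {
    "conv1/conv/W": SimpleNamespace(d=np.zeros((16, 1, 5, 5), dtype=np.float32)),
    "fc/affine/b": SimpleNamespace(d=np.ones(10, dtype=np.float32)),
}
rec = RecordingWriter()
n = log_param_histograms(rec, fake_params, step=100)
```

In the MNIST script this could be called as log_param_histograms(tb_writer, nn.get_parameters(), i) inside the val_interval branch, so that histograms are written periodically rather than on every iteration.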

Visualize MNIST learning with NNabla + tensorboard

NNabla provides sample code in the repository https://github.com/sony/nnabla-examples/. This time, I took its MNIST training code and tried visualizing the training progress on TensorBoard in real time.

Only two functions, train() and validation(), need to change (only the parts marked NEW). Also, import the package at the top of the file with from nnabla_tensorboard import SummaryWriter.

from nnabla_tensorboard import SummaryWriter


def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create a monitor and a SummaryWriter for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get the next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients to zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)

    # For tensorboard (NEW)
    tb_writer = SummaryWriter(args.monitor_path)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation (NEW)
            validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        loss.data.cast(np.float32, ctx)
        pred.data.cast(np.float32, ctx)
        e = categorical_error(pred.d, label.d)

        # Instead of using nnabla.monitor, use nnabla_tensorboard. (NEW)
        if i % args.val_interval == 0:
            tb_writer.add_image('image/train_data_{}'.format(i), image.d[0])

        tb_writer.add_scalar('train/loss', loss.d.copy(), global_step=i)
        tb_writer.add_scalar('train/error', e, global_step=i)
        monitor_time.add(i)

    validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)

    # append F.Softmax to the prediction graph so users see intuitive outputs
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batch_size,
             'outputs': {'y': F.softmax(vpred)},
             'names': {'x': vimage}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(os.path.join(args.model_save_path,
                           '{}_result.nnp'.format(args.net)), runtime_contents)

    tb_writer.close()


def validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer):
    # Average the error over args.val_iter validation batches and log it at step i. (NEW)
    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        vpred.data.cast(np.float32, ctx)
        ve += categorical_error(vpred.d, vlabel.d)
    tb_writer.add_scalar('test/error', ve / args.val_iter, i)
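The code above calls categorical_error, which comes from the MNIST example rather than from nnabla_tensorboard. To make clear what the train/error and test/error curves contain, here is a sketch of a function with the behavior implied by how it is called (my reconstruction, not a copy of the nnabla-examples source): the fraction of samples whose arg-max prediction differs from the label.

```python
import numpy as np


def categorical_error(pred: np.ndarray, label: np.ndarray) -> float:
    """Fraction of misclassified samples.

    pred:  (batch, num_classes) scores or probabilities
    label: (batch, 1) integer class labels, the shape of the label Variables above
    """
    pred_label = pred.argmax(axis=1)
    # ravel() turns (batch, 1) labels into (batch,), so the comparison is
    # element-wise instead of broadcasting to a (batch, batch) matrix.
    return float((pred_label != label.ravel()).mean())


pred = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
label = np.array([[0], [1], [1], [1]])
err = categorical_error(pred, label)
```

Here only the third sample is misclassified, so err is 0.25. validation() then averages this per-batch value over args.val_iter batches before logging it as test/error.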

NNabla + tensorboard: MNIST execution result

Learning curve

(Figure: mnist_curve.png)

I also plotted the input images.

(Figure: mnist_image.png)
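The training loop above logs each sample under a new tag ('image/train_data_{i}'), so the image tab gains one entry per validation interval. A hypothetical alternative is to tile several samples of the batch into one grid and log it under a fixed tag with global_step, so TensorBoard's step slider can scrub through them. The sketch below covers only the NumPy tiling; make_grid_chw is an illustrative name, and it assumes add_image accepts a CHW array, which is the layout of image.d[0] in the code above (1 x 28 x 28).

```python
import numpy as np


def make_grid_chw(batch: np.ndarray, ncols: int = 4) -> np.ndarray:
    """Tile a (N, C, H, W) batch into one (C, nrows*H, ncols*W) image.

    Samples fill the grid row by row; empty cells in the last row stay zero.
    """
    n, c, h, w = batch.shape
    nrows = -(-n // ncols)  # ceiling division
    grid = np.zeros((c, nrows * h, ncols * w), dtype=batch.dtype)
    for idx in range(n):
        r, col = divmod(idx, ncols)
        grid[:, r * h:(r + 1) * h, col * w:(col + 1) * w] = batch[idx]
    return grid


# Six fake 1x28x28 "images", each filled with its own index value.
batch = np.arange(6, dtype=np.float32).reshape(6, 1, 1, 1) * np.ones((6, 1, 28, 28), dtype=np.float32)
grid = make_grid_chw(batch, ncols=4)
```

In the training loop this might become tb_writer.add_image('image/train_data', make_grid_chw(image.d[:16]), global_step=i); whether nnabla_tensorboard's add_image takes global_step the same way tensorboardX does is an assumption to check against the package.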

You don't have to write your own plotting scripts; TensorBoard really is convenient.

Features I would like to add
