TensorBoard is a very handy tool for plotting loss curves, histograms, and images during training. I recently started using NNabla (https://nnabla.org/), Sony's neural network framework, but it had no visualization tool, so I wrote a Python package that lets NNabla use TensorBoard as well.
https://github.com/naibo-code/nnabla_tensorboard
It is largely based on tensorboardX for PyTorch.
The quickest way to see what it looks like is to run `demo.py`. It supports plotting scalars, histograms, images, and more.
# Install
pip install 'git+https://github.com/naibo-code/nnabla_tensorboard.git'
# Demo
python examples/demo.py
NNabla provides a number of examples in the repository https://github.com/sony/nnabla-examples/. This time, I took the MNIST training code from there and tried visualizing the training results in real time in TensorBoard.
Only the two functions below need to be modified (and only in the parts marked "NEW"). You also need to import the package at the top of the file with `from nnabla_tensorboard import SummaryWriter`.
from nnabla_tensorboard import SummaryWriter


def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TRAIN
    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)

    # For tensorboard (NEW)
    tb_writer = SummaryWriter(args.monitor_path)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)

    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation (NEW)
            validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))

        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        loss.data.cast(np.float32, ctx)
        pred.data.cast(np.float32, ctx)
        e = categorical_error(pred.d, label.d)

        # Instead of using nnabla.monitor, use nnabla_tensorboard. (NEW)
        if i % args.val_interval == 0:
            tb_writer.add_image('image/train_data_{}'.format(i), image.d[0])
        tb_writer.add_scalar('train/loss', loss.d.copy(), global_step=i)
        tb_writer.add_scalar('train/error', e, global_step=i)
        monitor_time.add(i)

    validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer)

    parameter_file = os.path.join(
        args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter))
    nn.save_parameters(parameter_file)

    # append F.Softmax to the prediction graph so users see intuitive outputs
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batch_size,
             'outputs': {'y': F.softmax(vpred)},
             'names': {'x': vimage}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(os.path.join(args.model_save_path,
                           '{}_result.nnp'.format(args.net)), runtime_contents)

    tb_writer.close()
def validation(args, ctx, vdata, vimage, vlabel, vpred, i, tb_writer):
    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        vpred.data.cast(np.float32, ctx)
        ve += categorical_error(vpred.d, vlabel.d)
    tb_writer.add_scalar('test/error', ve / args.val_iter, i)
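The listing calls `categorical_error` without showing it, since it comes from the surrounding example script. To make clear what the `train/error` and `test/error` scalars measure, here is a minimal NumPy sketch. It assumes the helper simply returns the fraction of samples whose arg-max prediction differs from the label. The function body, `mean_validation_error`, and the toy arrays are illustrative assumptions, not code taken from the repository.

```python
import numpy as np


def categorical_error(pred, label):
    """Fraction of samples whose highest-scoring class differs from the label.

    Assumed behaviour, not copied from nnabla-examples.
    pred:  array of shape (batch, n_classes) with class scores
    label: array of shape (batch, 1) with integer class ids
    """
    pred_label = pred.argmax(axis=1)
    return float((pred_label != label.ravel()).mean())


def mean_validation_error(batches):
    """Average the per-batch error, as validation() does with ve / args.val_iter."""
    ve = 0.0
    for pred, label in batches:
        ve += categorical_error(pred, label)
    return ve / len(batches)


# Toy data (hypothetical): two classes, two validation batches.
pred_a = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
label_a = np.array([[1], [0], [0], [0]])   # third sample is misclassified
pred_b = np.array([[0.2, 0.8], [0.9, 0.1]])
label_b = np.array([[0], [0]])             # first sample is misclassified
batches = [(pred_a, label_a), (pred_b, label_b)]

print(categorical_error(pred_a, label_a))  # 0.25
print(mean_validation_error(batches))      # 0.375
```

The value passed to `add_scalar('test/error', ...)` is therefore a number between 0 and 1, averaged over `args.val_iter` batches.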
Learning curves
I also log the input images.
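Because the training loop puts the iteration number into the image tag, every logged image gets a distinct tag instead of becoming another step of a single tag. The sketch below lists the tags such a loop would write. The helper `image_tags` and the values chosen for `max_iter` and `val_interval` are hypothetical and only serve as illustration.

```python
def image_tags(max_iter, val_interval):
    """Return the image tags written by a loop that logs every val_interval steps."""
    return ['image/train_data_{}'.format(i)
            for i in range(max_iter) if i % val_interval == 0]


# Hypothetical settings: 10 iterations, one image every 4 iterations.
print(image_tags(max_iter=10, val_interval=4))
# ['image/train_data_0', 'image/train_data_4', 'image/train_data_8']
```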
There is no need to write your own plotting script, and TensorBoard turns out to be convenient after all.