[PYTHON] Visualization of the firing state of the hidden layer of the model learned in the TensorFlow MNIST tutorial

Overview

To understand machine learning better, I worked through TensorFlow's MNIST tutorial. I also built an application that lets you draw handwritten digits in a browser, and visualized the firing state (activations) of the hidden layers together with the classification results.

Main execution environment

TensorFlow MNIST tutorial

I followed the MNIST tutorial (Deep MNIST for Experts) on the official TensorFlow website. It builds a network with two pairs of convolution and pooling layers, followed by one fully connected layer and a 10-unit output layer. The only changes I made to the tutorial were to use an interactive session and to save the weights and biases after training.
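To see how the image shrinks through this network, here is a small sketch that traces the per-image tensor shape after each layer. It assumes the settings used in train.py (5x5 convolutions with stride 1 and 2x2 max pooling with stride 2, both with SAME padding); the helper functions are hypothetical and not part of TensorFlow.

```python
import math


def conv_same(shape, out_channels):
    """Stride-1 convolution with SAME padding: height and width are unchanged."""
    height, width, _ = shape
    return (height, width, out_channels)


def pool_same(shape, stride=2):
    """Max pooling with SAME padding: each spatial size becomes ceil(size / stride)."""
    height, width, channels = shape
    return (math.ceil(height / stride), math.ceil(width / stride), channels)


def trace_shapes():
    """Return the per-image shape of each layer, keyed by the names used in train.py."""
    shapes = {}
    shape = (28, 28, 1)  # x_image: one 28x28 grayscale channel
    shape = conv_same(shape, 32)
    shapes['h_conv1'] = shape
    shape = pool_same(shape)
    shapes['h_pool1'] = shape
    shape = conv_same(shape, 64)
    shapes['h_conv2'] = shape
    shape = pool_same(shape)
    shapes['h_pool2'] = shape
    shapes['h_pool2_flat'] = (shape[0] * shape[1] * shape[2],)
    shapes['h_fc1'] = (1024,)
    shapes['y_conv'] = (10,)
    return shapes


for name, shape in trace_shapes().items():
    print(name, shape)
```

The 7 * 7 * 64 in W_fc1, and the 28x28, 14x14 and 7x7 maps drawn in the browser later, all follow from these shapes.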

Running the script trains the model (this took about an hour) and finally reaches an accuracy of about 99%. After training, the learned weights and biases are saved as a TensorFlow checkpoint, a set of binary files whose names start with cnn_model.
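For a sense of what the checkpoint holds, the sketch below counts the values in the eight variables handed to tf.train.Saver, using the shapes from train.py. It is plain arithmetic on the shapes; the files on disk also contain checkpoint metadata, so their size will differ somewhat.

```python
from math import prod  # math.prod requires Python 3.8+

# Shapes of the variables passed to tf.train.Saver in train.py
variable_shapes = {
    'W_conv1': [5, 5, 1, 32],
    'b_conv1': [32],
    'W_conv2': [5, 5, 32, 64],
    'b_conv2': [64],
    'W_fc1': [7 * 7 * 64, 1024],
    'b_fc1': [1024],
    'W_fc2': [1024, 10],
    'b_fc2': [10],
}

counts = {name: prod(shape) for name, shape in variable_shapes.items()}
total = sum(counts.values())
raw_bytes = total * 4  # every value is a 4-byte float32

print(total)                              # 3274634
print(round(counts['W_fc1'] / total, 3))  # 0.981
```

Almost all of the parameters (about 98%) sit in W_fc1, the 3136 x 1024 weight matrix of the fully connected layer, and the raw float32 values come to roughly 13 MB.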

train.py


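# Written for TensorFlow 1.x: tf.placeholder, tf.InteractiveSession and the
# tensorflow.examples.tutorials package are not available as-is in TensorFlow 2.x
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Downloads MNIST into ./MNIST_data on the first run; labels are one-hot vectors
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)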

def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])

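# 1st convolution: 32 filters of 5x5 -> h_conv1 is [batch, 28, 28, 32]
W_conv1 = weight_variable([5, 5, 1, 32], name='W_conv1')
b_conv1 = bias_variable([32], name='b_conv1')

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 2x2 max pooling halves height and width -> h_pool1 is [batch, 14, 14, 32]
h_pool1 = max_pool_2x2(h_conv1)

# 2nd convolution: 64 filters of 5x5 over 32 channels -> h_conv2 is [batch, 14, 14, 64]
W_conv2 = weight_variable([5, 5, 32, 64], name='W_conv2')
b_conv2 = bias_variable([64], name='b_conv2')

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
# h_pool2 is [batch, 7, 7, 64]
h_pool2 = max_pool_2x2(h_conv2)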

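# Fully connected layer: flatten h_pool2 (7 * 7 * 64 = 3136 values) into 1024 ReLU units
W_fc1 = weight_variable([7 * 7 * 64, 1024], name='W_fc1')
b_fc1 = bias_variable([1024], name='b_fc1')

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# Dropout: keep_prob is fed as 0.5 while training and 1.0 when evaluating
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# Output layer: 10 raw scores (logits), one per digit; softmax is applied only inside the loss
W_fc2 = weight_variable([1024, 10], name='W_fc2')
b_fc2 = bias_variable([10], name='b_fc2')

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2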

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

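# 20000 training steps with mini-batches of 50 images
for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        # Report accuracy on the current batch every 100 steps (dropout disabled)
        train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
        print('step %d, training accuracy %g' % (i, train_accuracy))
    # Train with dropout keeping 50% of the fully connected units
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})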

print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

# Save only the eight weight and bias variables, under these names, to the cnn_model checkpoint
saver = tf.train.Saver({'W_conv1':W_conv1, 'b_conv1':b_conv1, 'W_conv2':W_conv2, 'b_conv2':b_conv2, 'W_fc1':W_fc1, 'b_fc1':b_fc1, 'W_fc2':W_fc2, 'b_fc2':b_fc2})
saver.save(sess, 'cnn_model')
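The loop runs 20000 steps with mini-batches of 50 images. Assuming read_data_sets' default split, which keeps 5000 images for validation and leaves 55000 in mnist.train, this quick calculation shows how many passes (epochs) over the training set that amounts to.

```python
steps = 20000
batch_size = 50
# Assumption: mnist.train holds 55000 images (read_data_sets' default 5000-image validation split)
train_images = 55000

examples_seen = steps * batch_size
epochs = examples_seen / train_images

print(examples_seen)     # 1000000
print(round(epochs, 1))  # 18.2
```

So, under that assumption, the model sees each training image roughly 18 times.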

Handwriting input interface

I implemented the input interface for handwritten digits using an HTML5 Canvas. Each time a stroke is finished, the pixel values of the drawn image are POSTed to the server as JSON and the classification result is retrieved (described later).

templates/index.html


<!DOCTYPE html>
<html>
<head>
    <title>TensorFlow MNIST Demo</title>
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
    <link href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous">
</head>
<body>

<div class="container">
    <h1>TensorFlow MNIST Demo</h1>
    <div class="row">
        <div class="col-md-4">
            <h2>Input</h2>
            <canvas id="input-image" width="28" height="28" style="width: 280px; height: 280px;"></canvas>
        </div>

        <div class="col-md-8">
            <h2>Prediction</h2>
            <table class="table table-bordered table-striped">
                <tr>
                    <th width="10%">0</th>
                    <th width="10%">1</th>
                    <th width="10%">2</th>
                    <th width="10%">3</th>
                    <th width="10%">4</th>
                    <th width="10%">5</th>
                    <th width="10%">6</th>
                    <th width="10%">7</th>
                    <th width="10%">8</th>
                    <th width="10%">9</th>
                </tr>
                <tr>
                    <td id="score0">-</td>
                    <td id="score1">-</td>
                    <td id="score2">-</td>
                    <td id="score3">-</td>
                    <td id="score4">-</td>
                    <td id="score5">-</td>
                    <td id="score6">-</td>
                    <td id="score7">-</td>
                    <td id="score8">-</td>
                    <td id="score9">-</td>
                </tr>
            </table>
            <button class="btn btn-large" id="clear">Clear</button>
        </div>
    </div>

    <div class="row">
        <h2>h_conv1</h2>
        <canvas id="conv1" width="260" height="132" style="width: 1300px; height: 660px;"></canvas>
    </div>
    <div class="row">
        <h2>h_pool1</h2>
        <canvas id="pool1" width="260" height="132" style="width: 1300px; height: 660px;"></canvas>
    </div>
    <div class="row">
        <h2>h_conv2</h2>
        <canvas id="conv2" width="260" height="66" style="width: 1300px; height: 330px;"></canvas>
    </div>    
    <div class="row">
        <h2>h_pool2</h2>
        <canvas id="pool2" width="260" height="66" style="width: 1300px; height: 330px;"></canvas>
    </div>
</div>

<script type="text/javascript">
    
    var canvas = document.getElementById('input-image');
    var context = canvas.getContext('2d');    
    var moveFlag = false;
    var Xpoint;
    var Ypoint;
    var offsetX = canvas.getBoundingClientRect().left;
    var offsetY = canvas.getBoundingClientRect().top;
    var size = 28;
    var scale = 10;

    context.lineWidth = 1;
    context.strokeStyle = '#FFF';
    context.fillStyle = '#000';
    context.fillRect(0, 0, size, size);

    canvas.addEventListener('mousedown', startPoint, false);
    canvas.addEventListener('mousemove', movePoint, false);
    canvas.addEventListener('mouseup', endPoint, false);

    document.getElementById('clear').addEventListener('click', clear, false);
    updateImage();

    function startPoint(e) {
        e.preventDefault();
        context.beginPath();
        Xpoint = Math.round((e.pageX - offsetX) / scale);
        Ypoint = Math.round((e.pageY - offsetY) / scale);
        context.moveTo(Xpoint, Ypoint);
    }

    function movePoint(e) {
        if (e.buttons === 1 || e.which === 1) {
            Xpoint = Math.round((e.pageX - offsetX) / scale);
            Ypoint = Math.round((e.pageY - offsetY) / scale);
            moveFlag = true;
            context.lineTo(Xpoint, Ypoint);
            context.stroke();
        }
    }

    function endPoint(e) {
        if (moveFlag === true) {
            context.lineTo(Xpoint, Ypoint);
            context.stroke();
        }
        moveFlag = false;
        updateImage();
    }

    function clear() {
        context.fillStyle = '#000';
        context.fillRect(0, 0, size, size);
        updateImage();
    }

    function updateImage() {
        //See below
    }

</script>

</body>
</html>
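The POST body is just a JSON array of 784 numbers. The sketch below (the helper names and the sample stroke are hypothetical) illustrates the layout it relies on: getImageData returns pixels row by row from the top-left corner, so pixel (row, col) of the 28x28 canvas lands at index row * 28 + col. That is the same row-major order in which tf.reshape(x, [-1, 28, 28, 1]) rebuilds the image on the server.

```python
import json

SIZE = 28  # the canvas has the same 28x28 resolution as an MNIST image


def flatten_row_major(image):
    """Turn a list of 28 rows (each with 28 values) into one list of 784 values."""
    return [pixel for row in image for pixel in row]


def pixel_at(flat, row, col):
    """Read pixel (row, col) back out of the flat list."""
    return flat[row * SIZE + col]


# Hypothetical drawing: a white (255) vertical stroke in column 14 on a black background
image = [[255 if col == 14 else 0 for col in range(SIZE)] for _ in range(SIZE)]
flat = flatten_row_major(image)
body = json.dumps(flat)  # same kind of payload as JSON.stringify(image) in updateImage()

print(len(flat))  # 784
```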

Handwriting input application

When the Flask app starts, it builds the same network as in the MNIST tutorial and restores the trained weights and biases.

index.py


from flask import Flask, render_template, request, redirect, jsonify
import numpy as np
import tensorflow as tf

def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])

W_conv1 = weight_variable([5, 5, 1, 32], name='W_conv1')
b_conv1 = bias_variable([32], name='b_conv1')

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([5, 5, 32, 64], name='W_conv2')
b_conv2 = bias_variable([64], name='b_conv2')

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

W_fc1 = weight_variable([7 * 7 * 64, 1024], name='W_fc1')
b_fc1 = bias_variable([1024], name='b_fc1')

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

W_fc2 = weight_variable([1024, 10], name='W_fc2')
b_fc2 = bias_variable([10], name='b_fc2')

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

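sess = tf.InteractiveSession()
# Initialize every variable first (including the ones AdamOptimizer creates),
# then overwrite the weights and biases with the trained values from cnn_model
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver({'W_conv1':W_conv1, 'b_conv1':b_conv1, 'W_conv2':W_conv2, 'b_conv2':b_conv2, 'W_fc1':W_fc1, 'b_fc1':b_fc1, 'W_fc2':W_fc2, 'b_fc2':b_fc2})
saver.restore(sess, 'cnn_model')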

app = Flask(__name__)

if __name__ == '__main__':
    app.debug = True
    app.run(host='0.0.0.0')

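The index action just renders the page containing the handwriting input interface and the inference results; nothing special happens on the server side. Like the predict action below, this route has to be defined in index.py before the if __name__ == '__main__': block, because app.run() blocks when the script is run directly, so routes defined after it would never be registered.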

index.py


@app.route('/')
def index():
    return render_template('index.html')

The predict action feeds the POSTed pixel values into the network and returns the activations of each hidden layer, together with the output scores, as JSON.

index.py


@app.route('/predict', methods=['POST'])
def predict():
    # The request body is a JSON array of 784 pixel values (row-major 28x28)
    image = request.get_json()
    # Evaluate all requested tensors in a single run instead of five separate ones
    conv1, pool1, conv2, pool2, scores = sess.run(
        [h_conv1, h_pool1, h_conv2, h_pool2, y_conv],
        feed_dict={x: [image], keep_prob: 1.0})

    result = {}
    # [height, width, channels] -> [channels, height, width], one 2-D map per channel
    result['h_conv1'] = conv1[0].transpose(2, 0, 1).tolist()
    result['h_pool1'] = pool1[0].transpose(2, 0, 1).tolist()
    result['h_conv2'] = conv2[0].transpose(2, 0, 1).tolist()
    result['h_pool2'] = pool2[0].transpose(2, 0, 1).tolist()
    # Raw output scores (logits, before softmax) for the digits 0-9
    result['y_conv'] = scores[0].tolist()

    return jsonify(result)
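sess.run returns each hidden layer as an array of shape [batch, height, width, channels]. Indexing with [0] drops the batch axis, and transpose(2, 0, 1) moves the channel axis to the front, which is why the browser code can read data['h_conv1'][i][j][k] as feature map i, row j, column k. The pure-Python function below is a hypothetical helper, not used by the app, that performs the same reordering on nested lists to make the axis order explicit.

```python
def hwc_to_chw(maps):
    """Reorder nested lists from [height][width][channel] to [channel][height][width].

    Equivalent to ndarray.transpose(2, 0, 1).tolist() for a 3-D array.
    """
    height = len(maps)
    width = len(maps[0])
    channels = len(maps[0][0])
    return [[[maps[row][col][ch] for col in range(width)]
             for row in range(height)]
            for ch in range(channels)]


# A 2x2 image with 2 channels: maps[row][col] == [channel 0, channel 1]
hwc = [[[1, 2], [3, 4]],
       [[5, 6], [7, 8]]]
print(hwc_to_chw(hwc))  # [[[1, 3], [5, 7]], [[2, 4], [6, 8]]]
```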

updateImage() sends the Ajax request and displays the results. The Canvas code looks cluttered, but it simply draws each feature map of the hidden layers as a small grayscale square, using each rounded activation value directly as its gray level.

templates/index.html


<script>
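function updateImage() {
    // getImageData returns RGBA bytes; the drawing is grayscale, so keep only
    // the red channel (every 4th value) to get 28 * 28 = 784 pixel values
    var rawImage = context.getImageData(0, 0, size, size);
    var image = Array.from(rawImage.data.filter(function(element, index, array) {
        return index % 4 === 0;
    }));

    $.ajax({
        url: '/predict',
        type: 'POST',
        data: JSON.stringify(image),
        contentType: 'application/json',
        dataType: 'json',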
        success: function(data, status, xhr) {
            console.log('success');
            console.log(data);

            var canvas = document.getElementById('conv1');
            var context = canvas.getContext('2d');
            context.fillStyle = '#ddf';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_conv1'].length; i++) {
                var x0 = 4 + (i%8) * 32;
                var y0 = 4 + Math.floor(i/8) * 32;
                for (var j = 0; j < 28; j++) {
                    for (var k = 0; k < 28; k++) {
                        var val = Math.round(data['h_conv1'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k, y0+j, 1, 1);
                    }
                }
            }

            canvas = document.getElementById('pool1');
            context = canvas.getContext('2d');
            context.fillStyle = '#fdd';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_pool1'].length; i++) {
                var x0 = 4 + (i%8) * 32;
                var y0 = 4 + Math.floor(i/8) * 32;
                for (var j = 0; j < 14; j++) {
                    for (var k = 0; k < 14; k++) {
                        var val = Math.round(data['h_pool1'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k*2, y0+j*2, 2, 2);
                    }
                }
            }

            canvas = document.getElementById('conv2');
            context = canvas.getContext('2d');
            context.fillStyle = '#ddf';
            context.fillRect(0, 0, 260, 66);

            for (var i = 0; i < data['h_conv2'].length; i++) {
                var x0 = 3 + (i%16) * 16;
                var y0 = 2 + Math.floor(i/16) * 16;
                for (var j = 0; j < 14; j++) {
                    for (var k = 0; k < 14; k++) {
                        var val = Math.round(data['h_conv2'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k, y0+j, 1, 1);
                    }
                }
            }

            canvas = document.getElementById('pool2');
            context = canvas.getContext('2d');
            context.fillStyle = '#fdd';
            context.fillRect(0, 0, 260, 66);

            for (var i = 0; i < data['h_pool2'].length; i++) {
                var x0 = 3 + (i%16) * 16;
                var y0 = 2 + Math.floor(i/16) * 16;
                for (var j = 0; j < 7; j++) {
                    for (var k = 0; k < 7; k++) {
                        var val = Math.round(data['h_pool2'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k*2, y0+j*2, 2, 2);
                    }
                }
            }

            $('[id^="score"]').removeClass('warning');
            for (var i = 0; i < 10; i++) {
                $('#score'+i).text(Math.round(data['y_conv'][i]));
                if (Math.max.apply(null, data['y_conv']) === data['y_conv'][i]) {
                    $('#score'+i).addClass('warning');
                }
            }

        },
        error: function(data, status, error) {
            console.log('error');
            console.log(error);
        }
    });
}
</script>
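The x0/y0 formulas in updateImage() lay the feature maps out as a grid of tiles. The sketch below reproduces those formulas in Python (the function names are hypothetical) and checks that every tile stays inside the canvas sizes declared in index.html: 260x132 for h_conv1/h_pool1 and 260x66 for h_conv2/h_pool2, which CSS then scales up 5x. The pooling layers are drawn at 2 px per value, so their tiles are the same size as those of the convolution layer before them.

```python
def tile_origins(n_maps, columns, pitch, margin_x, margin_y):
    """Top-left corner of each tile, mirroring x0 = margin + (i % columns) * pitch."""
    return [(margin_x + (i % columns) * pitch,
             margin_y + (i // columns) * pitch)
            for i in range(n_maps)]


def fits(origins, tile_size, canvas_width, canvas_height):
    """True if every tile_size x tile_size tile lies inside the canvas."""
    return all(x + tile_size <= canvas_width and y + tile_size <= canvas_height
               for x, y in origins)


# h_conv1 (28x28 at 1 px) and h_pool1 (14x14 at 2 px): 32 tiles of 28 px, 8 per row
layer1 = tile_origins(32, columns=8, pitch=32, margin_x=4, margin_y=4)
# h_conv2 (14x14 at 1 px) and h_pool2 (7x7 at 2 px): 64 tiles of 14 px, 16 per row
layer2 = tile_origins(64, columns=16, pitch=16, margin_x=3, margin_y=2)

print(fits(layer1, 28, 260, 132), fits(layer2, 14, 260, 66))  # True True
```

This also shows why the conv2 and pool2 canvases only need to be 66 px tall: four rows of 16 px tiles plus a 2 px top margin.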

Execution result

I entered a handwritten "5". The score for "5" is the highest, so the digit is classified correctly.

Output of the first convolution layer (5x5 filters). My impression is that the filters respond to shapes such as the left and right edges of vertical strokes, the top and bottom edges of horizontal strokes, and diagonal strokes.

Output of the first 2x2 max pooling layer.

Output of the second convolution layer (5x5 filters). It seems to respond to slightly more complex patterns than the first layer.

Output of the second 2x2 max pooling layer.

After this, there is one fully connected layer with 1024 units, which leads to an output layer with 10 units.
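The 10 output units produce y_conv, which is what the prediction table displays. In train.py, softmax is applied only inside softmax_cross_entropy_with_logits, so these numbers are unnormalized scores (logits), not probabilities. If probabilities are preferred for display, a numerically stable softmax like the sketch below could be applied to them (the same few lines translate directly to JavaScript); the input scores here are made up.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1.

    Subtracting the maximum first avoids overflow in math.exp for large scores
    and does not change the result.
    """
    largest = max(logits)
    exps = [math.exp(v - largest) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up scores for the digits 0-9 (not output from the trained model)
scores = [-3.0, 1.0, 0.5, 2.0, -1.0, 9.0, 0.0, -2.0, 3.0, 1.5]
probs = softmax(scores)
print(probs.index(max(probs)))  # 5
```

Softmax preserves the ordering of the scores, so the highlighted digit in the table would stay the same.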

