[PYTHON] Visualisierung des Zündzustands der verborgenen Schicht des Modells, die im TensorFlow MNIST-Lernprogramm gelernt wurde

Überblick

Um maschinelles Lernen zu verstehen, habe ich das MNIST-Tutorial von TensorFlow ausprobiert. Wir haben auch eine Anwendung implementiert, mit der Sie handschriftliche Zeichen über einen Browser eingeben und den Zündstatus und die Unterscheidungsergebnisse der verborgenen Schicht visualisieren können.

Hauptausführungsumgebung

TensorFlow MNIST Tutorial

Folgen Sie dem MNIST-Tutorial (Deep MNIST for Experts) auf der offiziellen Website von TensorFlow. Hier wird ein Netzwerk aufgebaut, in dem ein Satz aus einer Faltschicht und einer Poolschicht zwei Stufen aufweist und eine vollständig verbundene Schicht eine Stufe aufweist. Das Tutorial hat sich dahingehend geändert, dass die Sitzung interaktiv wird und nach dem Training des Modells Gewichte und Verzerrungen gespeichert werden.

Bei der Ausführung wird das Lernen durchgeführt (ca. 1 Stunde), und schließlich kann eine Genauigkeit von ca. 99% erzielt werden. Nach der Ausführung werden die trainierten Gewichte und Verzerrungen als Binärdatei gespeichert.

train.py


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])

W_conv1 = weight_variable([5, 5, 1, 32], name='W_conv1')
b_conv1 = bias_variable([32], name='b_conv1')

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([5, 5, 32, 64], name='W_conv2')
b_conv2 = bias_variable([64], name='b_conv2')

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

W_fc1 = weight_variable([7 * 7 * 64, 1024], name='W_fc1')
b_fc1 = bias_variable([1024], name='b_fc1')

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

W_fc2 = weight_variable([1024, 10], name='W_fc2')
b_fc2 = bias_variable([10], name='b_fc2')

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
    print('step %d, training accuracy %g' % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

saver = tf.train.Saver({'W_conv1':W_conv1, 'b_conv1':b_conv1, 'W_conv2':W_conv2, 'b_conv2':b_conv2, 'W_fc1':W_fc1, 'b_fc1':b_fc1, 'W_fc2':W_fc2, 'b_fc2':b_fc2})
saver.save(sess, 'cnn_model')

Handschriftliche Zeicheneingabeschnittstelle

Implementieren Sie eine Eingabeschnittstelle für handgeschriebene Zeichen mithilfe von HTML5 Canvas. Jedes Mal, wenn die Eingabe abgeschlossen ist, wird der Pixelwert des gezeichneten Eingabebildes im JSON-Format an den Server gesendet und das Unterscheidungsergebnis erfasst (später beschrieben).

templates/index.html


<!DOCTYPE html>
<html>
<head>
    <title>TensorFlow MNIST Demo</title>
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.2.1/jquery.min.js"></script>
    <link href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous">
</head>
<body>

<div class="container">
    <h1>TensorFlow MNIST Demo</h1>
    <div class="row">
        <div class="col-md-4">
            <h2>Input</h2>
            <canvas id="input-image" width="28" height="28" style="width: 280px; height: 280px;"></canvas>
        </div>

        <div class="col-md-8">
            <h2>Prediction</h2>
            <table class="table table-bordered table-striped">
                <tr>
                    <th width="10%">0</th>
                    <th width="10%">1</th>
                    <th width="10%">2</th>
                    <th width="10%">3</th>
                    <th width="10%">4</th>
                    <th width="10%">5</th>
                    <th width="10%">6</th>
                    <th width="10%">7</th>
                    <th width="10%">8</th>
                    <th width="10%">9</th>
                </tr>
                <tr>
                    <td id="score0">-</td>
                    <td id="score1">-</td>
                    <td id="score2">-</td>
                    <td id="score3">-</td>
                    <td id="score4">-</td>
                    <td id="score5">-</td>
                    <td id="score6">-</td>
                    <td id="score7">-</td>
                    <td id="score8">-</td>
                    <td id="score9">-</td>
                </tr>
            </table>
            <button class="btn btn-large" id="clear">Clear</button>
        </div>
    </div>

    <div class="row">
        <h2>h_conv1</h2>
        <canvas id="conv1" width="260" height="132" style="width: 1300px; height: 660px;"></canvas>
    </div>
    <div class="row">
        <h2>h_pool1</h2>
        <canvas id="pool1" width="260" height="132" style="width: 1300px; height: 660px;"></canvas>
    </div>
    <div class="row">
        <h2>h_conv2</h2>
        <canvas id="conv2" width="260" height="66" style="width: 1300px; height: 330px;"></canvas>
    </div>    
    <div class="row">
        <h2>h_pool2</h2>
        <canvas id="pool2" width="260" height="66" style="width: 1300px; height: 330px;"></canvas>
    </div>
</div>

<script type="text/javascript">
    
    var canvas = document.getElementById('input-image');
    var context = canvas.getContext('2d');    
    var moveFlag = false;
    var Xpoint;
    var Ypoint;
    var offsetX = canvas.getBoundingClientRect().left;
    var offsetY = canvas.getBoundingClientRect().top;
    var size = 28;
    var scale = 10;

    context.lineWidth = 1;
    context.strokeStyle = '#FFF';
    context.fillStyle = '#000';
    context.fillRect(0, 0, size, size);

    canvas.addEventListener('mousedown', startPoint, false);
    canvas.addEventListener('mousemove', movePoint, false);
    canvas.addEventListener('mouseup', endPoint, false);

    document.getElementById('clear').addEventListener('click', clear, false);
    updateImage();

    function startPoint(e) {
        e.preventDefault();
        context.beginPath();
        Xpoint = Math.round((e.pageX - offsetX) / scale);
        Ypoint = Math.round((e.pageY - offsetY) / scale);
        context.moveTo(Xpoint, Ypoint);
    }

    function movePoint(e) {
        if(e.buttons === 1 || e.witch === 1) {
            Xpoint = Math.round((e.pageX - offsetX) / scale);
            Ypoint = Math.round((e.pageY - offsetY) / scale);
            moveFlag = true;
            context.lineTo(Xpoint, Ypoint);
            context.stroke();
        }
    }

    function endPoint(e) {
        if (moveFlag === true) {
            context.lineTo(Xpoint, Ypoint);
            context.stroke();
        }
        moveFlag = false;
        updateImage();
    }

    function clear() {
        context.fillStyle = '#000';
        context.fillRect(0, 0, size, size);
        updateImage();
    }

    function updateImage() {
        //Siehe unten
    }

</script>

</body>
</html>

Handschriftliche Zeicheneingabeanwendung

Erstellen Sie gleichzeitig mit dem Starten von Flask dasselbe Netzwerk, das im MNIST-Lernprogramm erstellt wurde, und laden Sie die trainierten Gewichte und Verzerrungen.

index.py


from flask import Flask, render_template, request, redirect, jsonify
import numpy as np
import tensorflow as tf

def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])

W_conv1 = weight_variable([5, 5, 1, 32], name='W_conv1')
b_conv1 = bias_variable([32], name='b_conv1')

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([5, 5, 32, 64], name='W_conv2')
b_conv2 = bias_variable([64], name='b_conv2')

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

W_fc1 = weight_variable([7 * 7 * 64, 1024], name='W_fc1')
b_fc1 = bias_variable([1024], name='b_fc1')

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

W_fc2 = weight_variable([1024, 10], name='W_fc2')
b_fc2 = bias_variable([10], name='b_fc2')

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver({'W_conv1':W_conv1, 'b_conv1':b_conv1, 'W_conv2':W_conv2, 'b_conv2':b_conv2, 'W_fc1':W_fc1, 'b_fc1':b_fc1, 'W_fc2':W_fc2, 'b_fc2':b_fc2})
saver.restore(sess, 'cnn_model')

app = Flask(__name__)

if __name__ == '__main__':
    app.debug = True
    app.run(host='0.0.0.0')

Die Indexaktion zeigt die handschriftliche Zeicheneingabeschnittstelle und das Inferenzergebnis an. Auf der Serverseite gibt es nichts Besonderes.

index.py


@app.route('/')
def index():
    return render_template('index.html')

Bei der Vorhersageaktion wird der Pixelwert des POST-Eingabebilds in das Netzwerk eingegeben und das Unterscheidungsergebnis in JSON zurückgegeben.

index.py


@app.route('/predict', methods=['POST'])
def predict():
    result = {}
    result['h_conv1'] = sess.run(h_conv1, feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].transpose(2, 0, 1).tolist()
    result['h_pool1'] = sess.run(h_pool1, feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].transpose(2, 0, 1).tolist()
    result['h_conv2'] = sess.run(h_conv2, feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].transpose(2, 0, 1).tolist()
    result['h_pool2'] = sess.run(h_pool2, feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].transpose(2, 0, 1).tolist()
    result['y_conv']  = sess.run(y_conv,  feed_dict={x: [request.get_json()], keep_prob: 1.0})[0].tolist()

    return jsonify(result)

Kommunikation und Verarbeitung der Ergebnisanzeige durch Ajax. Der Bereich um die Leinwand ist unübersichtlich, aber der Brennzustand der verborgenen Ebene wird nur in einem Quadrat gezeichnet.

templates/index.html


<script>
function updateImage() {
    rawImage = context.getImageData(0, 0, size, size);
    image = Array.from(rawImage.data.filter(function(element, index, array) {
        return index % 4 === 0;
    }));
    
    $.ajax({
        url: '/predict',
        type: 'POST',
        data: JSON.stringify(image),
        contentType: 'application/JSON',
        dataType : 'JSON',
        success: function(data, status, xhr) {
            console.log('success');
            console.log(data);

            var canvas = document.getElementById('conv1');
            var context = canvas.getContext('2d');
            context.fillStyle = '#ddf';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_conv1'].length; i++) {
                var x0 = 4 + (i%8) * 32;
                var y0 = 4 + Math.floor(i/8) * 32;
                for (var j = 0; j < 28; j++) {
                    for (var k = 0; k < 28; k++) {
                        val = Math.round(data['h_conv1'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k, y0+j, 1, 1);
                    }
                }
            }

            canvas = document.getElementById('pool1');
            context = canvas.getContext('2d');
            context.fillStyle = '#fdd';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_pool1'].length; i++) {
                var x0 = 4 + (i%8) * 32;
                var y0 = 4 + Math.floor(i/8) * 32;
                for (var j = 0; j < 14; j++) {
                    for (var k = 0; k < 14; k++) {
                        val = Math.round(data['h_pool1'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k*2, y0+j*2, 2, 2);
                    }
                }
            }

            canvas = document.getElementById('conv2');
            context = canvas.getContext('2d');
            context.fillStyle = '#ddf';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_conv2'].length; i++) {
                var x0 = 3 + (i%16) * 16;
                var y0 = 2 + Math.floor(i/16) * 16;
                for (var j = 0; j < 14; j++) {
                    for (var k = 0; k < 14; k++) {
                        val = Math.round(data['h_conv2'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k, y0+j, 1, 1);
                    }
                }
            }

            canvas = document.getElementById('pool2');
            context = canvas.getContext('2d');
            context.fillStyle = '#fdd';
            context.fillRect(0, 0, 260, 132);

            for (var i = 0; i < data['h_conv2'].length; i++) {
                var x0 = 3 + (i%16) * 16;
                var y0 = 2 + Math.floor(i/16) * 16;
                for (var j = 0; j < 7; j++) {
                    for (var k = 0; k < 7; k++) {
                        val = Math.round(data['h_pool2'][i][j][k]);
                        context.fillStyle = 'rgb('+val+','+val+','+val+')';
                        context.fillRect(x0+k*2, y0+j*2, 2, 2);
                    }
                }
            }

            $('[id^="score"]').removeClass('warning');
            for (var i = 0; i < 10; i++) {
                $('#score'+i).text(Math.round(data['y_conv'][i]));
                if (Math.max.apply(null, data['y_conv']) === data['y_conv'][i]) {
                    $('#score'+i).addClass('warning');
                }
            }

        },
        error: function(data, status, error) {
            console.log('error');
            console.log(error);
        }
    });
}
</script>

Ausführungsergebnis

Ich habe versucht, "5" von Hand einzugeben. Die Punktzahl "5" ist die höchste und kann korrekt identifiziert werden. Screen Shot 2017-07-24 at 17.07.13.png

Ausgabeergebnis der Faltungsschicht durch das 5x5-Filter der 1. Stufe. Eindruck, dass es auf Formen wie links und rechts von der vertikalen Linie, oben und unten auf der horizontalen Linie und diagonale Linie reagiert. Screen Shot 2017-07-24 at 17.07.35.png

Ausgabeergebnis der 2x2-Maximalwert-Pooling-Schicht der ersten Stufe. Screen Shot 2017-07-24 at 17.07.45.png

Ausgabeergebnis der Faltungsschicht durch das 5x5-Filter der 2. Stufe. Reagieren Sie auf ein etwas komplexeres Muster als in der ersten Reihe? Screen Shot 2017-07-24 at 17.07.57.png

Ausgabeergebnis der 2x2-Maximalwert-Pooling-Schicht der zweiten Stufe. Screen Shot 2017-07-24 at 17.08.08.png

Danach gibt es eine vollständig verbundene Schicht mit 1024 Einheiten, was zu einer Ausgangsschicht mit 10 Einheiten führt.

Referenz

Recommended Posts

Visualisierung des Zündzustands der verborgenen Schicht des Modells, die im TensorFlow MNIST-Lernprogramm gelernt wurde
Ich habe das MNIST-Tutorial von tensorflow für Anfänger ausprobiert.
Ich habe eine Demo erstellt, mit der das in Tensorflows Mnist-Tutorial trainierte Modell die handgeschriebenen Zahlen auf der Leinwand unterscheiden kann.
Verwenden Sie den von word2vec gelernten Vektor in der Einbettungsebene von LSTM
TensorFlow Tutorial Ich habe MNIST 3rd ausprobiert
Erstellen Sie eine REST-API mit dem in Lobe und TensorFlow Serving erlernten Modell.
Aufzeichnung der TensorFlow Mnist Expert Edition (Visualisierung von TensorBoard)
Durchführen des TensorFlow MNIST für ML-Anfänger-Tutorials
Überwachtes Lernen von Mnist in der vollständig verbundenen Ebene, Clustering und Bewertung der letzten Phase
Geben Sie das Beleuchtungsmodell für SCN-Material in Pythonista an
Zählen Sie die Anzahl der Parameter im Deep-Learning-Modell
Tensorflows Denkweise lernte aus der Kartoffelherstellung
Untersuchen Sie die Parameter von RandomForestClassifier im Kaggle / Titanic-Tutorial
Verwendung des in Lobe in Python erlernten Modells
Ich habe versucht, das CNN-Modell von TensorFlow mit TF-Slim umzugestalten
[Blender] Kennen Sie den Auswahlstatus von versteckten Objekten im Outliner
Die Geschichte der Herabstufung der Version von Tensorflow in der Demo von Mask R-CNN.