[PYTHON] Erläutern Sie den Code von Tensorflow_in_ROS

Vorwort

Ergebnis

Bevor wir uns mit der detaillierten Codebeschreibung befassen, zeige ich Ihnen nur, was Sie damit tun können. [Screenshot vom 23.11.2016 15_44_07.png](https: //qiita-image-store.s3.amazonaws. com / 0/134368 / b426cc2e-e963-0897-f2d0-f83a5ec7a3d0.png)

Handschriftliche 9 Eingabe von der Kamera auf der rechten Seite wie folgt Das linke ist das Schätzergebnis des trainierten CNN, und 9 wird ordnungsgemäß zurückgegeben, nachdem 9 angezeigt wurde. Dieses Mal habe ich den Code geschrieben, mit dem CNN wie dieses auf ROS ausgeführt werden kann. Weitere Informationen zu Ausführungsmethoden und -vorbereitungen finden Sie unter Vorheriger Qiita-Artikel.

Überblick über die Verarbeitung

Wenn Sie die Gliederung des Prozesses überprüfen

  1. Konfigurieren Sie CNN zum Lesen trainierter Dateien
  2. Abonnieren Sie Bildinformationen vom Kameraknoten
  3. Komprimieren Sie die Bildinformationen auf 28 * 28 und binärisieren Sie Schwarzweiß, sodass CNN von MNIST eingegeben wird
  4. Zeigen Sie CNN das Bild und schätzen Sie die Anzahl
  5. Veröffentlichen Sie das Ergebnis

Es ist wie es ist.

Ganzer Code

Der gesamte Code sieht folgendermaßen aus:

tensorflow_in_ros_mnist.py


import rospy
from sensor_msgs.msg import Image
from std_msgs.msg import Int16
from cv_bridge import CvBridge
import cv2
import numpy as np
import tensorflow as tf


def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], 
                      padding='SAME')

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')


def makeCNN(x,keep_prob):
    # --- define CNN model
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)

    h_pool1 = max_pool_2x2(h_conv1)

    W_conv2 = weight_variable([3, 3, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

    h_pool2 = max_pool_2x2(h_conv2)

    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])

    y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
 
    return y_conv



class RosTensorFlow():
    def __init__(self):
        self._cv_bridge = CvBridge()
        self._sub = rospy.Subscriber('image', Image, self.callback, queue_size=1)
        self._pub = rospy.Publisher('result', Int16, queue_size=1)

        self.x = tf.placeholder(tf.float32, [None,28,28,1], name="x")
        self.keep_prob = tf.placeholder("float")
        self.y_conv = makeCNN(self.x,self.keep_prob)

        self._saver = tf.train.Saver()
        self._session = tf.InteractiveSession()
        
        init_op = tf.initialize_all_variables()
        self._session.run(init_op)

        self._saver.restore(self._session, "model.ckpt")


    def callback(self, image_msg):
        cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8")
        cv_image_gray = cv2.cvtColor(cv_image, cv2.COLOR_RGB2GRAY)
        ret,cv_image_binary = cv2.threshold(cv_image_gray,128,255,cv2.THRESH_BINARY_INV)
        cv_image_28 = cv2.resize(cv_image_binary,(28,28))
        np_image = np.reshape(cv_image_28,(1,28,28,1))
        predict_num = self._session.run(self.y_conv, feed_dict={self.x:np_image,self.keep_prob:1.0})
        answer = np.argmax(predict_num,1)
        rospy.loginfo('%d' % answer)
        self._pub.publish(answer)

    def main(self):
        rospy.spin()

if __name__ == '__main__':
    rospy.init_node('rostensorflow')
    tensor = RosTensorFlow()
    tensor.main()

Codekommentar

** Teil importieren **

tensorflow_in_ros_mnist.py


import rospy
from sensor_msgs.msg import Image
from std_msgs.msg import Int16
from cv_bridge import CvBridge
import cv2
import numpy as np
import tensorflow as tf

Dieses Mal wollte ich einen ROS-Knoten in Python einrichten, also habe ich rospy hinzugefügt. Ich habe auch Image zum Lesen von Bildern, Int16 zum Exportieren, cv_bridge zum Übergeben von ROS-Nachrichtendateien an OpenCV, OpenCV, Numpy und Tensorflow zum maschinellen Lernen hinzugefügt.

** CNN-Definitionsteil **

tensorflow_in_ros_mnist.py


def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], 
                      padding='SAME')

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')


def makeCNN(x,keep_prob):
    # --- define CNN model
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)

    h_pool1 = max_pool_2x2(h_conv1)

    W_conv2 = weight_variable([3, 3, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

    h_pool2 = max_pool_2x2(h_conv2)

    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])

    y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
 
    return y_conv

Ich habe eine Funktion zum Definieren von CNN mit der Konfiguration gemäß Tensorflows Tutorial (Deep MNIST for Expert) erstellt. .. Die Konfiguration ist ein Modell, bei dem nach zweimaligem Durchführen eines Faltungspools eine vollständig verbundene Schicht eingefügt und die Wahrscheinlichkeit für jede Zahl berechnet wird.

** Erste Hälfte des init Teils der Klasse **

tensorflow_in_ros_mnist.py


class RosTensorFlow():
    def __init__(self):
        self._cv_bridge = CvBridge()
        self._sub = rospy.Subscriber('image', Image, self.callback, queue_size=1)
        self._pub = rospy.Publisher('result', Int16, queue_size=1)

Die erste Hälfte ist die ROS-Verarbeitung. Hier rufen wir die CvBridge-Funktion auf und definieren Subscriber und Publisher. Diesmal empfängt der Abonnent die Nachricht vom Typ Image und der Publisher die Nachricht vom Typ Int16. Grundsätzlich wie im Beispiel von ROS. Durch Einfügen eines Rückrufs in das Argument des Unterteils wird die Rückruffunktion jedes Mal aufgerufen, wenn eine Bildnachricht empfangen wird.

** Zweite Hälfte von init Teil der Klasse **

tensorflow_in_ros_mnist.py


        self.x = tf.placeholder(tf.float32, [None,28,28,1], name="x")
        self.keep_prob = tf.placeholder("float")
        self.y_conv = makeCNN(self.x,self.keep_prob)

        self._saver = tf.train.Saver()
        self._session = tf.InteractiveSession()
        
        init_op = tf.initialize_all_variables()
        self._session.run(init_op)

        self._saver.restore(self._session, "model.ckpt")

Die zweite Hälfte ist die Tensorflow-Verarbeitung. Zuerst definieren wir x, den Platzhalter, der das Bild enthält, und keep_prob, den Platzhalter, der die DropOut-Rate enthält. Ein Platzhalter ist wie ein Zugang zu Daten, und zur Laufzeit werden hier immer mehr Daten eingegeben.

Definieren Sie als Nächstes das diesmal verwendete CNN als y_conv. Stellen Sie sich vor, es definiert den Pfad, über den Daten vom Dateneingang von x, keep_prob über das CNN zum Ausgang mit dem Namen y_conv ausgegeben werden.

Bereiten Sie nach dem Definieren der Route den tatsächlichen Datenfluss mit der Sitzungsfunktion vor. Initialisieren Sie das Gewicht W und die Vorspannung b von CNN einmal mit der Funktion tf.initialize_all_variables.

Lesen Sie hier die gelernten Parameter. Um die Parameter zu lesen, müssen Sie saver.restore ausführen, nachdem Sie die Funktion tf.train.Saver verwendet haben.

** Rückrufteil **

tensorflow_in_ros_mnist.py


    def callback(self, image_msg):
        cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8")
        cv_image_gray = cv2.cvtColor(cv_image, cv2.COLOR_RGB2GRAY)
        ret,cv_image_binary = cv2.threshold(cv_image_gray,128,255,cv2.THRESH_BINARY_INV)
        cv_image_28 = cv2.resize(cv_image_binary,(28,28))
        np_image = np.reshape(cv_image_28,(1,28,28,1))
        predict_num = self._session.run(self.y_conv, feed_dict={self.x:np_image,self.keep_prob:1.0})
        answer = np.argmax(predict_num,1)
        rospy.loginfo('%d' % answer)
        self._pub.publish(answer)

Dies wird jedes Mal gelesen, wenn eine Bildnachricht eingeht. Nach dem Konvertieren der Nachricht in ein Bild mit cv_bridge wird sie grau skaliert, binärisiert, schwarzweiß invertiert und die Größe angepasst, und das Bild wird in den CNN geworfen. Der mit der höchsten Wahrscheinlichkeit des zurückgegebenen Schätzergebnisses Predict_Num wird als Antwort festgelegt und veröffentlicht.

Die Hauptfunktion entfällt.

Nachwort

Jetzt können Sie mit ROS alles erledigen, solange Sie über ein geschultes Tensorflow-Modell verfügen. Ich denke, FasterRCNN kann Objekterkennung und Gesichtserkennung durchführen, daher würde ich es gerne ausprobieren.

Recommended Posts

Erläutern Sie den Code von Tensorflow_in_ROS
Erläutern Sie den Mechanismus der PEP557-Datenklasse
[Python3] Schreiben Sie das Codeobjekt der Funktion neu
[Python] Ruft den Zeichencode der Datei ab
[Python] Lesen Sie den Quellcode von Flasche Teil 2
[Python] Lesen Sie den Quellcode von Flasche Teil 1
Die Geschichte, Sourcetrail × macOS × VS Code auszuprobieren
Code zum Überprüfen des Betriebs von Python Matplot lib
Konvertieren Sie den Zeichencode der Datei mit Python3
Der Beginn von cif2cell
Die Bedeutung des Selbst
Die Geschichte von sys.path.append ()
Erklären Sie die assoziative Sequenz
Rache der Typen: Rache der Typen
Lassen Sie uns die Grundlagen des Python-Codes von TensorFlow aufschlüsseln
Holen Sie sich den Rückkehrcode eines Python-Skripts von bat
#Eine Funktion, die den Zeichencode einer Zeichenfolge zurückgibt
Ich habe versucht, den Beispielcode des Ansible-Moduls auszuführen
Richten Sie die Version von chromedriver_binary aus
Scraping das Ergebnis von "Schedule-Kun"
10. Zählen der Anzahl der Zeilen
Code, der bei AttributeError Standardwerte festlegt
Die Geschichte des Baus von Zabbix 4.4
Auf dem Weg zum Ruhestand von Python2
Vergleichen Sie die Schriftarten von Jupyter-Themen
Holen Sie sich die Anzahl der Ziffern
Einstellungen zum Eingeben und Debuggen des Inhalts der Bibliothek mit VS-Code
Verwenden Sie die Clustering-Ergebnisse erneut
2.x, 3.x Serienzeichencode von Python
Der Prozess, Python-Code objektorientiert zu machen und zu verbessern
GoPiGo3 des alten Mannes
Berechnen Sie die Anzahl der Änderungen
Überprüfen Sie den Code mit flake8
Ändern Sie das Thema von Jupyter
Die Popularität von Programmiersprachen
Ändern Sie den Stil von matplotlib
Visualisieren Sie die Flugbahn von Hayabusa 2
Über die Komponenten von Luigi
Verknüpfte Komponenten des Diagramms
Filtern Sie die Ausgabe von tracemalloc
Über die Funktionen von Python
Simulation des Inhalts der Brieftasche
Die Kraft der Pandas: Python
Folgen Sie dem QAOA-Fluss (VQE) auf der Quellcode-Ebene von Blueqat
Messen Sie die Testabdeckung von Push-Python-Code auf GitHub.
Erste Python ② Versuchen Sie, Code zu schreiben, während Sie die Funktionen von Python untersuchen
Erklären Sie detailliert den magischen Code für IQ Bot-Tabellenelemente
Ich habe den Code geschrieben, um den Brainf * ck-Code in Python zu schreiben
Aufzeichnung des Codes für klinische Studien, die von der Ethikkommission abgelehnt wurden
Überprüfen Sie den Speicherschutz von Linux Kern mit Code für ARM
Fassen wir den Grad der Kopplung zwischen Modulen mit Python-Code zusammen
for, continue, break Erläutern Sie den Ablauf der iterativen Verarbeitung in Python3-Teil 1
Die Spezifikationen von Pytz haben sich geändert
Testen Sie die Version des Argparse-Moduls
Finden Sie die Definition des Wertes von errno
Zeichnen Sie die Ausbreitung des neuen Koronavirus
Die Geschichte von Python und die Geschichte von NaN
Erhöhen Sie die Version von pyenv selbst