[PYTHON] Bildklassifizierung mit Weitwinkel-Fundusbilddatensatz

1. Zuallererst

Für Anfänger zielt dieser Artikel darauf ab, TensorFlow 2.0 zu verwenden, um Bilder vorerst mit Deep Learning zu klassifizieren. Da der Bilddatensatz für MNIST nicht interessant ist, werde ich den vom Tsukazaki Hospital veröffentlichten Weitwinkel-Fundusbilddatensatz [^ 1] verwenden. Das Netzwerk ist auch ein einfaches 10-Tier-CNN.

Alle Codes

2. Umwelt

3. Weitwinkel-Fundusbilddatensatz

Ein Weitwinkel-Fundus-Datensatz von 13047 Blättern (5389 Personen, 8588 Augen), veröffentlicht vom Tsukazaki Hospital. Sie können die CSV-Datei mit dem dazugehörigen Bild- und Krankheitslabel über den folgenden Link herunterladen. Tsukazaki Optos Public Project https://tsukazaki-ai.github.io/optos_dataset/

Die Aufteilung der Krankheitsbezeichnung ist wie folgt.

Etikette Krankheit Anzahl der Blätter
AMD Altersbedingte Degeneration des gelben Flecks 413
RVO Netzhautvenenverschluss 778
Gla Glaukom 2619
MH Gelbes Fleckloch 222
DR Diabetes-Retinopathie 3323
RD Netzhautablösung 974
RP Netzhautpigmentdegeneration 258
AO Arterielle Okklusion 21
DM Diabetes 3895

Unterscheidet sich die Gesamtzahl der Blätter in der Tabelle von der Anzahl der Bilder? Ich bin sicher, einige von Ihnen haben das vielleicht gedacht, also schauen wir uns die eigentliche CSV-Datei an.

filename age sex LR AMD RVO Gla MH DR RD RP AO DM
000000_00.jpg 78 M L 0 0 0 0 0 0 0 0 0
000000_01.jpg 78 M R 0 0 0 0 0 0 0 0 0
000001_02.jpg 69 M L 0 0 1 0 0 0 0 0 0
000011_01.jpg 70 F L 0 0 0 0 1 0 0 0 1

Auf diese Weise handelt es sich um ein Problem mit mehreren Etiketten mit mehreren Etiketten (Komplikationen) für ein Bild. Es gibt insgesamt 4364 nicht erkrankte Bilder, die nicht beschriftet sind. Zusätzlich wird unten ein Bildbeispiel gezeigt.

Enthält groteske Bilder
000000_00.jpg 000000_01.jpg 000001_02.jpg 000011_01.jpg

Es gibt ein Ungleichgewicht in der Anzahl der Daten, und es ist ziemlich ärgerlich mit Multi-Label ~ ~ Es ist ein praktischer Datensatz, aber in diesem Artikel ist es einfach, nur Bilder ohne Multi-Label und nur solche mit einer großen Anzahl von Klassen zu verwenden Klassifizieren.

4. Datenaufteilung

Extrahieren Sie zunächst nur nicht mehrfach beschriftete Bilder aus der CSV-Datei. Da das DR-Bild jedoch auch DM enthält, wird auch das Bild extrahiert, in dem DR und DM gleichzeitig auftreten. Wir haben uns jedoch entschieden, DR und AO nicht zu verwenden, die nur 3 bzw. 11 Bilder enthalten. Da es 3113 DR + DMs und 530 DMs mit teilweise überlappenden Etiketten gab, haben wir uns dieses Mal entschieden, das DM mit der kleineren Anzahl nicht zu verwenden. Außerdem habe ich das Format der CSV-Datei geändert, damit sie später verarbeitet werden kann.

Code zum Extrahieren von Bildern ohne Multilabel und zum Kombinieren dieser Bilder zu einer CSV-Datei
from collections import defaultdict
import pandas as pd


#Lesen Sie die CSV-Datei des Weitwinkel-Fundus-Datensatzes
df = pd.read_csv('data.csv')

dataset = defaultdict(list)

for i in range(len(df)):
    #Konvertieren Sie die angehängte Beschriftung in Zeichen
    labels = ''
    if df.iloc[i]['AMD'] == 1:
        labels += '_AMD'
    if df.iloc[i]['RVO'] == 1:
        labels += '_RVO'
    if df.iloc[i]['Gla'] == 1:
        labels += '_Gla'
    if df.iloc[i]['MH'] == 1:
        labels += '_MH'
    if df.iloc[i]['DR'] == 1:
        labels += '_DR'
    if df.iloc[i]['RD'] == 1:
        labels += '_RD'
    if df.iloc[i]['RP'] == 1:
        labels += '_RP'
    if df.iloc[i]['AO'] == 1:
        labels += '_AO'
    if df.iloc[i]['DM'] == 1:
        labels += '_DM'
    if labels == '':
        labels = 'Normal'
    else:
        labels = labels[1:]

    #Nicht Multi-Label(DR+Ohne DM)Bild und
    #Ein paar DR, DM und
    #Doppelte Etiketten, aber DR+Extrahieren Sie weniger Nicht-DM-Bilder als DM
    if '_' not in labels or labels == 'DR_DM':
        if labels not in ('DR', 'AO', 'DM'):
            dataset['filename'].append(df.iloc[i]['filename'])
            dataset['id'].append(df.iloc[i]['filename'].split('_')[0].split('.')[0])
            dataset['label'].append(labels)

#Als CSV-Datei speichern
dataset = pd.DataFrame(dataset)
dataset.to_csv('dataset.csv', index=False)

Ich habe die folgende CSV-Datei mit dem obigen Code erstellt. Da die Bilder gemäß der Regel {Seriennummer ID} _ {Seriennummer} .jpg benannt sind, wird die Seriennummer ID als ID verwendet.

filename id label
000000_00.jpg 0 Normal
000000_01.jpg 0 Normal
000001_02.jpg 1 Gla
000011_01.jpg 11 DR_DM

Als Ergebnis der Extraktion ist die Aufteilung der Klassifizierungsklasse und der Anzahl der Bilder wie folgt. Normal ist ein Bild ohne Krankheit.

Etikette Anzahl der Blätter
Normal 4364
Gla 2293
AMD 375
RP 247
DR_DM 3113
RD 883
RVO 537
MH 161

Teilen Sie als nächstes die Bilddaten. Da der Datensatz 13047 Blätter (5389 Personen, 8588 Augen) umfasst, sind Bilder derselben Person und desselben Auges enthalten. Bilder derselben Person oder derselben Augen enthalten ähnliche Merkmale und Beschriftungen, die zu Datenlecks führen können. Daher wird die Aufteilung so durchgeführt, dass in den Trainingsdaten und den Testdaten nicht dieselbe Person vorhanden ist. Stellen Sie außerdem sicher, dass das Verhältnis der Aufschlüsselung der Trainingsdaten und Testdaten für jede Klasse ungefähr gleich ist. Diesmal betrugen die Trainingsdaten 60%, die Verifizierungsdaten 20% und die Testdaten 20%.

Split-Code für Gruppenschichtung K

5. Modellbau & Lernen

Importieren Sie zunächst die Bibliothek, die Sie verwenden möchten.

import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Input, MaxPool2D
from tensorflow.keras.layers import Conv2D, Dense, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam

Beschreiben Sie als Nächstes die Parameter usw. label_list ist zur Vereinfachung der Bibliothek in abc-Reihenfolge angeordnet.

directory = 'img' #Ordner, in dem Bilder gespeichert werden
df_train = pd.read_csv('train.csv') #DataFrame mit Trainingsdateninformationen
df_validation = pd.read_csv('val.csv') #DataFrame mit Informationen zu Validierungsdaten
label_list = ['AMD', 'DR_DM', 'Gla', 'MH', 'Normal', 'RD', 'RP', 'RVO'] #Markenname
image_size = (224, 224) #Bildgröße eingeben
classes = len(label_list) #Anzahl der Klassifizierungsklassen
batch_size = 32 #Chargengröße
epochs = 300 #Anzahl der Epochen
loss = 'categorical_crossentropy' #Verlustfunktion
optimizer = Adam(lr=0.001, amsgrad=True) #Optimierungsfunktion
metrics = 'accuracy' #Bewertungsmethoden
#ImageDataGenerator Bildverstärkungsparameter
aug_params = {'rotation_range': 5,
              'width_shift_range': 0.05,
              'height_shift_range': 0.05,
              'shear_range': 0.1,
              'zoom_range': 0.05,
              'horizontal_flip': True,
              'vertical_flip': True}

Das Folgende wird als Rückrufprozess während des Lernens angewendet.

# val_Modell nur speichern, wenn der Verlust minimiert ist
mc_cb = ModelCheckpoint('model_weights.h5',
                        monitor='val_loss', verbose=1,
                        save_best_only=True, mode='min')
#Wenn das Lernen stagniert, beträgt die Lernrate 0.Doppelt
rl_cb = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=3,
                          verbose=1, mode='auto',
                          min_delta=0.0001, cooldown=0, min_lr=0)
#Wenn das Lernen nicht voranschreitet, wird das Lernen gewaltsam abgebrochen
es_cb = EarlyStopping(monitor='loss', min_delta=0,
                      patience=5, verbose=1, mode='auto')

Da die Anzahl der Daten in jeder Klasse unausgewogen ist, stellen Sie sicher, dass der Verlust groß ist, wenn Sie in einer Klasse mit einer kleinen Anzahl von Daten einen Fehler machen.

#Passen Sie die Verlustgewichte an die Anzahl der Daten an
weight_balanced = {}
for i, label in enumerate(label_list):
    weight_balanced[i] = (df_train['label'] == label).sum()
max_count = max(weight_balanced.values())
for label in weight_balanced:
    weight_balanced[label] = max_count / weight_balanced[label]
print(weight_balanced)

Erzeugt einen Generator für Trainings- und Validierungsdaten. Verwenden Sie ImageDataGenerator zur Datenerweiterung und laden Sie Bilder aus DataFrame mit flow_from_dataframe. Der Grund, warum label_list in abc-Reihenfolge angeordnet ist, besteht darin, dass beim Lesen eines Bildes durch flow_from_dataframe Klassen in abc-Reihenfolge der Zeichenfolge zugewiesen werden, damit die Entsprechung zwischen der Klassennummer und dem Label-Namen verstanden werden kann. Sie können die Korrespondenz später überprüfen, aber es ist ärgerlich, also ...

#Generator generieren
##Trainingsdatengenerator
datagen = ImageDataGenerator(rescale=1./255, **aug_params)
train_generator = datagen.flow_from_dataframe(
    dataframe=df_train, directory=directory,
    x_col='filename', y_col='label',
    target_size=image_size, class_mode='categorical',
    classes=label_list,
    batch_size=batch_size)
step_size_train = train_generator.n // train_generator.batch_size
##Validierungsdatengenerator
datagen = ImageDataGenerator(rescale=1./255)
validation_generator = datagen.flow_from_dataframe(
    dataframe=df_validation, directory=directory,
    x_col='filename', y_col='label',
    target_size=image_size, class_mode='categorical',
    classes=label_list,
    batch_size=batch_size)
step_size_validation = validation_generator.n // validation_generator.batch_size

Erstellen Sie ein einfaches 10-Schicht-CNN.

#Aufbau eines 10-lagigen CNN
def cnn(input_shape, classes):
    #Eingabeebene
    inputs = Input(shape=(input_shape[0], input_shape[1], 3))

    #1. Schicht
    x = Conv2D(32, (3, 3), padding='same', kernel_initializer='he_normal')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #2. Schicht
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #3. Schicht
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #4. Schicht
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #5. und 6. Schicht
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #7. und 8. Schicht
    x = Conv2D(1024, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(1024, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)

    #9. und 10. Schicht
    x = Dense(256, kernel_initializer='he_normal')(x)
    x = Dense(classes, kernel_initializer='he_normal')(x)
    outputs = Activation('softmax')(x)


    return Model(inputs=inputs, outputs=outputs)

#Netzwerkaufbau
model = cnn(image_size, classes)
model.summary()
model.compile(loss=loss, optimizer=optimizer, metrics=[metrics])

Lernen Sie das Netzwerk.

#Lernen
history = model.fit_generator(
    train_generator, steps_per_epoch=step_size_train,
    epochs=epochs, verbose=1, callbacks=[mc_cb, rl_cb, es_cb],
    validation_data=validation_generator,
    validation_steps=step_size_validation,
    class_weight=weight_balanced,
    workers=3)

Speichern Sie abschließend das Trainingskurvendiagramm als Bild.

#Zeichnen und speichern Sie ein Diagramm der Lernkurve
def plot_history(history):
    fig, (axL, axR) = plt.subplots(ncols=2, figsize=(10, 4))

    # [links]Grafik über Metriken
    L_title = 'Accuracy_vs_Epoch'
    axL.plot(history.history['accuracy'])
    axL.plot(history.history['val_accuracy'])
    axL.grid(True)
    axL.set_title(L_title)
    axL.set_ylabel('accuracy')
    axL.set_xlabel('epoch')
    axL.legend(['train', 'test'], loc='upper left')

    # [Rechte Seite]Grafik über Verlust
    R_title = "Loss_vs_Epoch"
    axR.plot(history.history['loss'])
    axR.plot(history.history['val_loss'])
    axR.grid(True)
    axR.set_title(R_title)
    axR.set_ylabel('loss')
    axR.set_xlabel('epoch')
    axR.legend(['train', 'test'], loc='upper left')

    #Grafik als Bild speichern
    fig.savefig('history.jpg')
    plt.close()

#Lernkurve speichern
plot_history(history)

Die Lernergebnisse sind wie folgt.

history.jpg

6. Bewertung

Da es sich bei der Auswertung um unausgeglichene Daten handelt, wird sie anhand des F1-Scores ausgewertet. Leiten Sie zunächst die Testdaten anhand des zuvor erlernten Modells ab.

Zusätzlicher Import.

import numpy as np
from PIL import Image
from sklearn.metrics import classification_report
from tqdm import tqdm

Beschreiben Sie die Parameter. Lesen Sie diesmal die CSV-Testdatei.

directory = 'img' #Ordner, in dem Bilder gespeichert werden
df_test = pd.read_csv('test.csv') #DataFrame mit Testdateninformationen
label_list = ['AMD', 'DR_DM', 'Gla', 'MH', 'Normal', 'RD', 'RP', 'RVO'] #Markenname
image_size = (224, 224) #Bildgröße eingeben
classes = len(label_list) #Anzahl der Klassifizierungsklassen

Bauen Sie das erlernte Netzwerk auf und laden Sie die zuvor erlernten Gewichte.

#Netzwerkaufbau&Gelernte Gewichte lesen
model = cnn(image_size, classes)
model.load_weights('model_weights.h5')

Die Inferenz erfolgt durch Lesen und Konvertieren des Bildes, sodass die Bedingungen dieselben sind wie beim Lernen.

#Inferenz
X = df_test['filename'].values
y_true = list(map(lambda x: label_list.index(x), df_test['label'].values))
y_pred = []
for file in tqdm(X, desc='pred'):
    #Ändern Sie die Größe des Bilds so, dass es die gleichen Bedingungen wie beim Lernen hat&Umwandlung
    img = Image.open(f'{directory}/{file}')
    img = img.resize(image_size, Image.LANCZOS)
    img = np.array(img, dtype=np.float32)
    img *= 1./255
    img = np.expand_dims(img, axis=0)

    y_pred.append(np.argmax(model.predict(img)[0]))

Berechnen Sie den F1-Score mit scikit-learn.

#Auswertung
print(classification_report(y_true, y_pred, target_names=label_list))

Nachfolgend sind die Bewertungsergebnisse aufgeführt. Sicher genug, AMD und MH, die eine kleine Datenmenge haben, haben niedrige Werte.

              precision    recall  f1-score   support

         AMD       0.17      0.67      0.27        75
       DR_DM       0.72      0.75      0.73       620
         Gla       0.76      0.69      0.72       459
          MH       0.09      0.34      0.14        32
      Normal       0.81      0.50      0.62       871
          RD       0.87      0.79      0.83       176
          RP       0.81      0.86      0.83        50
         RVO       0.45      0.65      0.53       107

    accuracy                           0.64      2390
   macro avg       0.58      0.66      0.59      2390
weighted avg       0.73      0.64      0.67      2390

7. Zusammenfassung

In diesem Artikel haben wir ein einfaches 10-Schicht-CNN verwendet, um Bilder des vom Tsukazaki Hospital veröffentlichten Weitwinkel-Fundus-Datensatzes zu klassifizieren. Basierend auf diesem Ergebnis werden wir in Zukunft die Leistung verbessern und dabei die neuesten Methoden wie Netzwerkstruktur und Datenerweiterungsmethode einbeziehen.

Recommended Posts

Bildklassifizierung mit Weitwinkel-Fundusbilddatensatz
Bildsegmentierung mit CaDIS: ein Katarakt-Datensatz
Kochobjekterkennung durch Yolo + Bildklassifizierung
MNIST-Bildklassifizierung (handschriftliche Nummer) mit mehrschichtigem Perzeptron
Bildverarbeitung mit MyHDL
Bilderkennung mit Keras
Bildverarbeitung mit Python
Fordern Sie die Bildklassifizierung mit TensorFlow2 + Keras 3 heraus ~ Visualisieren Sie MNIST-Daten ~
Bildverarbeitung mit PIL
"Müll nach Bild klassifizieren!" App-Erstellungstagebuch Tag2 ~ Feinabstimmung mit VGG16 ~
[Deep Learning] Bildklassifizierung mit Faltungsnetz [DW Tag 4]
Bild herunterladen mit Flickr API
[PyTorch] Bildklassifizierung von CIFAR-10
Ich habe die Bildklassifizierung von AutoGluon ausprobiert
Lesen Sie die Bildkoordinaten mit Python-matplotlib
Bildverarbeitung mit PIL (Pillow)
Bildbearbeitung mit Python OpenCV
Dokumentenklassifizierung mit Satzstück
Hochladen und Anpassen von Bildern mit django-ckeditor
Sortieren von Bilddateien mit Python (3)
CNN (1) zur Bildklassifizierung (für Anfänger)
Erstellen Sie den Image Viewer mit Tkinter
Bilddateien mit Python sortieren
Bildverarbeitung mit Python (3)
Bildunterschriftengenerierung mit Chainer
Holen Sie sich Bildfunktionen mit OpenCV
Bilderkennung mit Keras + OpenCV
[Python] Bildverarbeitung mit Scicit-Image
Fordern Sie die Bildklassifizierung mit TensorFlow2 + Keras 4 heraus. ~ Lassen Sie uns mit dem trainierten Modell ~ vorhersagen
Fordern Sie die Bildklassifizierung mit TensorFlow2 + Keras 9 heraus. Lernen, Speichern und Laden von Modellen