[PYTHON] Image classification with wide-angle fundus image dataset

1. Introduction

This article is aimed at beginners. For now, the goal is simply to classify images with deep learning using TensorFlow 2.0. MNIST would make a dull dataset, so I use the wide-angle fundus image dataset[^1] published by Tsukazaki Hospital. The network is a simple 10-layer CNN.


2. Environment

- PC specs

3. Wide-angle fundus image dataset

Tsukazaki Hospital has published a wide-angle fundus dataset of 13,047 images (5,389 patients, 8,588 eyes). The images and a CSV file with their disease labels can be downloaded from the link below.

Tsukazaki Optos Public Project: https://tsukazaki-ai.github.io/optos_dataset/

The disease labels break down as follows.

Label  Disease                           Images
AMD    Age-related macular degeneration     413
RVO    Retinal vein occlusion               778
Gla    Glaucoma                            2619
MH     Macular hole                         222
DR     Diabetic retinopathy                3323
RD     Retinal detachment                   974
RP     Retinitis pigmentosa                 258
AO     Arterial occlusion                    21
DM     Diabetes mellitus                   3895

Some of you may have noticed that the counts in the table do not add up to the total number of images. The actual CSV file shows why.

filename       age  sex  LR  AMD  RVO  Gla  MH  DR  RD  RP  AO  DM
000000_00.jpg  78   M    L   0    0    0    0   0   0   0   0   0
000000_01.jpg  78   M    R   0    0    0    0   0   0   0   0   0
000001_02.jpg  69   M    L   0    0    1    0   0   0   0   0   0
000011_01.jpg  70   F    L   0    0    0    0   1   0   0   0   1

As shown, a single image can carry several labels (comorbidities), so this is a multi-label problem. In total, 4,364 images have no disease label (normal). Sample images are shown below.

(Sample images: 000000_00.jpg, 000000_01.jpg, 000001_02.jpg, 000011_01.jpg. Caution: graphic medical images.)

The number of images per class is imbalanced, and the multi-label setting makes things tricky. This makes it a realistic, practical dataset. In this article, however, we keep things simple: we use only single-label images and classify only the classes that have enough images.

4. Data split

First, extract only the single-label images from the CSV file. DR images usually also carry the DM label, so images labeled with both DR and DM are extracted as well, as a combined DR_DM class. Single-label DR and AO images are not used, because there are only 3 and 11 of them, respectively. The DR_DM images (3,113) and the single-label DM images (530) share the DM label, so this time we drop DM, the smaller group. The output CSV is also given a simpler format for later processing.

Code to extract the single-label images and save them to a CSV file
from collections import defaultdict
import pandas as pd


# Read the CSV file of the wide-angle fundus dataset
df = pd.read_csv('data.csv')

# Disease label columns, in the order they appear in data.csv
disease_cols = ['AMD', 'RVO', 'Gla', 'MH', 'DR', 'RD', 'RP', 'AO', 'DM']

dataset = defaultdict(list)

for _, row in df.iterrows():
    # Join the positive labels with '_' (e.g. 'DR_DM'); no label at all means 'Normal'
    labels = '_'.join(col for col in disease_cols if row[col] == 1)
    if labels == '':
        labels = 'Normal'

    # Keep single-label images and the DR_DM combination, but drop
    # single-label DR and AO (too few images) and single-label DM
    # (it shares the DM label with the larger DR_DM class)
    if '_' not in labels or labels == 'DR_DM':
        if labels not in ('DR', 'AO', 'DM'):
            dataset['filename'].append(row['filename'])
            # File names follow {ID}_{serial number}.jpg; keep the ID part
            dataset['id'].append(row['filename'].split('_')[0])
            dataset['label'].append(labels)

# Save as a CSV file
dataset = pd.DataFrame(dataset)
dataset.to_csv('dataset.csv', index=False)

The code above produces the following CSV file. Images are named {ID}_{serial number}.jpg, where the ID identifies the patient, so the ID part is stored in the id column.

filename       id  label
000000_00.jpg  0   Normal
000000_01.jpg  0   Normal
000001_02.jpg  1   Gla
000011_01.jpg  11  DR_DM

After extraction, the classes and their image counts are as follows. Normal denotes images with no disease label.

Label   Images
Normal    4364
Gla       2293
AMD        375
RP         247
DR_DM     3113
RD         883
RVO        537
MH         161
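The counts quoted earlier (3 single-label DR images, 11 AO, 530 DM, 3,113 DR_DM) come from counting exact label combinations. The following is a minimal sketch of that check with pandas. It assumes data.csv has the columns shown above; the helper name combination_counts is hypothetical.

```python
import pandas as pd

# Disease label columns of data.csv (assumed to match the excerpt above)
DISEASE_COLS = ['AMD', 'RVO', 'Gla', 'MH', 'DR', 'RD', 'RP', 'AO', 'DM']


def combination_counts(df):
    """Count how many images carry each exact combination of disease labels."""
    # Join the positive label names of each row, e.g. 'DR_DM'; rows with none become 'Normal'
    combos = df[DISEASE_COLS].apply(
        lambda row: '_'.join(col for col in DISEASE_COLS if row[col] == 1) or 'Normal',
        axis=1)
    return combos.value_counts()
```

For example, print(combination_counts(pd.read_csv('data.csv'))) lists every combination with its image count, from which the single-label classes can be read off.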

Next, split the image data. The dataset has 13,047 images from 5,389 patients (8,588 eyes), so it contains several images of the same patient and of the same eye. Such images share similar features and labels, so letting them appear in both the training and test data would cause data leakage. We therefore split the data so that no patient appears in more than one subset. We also keep the class proportions of each subset approximately equal. This time, 60% of the data is used for training, 20% for validation, and 20% for testing.

Code for the stratified group K-fold split

5. Model building & training

First, import the libraries.

import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Input, MaxPool2D
from tensorflow.keras.layers import Conv2D, Dense, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam

Next, define the parameters. label_list is sorted alphabetically to match the way the library numbers the classes.

directory = 'img' #Folder where images are stored
df_train = pd.read_csv('train.csv') #DataFrame with training data information
df_validation = pd.read_csv('val.csv') #DataFrame with validation data information
label_list = ['AMD', 'DR_DM', 'Gla', 'MH', 'Normal', 'RD', 'RP', 'RVO'] #Label name
image_size = (224, 224) #Input image size
classes = len(label_list) #Number of classification classes
batch_size = 32 #Batch size
epochs = 300 #Number of epochs
loss = 'categorical_crossentropy' #Loss function
optimizer = Adam(learning_rate=0.001, amsgrad=True) #Optimizer
metrics = 'accuracy' #Evaluation metric
#ImageDataGenerator data augmentation parameters
aug_params = {'rotation_range': 5,
              'width_shift_range': 0.05,
              'height_shift_range': 0.05,
              'shear_range': 0.1,
              'zoom_range': 0.05,
              'horizontal_flip': True,
              'vertical_flip': True}

The following callbacks are used during training.

# Save the model only when val_loss reaches a new minimum
mc_cb = ModelCheckpoint('model_weights.h5',
                        monitor='val_loss', verbose=1,
                        save_best_only=True, mode='min')
# When the training loss plateaus, multiply the learning rate by 0.2
rl_cb = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=3,
                          verbose=1, mode='auto',
                          min_delta=0.0001, cooldown=0, min_lr=0)
# Stop training early when the training loss stops improving
es_cb = EarlyStopping(monitor='loss', min_delta=0,
                      patience=5, verbose=1, mode='auto')

Because the classes are imbalanced, each class gets a loss weight inversely proportional to its number of training images. Misclassifying an image from a small class therefore increases the loss more.

#Loss weight per class index: (size of the largest class) / (size of this class)
weight_balanced = {}
for i, label in enumerate(label_list):
    weight_balanced[i] = (df_train['label'] == label).sum()
max_count = max(weight_balanced.values())
for label in weight_balanced:
    weight_balanced[label] = max_count / weight_balanced[label]
print(weight_balanced)
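As a worked example of this rule, the sketch below applies it to the class counts of the whole extracted dataset from the table in section 4. The real weights are computed from the training split only (about 60% of the images), so they will differ somewhat. The helper name balanced_weights is hypothetical. The keys are class indices in the alphabetical label_list order, the form that class_weight expects here.

```python
import pandas as pd


def balanced_weights(labels, label_list):
    """Map class index -> max_count / class_count (the largest class gets 1.0).

    Assumes every class in label_list occurs at least once in labels.
    """
    counts = {label: int((labels == label).sum()) for label in label_list}
    max_count = max(counts.values())
    return {i: max_count / counts[label] for i, label in enumerate(label_list)}


# Class counts after extraction (whole dataset, not only the training split)
class_counts = {'AMD': 375, 'DR_DM': 3113, 'Gla': 2293, 'MH': 161,
                'Normal': 4364, 'RD': 883, 'RP': 247, 'RVO': 537}
label_list = sorted(class_counts)
labels = pd.Series([name for name, n in class_counts.items() for _ in range(n)])
weights = balanced_weights(labels, label_list)
for i, w in weights.items():
    print(f'{label_list[i]:>6}: {w:.2f}')
```

With these counts, Normal gets 1.0 and MH, the smallest class, gets 4364 / 161 (about 27). An MH error therefore weighs roughly 27 times as much as a Normal error.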

Next, create generators for the training and validation data. ImageDataGenerator performs data augmentation, and flow_from_dataframe loads the images listed in a DataFrame. label_list is kept in alphabetical order because flow_from_dataframe assigns class indices in alphabetical order of the label strings. Keeping label_list in the same order makes the mapping between class index and label name obvious. The mapping can also be checked afterwards, but that is tedious.

#Generator generation
##Training data generator
datagen = ImageDataGenerator(rescale=1./255, **aug_params)
train_generator = datagen.flow_from_dataframe(
    dataframe=df_train, directory=directory,
    x_col='filename', y_col='label',
    target_size=image_size, class_mode='categorical',
    classes=label_list,
    batch_size=batch_size)
step_size_train = train_generator.n // train_generator.batch_size
##Validation data generator
datagen = ImageDataGenerator(rescale=1./255)
validation_generator = datagen.flow_from_dataframe(
    dataframe=df_validation, directory=directory,
    x_col='filename', y_col='label',
    target_size=image_size, class_mode='categorical',
    classes=label_list,
    batch_size=batch_size)
step_size_validation = validation_generator.n // validation_generator.batch_size
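The following is a plain-Python sketch, not Keras code, of that index mapping and of the one-hot target that class_mode='categorical' implies. The helper names class_indices and one_hot are hypothetical. On a real generator, the mapping can be inspected through train_generator.class_indices.

```python
def class_indices(labels):
    """Number the class names 0, 1, ... in sorted (alphabetical) order."""
    return {name: i for i, name in enumerate(sorted(set(labels)))}


def one_hot(label, indices):
    """Categorical target vector: 1.0 at the label's index, 0.0 elsewhere."""
    vec = [0.0] * len(indices)
    vec[indices[label]] = 1.0
    return vec


# Labels in the order they happen to appear in a CSV file
labels = ['Normal', 'Gla', 'DR_DM', 'RVO', 'AMD', 'RD', 'MH', 'RP']
indices = class_indices(labels)
# Inverse mapping, used to turn a predicted index back into a label name
index_to_label = {i: name for name, i in indices.items()}
print(indices)
print(one_hot('Gla', indices))
print(index_to_label[4])
```

Since sorted(set(labels)) here equals label_list, the i-th output of the network can be read directly as label_list[i].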

Build a simple 10-layer CNN.

#Build a 10-layer CNN
def cnn(input_shape, classes):
    #Input layer
    inputs = Input(shape=(input_shape[0], input_shape[1], 3))

    #1st layer
    x = Conv2D(32, (3, 3), padding='same', kernel_initializer='he_normal')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #2nd layer
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #3rd layer
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #4th layer
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #5th and 6th layers
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    #7th and 8th layers
    x = Conv2D(1024, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(1024, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)

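    #9th and 10th layers (fully connected)
    #Note: Dense(256) has no activation, so these two layers
    #together compute a single linear map
    x = Dense(256, kernel_initializer='he_normal')(x)
    x = Dense(classes, kernel_initializer='he_normal')(x)
    outputs = Activation('softmax')(x)

    return Model(inputs=inputs, outputs=outputs)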

#Network construction
model = cnn(image_size, classes)
model.summary()
model.compile(loss=loss, optimizer=optimizer, metrics=[metrics])
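The following sketch shows where the "10 layers" come from (8 convolutional plus 2 dense) and how the feature map shrinks. It traces output shapes and parameter counts by hand in plain Python and does not use TensorFlow. It assumes 224×224 RGB input and the 8 classes used here. Each BatchNormalization is counted as 4 values per channel (2 trainable, 2 non-trainable moving statistics), which is how Keras's model.summary() counts them. The function trace_cnn is hypothetical.

```python
def trace_cnn(image_size=(224, 224), classes=8):
    """Return (layer, output shape, parameter count) rows for the 10-layer CNN above."""
    h, w = image_size
    in_ch = 3
    rows = []
    # (filters, followed by 2x2 max pooling?) for the 8 convolutional layers
    conv_plan = [(32, True), (64, True), (128, True), (256, True),
                 (512, False), (512, True), (1024, False), (1024, False)]
    for i, (filters, pool) in enumerate(conv_plan, start=1):
        # 3x3 kernel weights plus one bias per filter; 'same' padding keeps h and w
        params = 3 * 3 * in_ch * filters + filters
        # BatchNormalization: gamma, beta, moving mean, moving variance per channel
        params += 4 * filters
        if pool:
            h, w = h // 2, w // 2  # MaxPool2D(pool_size=(2, 2))
        rows.append((f'conv{i}+bn', (h, w, filters), params))
        in_ch = filters
    # GlobalAveragePooling2D leaves one value per channel (in_ch features)
    rows.append(('dense9', (256,), in_ch * 256 + 256))
    rows.append(('dense10', (classes,), 256 * classes + classes))
    return rows


rows = trace_cnn()
for name, shape, params in rows:
    print(f'{name:>8}  {str(shape):>16}  {params:>10,}')
print('total parameters:', sum(p for _, _, p in rows))
```

Five 2×2 poolings reduce 224 to 7, and the last two 1024-filter convolutions hold about 14 million of the parameters.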

Train the network.

#Training (model.fit accepts generators directly; fit_generator is deprecated)
history = model.fit(
    train_generator, steps_per_epoch=step_size_train,
    epochs=epochs, verbose=1, callbacks=[mc_cb, rl_cb, es_cb],
    validation_data=validation_generator,
    validation_steps=step_size_validation,
    class_weight=weight_balanced,
    workers=3)

Finally, plot the learning curves and save them as an image.

#Draw and save a graph of the learning curve
def plot_history(history):
    fig, (axL, axR) = plt.subplots(ncols=2, figsize=(10, 4))

    # Left: accuracy
    L_title = 'Accuracy_vs_Epoch'
    axL.plot(history.history['accuracy'])
    axL.plot(history.history['val_accuracy'])
    axL.grid(True)
    axL.set_title(L_title)
    axL.set_ylabel('accuracy')
    axL.set_xlabel('epoch')
    axL.legend(['train', 'validation'], loc='upper left')

    # Right: loss
    R_title = "Loss_vs_Epoch"
    axR.plot(history.history['loss'])
    axR.plot(history.history['val_loss'])
    axR.grid(True)
    axR.set_title(R_title)
    axR.set_ylabel('loss')
    axR.set_xlabel('epoch')
    axR.legend(['train', 'validation'], loc='upper left')

    #Save the graph as an image
    fig.savefig('history.jpg')
    plt.close()

#Saving the learning curve
plot_history(history)

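The resulting learning curves are shown below.

(Figure: history.jpg. Left: accuracy vs. epoch; right: loss vs. epoch; each for training and validation data.)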

6. Evaluation

Because the classes are imbalanced, accuracy alone can be misleading, so the model is evaluated with the F1 score. First, run inference on the test data with the trained model.

Additional imports:

import numpy as np
from PIL import Image
from sklearn.metrics import classification_report
from tqdm import tqdm

Define the parameters. This time, the test CSV file is read.

directory = 'img' #Folder where images are stored
df_test = pd.read_csv('test.csv') #DataFrame with test data information
label_list = ['AMD', 'DR_DM', 'Gla', 'MH', 'Normal', 'RD', 'RP', 'RVO'] #Label name
image_size = (224, 224) #Input image size
classes = len(label_list) #Number of classification classes

Rebuild the network and load the trained weights.

#Build the network and load the trained weights
model = cnn(image_size, classes)
model.load_weights('model_weights.h5')

Each test image is loaded, resized to the input size and rescaled to [0, 1] as during training, and then classified.

#Inference
X = df_test['filename'].values
y_true = [label_list.index(label) for label in df_test['label']]
y_pred = []
for file in tqdm(X, desc='pred'):
    #Convert to RGB, resize and rescale to [0, 1] as during training
    #(flow_from_dataframe resizes with nearest-neighbour interpolation by default,
    # whereas LANCZOS is used here)
    img = Image.open(f'{directory}/{file}').convert('RGB')
    img = img.resize(image_size, Image.LANCZOS)
    img = np.array(img, dtype=np.float32)
    img *= 1./255
    img = np.expand_dims(img, axis=0)

    y_pred.append(np.argmax(model.predict(img)[0]))
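Calling model.predict once per image is slow. The following sketch loads the images in batches with the same steps as the loop above, so that model.predict can handle many images per call. Batching avoids holding the whole test set in memory at once: 2,390 images at 224×224×3 in float32 would take about 1.4 GB. The helper names load_image and iter_batches are hypothetical. image_size is taken as (height, width) like Keras's target_size, while PIL's resize expects (width, height).

```python
import numpy as np
from PIL import Image


def load_image(path, image_size=(224, 224)):
    """Load one image as float32 RGB in [0, 1], resized to image_size = (height, width)."""
    img = Image.open(path).convert('RGB')
    # PIL takes (width, height), so swap the order of image_size
    img = img.resize((image_size[1], image_size[0]), Image.LANCZOS)
    return np.asarray(img, dtype=np.float32) / 255.0


def iter_batches(paths, image_size=(224, 224), batch_size=32):
    """Yield arrays of shape (n, height, width, 3) with n <= batch_size."""
    for start in range(0, len(paths), batch_size):
        chunk = paths[start:start + batch_size]
        yield np.stack([load_image(p, image_size) for p in chunk])
```

With the trained model, the predictions could then be collected as y_pred = np.argmax(np.concatenate([model.predict(b) for b in iter_batches(paths)]), axis=1), where paths = [f'{directory}/{f}' for f in df_test['filename']].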

Compute the F1 score with scikit-learn's classification_report.

#Evaluation
print(classification_report(y_true, y_pred, target_names=label_list))

The evaluation results are below. As expected, AMD and MH, two of the smallest classes, have low F1 scores (0.27 and 0.14). RP is also small, yet it scores well (0.83).

              precision    recall  f1-score   support

         AMD       0.17      0.67      0.27        75
       DR_DM       0.72      0.75      0.73       620
         Gla       0.76      0.69      0.72       459
          MH       0.09      0.34      0.14        32
      Normal       0.81      0.50      0.62       871
          RD       0.87      0.79      0.83       176
          RP       0.81      0.86      0.83        50
         RVO       0.45      0.65      0.53       107

    accuracy                           0.64      2390
   macro avg       0.58      0.66      0.59      2390
weighted avg       0.73      0.64      0.67      2390
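For reference, the f1-score column is the harmonic mean of the precision and recall columns, computed per class from true positives (TP), false positives (FP) and false negatives (FN). The macro average is the unweighted mean over the C = 8 classes. The weighted average weights each class by its support n_c out of N = 2390 test images. Using the rounded AMD values from the table reproduces its score:

```latex
\mathrm{precision} = \frac{TP}{TP + FP}, \qquad
\mathrm{recall} = \frac{TP}{TP + FN}, \qquad
F_1 = \frac{2 \cdot \mathrm{precision} \cdot \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}}

F_1^{\mathrm{AMD}} \approx \frac{2 \times 0.17 \times 0.67}{0.17 + 0.67} = \frac{0.2278}{0.84} \approx 0.27

F_1^{\mathrm{macro}} = \frac{1}{C} \sum_{c=1}^{C} F_1^{(c)}, \qquad
F_1^{\mathrm{weighted}} = \sum_{c=1}^{C} \frac{n_c}{N} F_1^{(c)}
```

Because the large DR_DM and Normal classes dominate the weighted average, it (0.67) sits above the macro average (0.59), which gives the weak small classes equal say.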

7. Summary

In this article, we used a simple 10-layer CNN to classify images from the wide-angle fundus dataset published by Tsukazaki Hospital. Building on this result, we plan to improve performance by incorporating recent techniques, such as newer network architectures and data augmentation methods.
