[PYTHON] Image segmentation with CaDIS: a Cataract Dataset

1. Introduction

This article, aimed at beginners, uses TensorFlow 2.0 to perform semantic segmentation with deep learning. The images are from the cataract surgery segmentation dataset [^ 1] published by Digital Surgery Ltd. The network is SegNet [^ 2], using the 10-layer CNN from the previous article as the encoder.

All code

2. Environment

- PC specs

3. Dataset

CaDIS: Cataract Dataset for Image Segmentation is a cataract surgery segmentation dataset of 4738 images (25 videos) published by Digital Surgery Ltd. The surgical images and segmentation images can be downloaded from the link below.

CaDIS Dataset https://cataracts.grand-challenge.org/CaDIS/

The segmentation labels are listed below, together with the percentage of pixels each class occupies. From the table you can see that some classes do not appear in every split. The numbers of images for training, validation, and testing are 3584 (19 videos), 540 (3 videos), and 614 (3 videos), respectively.

Index Class Pixel ratio (training) [%] Pixel ratio (validation) [%] Pixel ratio (test) [%]
0 Pupil 17.1 15.7 16.2
1 Surgical Tape 6.51 6.77 4.81
2 Hand 0.813 0.725 0.414
3 Eye Retractors 0.564 0.818 0.388
4 Iris 11.0 11.0 12.8
5 Eyelid 0 0 1.86
6 Skin 12.0 20.4 10.7
7 Cornea 49.6 42.2 50.6
8 Hydro. Cannula 0.138 0.0984 0.0852
9 Visco. Cannula 0.0942 0.0720 0.0917
10 Cap. Cystotome 0.0937 0.0821 0.0771
11 Rycroft Cannula 0.0618 0.0788 0.0585
12 Bonn Forceps 0.241 0.161 0.276
13 Primary Knife 0.123 0.258 0.249
14 Phaco. Handpiece 0.173 0.240 0.184
15 Lens Injector 0.343 0.546 0.280
16 A/I Handpiece 0.327 0.380 0.305
17 Secondary Knife 0.102 0.0933 0.148
18 Micromanipulator 0.188 0.229 0.215
19 A/I Handpiece Handle 0.0589 0.0271 0.0358
20 Cap. Forceps 0.0729 0.0144 0.0384
21 Rycroft Cannula Handle 0.0406 0.0361 0.0101
22 Phaco. Handpiece Handle 0.0566 0.00960 0.0202
23 Cap. Cystotome Handle 0.0170 0.0124 0.0287
24 Secondary Knife Handle 0.0609 0.0534 0.0124
25 Lens Injector Handle 0.0225 0.0599 0.0382
26 Water Sprayer 0.000448 0 0.00361
27 Suture Needle 0.000764 0 0
28 Needle Holder 0.0201 0 0
29 Charleux Cannula 0.00253 0 0.0164
30 Vannas Scissors 0.00107 0 0
31 Primary Knife Handle 0.000321 0 0.000385
32 Viter. Handpiece 0 0 0.0782
33 Mendez Ring 0.0960 0 0
34 Biomarker 0.00619 0 0
35 Marker 0.0661 0 0

In addition, a sample image is shown below. The raw segmentation image is a grayscale image whose pixel values are the Index values in the table above.

(Figure: surgical image and segmentation image [^ 1], and the raw segmentation image. Note: contains graphic surgical imagery.)
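As a quick check of the raw label format, the following sketch (assuming the dataset is extracted under `CaDIS/`; the frame name is one of those listed in the csv example later) loads one raw segmentation image, lists the class indices it contains, and writes a color-coded version for visual inspection:

import cv2
import numpy as np

#Load the raw segmentation image; each pixel holds the class Index directly
label = cv2.imread('CaDIS/Video01/Labels/Video1_frame000090.png', cv2.IMREAD_GRAYSCALE)

#Class indices that appear in this frame
print(np.unique(label))

#Map each index to a fixed random color for a quick visual check
palette = np.random.default_rng(0).integers(0, 256, size=(36, 3), dtype=np.uint8)
cv2.imwrite('label_color.png', palette[label])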

4. Data split

This dataset specifies which images (videos) are to be used for training, validation, and testing; the details are given in the file splits.txt included in the dataset. We therefore adopt the split defined in splits.txt and use the following code to write the file paths of the surgical images and segmentation images of each split, together with their correspondence, to csv files.

Code that writes the image file paths to csv files
import os
from collections import defaultdict
import pandas as pd


#Create a csv file that describes the correspondence between images and labels
def make_csv(fpath, dirlist):
    #Collect the file paths of the images and their labels
    dataset = defaultdict(list)
    for dir in dirlist:
        filelist = sorted(os.listdir(f'CaDIS/{dir}/Images'))
        dataset['filename'] += list(map(lambda x: f'{dir}/Images/{x}', filelist))
        filelist = sorted(os.listdir(f'CaDIS/{dir}/Labels'))
        dataset['label'] += list(map(lambda x: f'{dir}/Labels/{x}', filelist))

    #Save as csv file
    dataset = pd.DataFrame(dataset)
    dataset.to_csv(fpath, index=False)



#Training data video folder
train_dir = ['Video01', 'Video03', 'Video04', 'Video06', 'Video08', 'Video09',
             'Video10', 'Video11', 'Video13', 'Video14', 'Video15', 'Video17',
             'Video18', 'Video20', 'Video21', 'Video22', 'Video23', 'Video24',
             'Video25']

#Validation data video folder
val_dir = ['Video05', 'Video07', 'Video16']

#Test data video folder
test_dir = ['Video02', 'Video12', 'Video19']


#Create a csv file that describes the correspondence between the image of the training data and the label
make_csv('train.csv', train_dir)

#Create a csv file that describes the correspondence between the images and labels of the validation data
make_csv('val.csv', val_dir)

#Create a csv file that describes the correspondence between the images and labels of the test data
make_csv('test.csv', test_dir)

The csv files containing the file paths of the training, validation, and test data look like this.

filename label
Video01/Images/Video1_frame000090.png Video01/Labels/Video1_frame000090.png
Video01/Images/Video1_frame000100.png Video01/Labels/Video1_frame000100.png
Video01/Images/Video1_frame000110.png Video01/Labels/Video1_frame000110.png

5. Model building & training

First, import the required libraries.

import dataclasses
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, MaxPool2D, UpSampling2D
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.utils import Sequence
import cv2
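
The loss function `cce_dice_loss` and the metric `dice_coeff` that appear in the parameters below are not defined in this article; they presumably come from the previous article. As a reference only, here is a minimal sketch of a common definition (the actual definitions may differ), assuming a softmax output of shape (batch, height, width, classes):

from tensorflow.keras import backend as K

#Dice coefficient: 2*|A∩B| / (|A| + |B|), computed over all pixels and classes
def dice_coeff(y_true, y_pred, smooth=1.0):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

#Dice loss = 1 - Dice coefficient
def dice_loss(y_true, y_pred):
    return 1.0 - dice_coeff(y_true, y_pred)

#Combined loss: categorical cross-entropy plus Dice loss
def cce_dice_loss(y_true, y_pred):
    return categorical_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)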

Next, define the parameters.

directory = 'CaDIS' #Folder where images are stored
df_train = pd.read_csv('train.csv') #DataFrame with training data information
df_validation = pd.read_csv('val.csv') #DataFrame with validation data information
image_size = (224, 224) #Input image size
classes = 36 #Number of classification classes
batch_size = 32 #Batch size
epochs = 300 #Number of epochs
loss = cce_dice_loss #Loss function
optimizer = Adam(lr=0.001, amsgrad=True) #Optimizer
metrics = dice_coeff #Evaluation metric
#ImageDataGenerator augmentation parameters
aug_params = {'rotation_range': 5,
              'width_shift_range': 0.05,
              'height_shift_range': 0.05,
              'shear_range': 0.1,
              'zoom_range': 0.05,
              'horizontal_flip': True,
              'vertical_flip': True}

The following callbacks are applied during training.

#Save the model only when val_loss reaches a new minimum
mc_cb = ModelCheckpoint('model_weights.h5',
                        monitor='val_loss', verbose=1,
                        save_best_only=True, mode='min')
#When learning stagnates, multiply the learning rate by 0.2
rl_cb = ReduceLROnPlateau(monitor='loss', factor=0.2, patience=3,
                          verbose=1, mode='auto',
                          min_delta=0.0001, cooldown=0, min_lr=0)
#Stop training early if learning does not progress
es_cb = EarlyStopping(monitor='loss', min_delta=0,
                      patience=5, verbose=1, mode='auto')

Generate generators for the training and validation data. `ImageDataGenerator` is used for data augmentation, and `Sequence` is used to create the mini-batches.

The `__getitem__` method is the part that actually creates a mini-batch. The input images are processed in the following steps.

  1. Load the image
  2. Resize to the specified input image size
  3. Convert to float type
  4. Apply data augmentation
  5. Divide the values by 255 to normalize them to the range 0-1

The segmentation images are processed in the following steps.

  1. Load the image
  2. Resize to the specified input image size
  3. Convert to float type
  4. Apply data augmentation (using the same random seed as the corresponding input image)
  5. For each class, create a binary image that is 1 where the pixel belongs to that class and 0 elsewhere, then stack these images along the channel axis into an array of shape (height, width, number of classes)

#Data generator
@dataclasses.dataclass
class TrainSequence(Sequence):
    directory: str #Folder where images are stored
    df: pd.DataFrame #DataFrame with data information
    image_size: tuple #Input image size
    classes: int #Number of classification classes
    batch_size: int #Batch size
    aug_params: dict #ImageDataGenerator augmentation parameters

    def __post_init__(self):
        self.df_index = list(self.df.index)
        self.train_datagen = ImageDataGenerator(**self.aug_params)

    def __len__(self):
        return math.ceil(len(self.df_index) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.df_index[idx * self.batch_size:(idx+1) * self.batch_size]

        x = []
        y = []
        for i in batch_x:
            rand = np.random.randint(0, int(1e9))
            #Input image
            img = cv2.imread(f'{self.directory}/{self.df.at[i, "filename"]}')
            img = cv2.resize(img, self.image_size, interpolation=cv2.INTER_LANCZOS4)
            img = np.array(img, dtype=np.float32)
            img = self.train_datagen.random_transform(img, seed=rand)
            img *= 1./255
            x.append(img)

            #Segmentation image
            img = cv2.imread(f'{self.directory}/{self.df.at[i, "label"]}', cv2.IMREAD_GRAYSCALE)
            img = cv2.resize(img, self.image_size, interpolation=cv2.INTER_LANCZOS4)
            img = np.array(img, dtype=np.float32)
            img = np.reshape(img, (self.image_size[0], self.image_size[1], 1))
            img = self.train_datagen.random_transform(img, seed=rand)
            img = np.reshape(img, (self.image_size[0], self.image_size[1]))
            seg = []
            for label in range(self.classes):
                seg.append(img == label)
            seg = np.array(seg, np.float32)
            seg = seg.transpose(1, 2, 0)
            y.append(seg)

        x = np.array(x)
        y = np.array(y)


        return x, y

#Generator generation
##Training data generator
train_generator = TrainSequence(directory=directory, df=df_train,
                                image_size=image_size, classes=classes,
                                batch_size=batch_size, aug_params=aug_params)
step_size_train = len(train_generator)
##Validation data generator
validation_generator = TrainSequence(directory=directory, df=df_validation,
                                     image_size=image_size, classes=classes,
                                     batch_size=batch_size, aug_params={})
step_size_validation = len(validation_generator)
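
As an optional sanity check, one mini-batch from the training generator should have shape (batch_size, 224, 224, 3) for the inputs and (batch_size, 224, 224, 36) for the one-hot segmentation labels:

#Fetch the first mini-batch and check the shapes
x_batch, y_batch = train_generator[0]
print(x_batch.shape, y_batch.shape)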

As in the previous article, SegNet is built by using the 10-layer simple CNN created last time (with its fully connected layers removed) as the encoder, and a decoder that mirrors the encoder in reverse order. See here for an explanation of SegNet.

#Build SegNet (8-layer encoder, 8-layer decoder)
def cnn(input_shape, classes):
    #Input image size must be a multiple of 32
    assert input_shape[0]%32 == 0, 'Input size must be a multiple of 32.'
    assert input_shape[1]%32 == 0, 'Input size must be a multiple of 32.'

    #encoder
    ##Input layer
    inputs = Input(shape=(input_shape[0], input_shape[1], 3))

    ##1st layer
    x = Conv2D(32, (3, 3), padding='same', kernel_initializer='he_normal')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    ##2nd layer
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    ##3rd layer
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    ##4th layer
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    ##5th and 6th layers
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

    ##7th and 8th layers
    x = Conv2D(1024, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(1024, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    #Decoder
    ##1st layer
    x = Conv2D(1024, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    ##2nd and 3rd layers
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    ##4th layer
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    ##5th layer
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    ##6th layer
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    ##7th and 8th layers
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    x = Conv2D(classes, (1, 1), strides=(1, 1), padding='same', kernel_initializer='he_normal')(x)
    outputs = Activation('softmax')(x)


    return Model(inputs=inputs, outputs=outputs)

#Network construction
model = cnn(image_size, classes)
model.summary()
model.compile(loss=loss, optimizer=optimizer, metrics=[metrics])

The rest is the same as last time: train the model and save the learning curves.

#Learning
history = model.fit_generator(
    train_generator, steps_per_epoch=step_size_train,
    epochs=epochs, verbose=1, callbacks=[mc_cb, rl_cb, es_cb],
    validation_data=validation_generator,
    validation_steps=step_size_validation,
    workers=3)

#Draw and save a graph of the learning curve
def plot_history(history):
    fig, (axL, axR) = plt.subplots(ncols=2, figsize=(10, 4))

    # [Left] Graph of the metric
    L_title = 'Dice_coeff_vs_Epoch'
    axL.plot(history.history['dice_coeff'])
    axL.plot(history.history['val_dice_coeff'])
    axL.grid(True)
    axL.set_title(L_title)
    axL.set_ylabel('dice_coeff')
    axL.set_xlabel('epoch')
    axL.legend(['train', 'test'], loc='upper left')

    # [Right side]Graph about loss
    R_title = "Loss_vs_Epoch"
    axR.plot(history.history['loss'])
    axR.plot(history.history['val_loss'])
    axR.grid(True)
    axR.set_title(R_title)
    axR.set_ylabel('loss')
    axR.set_xlabel('epoch')
    axR.legend(['train', 'test'], loc='upper left')

    #Save the graph as an image
    fig.savefig('history.jpg')
    plt.close()

#Saving the learning curve
plot_history(history)

The learning results are as follows.

(Figure: history.jpg, the learning curves)

6. Evaluation

The model is evaluated by the average IoU for each class and by the mean IoU, which is the average over all classes. The computation was done with the following code.
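For reference, the IoU that this code computes for class $c$ on a single image is

$\mathrm{IoU}_c = \dfrac{|P_c \cap G_c|}{|P_c \cup G_c|} = \dfrac{TP_c}{TP_c + FP_c + FN_c}$

where $P_c$ is the set of pixels predicted as class $c$ and $G_c$ is the set of pixels labeled as class $c$; images in which class $c$ appears in neither the prediction nor the label are skipped for that class.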

Additional imports. Note that `tqdm`, used for the progress bar in the inference loop below, also needs to be imported.

from collections import defaultdict
from tqdm import tqdm

Inference and evaluation are performed in the following steps.

  1. Load the image
  2. Resize to the specified input image size
  3. Convert to float type and normalize the value to 0 to 1
  4. Make an array of batch size 1
  5. Infer and get the segmentation image
  6. Resize the predicted segmentation back to the original image size
  7. Calculate IoU for each image and each class
  8. Calculate average IoU for each class

    directory = 'CaDIS' #Folder where images are stored
    df_test = pd.read_csv('test.csv') #DataFrame with test data information
    image_size = (224, 224) #Input image size
    classes = 36 #Number of classification classes


    #Network construction
    model = cnn(image_size, classes)
    model.summary()
    model.load_weights('model_weights.h5')


    #inference
    dict_iou = defaultdict(list)
    for i in tqdm(range(len(df_test)), desc='predict'):
        img = cv2.imread(f'{directory}/{df_test.at[i, "filename"]}')
        height, width = img.shape[:2]
        img = cv2.resize(img, image_size, interpolation=cv2.INTER_LANCZOS4)
        img = np.array(img, dtype=np.float32)
        img *= 1./255
        img = np.expand_dims(img, axis=0)
        label = cv2.imread(f'{directory}/{df_test.at[i, "label"]}', cv2.IMREAD_GRAYSCALE)

        pred = model.predict(img)[0]
        pred = cv2.resize(pred, (width, height), interpolation=cv2.INTER_LANCZOS4)

        ##IoU calculation
        pred = np.argmax(pred, axis=2)
        for j in range(classes):
            y_pred = np.array(pred == j, dtype=np.int)
            y_true = np.array(label == j, dtype=np.int)
            tp = sum(sum(np.logical_and(y_pred, y_true)))
            other = sum(sum(np.logical_or(y_pred, y_true)))
            if other != 0:
                dict_iou[j].append(tp/other)

    # average IoU
    for i in range(classes):
        if i in dict_iou:
            dict_iou[i] = sum(dict_iou[i]) / len(dict_iou[i])
        else:
            dict_iou[i] = -1
    print('average IoU', dict_iou)
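
The mean IoU quoted below is not computed by the code above; a minimal sketch for obtaining it, averaging the per-class average IoUs over the classes that were actually evaluated (i.e. excluding the -1 placeholders):

    #Mean IoU: average of the per-class average IoUs, excluding the -1 placeholders
    valid_iou = [v for v in dict_iou.values() if v != -1]
    print('mean IoU', sum(valid_iou) / len(valid_iou))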

The evaluation results are shown below. The mean IoU was 15.0%. According to the paper [^ 1], a VGG-based model achieves 20.61%, so this result seems about right.

Index Class average IoU[%]
0 Pupil 85.3
1 Surgical Tape 53.3
2 Hand 6.57
3 Eye Retractors 21.9
4 Iris 74.4
5 Eyelid 0.0
6 Skin 49.7
7 Cornea 88.0
8 Hydro. Cannula 0
9 Visco. Cannula 0
10 Cap. Cystotome 0
11 Rycroft Cannula 0
12 Bonn Forceps 3.58
13 Primary Knife 5.35
14 Phaco. Handpiece 0.0781
15 Lens Injector 16.4
16 A/I Handpiece 16.4
17 Secondary Knife 6.08
18 Micromanipulator 0
19 A/I Handpiece Handle 6.49
20 Cap. Forceps 0
21 Rycroft Cannula Handle 0
22 Phaco. Handpiece Handle 0
23 Cap. Cystotome Handle 0
24 Secondary Knife Handle 2.49
25 Lens Injector Handle 0
26 Water Sprayer
27 Suture Needle 0
28 Needle Holder
29 Charleux Cannula 0
30 Vannas Scissors
31 Primary Knife Handle 0
32 Viter. Handpiece 0
33 Mendez Ring
34 Biomarker
35 Marker

7. Summary

In this article, we performed semantic segmentation on the cataract surgery segmentation dataset [^ 1] published by Digital Surgery Ltd. using a SegNet with an 8-layer encoder and an 8-layer decoder. According to the paper [^ 1], PSPNet achieves a mean IoU of 52.66%, so going forward I will build on this result and aim for equal or better performance by incorporating newer network architectures and data augmentation methods.
