[PYTHON] Image segmentation using U-net

The environment used in this article is Python 3.7 and TensorFlow 2.1.1.

What is image segmentation?

Let's look at an example of image segmentation from the VOC2012 dataset (figure omitted). In this example, the pixels in the image are classified as bike, driver, or background. Classifying an image **in pixel units** in this way is called image segmentation (or semantic segmentation).

When you start studying neural networks, one of the first things you implement in tutorials is MNIST handwritten digit recognition. There, an image is the input and the output is the probability that the image shows each digit from 0 to 9, so the final output layer is one-dimensional with 10 nodes.

In image segmentation, on the other hand, we want to classify every pixel of the image, so the final output layer has the dimensions (image height, image width, number of classes). U-net, which this article introduces, achieves excellent accuracy on this kind of image segmentation problem.
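To make the difference concrete, here is a minimal sketch (my own illustration, not code from the original tutorials) contrasting the two kinds of output layers:

from tensorflow import keras
from tensorflow.keras import layers

# MNIST-style classifier: the whole image maps to 10 class probabilities.
img = keras.Input(shape=(28, 28, 1))
probs = layers.Dense(10, activation="softmax")(layers.Flatten()(img))
print(probs.shape)        # (None, 10)

# Segmentation head: every pixel gets its own class scores, so the output
# keeps the spatial dimensions (height, width, number of classes).
img = keras.Input(shape=(256, 256, 1))
pixel_probs = layers.Conv2D(3, 1, activation="softmax")(img)
print(pixel_probs.shape)  # (None, 256, 256, 3)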

U-net

U-net is a model built from CNNs and skip connections, introduced in the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation" (figure below).

(Figure: the U-net architecture from the original paper.)

The basic idea of U-net is:

- extract features while coarse-graining the image with convolution and pooling layers
- restore pixel-level information with skip connections (the horizontal connections in the figure above)

The skip connections in particular are the core of U-net: without them, the coarse-graining washes out most of the pixel-level information and the segmentation accuracy drops. U-net now serves as the basis of image segmentation and is also used as a component of more complex neural networks.
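As a toy illustration of why the skip connection matters (my own sketch, not from the paper): pooling throws away spatial resolution, upsampling alone cannot bring the lost detail back, and concatenating the pre-pooling tensor reattaches it.

import tensorflow as tf
from tensorflow.keras import layers

x = tf.random.normal((1, 256, 256, 64))  # a pre-pooling feature map
down = layers.MaxPooling2D((2, 2))(x)    # (1, 128, 128, 64): resolution halved
up = layers.UpSampling2D((2, 2))(down)   # (1, 256, 256, 64): size restored, but
                                         # the fine detail is already gone
merged = tf.concat([up, x], axis=3)      # skip connection: reattach the original
print(merged.shape)                      # detail -> (1, 256, 256, 128)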

U-net implementation

First, let's write the opening part of U-net (the three layers at the upper left of the figure above). In the implementation below, the input image size is set to (256, 256).

from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(256, 256, 1), name="img")
x = layers.Conv2D(64, 3, activation="relu", padding="same", kernel_initializer='he_normal')(inputs)
block_1_output = layers.Conv2D(64, 3, activation="relu", padding="same")(x)

An image of size (256, 256, 1) is convolved by a convolutional layer with 64 channels and a filter size of 3. ReLU is used as the activation function because it works well in deep models. The output is named block_1_output for a later skip connection. The next layers of U-net are as follows.

x = layers.MaxPooling2D(pool_size=(2, 2), padding="same")(block_1_output)
x = layers.Conv2D(128, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
block_2_output = layers.Conv2D(128, 3, activation="relu", padding="same")(x)

It receives "block_1_output", coarse-grains it with Max Pooling, increases the number of channels with CNN, and convolves it. Save "block_2_output" for skip connection as before. If you write U-net in this condition, the whole picture will be as follows.

import tensorflow as tf  # needed for tf.concat below
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(256, 256, 1), name="img")

# Contracting path (encoder): convolve, then halve the resolution.
x = layers.Conv2D(64, 3, activation="relu", padding="same", kernel_initializer='he_normal')(inputs)
block_1_output = layers.Conv2D(64, 3, activation="relu", padding="same")(x)

x = layers.MaxPooling2D(pool_size=(2, 2), padding="same")(block_1_output)
x = layers.Conv2D(128, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
block_2_output = layers.Conv2D(128, 3, activation="relu", padding="same")(x)

x = layers.MaxPooling2D(pool_size=(2, 2), padding="same")(block_2_output)
x = layers.Conv2D(256, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
block_3_output = layers.Conv2D(256, 3, activation="relu", padding="same")(x)

x = layers.MaxPooling2D(pool_size=(2, 2), padding="same")(block_3_output)
x = layers.Conv2D(512, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
block_4_output = layers.Conv2D(512, 3, activation="relu", padding="same")(x)

# Bottleneck, with Dropout as in the paper.
x = layers.Dropout(0.5)(block_4_output)
x = layers.MaxPooling2D(pool_size=(2, 2), padding="same")(x)
x = layers.Conv2D(1024, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
block_5_output = layers.Conv2D(1024, 3, activation="relu", padding="same")(x)

# Expanding path (decoder): upsample, then merge in the skip connection
# by concatenating along the channel axis (axis=3).
x = layers.Dropout(0.5)(block_5_output)
x = layers.UpSampling2D(size=(2, 2))(x)
x = layers.Conv2D(512, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
x = tf.concat([x, block_4_output], axis=3)
x = layers.Conv2D(512, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
block_6_output = layers.Conv2D(512, 3, activation="relu", padding="same")(x)

x = layers.UpSampling2D(size=(2, 2))(block_6_output)
x = layers.Conv2D(256, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
x = tf.concat([x, block_3_output], axis=3)
x = layers.Conv2D(256, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
block_7_output = layers.Conv2D(256, 3, activation="relu", padding="same")(x)

x = layers.UpSampling2D(size=(2, 2))(block_7_output)
x = layers.Conv2D(128, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
x = tf.concat([x, block_2_output], axis=3)
x = layers.Conv2D(128, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
block_8_output = layers.Conv2D(128, 3, activation="relu", padding="same")(x)

x = layers.UpSampling2D(size=(2, 2))(block_8_output)
x = layers.Conv2D(64, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
x = tf.concat([x, block_1_output], axis=3)
x = layers.Conv2D(64, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)
x = layers.Conv2D(64, 3, activation="relu", padding="same", kernel_initializer='he_normal')(x)

# 1x1 convolution down to a single channel: per-pixel foreground probability.
outputs = layers.Conv2D(1, 1, activation="sigmoid", padding="same")(x)

model = keras.Model(inputs, outputs, name="u-net")

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

The code is long because U-net is a complicated model, but it just repeats the same pattern. At each skip connection, tf.concat joins the two tensors with the channel dimension as the axis. The U-net diagram above does not show Dropout layers, but the text of the paper says:

> Drop-out layers at the end of the contracting path perform further implicit data augmentation

Accordingly, the Dropout layers are included in the network above as well. Also, since we assume segmentation into two classes here, the number of output channels is set to 1 and binary_crossentropy is used as the loss. Alternatively, you can set the number of output channels to 2 and train with sparse_categorical_crossentropy.
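For reference, here is a sketch of that two-class variant; only the output layer and the loss change, the rest of the network stays the same (x is the last 64-channel tensor from the listing above):

# Two output channels: one score per class for every pixel. With
# sparse_categorical_crossentropy, the labels are integer class IDs
# of shape (256, 256) rather than one-hot vectors.
outputs = layers.Conv2D(2, 1, activation="softmax", padding="same")(x)

model = keras.Model(inputs, outputs, name="u-net")
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])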

Let's check that the model was built correctly.

tf.keras.utils.plot_model(model)

Running this outputs the model diagram (omitted here). It is very long, but you can see that the layers in the first half are properly connected to those in the second half by skip connections. Outputting the model with model.summary() gives the following.

model.summary()

Model: "u-net"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
img (InputLayer)                [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 640         img[0][0]                        
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 64) 36928       conv2d[0][0]                     
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 128, 128, 64) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 128, 128, 128 73856       max_pooling2d[0][0]              
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 128 147584      conv2d_2[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 64, 64, 128)  0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 256)  295168      max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 256)  590080      conv2d_4[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 32, 32, 256)  0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 32, 32, 512)  1180160     max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 512)  2359808     conv2d_6[0][0]                   
__________________________________________________________________________________________________
dropout (Dropout)               (None, 32, 32, 512)  0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 16, 16, 512)  0           dropout[0][0]                    
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 16, 16, 1024) 4719616     max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 1024) 9438208     conv2d_8[0][0]                   
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 16, 16, 1024) 0           conv2d_9[0][0]                   
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 32, 32, 1024) 0           dropout_1[0][0]                  
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 32, 32, 512)  4719104     up_sampling2d[0][0]              
__________________________________________________________________________________________________
tf_op_layer_concat (TensorFlowO [(None, 32, 32, 1024 0           conv2d_10[0][0]                  
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 512)  4719104     tf_op_layer_concat[0][0]         
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 512)  2359808     conv2d_11[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 64, 64, 512)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 64, 64, 256)  1179904     up_sampling2d_1[0][0]            
__________________________________________________________________________________________________
tf_op_layer_concat_1 (TensorFlo [(None, 64, 64, 512) 0           conv2d_13[0][0]                  
                                                                 conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 256)  1179904     tf_op_layer_concat_1[0][0]       
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 64, 64, 256)  590080      conv2d_14[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 128, 128, 256 0           conv2d_15[0][0]                  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 128, 128, 128 295040      up_sampling2d_2[0][0]            
__________________________________________________________________________________________________
tf_op_layer_concat_2 (TensorFlo [(None, 128, 128, 25 0           conv2d_16[0][0]                  
                                                                 conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 128, 128, 128 295040      tf_op_layer_concat_2[0][0]       
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 128, 128, 128 147584      conv2d_17[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 256, 256, 128 0           conv2d_18[0][0]                  
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 256, 256, 64) 73792       up_sampling2d_3[0][0]            
__________________________________________________________________________________________________
tf_op_layer_concat_3 (TensorFlo [(None, 256, 256, 12 0           conv2d_19[0][0]                  
                                                                 conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 256, 256, 64) 73792       tf_op_layer_concat_3[0][0]       
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 256, 256, 64) 36928       conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 256, 256, 1)  65          conv2d_21[0][0]                  
==================================================================================================
Total params: 34,512,193
Trainable params: 34,512,193
Non-trainable params: 0
__________________________________________________________________________________________________

The total number of parameters is 34,512,193, which is quite a lot.

Learning

For the preparation of the input data (padding and so on), see the previous article. The image data from the ISBI Challenge 2012 (Segmentation of neuronal structures in EM stacks) is cut into (256, 256) patches and augmented with ImageDataGenerator.
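As a rough sketch of what such a pipeline can look like (the directory names and augmentation parameters below are placeholders, not the settings actually used in the previous article):

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Hypothetical layout: 'data/train_images/<subdir>' and 'data/train_masks/<subdir>'
# hold the patches (flow_from_directory expects one subdirectory level).
# Identical parameters and seed so that images and masks receive
# exactly the same random transformations.
data_gen_args = dict(rotation_range=10, horizontal_flip=True, rescale=1.0/255)
seed = 1

image_iter = ImageDataGenerator(**data_gen_args).flow_from_directory(
    'data/train_images', class_mode=None, color_mode='grayscale',
    target_size=(256, 256), seed=seed)
mask_iter = ImageDataGenerator(**data_gen_args).flow_from_directory(
    'data/train_masks', class_mode=None, color_mode='grayscale',
    target_size=(256, 256), seed=seed)

# Each step yields an (image_batch, mask_batch) pair for model.fit.
my_generator = zip(image_iter, mask_iter)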

Now let's run the training. my_generator is a generator that yields training image batches, and my_val_gen is a generator that yields validation image batches.

EPOCHS = 200
STEPS_PER_EPOCH = 300
VALIDATION_STEPS = STEPS_PER_EPOCH//10
model_history = model.fit(my_generator, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=my_val_gen)
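The "model from around epoch 30" used below implies that snapshots were saved during training; one way to get them is a ModelCheckpoint callback passed to fit (a sketch; the file name pattern is arbitrary):

# Save the weights after every epoch so that any snapshot (e.g. epoch 30)
# can be restored later with model.load_weights().
ckpt = keras.callbacks.ModelCheckpoint('unet_epoch{epoch:03d}.h5',
                                       save_weights_only=True)
# model.fit(..., callbacks=[ckpt])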

The learning curves are shown below: the horizontal axis is the epoch and the vertical axis is accuracy or loss. The training loss decreases monotonically and the training accuracy rises steadily, indicating that learning is progressing. The validation accuracy also exceeds 90%, so the model appears to have real predictive power. Looking closely, however, the validation accuracy peaks and the validation loss bottoms out around epochs 20 to 30, which suggests that the model overfits as the number of epochs increases further. (Figure: accuracy and loss curves for training and validation.)
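Curves like the ones above can be redrawn from the history object returned by fit, for example (a minimal sketch):

import matplotlib.pyplot as plt

hist = model_history.history
plt.figure(dpi=150)
for key in ['accuracy', 'val_accuracy', 'loss', 'val_loss']:
    plt.plot(hist[key], label=key)
plt.xlabel('epoch')
plt.legend()
plt.show()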

Let's look at the segmentation of the test data using the model from around epoch 30.

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

def create_mask(pred_img):
    # Threshold the per-pixel probabilities at 0.5 to get a binary mask.
    pred_img = tf.math.greater(pred_img, 0.5)
    return pred_img[0]

# test_image: a 2-D grayscale test image loaded beforehand.
plt.figure(dpi=150)
plt.subplot(1, 2, 1)
plt.imshow(test_image[:256, :256], cmap="gray")

plt.subplot(1, 2, 2)
# Reshape to (1, 256, 256, 1): a batch of one grayscale image.
test_image = test_image[:256, :256, tf.newaxis]
test_image = np.expand_dims(test_image, axis=0)
pred_img = model.predict(test_image)
masked_img = create_mask(pred_img)
plt.imshow(masked_img[:, :, 0])

(Figure: test image and its predicted mask.) For comparison, the segmentation of the training data looked like this: (Figure: training image and its predicted mask.) The segmentation works well on both the training data and the test data, although the result for the test data looks somewhat noisier than for the training data. In that respect, the model and the data augmentation method could probably be refined a bit further.

In closing

This article introduced a TensorFlow implementation of image segmentation with U-net. I implemented the U-net model of the original paper as-is, but various additional techniques such as Batch Normalization can be applied, so I will introduce them another time if I get the chance.
