[PYTHON] Photo segmentation and clustering with DBSCAN

Introduction

This article is part of the sample code for the Kaggle competition The Nature Conservancy Fisheries Monitoring. The CNN-related code is not included here; it will be summarized at a later date.

Here, two analyses are performed on photographs taken at fishery sites.

  1. Estimate a unique ID for each boat
  2. Segment the fish location from the photo

1. Clustering of boats that took pictures

Here, we introduce how to automatically cluster the boats that took the pictures. For example, the boat ID can be used to build a separate model for each boat, or to mask areas on board where there are never any fish.

First, create helper functions to display images side by side; there are two variants, one for four images and one for eight. After that, 500 samples are read from the training set.

import pandas as pd
import numpy as np
import glob
from sklearn import cluster
from scipy.misc import imread  # removed in newer SciPy versions; imageio.imread can be used instead
import cv2
import skimage.measure as sm
# import progressbar
import multiprocessing
import random
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
new_style = {'grid': False}
plt.rc('axes', **new_style)

# Function to show 4 images
def show_four(imgs, title):
    #select_imgs = [np.random.choice(imgs) for _ in range(4)]
    select_imgs = [imgs[np.random.choice(len(imgs))] for _ in range(4)]
    _, ax = plt.subplots(1, 4, sharex='col', sharey='row', figsize=(20, 3))
    plt.suptitle(title, size=20)
    for i, img in enumerate(select_imgs):
        ax[i].imshow(img)

# Function to show 8 images
def show_eight(imgs, title):
    select_imgs = [imgs[np.random.choice(len(imgs))] for _ in range(8)]
    _, ax = plt.subplots(2, 4, sharex='col', sharey='row', figsize=(20, 6))
    plt.suptitle(title, size=20)
    for i, img in enumerate(select_imgs):
        ax[i // 4, i % 4].imshow(img)

select = 500 # Only load 500 images for speed
# Data loading
train_files = sorted(glob.glob('../input/train/*/*.jpg'), key=lambda x: random.random())[:select]
train = np.array([imread(img) for img in train_files])
print('Length of train {}'.format(len(train)))

The sizes of the training images are not uniform. Is there an image size that contains only a single boat? Let's check whether image size can serve as a proxy for the boat ID.

print('Sizes in train:')
shapes = np.array([str(img.shape) for img in train])
pd.Series(shapes).value_counts()
(720, 1280, 3)    287
(750, 1280, 3)     81
(974, 1280, 3)     50
(670, 1192, 3)     29
(718, 1276, 3)     28
(924, 1280, 3)      9
(974, 1732, 3)      7
(700, 1244, 3)      5
(854, 1518, 3)      3
(750, 1334, 3)      1
dtype: int64

The images fall into several distinct sizes. Let's display four actual images for each size.

for uniq in pd.Series(shapes).unique():
    show_four(train[shapes == uniq], 'Images with shape: {}'.format(uniq))
    plt.show()

Apart from the (854, 1518, 3) images, each size group contains more than one boat, so image size alone is not enough to determine the boat ID; another approach is needed.

Boat clustering

For simplicity, we analyze these 500 images; of course, the same processing can be applied to the full data set. The analysis proceeds in three steps: define a distance function between images, build the full distance matrix, and cluster it with DBSCAN.

# Function for computing the distance between two images:
# z-score normalize each image, then take the mean absolute pixel difference
def compare(args):
    img, img2 = args
    img = (img - img.mean()) / img.std()
    img2 = (img2 - img2.mean()) / img2.std()
    return np.mean(np.abs(img - img2))

# Resize the images to speed up the distance computation.
# Note: interpolation must be passed as a keyword argument; the third
# positional argument of cv2.resize is dst, not the interpolation flag.
train = [cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR) for img in train]

# Create the distance matrix in a multithreaded fashion
pool = multiprocessing.Pool(8)
#bar = progressbar.ProgressBar(max=len(train))
distances = np.zeros((len(train), len(train)))
for i, img in enumerate(train): #enumerate(bar(train)):
    all_imgs = [(img, f) for f in train]
    dists = pool.map(compare, all_imgs)
    distances[i, :] = dists

This creates an NxN matrix, where N is the number of images and each entry is the distance between a pair of images. scikit-learn has many clustering methods that accept a precomputed distance matrix; here, the matrix is passed to DBSCAN.
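As a quick aside, and not part of the original kernel: the same precomputed matrix can be handed to other scikit-learn clusterers. A minimal sketch with AgglomerativeClustering, where n_clusters=10 is only a guessed value:

# Side illustration (not in the original kernel): another clusterer consuming the
# same precomputed distance matrix. n_clusters=10 is only a guess.
# Note: older scikit-learn releases use affinity='precomputed' instead of metric=.
agg = cluster.AgglomerativeClustering(n_clusters=10, metric='precomputed', linkage='average')
print(pd.Series(agg.fit_predict(distances)).value_counts())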

print(distances)
plt.hist(distances.flatten(), bins=50)
plt.title('Histogram of distance matrix')
print('')

(Figure: histogram of the distance matrix)

The histogram shows a group of pairs with distances of about 0.8 or less, which most likely corresponds to pairs of images of the same boat. DBSCAN's default eps of 0.5 treats distances up to 0.5 as belonging to the same cluster; judging from the histogram, 0.6 looks like a more suitable threshold here.
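Before settling on 0.6, a quick sanity check is to sweep eps and watch how the number of clusters and unassigned images changes. This is a small sketch, not part of the original kernel, and the candidate values are arbitrary.

# Sanity-check the eps choice (sketch, not in the original kernel): try a few values
# and report how many clusters form and how many images stay unassigned (label -1).
for eps in [0.4, 0.5, 0.6, 0.7, 0.8]:
    labels = cluster.DBSCAN(metric='precomputed', min_samples=5, eps=eps).fit_predict(distances)
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    n_noise = int(np.sum(labels == -1))
    print('eps={:.1f}: {} clusters, {} unassigned images'.format(eps, n_clusters, n_noise))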

cls = cluster.DBSCAN(metric='precomputed', min_samples=5, eps=0.6)
y = cls.fit_predict(distances)
print(y)
print('Cluster sizes:')
print(pd.Series(y).value_counts())

for uniq in pd.Series(y).value_counts().index:
    if uniq != -1:
        size = len(np.array(train)[y == uniq])
        if size > 10:
            show_eight(np.array(train)[y == uniq], 'BoatID: {} - Image count {}'.format(uniq, size))
            plt.show()
        else:
            show_four(np.array(train)[y == uniq], 'BoatID: {} - Image count {}'.format(uniq, size))
            plt.show()

It worked pretty well. However, there is a group of images that was not assigned to any boat ID (label -1). There are two possible reasons.

  1. The boat has fewer than the 5 images required by min_samples, so no cluster could form around it
  2. The distance function is not discriminative enough for some images; for example, night and day photos of the same boat are mixed

size = len(np.array(train)[y == -1])
show_eight(np.array(train)[y == -1], 'BoatID: {} (Unclassified images) - Image count {}'.format(-1, size))

Some of these could be clustered by eye, so there is room for improvement in the algorithm; one possible refinement is sketched below. Still, the fact that more than 75% of the images could be assigned to a boat by unsupervised learning means the clustering is quite accurate.
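One possible refinement, sketched here under the assumption that a brightness-robust similarity would separate day and night shots of the same boat better, is to swap the pixel-difference distance for structural similarity from scikit-image (the sm alias imported above). This is only an illustration, not the original kernel's code.

# Sketch of an alternative distance (not the original kernel's code): structural
# similarity is less sensitive to global brightness changes such as day/night shots.
def compare_ssim_dist(args):
    img, img2 = args
    g1 = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    g2 = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY)
    # compare_ssim lives in skimage.measure in older releases;
    # newer scikit-image moves it to skimage.metrics.structural_similarity
    return 1.0 - sm.compare_ssim(g1, g2)

To try it, rebuild the distance matrix with compare_ssim_dist in place of compare and rerun DBSCAN.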

2. Fish segmentation

Next, we segment the fish in each image. The procedure is as follows.

  1. Prepare a photo of a fish as a template
  2. Select one photo other than the template source and compare several template-matching methods to pick a suitable one
  3. Segment photos of every fish class using the method chosen in step 2

import os
from scipy import ndimage
from subprocess import check_output

import cv2
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

img_rows, img_cols= 350, 425
im_array = cv2.imread('../input/train/LAG/img_00091.jpg', 0)  # read as grayscale
template = np.zeros([img_rows, img_cols], dtype='uint8') # initialisation of the template
template[:, :] = im_array[100:450, 525:950] # crop box found by trial and error
#template /= 255.
plt.subplots(figsize=(10, 7))
plt.subplot(121),plt.imshow(template, cmap='gray') 
plt.subplot(122), plt.imshow(im_array, cmap='gray')

Next, use an image different from the one the template was cut from. OpenCV's matchTemplate finds the region most similar to the prepared template. It offers six comparison methods, so we experiment with each of them and draw a rectangle around the detected location.

file_name = '../input/train/LAG/img_01512.jpg' # img_00176,img_02758, img_01512
img = cv2.imread(file_name,0) 
img2 = img
w, h = template.shape[::-1]

# All the 6 methods for comparison in a list
methods = ['cv2.TM_CCOEFF', 'cv2.TM_CCOEFF_NORMED', 'cv2.TM_CCORR',
            'cv2.TM_CCORR_NORMED', 'cv2.TM_SQDIFF', 'cv2.TM_SQDIFF_NORMED']

for meth in methods:
    # Work on a fresh copy so rectangles from previous methods do not accumulate
    img = img2.copy()
    method = eval(meth)

    # Apply template matching
    res = cv2.matchTemplate(img, template, method)
    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)

    # If the method is TM_SQDIFF or TM_SQDIFF_NORMED, take the minimum
    if method in [cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED]:
        top_left = min_loc
    else:
        top_left = max_loc
    bottom_right = (top_left[0] + w, top_left[1] + h)

    cv2.rectangle(img, top_left, bottom_right, 255, 2)
    fig, ax = plt.subplots(figsize=(12, 7))
    plt.subplot(121), plt.imshow(res, cmap='gray')
    plt.title('Matching Result'), plt.xticks([]), plt.yticks([])
    plt.subplot(122), plt.imshow(img, cmap='gray')
    plt.title('Detected Point'), plt.xticks([]), plt.yticks([])
    plt.suptitle(meth)

    plt.show()

When run, every method except TM_SQDIFF and TM_SQDIFF_NORMED localizes the fish well. We therefore use TM_CCOEFF as the detection method from here on.
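As a side note, not from the original kernel: TM_CCOEFF_NORMED returns scores in [-1, 1], so if you also want to reject images with no confident match, you can threshold on the score. The 0.3 cut-off below is an illustrative guess, not a tuned value.

# Optional rejection step (sketch, not in the original kernel): TM_CCOEFF_NORMED
# scores are normalized, so a fixed threshold can flag weak matches.
res = cv2.matchTemplate(img2, template, cv2.TM_CCOEFF_NORMED)
_, max_val, _, max_loc = cv2.minMaxLoc(res)
if max_val > 0.3:
    print('confident match at {} (score {:.2f})'.format(max_loc, max_val))
else:
    print('no confident match (score {:.2f})'.format(max_val))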

The photos in this competition cover eight fish classes. We therefore select four photos from each class and run segmentation with TM_CCOEFF.

method = eval('cv2.TM_CCOEFF')
indexes=[1,30,40,5]

train_path = "../input/train/"
sub_folders = check_output(["ls", train_path]).decode("utf8").strip().split('\n')
for sub_folder in sub_folders:
    file_names = check_output(["ls", train_path+sub_folder]).decode("utf8").strip().split('\n')
    k=0
    _, ax = plt.subplots(2,2,figsize=(10, 7))
    for file_name in [file_names[x] for x in indexes]: # I take only 4 images of each group. 
        img = cv2.imread(train_path+sub_folder+"/"+file_name,0)
        img2 = img
        w, h = template.shape[::-1]
        # Apply template Matching
        res = cv2.matchTemplate(img,template,method)
        min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
        top_left = max_loc
        bottom_right = (top_left[0] + w, top_left[1] + h)
 
        cv2.rectangle(img,top_left, bottom_right, 255, 2)
        # Place the image in the 2x2 grid and hide the axis ticks
        ax[k // 2, k % 2].imshow(img, cmap='gray')
        ax[k // 2, k % 2].set_xticks([])
        ax[k // 2, k % 2].set_yticks([])
        k = k + 1
    plt.suptitle(sub_folder)
    plt.show()

When run, only the "LAG" class, from which the template was taken, is segmented with relatively high accuracy. For the other species, you would want to prepare separate templates.
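If you do prepare per-species templates, the matching step generalizes naturally: keep one template per class and take the best-scoring match over all of them. The sketch below assumes one hand-picked crop per species; only the LAG entry comes from this article, the rest are placeholders to fill in.

# Multi-template sketch (not part of the original kernel). Only the LAG crop below
# comes from this article; the other entries are placeholders to be filled in by hand.
templates = {
    'LAG': cv2.imread('../input/train/LAG/img_00091.jpg', 0)[100:450, 525:950],
    # 'ALB': cv2.imread('../input/train/ALB/...', 0)[...],  # one crop per species
}

def best_match(img):
    # Return (score, species, top_left) for the best-scoring template.
    # TM_CCOEFF_NORMED is used so scores are comparable across templates of different sizes.
    best = None
    for species, tmpl in templates.items():
        res = cv2.matchTemplate(img, tmpl, cv2.TM_CCOEFF_NORMED)
        _, max_val, _, max_loc = cv2.minMaxLoc(res)
        if best is None or max_val > best[0]:
            best = (max_val, species, max_loc)
    return best

Cropping a good template box by hand for each species is the main manual effort in this approach.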
