[PYTHON] Approximate Nearest Neighbor Search for Similar Image Analysis (For Beginners) (1)

I would like to write about similar-image search, but first I will explain the k-nearest neighbor (k-NN) method.

K-NN method

Given a query point, find the k training points closest to it and assign the query to the class that appears most often among those k neighbors (a majority vote). ex) When k = 3, the neighbors are one square and two triangles, so the query is classified as a triangle. When k = 5, there are three squares and two triangles, so it is classified as a square.

[Figure: the k = 3 and k = 5 example]
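The majority vote described above fits in a few lines of Python. This is a minimal brute-force sketch; the function name `knn_classify` and the coordinates of the squares and triangles are hypothetical, chosen so that they reproduce the k = 3 and k = 5 example.

```python
import math
from collections import Counter


def knn_classify(train, query, k):
    """Classify `query` by majority vote among its k nearest training points."""
    # Sort the labelled training points by Euclidean distance to the query
    nearest = sorted(train, key=lambda item: math.dist(item[0], query))[:k]
    # Count the labels of the k nearest points and return the most common one
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Hypothetical data: distances from the origin are 1.0, 1.2, 1.5, 2.0, 2.2, 5.66
train = [
    ((1.0, 0.0), "triangle"),
    ((0.0, 1.2), "square"),
    ((-1.5, 0.0), "triangle"),
    ((0.0, -2.0), "square"),
    ((2.2, 0.0), "square"),
    ((4.0, 4.0), "triangle"),
]
print(knn_classify(train, (0.0, 0.0), k=3))  # triangle (2 triangles vs 1 square)
print(knn_classify(train, (0.0, 0.0), k=5))  # square (3 squares vs 2 triangles)
```

An odd k avoids ties in a two-class vote; on a tie, `Counter.most_common` simply returns the label it encountered first.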

The result depends on which k you choose, so you need to find a suitable k. Use cross-validation to estimate the generalization error for each k, and take the k with the smallest error as the optimum. (The generalization error is the error on unseen test data, i.e. how far the predictions deviate from the actual results.) When images are converted into vectors, running all of these nearest-neighbor calculations for every k takes a huge amount of processing time. This is where approximate nearest neighbor search comes in.
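A pure-Python sketch of choosing k by cross-validation, assuming a small synthetic two-cluster dataset and simple interleaved folds (the sample later in this post uses scikit-learn's `StratifiedKFold` instead). The helper names `knn_predict`, `cv_error` and `choose_k` are hypothetical. Note that every prediction computes the distance to all training points, which is exactly the cost that becomes prohibitive for image vectors.

```python
import math
import random
from collections import Counter


def knn_predict(train_X, train_y, x, k):
    """Exact k-NN: majority vote among the k training points closest to x."""
    order = sorted(range(len(train_X)), key=lambda i: math.dist(train_X[i], x))
    return Counter(train_y[i] for i in order[:k]).most_common(1)[0][0]


def cv_error(X, y, k, n_folds):
    """Estimate the generalization error of k-NN by n_folds-fold cross-validation."""
    errors = 0
    for fold in range(n_folds):
        # Every n_folds-th sample, starting at `fold`, is held out as test data
        test = set(range(fold, len(X), n_folds))
        train_X = [X[i] for i in range(len(X)) if i not in test]
        train_y = [y[i] for i in range(len(X)) if i not in test]
        errors += sum(knn_predict(train_X, train_y, X[i], k) != y[i] for i in test)
    return errors / len(X)


def choose_k(X, y, candidates, n_folds=5):
    """Return the candidate k with the smallest cross-validation error."""
    errors = {k: cv_error(X, y, k, n_folds) for k in candidates}
    return min(errors, key=errors.get), errors


# Hypothetical toy data: two overlapping 2-D Gaussian clusters
random.seed(0)
X = ([(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(30)]
     + [(random.gauss(2, 1), random.gauss(2, 1)) for _ in range(30)])
y = [0] * 30 + [1] * 30
best_k, errors = choose_k(X, y, candidates=[1, 3, 5, 7, 9])
print(best_k, errors)
```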

Approximate nearest neighbor search

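Approximate nearest neighbor search does not insist on the exact nearest neighbor: a point that is somewhat farther away is accepted as the answer, as long as it satisfies

$$
d(q, x) \le (1 + \varepsilon)\, d(q, x^*)
$$

where d(q, x) is the distance from the query q to the approximate solution x, d(q, x*) is the distance from q to the true nearest neighbor x*, and ε > 0 controls how much error is tolerated.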
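To make the (1 + ε) condition concrete, the following sketch checks whether a candidate answer is an acceptable approximate solution. It finds the exact nearest neighbor x* by brute force purely for illustration; an ANN library avoids this at query time. The function name `is_eps_approximate` and the data points are hypothetical.

```python
import math


def is_eps_approximate(q, candidate, data, eps):
    """Return True if d(q, candidate) <= (1 + eps) * d(q, x*)."""
    # x* is the exact nearest neighbor of q, found here by brute force
    exact = min(data, key=lambda x: math.dist(q, x))
    return math.dist(q, candidate) <= (1 + eps) * math.dist(q, exact)


data = [(0.0, 0.0), (3.0, 4.0), (1.0, 1.0), (2.0, 0.0)]
q = (1.1, 0.0)  # exact nearest neighbor is (2.0, 0.0) at distance 0.9
print(is_eps_approximate(q, (1.0, 1.0), data, eps=0.2))  # True: 1.005 <= 1.08
print(is_eps_approximate(q, (0.0, 0.0), data, eps=0.2))  # False: 1.1 > 1.08
```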

The approximate solution is found with best-first search. Best-first search is a search algorithm that always explores next the node that looks most promising according to some rule.
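As an illustration of best-first search in this setting, here is a minimal sketch on a plain k-d tree. This is an assumption for teaching purposes: Annoy itself builds a forest of random-projection trees, not this exact structure. A priority queue always pops the region with the smallest lower bound on its distance to the query, and the hypothetical `max_leaves` budget stops the search early, which is what turns an exact search into an approximate one. The budget plays a role loosely similar to Annoy's `search_k`.

```python
import heapq
import math


def build_kdtree(points, depth=0):
    """Build a k-d tree whose leaves each hold one point (points must be non-empty)."""
    if len(points) == 1:
        return {"leaf": True, "point": points[0]}
    axis = depth % len(points[0])  # cycle through the coordinate axes
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    return {
        "leaf": False,
        "axis": axis,
        "split": points[mid][axis],  # left: coord <= split, right: coord >= split
        "left": build_kdtree(points[:mid], depth + 1),
        "right": build_kdtree(points[mid:], depth + 1),
    }


def best_first_search(tree, q, max_leaves):
    """Approximate nearest neighbor of q, visiting at most max_leaves leaves."""
    best, best_dist = None, math.inf
    # Priority queue of (lower bound on distance from q to the region, tie-breaker, node)
    heap = [(0.0, 0, tree)]
    counter = 1
    leaves = 0
    while heap and leaves < max_leaves:
        bound, _, node = heapq.heappop(heap)
        if bound >= best_dist:
            break  # no queued region can contain a closer point
        # Descend toward q, queuing each skipped branch with its lower bound
        while not node["leaf"]:
            diff = q[node["axis"]] - node["split"]
            near, far = (node["left"], node["right"]) if diff < 0 else (node["right"], node["left"])
            heapq.heappush(heap, (max(bound, abs(diff)), counter, far))
            counter += 1
            node = near
        leaves += 1
        d = math.dist(q, node["point"])
        if d < best_dist:
            best, best_dist = node["point"], d
    return best, best_dist


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build_kdtree(points)
print(best_first_search(tree, (9, 2), max_leaves=1))            # ((9, 6), 4.0): approximate
print(best_first_search(tree, (9, 2), max_leaves=len(points)))  # (8, 1): the exact answer
```

With `max_leaves=1` only one leaf is examined and the search returns (9, 6), while the true nearest neighbor is (8, 1) at distance √2; raising the budget trades speed for accuracy.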

Running a sample

The Iris dataset is used as the dataset.

shell


pip install annoy scikit-learn

python


from collections import Counter

import numpy as np
from annoy import AnnoyIndex
from sklearn import datasets
from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_validate
from sklearn.utils import check_X_y
from sklearn.utils import check_array


# scikit-learn recommends placing mixins to the left of BaseEstimator
class AnnoyClassifier(ClassifierMixin, BaseEstimator):
    # Model built on approximate nearest neighbor search; this class is the heart of the sample.
    def __init__(self, n_trees, metric='angular', n_neighbors=1, search_k=-1):
        # Hyperparameters are stored under the same names as the __init__ arguments,
        # so BaseEstimator's get_params/set_params (used by clone) work without overrides.
        # Number of trees in the Annoy forest
        self.n_trees = n_trees
        # Distance metric used for the calculation
        self.metric = metric
        # Number of neighbors that vote
        self.n_neighbors = n_neighbors
        # Nodes inspected per query: larger is more accurate but slower (-1 = Annoy's default)
        self.search_k = search_k

    def fit(self, X, y):
        # Validate the input and convert it to NumPy arrays
        X, y = check_X_y(X, y)
        # Save the class labels of the training data
        self.train_y_ = y
        # Prepare an Annoy index for vectors of dimension X.shape[1]
        self.clf_ = AnnoyIndex(X.shape[1], self.metric)
        # Register every training vector under its row index
        for i, x in enumerate(X):
            self.clf_.add_item(i, x)
        # Build the forest of random-projection trees
        self.clf_.build(self.n_trees)
        return self

    def predict(self, X):
        # Validate the input
        X = check_array(X)
        # Return one predicted class label per sample
        return np.array([self._predict(x) for x in X])

    def _predict(self, x):
        # Find the approximate nearest neighbors (row indices of the training data)
        neighbors = self.clf_.get_nns_by_vector(x, self.n_neighbors, search_k=self.search_k)
        # Convert the indices to class labels
        neighbor_classes = self.train_y_[neighbors]
        # Majority vote: return the most common class label
        counter = Counter(neighbor_classes)
        return counter.most_common(1)[0][0]


def main():
    #Load the Iris dataset
    dataset = datasets.load_iris()
    X, y = dataset.data, dataset.target
    #Classifier
    clf = AnnoyClassifier(n_trees=10)
    # 3-fold CV
    skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    #Measure generalization performance using accuracy as an evaluation index
    score = cross_validate(clf, X, y, cv=skf, scoring='accuracy')
    mean_test_score = score.get('test_score').mean()
    print('acc:', mean_test_score)


if __name__ == '__main__':
    main()

result

python


acc: 0.98

reference

Machine learning_k-nearest neighbor method_theory: https://dev.classmethod.jp/machine-learning/2017ad_20171218_knn/#sec4
