[Python] Implementation of Nelder–Mead method and saving of GIF images by matplotlib

This article describes the Nelder–Mead method, a well-known optimization algorithm.

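The purpose is twofold: to implement the Nelder–Mead method in Python, and to save the search process as a GIF image with matplotlib.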

Overview and implementation of the Nelder–Mead method

Algorithm overview

The Nelder–Mead method is an algorithm [^1] published by John A. Nelder and Roger Mead in 1965. It searches for the minimum of a function by moving an $n$-dimensional simplex consisting of $n + 1$ vertices around like an amoeba ([Wikipedia](https://en.wikipedia.org/wiki/%E3%83%8D%E3%83%AB%E3%83%80%E3%83%BC%E2%80%93%E3%83%9F%E3%83%BC%E3%83%89%E6%B3%95)). For example, for a problem with $2$ decision variables, the optimal solution is searched for while moving a triangle.

The specific algorithm is as follows (see reference [^2]).

Nelder–Mead

At each iteration, the vertex $x_h$ with the largest function value among the $n + 1$ vertices is updated. The candidate replacement points (Reflect, Expand, Contract) are computed from the centroid $c$ of the $n$ vertices other than $x_h$.

If none of these candidate points is an improvement, all vertices except $x_\ell$, the vertex with the smallest function value, are moved closer to $x_\ell$. (SHRINK)
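As a rough illustration of the candidate points described above, the sketch below computes the centroid and the Reflect, Expand, and Contract candidates for a single triangle. The function name `candidate_points` and the example objective are hypothetical, and the default coefficients are assumed to match the `alpha`, `gamma`, and `rho` values used in the implementation later in this article. Only the contraction toward $x_h$ is shown.

```python
import numpy as np


def candidate_points(simplex, fun, alpha=1.0, gamma=2.0, rho=0.5):
    """Return the centroid c and the reflect/expand/contract candidates (illustrative helper)."""
    simplex = np.asarray(simplex, dtype=float)
    values = np.array([fun(v) for v in simplex])
    worst = np.argmax(values)                 # index of x_h, the vertex with the largest value
    x_h = simplex[worst]
    c = np.delete(simplex, worst, axis=0).mean(axis=0)  # centroid of the other n vertices
    x_r = c + alpha * (c - x_h)               # Reflect: mirror x_h through the centroid
    x_e = c + gamma * (c - x_h)               # Expand: go further in the same direction
    x_c = c + rho * (x_h - c)                 # Contract: pull x_h toward the centroid
    return c, x_r, x_e, x_c


# hypothetical example: a triangle and an objective whose minimum is at (2, 2)
triangle = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
c, x_r, x_e, x_c = candidate_points(triangle, lambda v: (v - 2.0) @ (v - 2.0))
print(c, x_r, x_e, x_c)
```

Here the worst vertex is the origin, so all three candidates lie on the line through the origin and the centroid of the other two vertices.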

The figures in this blog post make Reflect, Expand, Contract, and SHRINK easy to understand: https://codesachin.wordpress.com/2016/01/16/nelder-mead-optimization/

Implementation in Python

In Python, the method can be used by specifying method='Nelder-Mead' in scipy.optimize.minimize. However, in this article I wanted to use all the vertices of the triangle at every step to create a GIF image, so I implemented it as follows.

from typing import Callable, Tuple, Union

import numpy as np


def _order(x: np.ndarray, ordering: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    indices = np.argsort(ordering)
    return x[indices], ordering[indices]

def optimize(
    fun: Callable,
    x0: np.ndarray,
    maxiter: Union[int, None] = None,
    initial_simplex: Union[np.ndarray, None] = None
):
    if x0.ndim != 1:
        raise ValueError(f'Expected 1D array, got {x0.ndim}D array instead')

    # initialize simplex
    if initial_simplex is not None:
        if initial_simplex.ndim != 2:
            raise ValueError(f'Expected 2D array, got {initial_simplex.ndim}D array instead')
        x = initial_simplex.copy()
        n = x[0].size
    else:
        # step size for coordinate i: 0.05 if x0[i] is nonzero, otherwise 0.00025
        h = lambda x: (x[0][x[1]] != 0) * (0.05 - 0.00025) + 0.00025
        n = x0.size
        x = np.array([x0 + h([x0, i]) * e for i, e in enumerate(np.identity(n))] + [x0])

    if maxiter is None:
        maxiter = 200 * n

    # parameters
    alpha = 1.0
    gamma = 2.0
    rho = 0.5
    sigma = 0.5

    # order
    fx = np.array(list(map(fun, x)))
    x, fx = _order(x, fx)

    # centroid
    xo = np.mean(x[:-1], axis=0)
    n_inv = 1 / n

    for _ in range(maxiter):
        fx1 = fx[0]
        fxn = fx[-2]
        fxmax = fx[-1]
        xmax = x[-1]

        xr = xo + alpha * (xo - xmax)
        fxr = fun(xr)

        if fx1 <= fxr and fxr < fxn:
            # reflect
            x[-1] = xr
            fx[-1] = fxr
            x, fx = _order(x, fx)
            xo = xo + n_inv * (xr - x[-1])

        elif fxr < fx1:
            xe = xo + gamma * (xo - xmax)
            fxe = fun(xe)
            if fxe < fxr:
                # expand
                x = np.append(xe.reshape(1, -1), x[:-1], axis=0)
                fx = np.append(fxe, fx[:-1])
                xo = xo + n_inv * (xe - x[-1])
            else:
                # reflect
                x = np.append(xr.reshape(1, -1), x[:-1], axis=0)
                fx = np.append(fxr, fx[:-1])
                xo = xo + n_inv * (xr - x[-1])

        else:
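            if fxr > fxmax:
                # inside contraction: move from the centroid toward the worst vertex
                xc = xo + rho * (xmax - xo)
            else:
                # outside contraction: move from the centroid toward the reflected point
                xc = xo + rho * (xr - xo)
                fxmax = fxr
            fxc = fun(xc)
            if fxc < fxmax:
                # contract
                x[-1] = xc
                fx[-1] = fxc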
                x, fx = _order(x, fx)
                xo = xo + n_inv * (xc - x[-1])
            else:
                # shrink
                x[1:] = (1 - sigma) * x[0] + sigma * x[1:]
                fx[1:] = np.array(list(map(fun, x[1:])))
                x, fx = _order(x, fx)
                xo = np.mean(x[:-1], axis=0)

    return x, fx
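One detail of the implementation above is that the centroid `xo` is updated incrementally, as `xo + n_inv * (new_point - x[-1])`, rather than recomputed as a mean after each accepted point. The following sketch checks that identity on hypothetical random points: when a new point joins the best $n$ vertices and the previous second-worst vertex becomes the new worst, the mean changes by exactly their difference divided by $n$. All names and data here are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4

# hypothetical simplex whose rows are assumed to be ordered from best to worst
simplex = rng.normal(size=(n + 1, n))
centroid = simplex[:-1].mean(axis=0)      # mean of the n best vertices

x_new = rng.normal(size=n)                # hypothetical accepted point replacing the worst vertex
new_worst = simplex[-2]                   # old second-worst vertex, now excluded from the mean

# incremental update, in the same form as used in the implementation
incremental = centroid + (x_new - new_worst) / n

# direct recomputation over the new set of n best vertices
recomputed = np.vstack([simplex[:-2], x_new]).mean(axis=0)

print(np.allclose(incremental, recomputed))
```

The incremental form avoids averaging all vertices again on every iteration; only the shrink step, which moves every vertex, needs a full recomputation.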

I also compared it with the SciPy implementation on the problem $\mathop{\mathrm{minimize}}_{x, y} \quad f(x, y) = x^2 + y^2$.

from scipy.optimize import minimize

maxiter = 25

fun = lambda x: x @ x
x0 = np.array([0.08, 0.08])

# scipy
%time res = minimize(fun=fun, x0=x0, options={'maxiter': maxiter}, method='Nelder-Mead')
xopt_scipy = res.x

# implemented
%time xopt, _ = optimize(fun=fun, x0=x0, maxiter=maxiter)

print('\n')
print(f'Scipy: {xopt_scipy}')
print(f'Implemented: {xopt[0]}')

Execution result

CPU times: user 1.49 ms, sys: 41 µs, total: 1.53 ms
Wall time: 1.54 ms
CPU times: user 1.64 ms, sys: 537 µs, total: 2.18 ms
Wall time: 1.86 ms


Scipy: [-0.00026184 -0.00030341]
Implemented: [ 2.98053651e-05 -1.26493496e-05]
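For reference, the sketch below shows what SciPy itself exposes about the simplex, as far as I understand its interface. With the `return_all` option, `res.allvecs` records one point per iteration (the best vertex), and `res.final_simplex` holds only the last simplex. This is a small illustration under those assumptions, not a full description of the SciPy API.

```python
import numpy as np
from scipy.optimize import minimize

fun = lambda x: x @ x
x0 = np.array([0.08, 0.08])

res = minimize(fun, x0, method='Nelder-Mead',
               options={'maxiter': 25, 'return_all': True})

# one point per recorded iteration: the best vertex, not the whole triangle
best_per_iteration = np.array(res.allvecs)

# vertices and function values of the final simplex only
last_vertices, last_values = res.final_simplex

print(best_per_iteration.shape)
print(last_vertices.shape)
```

Because the intermediate triangles are not available this way, drawing every vertex at every step is easier with a custom implementation.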

Creating GIF images with matplotlib

A GIF image was created using matplotlib.animation.FuncAnimation. For the implementation, I referred to the following article.

First, calculate the vertices of the triangles to be drawn. As in the previous example, the objective function is $f(x, y) = x^2 + y^2$.

maxiter = 25

fun = lambda x: x @ x
x = np.array([[0.08, 0.08], [0.13, 0.08], [0.08, 0.13]])
X = [x]
for _ in range(maxiter):
    x, fx = optimize(fun, x[0], maxiter=1, initial_simplex=x)
    X.append(x)

This stores maxiter + 1 simplices in X, each holding the three vertices of a triangle.

Next, create a GIF image using FuncAnimation.

FuncAnimation(fig, func, frames, fargs) creates an animation in which each frame is drawn by calling func(frames[i], *fargs); saving it produces the GIF image.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as pat
import matplotlib.animation as animation

def func(x, xmin, xmax, ymin, ymax, xx, yy, vals):
    # clear the current axes
    plt.cla()
    
    # set x-axis and y-axis
    plt.xlim([xmin, xmax])
    plt.ylim([ymin, ymax])
    plt.hlines(0, xmin=xmin, xmax=xmax, colors='gray')
    plt.vlines(0, ymin=ymin, ymax=ymax, colors='gray')
    
    # set aspect
    plt.gca().set_aspect('equal', adjustable='box')
    
    # draw filled contour
    plt.contourf(xx, yy, vals, 50, cmap='Blues')
    
    # draw triangle
    plt.gca().add_patch(pat.Polygon(x, ec='k', fc='m', alpha=0.2))
    
    # draw three vertices
    plt.scatter(x[:, 0], x[:, 1], color=['r', 'g', 'b'], s=20)

n_grid=100
delta=0.005
interval=300

xmax, ymax = np.max(X, axis=(0, 1)) + delta
xmin, ymin = np.min(X, axis=(0, 1)) - delta

# function values of lattice points
xx, yy = np.meshgrid(np.linspace(xmin, xmax, n_grid), np.linspace(ymin, ymax, n_grid))
vals = np.array([fun(np.array([x, y])) for x, y in zip(xx.ravel(), yy.ravel())]).reshape(n_grid, n_grid)

fig = plt.figure(figsize=(10, 10))
ani = animation.FuncAnimation(fig=fig, func=func, frames=X, fargs=(xmin, xmax, ymin, ymax, xx, yy, vals), interval=interval)
# the 'imagemagick' writer requires ImageMagick to be installed
ani.save("nelder-mead.gif", writer='imagemagick')

Created GIF image nelder-mead.gif
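As a minimal illustration of the FuncAnimation call pattern used above, the sketch below animates three hypothetical points and saves them as a GIF. The function name `draw_frame`, the data, and the output path are assumptions for illustration. It uses the `pillow` writer instead of `imagemagick`, which I believe avoids the need for an external ImageMagick installation, and the `Agg` backend so that it can run without a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np


def draw_frame(frame, ax, color):
    # FuncAnimation calls draw_frame(frames[i], *fargs) once per frame
    ax.cla()
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.scatter(frame[:, 0], frame[:, 1], color=color)


# hypothetical data: three points drifting toward the origin
base = np.array([[0.9, 0.9], [0.9, 0.5], [0.5, 0.9]])
frames = [base * s for s in (1.0, 0.75, 0.5, 0.25)]

fig, ax = plt.subplots(figsize=(3, 3))
ani = animation.FuncAnimation(fig, draw_frame, frames=frames,
                              fargs=(ax, "r"), interval=300)

# save to a temporary directory (hypothetical file name)
out_path = os.path.join(tempfile.mkdtemp(), "sketch.gif")
ani.save(out_path, writer="pillow")
plt.close(fig)
```

Passing the axes through `fargs` keeps the drawing function independent of the global pyplot state, which is one way to avoid accidentally drawing on a newly created axes.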
