[Python] Implementing the Nelder–Mead method and saving GIF animations with matplotlib

This article describes the Nelder–Mead method, a well-known derivative-free optimization algorithm, and shows how to save its search process as a GIF image with matplotlib.

Overview and implementation of the Nelder–Mead method

Algorithm overview

The Nelder–Mead method is an algorithm [^1] published by John A. Nelder and Roger Mead in 1965. It searches for the minimum of a function by moving an $n$-dimensional simplex, which consists of $n + 1$ vertices, like an amoeba ([Wikipedia](https://en.wikipedia.org/wiki/%E3%83%8D%E3%83%AB%E3%83%80%E3%83%BC%E2%80%93%E3%83%9F%E3%83%BC%E3%83%89%E6%B3%95)). For example, in a problem with $2$ decision variables, the optimal solution is sought by moving a triangle.

The specific algorithm is as follows (see reference [^2]).


At each iteration, the vertex $x_h$ that gives the largest function value among the $n + 1$ vertices is updated. The update candidates (the reflection, expansion, and contraction points) are computed from the centroid $c$ of the $n$ vertices other than $x_h$.

If none of these candidate points is an improvement, all points except $x_\ell$ (the vertex with the smallest function value) are moved closer to $x_\ell$ (SHRINK).

The figures in the following blog post make REFLECT, EXPAND, CONTRACT, and SHRINK easy to understand: https://codesachin.wordpress.com/2016/01/16/nelder-mead-optimization/
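Written out, the update candidates and the shrink step take roughly the following form. This is a sketch read off from the implementation later in this article, not quoted from the references. The coefficients correspond to that code's `alpha = 1.0`, `gamma = 2.0`, `rho = 0.5`, and `sigma = 0.5`, and the labels $x_r$, $x_e$, $x_c$ are names introduced here for illustration.

```latex
\begin{align*}
c   &= \frac{1}{n} \sum_{i \ne h} x_i && \text{(centroid of the vertices other than } x_h\text{)} \\
x_r &= c + \alpha (c - x_h) && \text{(REFLECT)} \\
x_e &= c + \gamma (c - x_h) && \text{(EXPAND)} \\
x_c &= c + \rho (x_h - c) \quad \text{or} \quad x_c = c + \rho (x_r - c) && \text{(CONTRACT)} \\
x_i &\leftarrow x_\ell + \sigma (x_i - x_\ell), \quad i \ne \ell && \text{(SHRINK)}
\end{align*}
```

In the code, the first contraction form is used when the reflected point is worse than $x_h$, and the second otherwise.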

Implementation in Python

In Python, the method is available by passing method='Nelder-Mead' to scipy.optimize.minimize. However, this article needs all the vertices of the triangle at every iteration to create a GIF image, so I implemented it as follows.
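For reference, a minimal sketch of the SciPy call might look like the following. The objective `sphere` and the starting point are chosen here for illustration. As far as I know, the result object only exposes the final simplex (`res.final_simplex`), not the intermediate ones, which is why a custom implementation is convenient for the animation.

```python
import numpy as np
from scipy.optimize import minimize


def sphere(x):
    # Illustrative objective: f(x, y) = x^2 + y^2
    return float(x @ x)


res = minimize(sphere, x0=np.array([0.08, 0.08]), method='Nelder-Mead')

# final_simplex is a pair: (vertices, function values at those vertices)
final_vertices, final_values = res.final_simplex

print(res.x, res.fun)
print(final_vertices.shape)  # (3, 2): three vertices of a triangle in 2D
```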

from typing import Callable, Tuple, Union

import numpy as np

def _order(x: np.ndarray, ordering: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    indices = np.argsort(ordering)
    return x[indices], ordering[indices]

def optimize(
    fun: Callable,
    x0: np.ndarray,
    maxiter: Union[int, None] = None,
    initial_simplex: Union[np.ndarray, None] = None
) -> Tuple[np.ndarray, np.ndarray]:
    if x0.ndim != 1:
        raise ValueError(f'Expected 1D array, got {x0.ndim}D array instead')

    # initialize simplex
    if initial_simplex is not None:
        if initial_simplex.ndim != 2:
            raise ValueError(f'Expected 2D array, got {initial_simplex.ndim}D array instead')
        x = initial_simplex.copy()
        n = x[0].size
    else:
        # step along axis i: 0.05 if x0[i] is nonzero, 0.00025 otherwise
        h = lambda x: (x[0][x[1]] != 0) * (0.05 - 0.00025) + 0.00025
        n = x0.size
        x = np.array([x0 + h([x0, i]) * e for i, e in enumerate(np.identity(n))] + [x0])

    if maxiter is None:
        maxiter = 200 * n

    # parameters
    alpha = 1.0
    gamma = 2.0
    rho = 0.5
    sigma = 0.5

    # order
    fx = np.array(list(map(fun, x)))
    x, fx = _order(x, fx)

    # centroid
    xo = np.mean(x[:-1], axis=0)
    n_inv = 1 / n

    for _ in range(maxiter):
        fx1 = fx[0]
        fxn = fx[-2]
        fxmax = fx[-1]
        xmax = x[-1]

        xr = xo + alpha * (xo - xmax)
        fxr = fun(xr)

        if fx1 <= fxr and fxr < fxn:
            # reflect
            x[-1] = xr
            fx[-1] = fxr
            x, fx = _order(x, fx)
            xo = xo + n_inv * (xr - x[-1])

        elif fxr < fx1:
            xe = xo + gamma * (xo - xmax)
            fxe = fun(xe)
            if fxe < fxr:
                # expand
                x = np.append(xe.reshape(1, -1), x[:-1], axis=0)
                fx = np.append(fxe, fx[:-1])
                xo = xo + n_inv * (xe - x[-1])
            else:
                # reflect
                x = np.append(xr.reshape(1, -1), x[:-1], axis=0)
                fx = np.append(fxr, fx[:-1])
                xo = xo + n_inv * (xr - x[-1])

        else:
            if fxr > fxmax:
                # contract toward the worst vertex
                xc = xo + rho * (xmax - xo)
            else:
                # contract toward the reflected point
                xc = xo + rho * (xr - xo)
                fxmax = fxr
            fxc = fun(xc)
            if fxc < fxmax:
                # contract
                x[-1] = xc
                fx[-1] = fxc
                x, fx = _order(x, fx)
                xo = xo + n_inv * (xc - x[-1])
            else:
                # shrink
                x[1:] = (1 - sigma) * x[0] + sigma * x[1:]
                fx[1:] = np.array(list(map(fun, x[1:])))
                x, fx = _order(x, fx)
                xo = np.mean(x[:-1], axis=0)

    return x, fx
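The lines of the form `xo = xo + n_inv * (... - x[-1])` update the centroid incrementally instead of recomputing the mean. The short check below is a sketch of why that works, under the assumption that the accepted candidate ranks better than the old second-worst vertex. The names `x_new`, `centroid_incremental`, and `centroid_recomputed` are introduced here for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
x = rng.normal(size=(n + 1, n))  # n + 1 vertices, assumed sorted from best to worst
x_new = rng.normal(size=n)       # hypothetical accepted candidate that beats x[-2]

centroid_old = x[:-1].mean(axis=0)  # centroid of all vertices except the worst

# After the update, the worst vertex x[-1] is discarded, x_new joins the kept set,
# and the old second-worst vertex x[-2] becomes the new worst vertex.
centroid_incremental = centroid_old + (x_new - x[-2]) / n
centroid_recomputed = np.vstack([x[:-2], x_new]).mean(axis=0)

print(np.allclose(centroid_incremental, centroid_recomputed))  # True
```

The shrink step moves every vertex except the best one, so there the code recomputes the centroid with `np.mean` instead.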

I also compared it with the SciPy implementation on the problem $\mathop{\mathrm{minimize}}_{x, y} \quad f(x, y) = x^2 + y^2$.

from scipy.optimize import minimize

maxiter = 25

fun = lambda x: x @ x
x0 = np.array([0.08, 0.08])

# scipy
%time res = minimize(fun=fun, x0=x0, options={'maxiter': maxiter}, method='Nelder-Mead')
xopt_scipy = res.x

# implemented
%time xopt, _ = optimize(fun=fun, x0=x0, maxiter=maxiter)

print(f'Scipy: {xopt_scipy}')
print(f'Implemented: {xopt[0]}')

Execution result

CPU times: user 1.49 ms, sys: 41 µs, total: 1.53 ms
Wall time: 1.54 ms
CPU times: user 1.64 ms, sys: 537 µs, total: 2.18 ms
Wall time: 1.86 ms

Scipy: [-0.00026184 -0.00030341]
Implemented: [ 2.98053651e-05 -1.26493496e-05]

Creating GIF images with matplotlib

A GIF image was created using matplotlib.animation.FuncAnimation. At the time of implementation, I referred to the following article.

First, compute the triangle vertices to be drawn. As in the previous example, the objective function is $f(x, y) = x^2 + y^2$.

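maxiter = 25

fun = lambda x: x @ x
x = np.array([[0.08, 0.08], [0.13, 0.08], [0.08, 0.13]])
X = [x]
for _ in range(maxiter):
    # advance one iteration, restarting from the current simplex
    x, fx = optimize(fun, x[0], maxiter=1, initial_simplex=x)
    X.append(x)

This stores maxiter + 1 simplices (the three triangle vertices at each step) in X.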

Next, create a GIF image using FuncAnimation.

FuncAnimation(fig, func, frames, fargs) builds the animation by calling func(frames[i], *fargs) to draw each frame.
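To make the calling convention concrete, here is a small self-contained sketch that is separate from the Nelder–Mead plot. The function name `draw`, the list `drawn`, and the output path are hypothetical. The sketch also assumes the headless Agg backend and the Pillow writer, so it runs without a display and without ImageMagick.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no display is needed
import matplotlib.pyplot as plt
import matplotlib.animation as animation

drawn = []  # records the frame values passed to draw()


def draw(frame, ax, color):
    # FuncAnimation calls this as draw(frames[i], *fargs)
    ax.cla()
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.scatter([frame], [frame], color=color)
    drawn.append(frame)


fig, ax = plt.subplots(figsize=(3, 3))
frames = [0.2, 0.5, 0.8]
ani = animation.FuncAnimation(fig, draw, frames=frames, fargs=(ax, "r"), interval=200)

# assumption: the Pillow writer is used here instead of ImageMagick
out_path = os.path.join(tempfile.gettempdir(), "funcanimation_demo.gif")
ani.save(out_path, writer="pillow")
plt.close(fig)
```

Each element of `frames` arrives as the first argument of `draw`, followed by the entries of `fargs` in order.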

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as pat
import matplotlib.animation as animation

def func(x, xmin, xmax, ymin, ymax, xx, yy, vals):
    # clear the current axes
    plt.cla()
    # set x-axis and y-axis
    plt.xlim([xmin, xmax])
    plt.ylim([ymin, ymax])
    plt.hlines(0, xmin=xmin, xmax=xmax, colors='gray')
    plt.vlines(0, ymin=ymin, ymax=ymax, colors='gray')
    # set aspect
    plt.gca().set_aspect('equal', adjustable='box')
    # draw filled contour
    plt.contourf(xx, yy, vals, 50, cmap='Blues')
    # draw triangle on the current axes
    plt.gca().add_patch(pat.Polygon(x, ec='k', fc='m', alpha=0.2))
    # draw three vertices
    plt.scatter(x[:, 0], x[:, 1], color=['r', 'g', 'b'], s=20)


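# plot parameters (assumed values; the original settings are not shown in this article)
delta = 0.02    # margin around the visited vertices
n_grid = 100    # grid points per axis for the contour
interval = 300  # delay between frames in milliseconds

xmax, ymax = np.max(X, axis=(0, 1)) + delta
xmin, ymin = np.min(X, axis=(0, 1)) - delta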

# function values of lattice points
xx, yy = np.meshgrid(np.linspace(xmin, xmax, n_grid), np.linspace(ymin, ymax, n_grid))
vals = np.array([fun(np.array([x, y])) for x, y in zip(xx.ravel(), yy.ravel())]).reshape(n_grid, n_grid)

fig = plt.figure(figsize=(10, 10))
ani = animation.FuncAnimation(fig=fig, func=func, frames=X, fargs=(xmin, xmax, ymin, ymax, xx, yy, vals), interval=interval)
# the 'imagemagick' writer requires ImageMagick to be installed
ani.save("nelder-mead.gif", writer='imagemagick')

Created GIF image nelder-mead.gif

