[PYTHON] Beispiele und Lösungen, die mit scipy.optimize.least_squares nicht gut optimiert werden können

Mit scipy können Sie die Parameter einer nichtlinearen Funktion mithilfe von "optimize.least_squares" an Ihre Daten anpassen. Abhängig von der Form der nichtlinearen Funktion ist es jedoch möglicherweise nicht möglich, die optimalen Parameter zu finden. Dies liegt daran, dass optimize.least_squares nur lokale optimale Lösungen finden kann.

Dieses Mal werde ich ein Beispiel geben, in dem optimize.least_squares in eine lokal optimale Lösung fällt, und versuchen, mit optimize.basinhopping eine global optimale Lösung zu finden.

Ausführung:

Beispiele, die nicht gut optimiert werden können

Betrachten Sie die folgende Funktion mit $ a $ als Parameter.

y(x)=\frac{1}{100}(x-3a)(2x-a)(3x+a)(x+2a)

Angenommen, Sie erhalten verrauschte Daten, wenn $ a = 2 $.


import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

seed = 0
np.random.seed(seed)

def y(x, a):
    return (x-3.*a) * (2.*x-a) * (3.*x+a) * (x+2.*a) / 100.

a_orig = 2.
xs = np.linspace(-5, 7, 1000)
ys = y(xs,a_orig)

num_data = 30
data_x = np.random.uniform(-5, 5, num_data)
data_y = y(data_x, a_orig) + np.random.normal(0, 0.5, num_data)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(data_x, data_y, 'o', label='data')
plt.legend()

qiita_1.png

Versuchen Sie andererseits, die Parameter mit optimize.least_squares zu finden.

from scipy.optimize import least_squares

def calc_residuals(params, data_x, data_y):
    model_y = y(data_x, params[0])
    return model_y - data_y

a_init = -3
res = least_squares(calc_residuals, np.array([a_init]), args=(data_x, data_y))

a_fit = res.x[0]
ys_fit = y(xs,a_fit)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(xs, ys_fit, label='fit a = %.2f'%(a_fit))
plt.plot(data_x, data_y, 'o')
plt.legend()

qiita_2.png

Ich habe den Anfangswert des Parameters auf $ a_0 = -3 $ gesetzt, aber er passte nicht gut zu den Daten.

Warum Sie nicht gut optimieren können

Betrachten Sie, wie sich das Ergebnis abhängig vom Anfangswert des Parameters ändert.

a_inits = np.linspace(-4, 4, 1000)
a_fits = np.zeros(1000)
for i, a_init in enumerate(a_inits):
    res = least_squares(calc_residuals, np.array([a_init]), args=(data_x, data_y))
    a_fits[i] = res.x[0]

plt.plot(a_inits, a_fits)
plt.xlabel("initial value")
plt.ylabel("optimized value")

qiita_3.png

Wenn der Anfangswert negativ ist, sind Sie in die lokal optimalen Parameter gefallen. Der Grund dafür lässt sich anhand der Beziehung zwischen Parameterwerten und Residuen erkennen. Wie in der folgenden Abbildung gezeigt, gibt es zwei Mindestwerte für den Parameter, sodass sich das Ergebnis abhängig vom Anfangswert ändert.

def calc_cost(params, data_x, data_y):
    residuals = calc_residuals(params, data_x, data_y)
    return (residuals * residuals).sum()

costs = np.zeros(1000)
for i, a in enumerate(a_inits):
    costs[i] = calc_cost(np.array([a]), data_x, data_y)
plt.plot(a_inits, costs)
plt.xlabel("parameter")
plt.ylabel("sum of squares")

qiita_4.png

Wie man gut optimiert

Um die optimalen Parameter global zu finden, reicht es aus, aus verschiedenen Anfangswerten zu berechnen. Es gibt ein "Optimize.basinhopping" in scipy, um dies gut zu machen. Machen wir das.

from scipy.optimize import basinhopping
a_init = -3.0
minimizer_kwargs = {"args":(data_x, data_y)}
res = basinhopping(calc_cost, np.array([a_init]),stepsize=2.,minimizer_kwargs=minimizer_kwargs)
print(res.x)

a_fit = res.x[0]
ys_fit = y(xs,a_fit)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(xs, ys_fit, label='fit by basin-hopping a = %.2f'%(a_fit))
plt.plot(data_x, data_y, 'o')
plt.legend()

qiita_5.png

Die Parameter wurden erfolgreich erhalten. Der Trick ist das Argument "stepize". Legt fest, um wie viel dieses Argument den Anfangswert ändert.

Recommended Posts

Beispiele und Lösungen, die mit scipy.optimize.least_squares nicht gut optimiert werden können
Angelegenheiten, die mit sklearn nicht importiert werden können
Zusammenfassung von Beispielen, die nicht rückwärts pyTorch sein können
Importieren Sie Bibliotheken, die mit PyCharm nicht per Pip installiert werden können
Optionen bei der Installation von Bibliotheken, die nicht in pyenv weitergeleitet werden können
Behandlung des Fehlers, dass ein HTTP-Abruffehler in gpg auftritt und der Schlüssel nicht abgerufen werden kann
Konvertieren Sie Wetterdaten im GRIB2-Format, die mit pygrib nicht geöffnet werden können, in netCDF und visualisieren Sie sie
Parallele Berechnung (Pathos) beim Umgang mit Objekten, die nicht eingelegt werden können
Maßnahmen, um SSL nicht mit Pycharm installieren oder importieren zu können
Es wurde ein Fehler behoben, bei dem node.surface mit python3 + mecab nicht abgerufen werden konnte
Das matplotlib-Bild kann nicht umbenannt und gespeichert werden
Python-Modul mit "- (Bindestrich)" kann nicht gelöscht werden
Über localhost: Auf 4040 kann nicht zugegriffen werden, nachdem Spark mit Docker ausgeführt wurde