[PYTHON] Der kleinste Hinweis der Welt wurde mit numpy & numba gelöst

Einführung

Ich habe das Gefühl, dass ich durch das Lösen von Mathematik mit Python erschöpft bin, aber ich habe beschlossen, einen Artikel zu schreiben, weil es sich um Code handelt, der nur in meinem Gedächtnis vorhanden ist. (Der Hauptgrund ist, dass es schmerzhaft war, einen Artikel zu schreiben, aber ich werde ihn schreiben, weil die Frühlingsferien in Corona um einen Monat verlängert wurden und ich viel Zeit habe.)

Politik

Tiefenprioritätssuche. Geben Sie in aufsteigender Reihenfolge vorübergehend Zahlen in die leeren Zellen ein, die den Anforderungen von Suguru entsprechen. Versuchen Sie, dasselbe mit anderen Quadraten zu tun, auch wenn Sie es einfügen. Wenn irgendwo ein Widerspruch besteht, korrigieren Sie die unmittelbar zuvor getroffene Annahme. Dies wird ernsthaft wiederholt. Eines Tages wird der Widerspruch nicht mehr auftreten und alle Lücken werden mit den angenommenen Werten gefüllt. Das ist die richtige Antwort.

image.png Quelle der Abbildung

Implementierung

Überprüfen Sie, ob die Annahmen den Anforderungen Deutschlands entsprechen

Um "Index" zu zählen, zählen Sie 0 bis 81 von links oben nach rechts. Am Beispiel der obigen Abbildung sind die "Indizes" der rosa Zahlen "5, 3, 6, 9, 8" im oberen linken 3x3-Block "0, 1, 9, 19, 20".

def fillable(table, index, value):
    """Überprüfen Sie in der zu berechnenden Matrix, ob der als richtige Antwort angenommene Wert die Bedingung mehrerer Deutscher erfüllt.

    Parameters
    -------
        table : np.ndarray of int
Die Matrix wird berechnet. 2D-array。
        index : int
Die Position des angenommenen Quadrats. Die erste Zeile ist 0~8,Die zweite Zeile ist 9~17, ...
        value : int
Angenommene Anzahl.

    Returns
    -------
        bool
Ist es möglich, dass die Annahme in der zu berechnenden Matrix korrekt ist?
    """

    row = index // 9
    col = index % 9

    fillable_in_row = value not in table[row, :]
    fillable_in_col = value not in table[:, col]
    fillable_in_block = value not in table[
                                     (row // 3) * 3: (row // 3 + 1) * 3,
                                     (col // 3) * 3: (col // 3 + 1) * 3,
                                     ]

    return fillable_in_row and fillable_in_col and fillable_in_block

Ich denke, es ist eine ziemlich natürliche Implementierung, da nur die grundlegende Syntax von Python verwendet wird.

Vorausgesetzt

def fill(table_flat):
    """Lösen Sie die deutsche Suche nach Tiefenpriorität.

    Parameters
    -------
        table_flat : np.ndarray of int
Das Problem mehrerer Deutscher. 1D-array

    Returns
    -------
        np.ndarray of int
Eine Matrix, der Annahmen zugeordnet sind. gestalten= (9, 9)
    """

    for tmp_i, tmp_val in enumerate(table_flat):

        if tmp_val == 0:
            fillable_vals = [val
                             for val in range(tmp_val + 1, 10)
                             if fillable(table_flat.reshape(9, 9), tmp_i, val)]

            for new_val in fillable_vals:
                table_flat[tmp_i] = new_val

                if fill(table_flat).all():
                    break
            else:
                table_flat[tmp_i] = 0
                break
    return table_flat.reshape(9, 9)

Ich werde es Zeile für Zeile erklären.

Betrachten Sie in der Reihenfolge vom Quadrat mit dem kleinsten Index.

for tmp_i, tmp_val in enumerate(table_flat):

Ich habe vergessen, es zu sagen, aber ich gehe davon aus, dass Sie es als "Frage Leerzeichen == 0" eingeben. Daher sollte dieser Prozess nur berücksichtigt werden, wenn die Zelle leer ist.

if tmp_val == 0:

Eine etwas längere Listeneinschlussnotation. Die folgenden Annahmen sollten größer sein als die aktuellen Annahmen. Ich werde jedoch die Annahme entfernen, dass es die Bedingungen von Susumu nicht erfüllt. Die Überprüfung verwendet die Funktion "ausfüllbar". Außerdem habe ich vergessen zu erwähnen, dass "Index" einfacher mit Zahlen als mit Tapples zu tun ist, sodass angenommen wird, dass die Eingabe des Problems ein 1x81-Vektor einer 9x9-Matrix (Abflachung) ist. Da es jedoch erforderlich ist, gemäß den Regeln in 3x3-Blöcken zu extrahieren, muss die Form erst zu diesem Zeitpunkt geändert werden.

fillable_vals = [val
                 for val in range(tmp_val + 1, 10)
                 if fillable(table_flat.reshape(9, 9), tmp_i, val)]

Versuchen Sie, die Anzahl der möglichen richtigen Antworten durch das Quadrat zu ersetzen.

for new_val in fillable_vals:
    table_flat[tmp_i] = new_val

** Das ist der Punkt. ** ** ** Zunächst wird die Anweisung "for / else" beschrieben. Gibt die "else" -Anweisung ein, wenn die "for" -Anweisung normal endet. Der Fall, in dem die "for" -Anweisung nicht abgeschlossen ist, ist, wenn die "for" -Anweisung noch nicht abgeschlossen ist, aber "break" auftritt. Übrigens, selbst wenn die "for" -Anweisung im Leerlauf ausgeführt wird, endet sie normal, sodass die "else" -Anweisung eingegeben wird. Hier sollten Sie also auf die Position von "Pause" achten.

Wiederholt die angenommene Matrix als Argument für die Funktion "Füllen". numpy.ndarray.all () gibt True zurück, wenn keine Zahlen ungleich Null vorhanden sind. Dieses Mal wird das Leerzeichen im Problem auf == 0 gesetzt, sodass mit anderen Worten das Verhalten von "all ()" in Silbenformulierung umformuliert werden kann. Mit anderen Worten, wenn die Silbenschrift gelöst ist, wird die Anweisung "for" beendet. Wenn es nicht gelöst ist, setzen Sie den angenommenen Massenwert auf 0 zurück und beenden Sie die Untersuchung. Wenn Sie "return" mit dem auf 0 zurückgegebenen Wert ausführen, wird dieser an "if fill (table_flat) .all ():" übertragen, und der Vorgang zum Korrigieren der Annahme wird ausgeführt.

    for ~~~Abkürzung~~~:
     ~~~Abkürzung~~~
        if fill(table_flat).all():
            break
    else:
        table_flat[tmp_i] = 0
        break
return table_flat.reshape(9, 9)

Ausführungsmethode

Sie kann ausgeführt werden, indem Sie die Funktion "Füllen" und die Funktion "Füllen" in derselben Datei anordnen und den Vektor 81. Ordnung mit der ersten Zeile in das Argument der Funktion "Füllen" einfügen.

Beschleunigen (das kleinste Hinweisproblem der Welt)

Selbst derzeit können gewöhnliche Probleme in wenigen Millisekunden bis zu einigen Sekunden gelöst werden. Er wurde jedoch von einem Problem besiegt. Es ist der kleinste Hinweis der Welt auf Deutschland. image.png

Als ich versuchte, dies mit dem aktuellen Code zu lösen, tat mir mein Computer leid. Ich beschloss, hier mit Numba zu beschleunigen.

Änderungen ① importieren

Import numba am Anfang der Datei

from numba import njit, int64, bool_

Ändern (2) Ein Dekorateur wurde hinzugefügt

Definieren Sie Eingabe- und Ausgabetypen als "@njit (Ausgabe, Eingabe)". ** Dies war diesmal der vollste Ort. ** Wenn Sie "numpy.ndarray.reshape ()" verwenden möchten, müssen Sie "int64 [:: 1]" verwenden. (Es wurde irgendwo in der offiziellen Numba-Dokumentation geschrieben. Es war hier.) Eigentlich scheint "int8" gut zu sein anstatt "int64", aber es hat nicht funktioniert, also habe ich es als Standard belassen. (Wird es als dtype = int8 angegeben?)

#Eine Funktion, die ein zweidimensionales Array von int und zwei Ints eingibt und bool zurückgibt
@njit(bool_(int64[:, :], int64, int64))
def fillable(table, index, value):

#Eine Funktion, die ein eindimensionales Array von int eingibt und ein zweidimensionales Array von int zurückgibt
@njit(int64[:, :](int64[::1]))
def fill(table_flat):

Ändern ③ Entfernen Sie den In-Betrieb

numba sollte die Python-Syntax ziemlich gut unterstützen, aber ich habe einen Fehler mit der in-Operation innerhalb der fillable-Funktion erhalten. (Fillable_in_row = Wert nicht in Tabelle [Zeile ,:] usw.) Daher wurde die "füllbare" Funktion ohne Verwendung der "in" -Operation erneut implementiert.

def fillable(table, index, value):

    row = index // 9
    col = index % 9
    block = lambda x: ((x // 3) * 3, (x // 3 + 1) * 3)

    same_in_areas = False
    for area in [
        table[row, :], table[:, col],
        table[block(row)[0]:block(row)[1], block(col)[0]:block(col)[1]].flatten()
    ]:
        for i in area:
            same_in_areas |= (i == value)

    return not same_in_areas

Das "in" in der "for" -Anweisung scheint in Ordnung zu sein. Geben Sie die eingegebene Zahl in die vertikalen und horizontalen Blöcke der Zelle, die Sie überprüfen möchten, in "i" ein und drehen Sie sie. Stellen Sie sicher, dass das "i" mit dem angenommenen Wert "value" übereinstimmt, dh, die gleiche Zahl kann in den vertikalen und horizontalen Blöcken vorhanden sein und möglicherweise nicht die deutsche Bedingung erfüllen. In diesem Fall ist der numba-Fehler nicht aufgetreten.

Ganzer Code

suudoku.py


def fillable(table, index, value):
    """Überprüfen Sie in der zu berechnenden Matrix, ob der als richtige Antwort angenommene Wert die Bedingung mehrerer Deutscher erfüllt.

    Parameters
    -------
        table : np.ndarray of int
Die Matrix wird berechnet. 2D-array。
        index : int
Die Position des angenommenen Quadrats. Die erste Zeile ist 0~8,Die zweite Zeile ist 9~17, ...
        value : int
Angenommene Anzahl.

    Returns
    -------
        bool
Ist es möglich, dass die Annahme in der zu berechnenden Matrix korrekt ist?
    """

    row = index // 9
    col = index % 9

    fillable_in_row = value not in table[row, :]
    fillable_in_col = value not in table[:, col]
    fillable_in_block = value not in table[
                                     (row // 3) * 3: (row // 3 + 1) * 3,
                                     (col // 3) * 3: (col // 3 + 1) * 3,
                                     ]

    return fillable_in_row and fillable_in_col and fillable_in_block


def fill(table_flat):
    """Lösen Sie die deutsche Suche nach Tiefenpriorität.

    Parameters
    -------
        table_flat : np.ndarray of int
Das Problem mehrerer Deutscher. 1D-array

    Returns
    -------
        np.ndarray of int
Eine Matrix, der Annahmen zugeordnet sind. gestalten= (9, 9)
    """

    for tmp_i, tmp_val in enumerate(table_flat):

        if tmp_val == 0:
            fillable_vals = [val
                             for val in range(tmp_val + 1, 10)
                             if fillable(table_flat.reshape(9, 9), tmp_i, val)]

            for new_val in fillable_vals:
                table_flat[tmp_i] = new_val

                if fill(table_flat).all():
                    break
            else:
                table_flat[tmp_i] = 0
                break
    return table_flat.reshape(9, 9)

suudoku_faster.py


from numba import njit, int64, bool_


@njit(bool_(int64[:, :], int64, int64))
def fillable(table, index, value):

    row = index // 9
    col = index % 9
    block = lambda x: ((x // 3) * 3, (x // 3 + 1) * 3)

    same_in_areas = False
    for area in [
        table[row, :], table[:, col],
        table[block(row)[0]:block(row)[1], block(col)[0]:block(col)[1]].flatten()
    ]:
        for i in area:
            same_in_areas |= (i == value)

    return not same_in_areas


@njit(int64[:, :](int64[::1]))
def fill(table_flat):

    for tmp_i, tmp_val in enumerate(table_flat):

        if tmp_val == 0:
            fillable_vals = [val
                             for val in range(tmp_val + 1, 10)
                             if fillable(table_flat.reshape(9, 9), tmp_i, val)]

            for new_val in fillable_vals:
                table_flat[tmp_i] = new_val

                if fill(table_flat).all():
                    break
            else:
                table_flat[tmp_i] = 0
                break
    return table_flat.reshape(9, 9)

suudoku_input.py


from suudoku import fill as fill_slow
from suudoku_faster import fill as fill_fast
import matplotlib.pyplot as plt
import numpy as np
from time import time

table =  [0, 0, 0, 8, 0, 1, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 4, 3, 0,
          5, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 7, 0, 8, 0, 0,
          0, 0, 0, 0, 0, 0, 1, 0, 0,
          0, 2, 0, 0, 3, 0, 0, 0, 0,
          6, 0, 0, 0, 0, 0, 0, 7, 5,
          0, 0, 3, 4, 0, 0, 0, 0, 0,
          0, 0, 0, 2, 0, 0, 6, 0, 0]

#Die erste Runde von numba enthält nicht nur die Berechnungszeit, sondern auch die Kompilierungszeit. Lassen Sie sie also einmal im Leerlauf laufen.
fill_fast(np.array(table).copy())

start = time()
ans = fill_fast(np.array(table).copy())
finish = time()

print("Antworten")
print(ans)
print()
print(f"Berechnungszeit(sec): {finish - start:5.7}")

fig, ax = plt.subplots(figsize=(6, 6))
fig.patch.set_visible(False)
ax.axis("off")

colors = np.where(table, "#F5A9BC", "#ffffff").reshape(9, 9)
picture = ax.table(cellText=ans, cellColours=colors, loc='center')
picture.set_fontsize(25)
picture.scale(1, 3)

ax.axhline(y=0.340, color="b")
ax.axhline(y=0.665, color="b")
ax.axvline(x=0.335, color="b")
ax.axvline(x=0.665, color="b")

fig.tight_layout()
plt.show()

Ausgabe

Wenn es ein normales Problem ist, dauert es ungefähr 0,01 Sekunden, aber welche Art von Problem dauert 100 Sekunden ... Diese schwierigste Zahl der Welt kann übrigens normal gelöst werden.

Antworten
[[2 3 4 8 9 1 5 6 7]
 [1 6 9 7 2 5 4 3 8]
 [5 7 8 3 4 6 9 1 2]
 [3 1 6 5 7 4 8 2 9]
 [4 9 7 6 8 2 1 5 3]
 [8 2 5 1 3 9 7 4 6]
 [6 4 2 9 1 8 3 7 5]
 [9 5 3 4 6 7 2 8 1]
 [7 8 1 2 5 3 6 9 4]]

Berechnungszeit(sec): 101.6835

image.png

Recommended Posts

Der kleinste Hinweis der Welt wurde mit numpy & numba gelöst
Suchen Sie mit numpy den kleinsten Index, der den kumulativen Summenschwellenwert erfüllt
2016 Todai Mathematik mit Python gelöst
Finden Sie mit NumPy die Position über dem Schwellenwert