scikit-learn ist Pythons fast de facto Bibliothek für maschinelles Lernen. Der Vorteil von Scikit-Learn besteht darin, dass viele Algorithmen implementiert sind, diese jedoch konsistent gestaltet sind und verschiedene Algorithmen auf gemeinsame Weise verarbeiten können. Wenn Sie einen neuen Algorithmus implementieren, der nicht in scikit-learn enthalten ist, oder wenn Sie ihn so implementieren, dass er bei Verwendung anderer Bibliotheken wie andere Schätzer von sciki-learn behandelt werden kann, wird er wie der ursprünglich implementierte Schätzer kreuzvalidiert. Sie können die Leistung bewerten und die Parameter durch Rastersuche optimieren. Hier ist die Implementierung des Mindestschätzers. Hier betrachten wir Diskriminatoren oder Regressionen als Ziele (nicht Clustering oder unbeaufsichtigtes Lernen).
from sklearn.base import BaseEstimator
class MyEstimator(BaseEstimator):
def __init__(self, param1, param2):
self.param1 = param1
self.param2 = param2
def fit(self, x, y):
return self
def predict(self, x):
return [1.0]*len(x)
def score(self, x, y):
return 1
def get_params(self, deep=True):
return {'param1': self.param1, 'param2': self.param2}
def set_params(self, **parameters):
for parameter, value in parameters.items():
setattr(self,parameter, value)
return self
Definieren Sie die Schätzerklasse, indem Sie "sklearn.base.BaseEstimator" erben. Bitte schreiben Sie den Inhalt der Methode entsprechend um.
Cross validation:
x = [[2,3],[4,5],[6,1],[2,0]]
y = [0.0,9.4,2.1,0.9]
estimator = MyEstimator()
cross_validation.cross_val_score(estimator,x,y,cv=3)
Result:
array([ 1., 1., 1.])
Grid search:
gs = grid_search.GridSearchCV(estimator, {'param1': [0,10], 'param2': (1, 1e-1, 1e-2)})
gs.fit(x,y)
gs.best_estimator_, gs.best_params_, gs.best_score_
Result:
(MyEstimator(), {'param1': 0, 'param2': 1}, 1.0)
cross_validation
Um eine Kreuzvalidierung durchzuführen, sind die "Fit" -Methode zum Lernen von Trainingsdaten und die "Score" -Methode zum Eingeben von Testdaten, zum Vergleichen des daraus geschätzten Werts mit dem richtigen Antwortwert und zum Ausgeben des Scores erforderlich.
fit(self, x, y)
Es ist eine Funktion, die lernt, so dass die Ausgabe "y" für die Eingabe "x" ist.
predict(self, x)
Eine Funktion, deren Ausgabe "y_pred" für die Eingabe "x" zurückgibt. Sie brauchen keine "Vorhersage", wenn Sie nur "cross_validation" durchführen möchten, aber in den meisten Fällen nennen Sie "Vorhersage" innerhalb der "Punktzahl". Durch Implementieren mehrerer Vererbungen von "sklearnbase.ClassifierMixin" und "scikit-learn.base.RegressionMixin" und Implementieren nur "Vorhersagen" können Sie die implementierte "Score" -Funktion verwenden.
score(self, x, y)
Es ist eine Funktion, die die Ausgabe "y_pred" für die Eingabe "x" schätzt, die "y_pred" mit der richtigen Antwort "y" vergleicht und die Punktzahl zurückgibt (unabhängig davon, ob der Fehler oder die Bezeichnung übereinstimmt usw.).
grid_search
Um grid_search
durchzuführen, müssen Sie die Parameter zusätzlich zum Lernen und Berechnen der Punktzahl wie oben definiert manipulieren. Implementieren Sie die Methode get_params
, um datenunabhängige Parameter abzurufen, und die Methode set_params
, um Parameter festzulegen.
get_params(self, deep=True)
In der Methode "get_params" ist der Parameterschlüssel der Attributname. Versuchen Sie, ein Wörterbuch zurückzugeben, in dem value ein Wert ist.
set_params(self, **parameters)
Es ist ein Parametersetzer. Übergeben Sie es in einem Wörterbuch wie "get_params".
Mixin
Die implementierte Methode kann durch Mehrfachvererbung von "sklearn.base.ClassifierMixin" für das Identifikationsmodell und "sklearn.base.RegressorMixin" für das Regressionsmodell verwendet werden. Wenn Sie diese erben
sklearn.base.ClassifierMixin
_estimator_type
auf classifier
oder regressor
sklearn.base.RegressorMixin
_estimator_type
auf regressor
Sie können mit sklearn.utils.estimator_checks.check_estimator
überprüfen, ob Ihr eigener Schätzer mit sklearn kompatibel ist. Übrigens erhalte ich in dem in diesem Artikel gezeigten Beispiel die Fehlermeldung, dass die Eingabe nicht validiert ist. Es sollte kein Problem geben, wenn Sie es selbst verwenden.
--Erstellen Sie Ihre eigene Schätzerklasse, indem Sie "sklearn.base.BaseEstimator" erben
fit
, score
Methoden, um cross_validation
durchzuführengrid_search
auszuführen, benötigen Sie mehr get_params
- und set_params
-Methoden.Das meiste, was ich hier geschrieben habe API-Referenz für das sklearn.base-Modul Informationen für Entwickler auf der offiziellen Website Wird bezeichnet.
Recommended Posts