[PYTHON] Qu'est-ce que le réglage des hyper paramètres?

introduction

Le réglage des hyperparamètres est une technique utilisée pour améliorer la précision du modèle. Si vous créez un modèle avec scikit-learn et ne définissez pas de paramètres, il sera défini avec la complexité appropriée.

Qu'est-ce qu'un hyper paramètre?

C'est un paramètre qui est spécifié avant l'entraînement et qui détermine la méthode d'entraînement, la vitesse et la complexité du modèle.

Méthode (type)

bergstra12a-04.jpg Source: http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf

Recherche de grille

C'est une méthode pour décider des candidats pour chaque paramètre et essayer toutes les combinaisons d'entre eux. Comme tous sont essayés, il n'est pas possible d'augmenter le nombre de paramètres candidats.

Recherche aléatoire

C'est une méthode pour choisir les candidats pour chaque paramètre et répéter une combinaison aléatoire de paramètres n fois. Il peut ne pas être possible de rechercher une meilleure combinaison de paramètres car nous ne les essayons pas tous.

Combinaison de paramètres

import numpy as np

params_list01 = [1, 3, 5, 7]
params_list02 = [1, 2, 3, 4, 5]

#Recherche de grille
grid_search_params = []
for p1 in params_list01:
    for p2 in params_list02:
        grid_search_params.append(p1, p2)
# append():Ajouter un élément à la fin de la liste

#Recherche aléatoire
random_search_params = []
count = 10
for i in range(count):
    p1 = np.random.choice(params_list01)  # random.choice():Obtenez le contenu du tableau au hasard
    p2 = np.random.choice(params_list02)
    random_search_params.append(p1, p2)

scikit-learn

Cliquez ici pour une référence à scikit-learn

from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
params = {
    "max_depth": [2, 4, 6, 8, None],
    "n_estimators": [50,100,200,300,400,500],
    "max_features": range(1, 11),
    "min_samples_split": range(2, 11),
    "min_samples_leaf": range(1, 11)
}

#Recherche de grille
gscv = GridSearchCV(RandomForestRegressor(), params, cv=3, n_jobs=-1, verbose=1)
gscv.fit(X_train_valid, y_train_valid)
            
print("Best score: {}".format(gscv.best_score_))
print("Best parameters: {}".format(gscv.best_params_))

#Recherche aléatoire
rscv = RandomizedSearchCV(RandomForestRegressor(), params, cv=3, n_iter=10, n_jobs=-1, verbose=1)
rscv.fit(X_train_valid, y_train_valid)
            
print("Best score: {}".format(rscv.best_score_))
print("Best parameters: {}".format(rscv.best_params_))

en conclusion

Lorsqu'on demande lequel adopter, une recherche aléatoire est effectuée, il semble qu'une bonne combinaison de paramètres puisse être trouvée efficacement.

référence

Livre: Technologie d'analyse de données qui gagne avec Kaggle (Revue technique)

Recommended Posts

Qu'est-ce que le réglage des hyper paramètres?
Réglage des hyper paramètres
Réglage de l'hyper paramètre 2
Qu'est-ce que l'espace de noms
Qu'est-ce que copy.copy ()
Qu'est-ce que dotenv?
Qu'est-ce que Linux
Qu'est-ce que le klass?
Qu'est-ce que SALOME?
Qu'est-ce que Linux?
Qu'est-ce que python
Qu'est-ce que l'hyperopt?
Qu'est-ce que Linux
Qu'est-ce que pyvenv
Qu'est-ce que __call__
Qu'est-ce que Linux
Qu'est-ce que Python
Qu'est-ce qu'une distribution?
Qu'est-ce que le F-Score de Piotroski?
Qu'est-ce que Raspberry Pi?
[Python] Qu'est-ce que Pipeline ...
Qu'est-ce que Calmar Ratio?
Qu'est-ce qu'un terminal?
[Tutoriel PyTorch ①] Qu'est-ce que PyTorch?
Qu'est-ce qu'un hacker?
Qu'est-ce que JSON? .. [Remarque]
À quoi sert Linux?
Qu'est-ce qu'un pointeur?
Qu'est-ce que l'apprentissage d'ensemble?
Qu'est-ce que TCP / IP?
Qu'est-ce que __init__.py de Python?
Qu'est-ce qu'un itérateur?
Qu'est-ce que UNIT-V Linux?
[Python] Qu'est-ce que virtualenv
Qu'est-ce que l'apprentissage automatique?
Qu'est-ce que Mini Sam ou Mini Max?
Qu'est-ce que l'analyse de régression logistique?
Quelle est la fonction d'activation?
Qu'est-ce qu'une variable d'instance?
Qu'est-ce qu'un arbre de décision?
Qu'est-ce qu'un changement de contexte?
Qu'est-ce que Google Cloud Dataflow?
[DL] Qu'est-ce que la décroissance du poids?
[Python] Python et sécurité-① Qu'est-ce que Python?
Qu'est-ce qu'un super utilisateur?
La programmation du concours, c'est quoi (bonus)
[Python] * args ** Qu'est-ce que kwrgs?
Qu'est-ce qu'un appel système
[Définition] Qu'est-ce qu'un cadre?
A quoi sert l'interface ...
Qu'est-ce que Project Euler 3 Acceleration?
Qu'est-ce qu'une fonction de rappel?
Qu'est-ce que la fonction de rappel?
Quel est votre "coefficient de Tanimoto"?
Cours de base Python (1 Qu'est-ce que Python)
[Python] Qu'est-ce qu'une fonction zip?
[Python] Qu'est-ce qu'une instruction with?
Qu'est-ce que la régression de crête de rang réduit?