[PYTHON] Heat Map for Grid Search with Matplotlib

Grid Search

http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

First, for grid search, you need to choose which parameters to optimize and define the candidate values for each of them.

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer

learner = RandomForestClassifier(random_state = 2)
n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}

In this case, AUC is used as the scoring metric, so you need to create your own scorer for it.

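from sklearn.metrics import roc_auc_score

def auc_scorer(target_score, prediction):
    # make_scorer passes (y_true, y_pred); roc_auc_score expects the true labels first
    auc_value = roc_auc_score(target_score, prediction)
    return auc_value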

scorer = make_scorer(auc_scorer, greater_is_better=True)
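One caveat worth keeping in mind: by default, a scorer built with make_scorer scores the output of predict(), which for a classifier is hard class labels rather than probabilities. The toy sketch below (the labels and probabilities are made up purely for illustration) shows how the AUC can differ between the two kinds of input.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Toy data, invented for illustration only.
y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.6, 0.4, 0.9])   # predicted probability of class 1
y_label = (y_prob >= 0.5).astype(int)      # hard labels, as predict() would return them

# AUC computed on continuous scores keeps the ranking information.
auc_from_probabilities = roc_auc_score(y_true, y_prob)

# AUC computed on hard labels only sees the thresholded predictions.
auc_from_labels = roc_auc_score(y_true, y_label)

print(auc_from_probabilities, auc_from_labels)
```

If a ranking-based AUC is what you are after, scikit-learn also accepts the built-in string scoring='roc_auc', which scores continuous outputs (probabilities or decision values) instead of labels.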

Finally, you can define the grid search object.

# scoring must be passed by keyword in recent scikit-learn versions
grid_obj = GridSearchCV(learner, parameters, scoring=scorer)
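Defining the object does not run the search yet; results such as cv_results_ only become available after fit() has been called. The post does not show its dataset, so the following is a minimal sketch under assumptions: it uses synthetic data from make_classification, a smaller grid, and cv=3 to keep the run short.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in data (an assumption; the original dataset is not shown).
X, y = make_classification(n_samples=200, n_features=8, random_state=0)

# Smaller grid than in the post, just to keep the example fast.
n_estimators = [12, 24]
min_samples_leaf = [1, 4, 16]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}

learner = RandomForestClassifier(random_state=2)
scorer = make_scorer(roc_auc_score, greater_is_better=True)

grid_obj = GridSearchCV(learner, parameters, scoring=scorer, cv=3)
grid_fit = grid_obj.fit(X, y)   # the search runs here

best_params = grid_fit.best_params_
mean_scores = grid_fit.cv_results_['mean_test_score']  # one entry per parameter combination
print(best_params)
print(mean_scores)
```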

Heat Map

http://scikit-learn.org/stable/auto_examples/svm/plot_rbf_parameters.html

To create a heat map, you first need a 2-dimensional matrix. Once the grid search object has been fitted (grid_obj.fit), you can retrieve from it the mean cross-validated score for every parameter combination in the grid. In the example below, all results are put into scores.

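# Parameter names are iterated in sorted order, so rows are min_samples_leaf and columns are n_estimators
scores = grid_obj.cv_results_['mean_test_score'].reshape(len(min_samples_leaf), len(n_estimators))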

Note: scores contains the following array.

[[ 0.91803961  0.92444425  0.9264368   0.92730609  0.92808348]
 [ 0.91263539  0.91757799  0.91892211  0.91957058  0.91950196]
 [ 0.90143663  0.90590379  0.90669241  0.90751479  0.90758263]
 [ 0.89168321  0.89370183  0.89414698  0.89497685  0.89506426]
 [ 0.88276445  0.88386261  0.88380793  0.88408826  0.88448689]]
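Because both lists in this post happen to have five entries, it is easy to lose track of which axis of the reshaped array belongs to which parameter. One way to check is to look at the candidate order directly with ParameterGrid, which yields combinations in the same order GridSearchCV evaluates them. The sketch below uses lists of different lengths (chosen here only for illustration) so that the layout is unambiguous.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

# Lists of different lengths, chosen for illustration so the axes cannot be confused.
n_estimators = [12, 24, 36]
min_samples_leaf = [1, 2]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}

# Candidates in evaluation order: parameter names are sorted,
# and the last name ('n_estimators') varies fastest.
candidates = list(ParameterGrid(parameters))

shape = (len(min_samples_leaf), len(n_estimators))
leaf_grid = np.array([c['min_samples_leaf'] for c in candidates]).reshape(shape)
tree_grid = np.array([c['n_estimators'] for c in candidates]).reshape(shape)

print(leaf_grid)   # constant along each row
print(tree_grid)   # constant along each column
```

With this shape, each row holds one min_samples_leaf value and each column one n_estimators value, which matches the axis labels used in the plot.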

Then, you can use scores for plotting a heat map.

import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(8, 6))
plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('n_estimators')
plt.ylabel('min_samples_leaf')
plt.colorbar()
plt.xticks(np.arange(len(n_estimators)), n_estimators)
plt.yticks(np.arange(len(min_samples_leaf)), min_samples_leaf)
plt.title('Grid Search AUC Score')
plt.show()

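The result is a heat map like the one below.

[Image: heat map titled "Grid Search AUC Score", with n_estimators on the x-axis and min_samples_leaf on the y-axis]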
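Since the scores in this grid differ only in the second decimal place, it can help to print each value inside its cell and to read off the best combination programmatically. The following is a possible variation on the plot, not part of the original post: it reuses the scores array printed earlier, and the output file name grid_search_auc.png is a hypothetical choice.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]

# The scores array as printed in the post (rows: min_samples_leaf, columns: n_estimators).
scores = np.array([
    [0.91803961, 0.92444425, 0.9264368, 0.92730609, 0.92808348],
    [0.91263539, 0.91757799, 0.91892211, 0.91957058, 0.91950196],
    [0.90143663, 0.90590379, 0.90669241, 0.90751479, 0.90758263],
    [0.89168321, 0.89370183, 0.89414698, 0.89497685, 0.89506426],
    [0.88276445, 0.88386261, 0.88380793, 0.88408826, 0.88448689],
])

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
fig.colorbar(im, ax=ax)
ax.set_xlabel('n_estimators')
ax.set_ylabel('min_samples_leaf')
ax.set_xticks(np.arange(len(n_estimators)))
ax.set_xticklabels(n_estimators)
ax.set_yticks(np.arange(len(min_samples_leaf)))
ax.set_yticklabels(min_samples_leaf)
ax.set_title('Grid Search AUC Score')

# Write each score into its cell; x is the column index, y is the row index.
for i in range(scores.shape[0]):
    for j in range(scores.shape[1]):
        ax.text(j, i, f'{scores[i, j]:.3f}', ha='center', va='center', color='tab:blue')

# Locate the best cell and map it back to parameter values.
best_row, best_col = np.unravel_index(np.argmax(scores), scores.shape)
best_setting = (min_samples_leaf[best_row], n_estimators[best_col])
print(best_setting)

fig.savefig('grid_search_auc.png')  # hypothetical file name; use plt.show() interactively
```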
