[PYTHON] La matrice de confusion exige-t-elle également le rapport de chaque élément à la ligne totale?

introduction

Manuel de science des données Pythonの勉強中に思ったこと。

J'ai fait une carte thermique avec seaborn pour visualiser les erreurs de classification, mais n'est-ce pas la couleur avec le plus grand nombre dans l'ensemble? (Bien que le nombre d'échantillons devrait être le même pour toutes les catégories, des déséquilibres de données se produisent souvent)

Il serait préférable d'avoir une carte thermique qui montre le rapport de chaque élément au nombre total de lignes (= chaque classification). Je l'ai fait.

Lire des données et appliquer des algorithmes de classification

load_and_modelfitting.py


import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

#À titre d'exemple, chargez cette fois l'image des caractères manuscrits en tant que tâche de classification
digits = load_digits()
X = digits.data
y = digits.target

#Divisé pour la formation et l'évaluation
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=0)

#Appliquer de manière appropriée les bayes naïfs gaussiennes à l'algorithme de classification
model = GaussianNB()
model.fit(Xtrain, ytrain)
y_model = model.predict(Xtest)
accuracy_score(ytest, y_model)

Le sujet principal d'ici

Heatmap la matrice de confusion habituelle et le tableau des ratios aux lignes

create_confmrx.py


#Tableau bidimensionnel de matrice de confusion ordinaire
mat = confusion_matrix(ytest, y_model)
#Un tableau à deux dimensions qui calcule le rapport au total de chaque ligne et arrondit le troisième chiffre.
mat_dec = np.round(mat / np.sum(mat, axis=1), decimals=2)

fig, axes = plt.subplots(1, 2, figsize=(10, 10))
kwargs = dict(square=True, annot=True, cbar=False, cmap='RdPu')

#Dessinez deux cartes thermiques
for i, dat in enumerate([mat, mat_dec]):
    sns.heatmap(dat, **kwargs, ax=axes[i])

#Définir le titre du graphique, les étiquettes des axes X et Y
for ax, t in zip(axes, ['Real number', 'Percentage(per row)']):
    plt.axes(ax)
    plt.title(t)
    plt.xlabel('predicted value')
    plt.ylabel('true value')

heatmap.png

Interprétation proposée des informations de la carte thermique dessinée

en conclusion

La matrice de confusion est difficile à voir ...

Recommended Posts

La matrice de confusion exige-t-elle également le rapport de chaque élément à la ligne totale?
Obtenez l'index de chaque élément de la matrice de confusion en Python
Comment compter le nombre d'occurrences de chaque élément de la liste en Python avec poids
Essayez de résoudre les problèmes / problèmes du "programmeur matriciel" (Chapitre 1)
À propos de la matrice de confusion
Essayez de résoudre les problèmes / problèmes du "programmeur matriciel" (fonction du chapitre 0)
Obtenez le nombre d'occurrences pour chaque élément de la liste