[PYTHON] Benötigt die Verwirrungsmatrix auch das Verhältnis jedes Elements zur gesamten Zeile?

Einführung

Python Data Science Handbookの勉強中に思ったこと。

Ich habe mit Seaborn eine Heatmap erstellt, um eine Fehlklassifizierung zu visualisieren. Ist dies nicht die Farbe mit der größeren Anzahl im Ganzen? (Obwohl die Anzahl der Stichproben für alle Kategorien gleich sein sollte, treten häufig Datenungleichgewichte auf.)

Es wäre besser, eine Heatmap zu haben, die das Verhältnis jedes Elements zur Gesamtzahl der Zeilen (= jede Klassifizierung) zeigt. Ich habe es gemacht.

Daten lesen und Klassifizierungsalgorithmen anwenden

load_and_modelfitting.py


import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

#Laden Sie diesmal als Beispiel das Bild handgeschriebener Zeichen als Klassifizierungsaufgabe
digits = load_digits()
X = digits.data
y = digits.target

#Geteilt für Schulung und Bewertung
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=0)

#Wenden Sie Gaußsche Naive Bayes entsprechend auf den Klassifizierungsalgorithmus an
model = GaussianNB()
model.fit(Xtrain, ytrain)
y_model = model.predict(Xtest)
accuracy_score(ytest, y_model)

Das Hauptthema von hier

Heatmap der üblichen Verwirrungsmatrix und der Anordnung von Verhältnissen zu Zeilen

create_confmrx.py


#Zweidimensionales Array einer gewöhnlichen Verwirrungsmatrix
mat = confusion_matrix(ytest, y_model)
#Ein zweidimensionales Array, das das Verhältnis zur Summe jeder Zeile berechnet und die dritte Ziffer abrundet.
mat_dec = np.round(mat / np.sum(mat, axis=1), decimals=2)

fig, axes = plt.subplots(1, 2, figsize=(10, 10))
kwargs = dict(square=True, annot=True, cbar=False, cmap='RdPu')

#Zeichnen Sie zwei Heatmaps
for i, dat in enumerate([mat, mat_dec]):
    sns.heatmap(dat, **kwargs, ax=axes[i])

#Legen Sie die Beschriftung für Diagrammtitel, x-Achse und y-Achse fest
for ax, t in zip(axes, ['Real number', 'Percentage(per row)']):
    plt.axes(ax)
    plt.title(t)
    plt.xlabel('predicted value')
    plt.ylabel('true value')

heatmap.png

Interpretationsplan der gezeichneten Wärmekarteninformationen

abschließend

Die Verwirrungsmatrix ist unangenehm zu sehen ...

Recommended Posts

Benötigt die Verwirrungsmatrix auch das Verhältnis jedes Elements zur gesamten Zeile?
Ruft den Index jedes Elements der Verwirrungsmatrix in Python ab
So zählen Sie die Anzahl der Vorkommen jedes Elements in der Liste in Python mit der Gewichtung
Versuchen Sie, die Probleme des "Matrix-Programmierers" zu lösen (Kapitel 1).
Über die Verwirrungsmatrix
Versuchen Sie, die Probleme / Probleme des "Matrix-Programmierers" zu lösen (Kapitel 0-Funktion)
Ermitteln Sie die Anzahl der Vorkommen für jedes Element in der Liste