[PYTHON] Shouldn't a confusion matrix also show the ratio of each element to its row total?

Introduction

This is something I thought about while studying the Python Data Science Handbook.

I made a heat map with seaborn to visualize misclassifications, but the colors simply follow whichever counts are largest in the whole matrix. (That is fine when every class has the same number of samples, but imbalanced data is common.)

It would be better to also have a heat map that shows the ratio of each element to its row total (= the number of samples in each true class), so I made one.
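To make the concern concrete, here is a minimal sketch with a made-up two-class confusion matrix. The array `toy` and its numbers are purely hypothetical and are not taken from the digits data used below.

```python
import numpy as np

# Hypothetical confusion matrix (rows = true class, columns = predicted class).
# Class 0 has 1000 samples, class 1 has only 20.
toy = np.array([[950, 50],
                [10, 10]])

# Row totals kept as a column vector so the division is applied row by row.
row_totals = toy.sum(axis=1, keepdims=True)
toy_ratio = toy / row_totals

print(toy)        # raw counts: the 50 errors in class 0 stand out
print(toy_ratio)  # row ratios: class 1 is wrong half of the time
```

In raw counts, the 50 misclassified samples of class 0 look worse than the 10 of class 1, but per row class 0 is misclassified 5% of the time while class 1 is misclassified 50% of the time.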

Loading the data and applying a classification algorithm

load_and_modelfitting.py


import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

#As a sample classification task, load the handwritten digit images
digits = load_digits()
X = digits.data
y = digits.target

#Split into training data and evaluation data
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=0)

#Use Gaussian Naive Bayes as a simple choice of classification algorithm
model = GaussianNB()
model.fit(Xtrain, ytrain)
y_model = model.predict(Xtest)
#Print the accuracy so that it is also visible when run as a script
print(accuracy_score(ytest, y_model))

The main topic starts here

Drawing heat maps of the usual confusion matrix and of the row-wise ratios

create_confmrx.py


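#Two-dimensional array of the ordinary confusion matrix (rows = true class, columns = predicted class)
mat = confusion_matrix(ytest, y_model)
#Two-dimensional array of each element's ratio to its row total, rounded to two decimal places.
#keepdims=True keeps the row totals as a column so that each row is divided by its own total.
mat_dec = np.round(mat / np.sum(mat, axis=1, keepdims=True), decimals=2)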

fig, axes = plt.subplots(1, 2, figsize=(10, 10))
kwargs = dict(square=True, annot=True, cbar=False, cmap='RdPu')

#Draw two heatmaps
for i, dat in enumerate([mat, mat_dec]):
    sns.heatmap(dat, **kwargs, ax=axes[i])

#Set the title and the x-axis and y-axis labels of each heat map
for ax, t in zip(axes, ['Count', 'Ratio (per row)']):
    ax.set_title(t)
    ax.set_xlabel('predicted value')
    ax.set_ylabel('true value')

heatmap.png
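
As a cross-check on the row-wise ratio, the sketch below compares a manual calculation with the `normalize='true'` option of `confusion_matrix`, which I believe is available in scikit-learn 0.22 and later (treat the version as an assumption). The labels `y_true_toy` and `y_pred_toy` are made up for illustration. It also shows what happens if the row totals are left as a one-dimensional array: NumPy then broadcasts them across the columns, which is not the ratio we want.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels, for illustration only
y_true_toy = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred_toy = [0, 0, 1, 0, 1, 0, 2, 2, 2, 1]

counts = confusion_matrix(y_true_toy, y_pred_toy)

# Manual row-wise ratio: keepdims=True gives a (3, 1) column of row totals
manual = counts / counts.sum(axis=1, keepdims=True)

# Same normalization done by scikit-learn (over the true labels, i.e. per row)
builtin = confusion_matrix(y_true_toy, y_pred_toy, normalize='true')

# Without keepdims the totals have shape (3,), so element [i, j] is divided
# by the total of row j instead of row i
column_wise = counts / counts.sum(axis=1)

print(np.allclose(manual, builtin))
print(np.allclose(manual, column_wise))
```

If the two normalized arrays agree, every row of the ratio matrix sums to 1, which is a quick sanity check before drawing the heat map.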

How to interpret the heat maps

- Get an overview of the whole
  - In the left (ordinary) heat map, look at the light-colored elements on the diagonal to identify the rows with the most misclassifications overall
    - Classes 2, 9, 4 and 0 apply
  - In the right heat map, check which elements are also misclassified often within their own row
    - Classes 2 and 9 seem to apply, so increase their number of samples or prioritize them for tuning
    - Conversely, classes 4 and 0 have a low tuning priority: their color is light when looking only at the left side, but dark on the right side
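
The same reading can be sketched numerically instead of by eye. The matrix `toy_mat` below is hypothetical (it is not the digits result); the idea is simply to list, per true class, the absolute number of misses and the ratio of correct predictions within the row, then sort by the ratio.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows = true class, columns = predicted class)
toy_mat = np.array([[40, 2, 3],
                    [5, 10, 5],
                    [1, 1, 98]])

row_totals = toy_mat.sum(axis=1)
correct = np.diag(toy_mat)

# Absolute number of misclassified samples per true class (the left heat map's view)
missed = row_totals - correct
# Ratio of correct predictions within each row (the diagonal of the right heat map)
row_ratio = correct / row_totals

# Classes ordered from the lowest to the highest row-wise ratio
order = np.argsort(row_ratio)
for k in order:
    print(f"class {k}: missed {missed[k]} of {row_totals[k]}, ratio correct {row_ratio[k]:.2f}")
```

Classes at the top of this list are the ones that are weak relative to their own sample count, which is the criterion used above to decide tuning priority.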

In conclusion

Confusion matrices are hard to read ...
