Как построить матрицу путаницы без цветового кодирования

Из всех ответов, которые я вижу в stackoverflow, такие как 1, 2 и 3 имеют цветовую кодировку.

В моем случае я бы не хотел, чтобы он был цветным, тем более что мой набор данных в значительной степени несбалансирован, классы меньшинств всегда отображаются светлым цветом. Вместо этого я бы предпочел отображать количество фактических/прогнозируемых в каждой ячейке.

В настоящее время я использую:

def plot_confusion_matrix(cm, classes, title,
                          normalize=False,
                          file='confusion_matrix',
                          cmap=plt.cm.Blues):
    
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        cm_title = "Normalized confusion matrix"
    else:
        cm_title = title

    # print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(cm_title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.3f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment = "center",
                 color = "white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True class')
    plt.xlabel('Predicted class')
    plt.tight_layout()
    plt.savefig(file + '.png')

Выход:

Как построить матрицу путаницы без цветового кодирования

Поэтому я хочу, чтобы отображался только номер.

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
2
0
83
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Вы можете использовать ListedColormap только с одним цветом для палитры. Использование Сиборн автоматизирует многие вещи, в том числе:

  • установка аннотаций в правильном месте, с черным или белым цветом в зависимости от темноты ячейки
  • некоторые параметры для установки разделительных линий
  • параметры для установки галочек
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
import pandas as pd
import seaborn as sns

def plot_confusion_matrix(cm, classes, title,
                          normalize=False, file='confusion_matrix', background='aliceblue'):
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        plt.title("Normalized confusion matrix")
    else:
        plt.title(title)

    fmt = '.3f' if normalize else 'd'
    sns.heatmap(np.zeros_like(cm), annot=cm, fmt=fmt,
                xticklabels=classes, yticklabels=classes,
                cmap=ListedColormap([background]), linewidths=1, linecolor='navy', clip_on=False, cbar=False)
    plt.tick_params(axis='x', labelrotation=30)

    plt.tight_layout()
    plt.ylabel('True class')
    plt.xlabel('Predicted class')
    plt.tight_layout()
    plt.savefig(file + '.png')

cm = np.random.randint(1, 20000, (5, 5))
plot_confusion_matrix(cm, [*'abcde'], 'title')

heatmap with single color

Ответ принят как подходящий

Используйте seaborn.heatmap с цветовой картой в оттенках серого и установите vmin=0, vmax=0:

import seaborn as sns

sns.heatmap(cm, fmt='d', annot=True, square=True,
            cmap='gray_r', vmin=0, vmax=0,  # set all to white
            linewidths=0.5, linecolor='k',  # draw black grid lines
            cbar=False)                     # disable colorbar

# re-enable outer spines
sns.despine(left=False, right=False, top=False, bottom=False)

Полная функция:

def plot_confusion_matrix(cm, classes, title,
                          normalize=False,
                          file='confusion_matrix',
                          cmap='gray_r',
                          linecolor='k'):
    
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        cm_title = 'Confusion matrix, with normalization'
    else:
        cm_title = title

    fmt = '.3f' if normalize else 'd'
    sns.heatmap(cm, fmt=fmt, annot=True, square=True,
                xticklabels=classes, yticklabels=classes,
                cmap=cmap, vmin=0, vmax=0,
                linewidths=0.5, linecolor=linecolor,
                cbar=False)
    sns.despine(left=False, right=False, top=False, bottom=False)

    plt.title(cm_title)
    plt.ylabel('True class')
    plt.xlabel('Predicted class')
    plt.tight_layout()
    plt.savefig(f'{file}.png')

Другие вопросы по теме