Показывать количество в каждой ячейке диаграммы

Как объяснено в histplot seaborn — выведите значения y над каждым столбцом, счетчики можно отобразить на каждом столбце в 1-мерной гистограмме с помощью .bar_label(ax.containers[0]).

Я изо всех сил пытаюсь понять, как сделать эквивалент двумерной гистограммы (созданной с помощью sns.histplot(data, x='var1', y='var2')).

Я знаю, что могу сделать аннотацию для контейнера (a,b) с помощью .annotate('foo', xy=(a, b)), но я не уверен, как получить счетчик для этого контейнера (чтобы передать его в .annotate()).

Я бы хотел, чтобы результат был похож на тот, что показан на https://seaborn.pydata.org/examples/spreadsheet_heatmap.html, за исключением того, что это гистограмма, а не тепловая карта.

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
116
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

1. Работаем с np.histogram2d

Вы можете использовать np.histogram2d() для подсчета, а затем отобразить результат с помощью sns.heatmap():

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

np.random.seed(20240529)
x = np.random.randn(1000).cumsum()
y = np.random.randn(1000).cumsum()

hist, xbins, ybins = np.histogram2d(x, y)
xlabels = [f'{x:.2f}' for x in (xbins[:-1] + xbins[1:]) / 2]
ylabels = [f'{y:.2f}' for y in (ybins[:-1] + ybins[1:]) / 2]
sns.heatmap(hist.T, xticklabels=xlabels, yticklabels=ylabels, annot=True, fmt='.0f', cbar=False)

2. Извлечение информации из sns.histplot

Как упоминалось в ответе Питера, информацию о количестве и положении также можно извлечь из QuadMesh, созданного sns.histplot. Вот обобщение.

Вызов .get_array() на QuadMesh дает замаскированный массив со счетчиками каждой ячейки. Это 2D-матрица. Первая строка этой матрицы — это счетчики первой строки (наименьшее значение y) на графике.

Аналогично, .get_coordinates() обозначает позиции. Это положение не центров ячеек, а их краев. Ребер на 1 строку и 1 столбец больше, чем ячеек. Координаты организованы в виде строк и столбцов значений xy (формирующих трехмерный массив).

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

np.random.seed(20240529)
x = np.random.randn(1000).cumsum()
y = np.random.randn(1000).cumsum()

ax = sns.histplot(x=x, y=y, bins=(10, 10), cbar=False)
coords = ax.collections[0].get_coordinates()
half_width = (coords[0, 1, 0] - coords[0, 0, 0]) / 2
half_height = (coords[1, 0, 1] - coords[0, 0, 1]) / 2
for v, (xv, yv) in zip(ax.collections[0].get_array().ravel(), coords[:-1, :-1, :].reshape(-1, 2)):
    if not np.ma.is_masked(v):
        ax.text(xv + half_width, yv + half_height, f'{v:.0f}', ha='center', va='center', color='white')

Вот как это будет выглядеть с категориальными данными:

sns.set_style('white')
titanic = sns.load_dataset('titanic')
ax = sns.histplot(data=titanic, x='who', y='class', cbar=False)
coords = ax.collections[0].get_coordinates()
half_width = (coords[0, 1, 0] - coords[0, 0, 0]) / 2
half_height = (coords[1, 0, 1] - coords[0, 0, 1]) / 2
for v, (xv, yv) in zip(ax.collections[0].get_array().ravel(), coords[:-1, :-1, :].reshape(-1, 2)):
    if not np.ma.is_masked(v):
        ax.text(xv + half_width, yv + half_height, f'{v:.0f}', ha='center', va='center', color='white')
ax.margins(x=0, y=0)  # remove unneeded whitespace
plt.tight_layout()

Спасибо! К сожалению, я не думаю, что смогу сделать простой np.histogram2d, поскольку мне нужно подсчитывать нечисловые значения (т. е. нечисловые метки ячеек).

Peter Thomassen 29.05.2024 11:19
Ответ принят как подходящий

Я обнаружил, что счетчики можно извлечь из атрибута histplot().collections, а затем .annotate(), следующим образом:

import numpy as np
import seaborn as sns

ax = sns.histplot(data, x='var1', y='var2', cbar=True)
w = ax.collections[0].get_coordinates().shape[1] - 1
for k, v in enumerate(ax.collections[0].get_array()):
    if not np.ma.is_masked(v):
        ax.annotate(v, xy=(k % w, k // w), ha='center', color='white')

Можете ли вы добавить некоторые воспроизводимые тестовые данные? Кажется, это не работает с моими тестовыми данными.

JohanC 29.05.2024 11:49

Чтобы ваш код работал с категориальными данными со значением x более 1, вы можете изменить его на for k, v in enumerate(ax.collections[0].get_array().ravel()): if not np.ma.is_masked(v): ax.annotate(f'{v:.0f}', ...)

JohanC 29.05.2024 13:18

Другие вопросы по теме