Как раскрасить гистограмму matplotlib по значениям в другом столбце фрейма данных

Я создаю фрейм данных, используя

import pandas as pd 
import matplotlib.pyplot as plt 

df_dict = {
    "test_predictions": [0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4, 0.5, 0.5, 0.6, 0.6, 0.6, 0.7, 0.7, 0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
    "y_true": [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1],
    "distance" : [-0.1, -0.09, -0.08, -0.08, -0.07, -0.05, -0.05, -0.05, -0.05, -0.05, -0.04, -0.04, -0.04, -0.03, -0.02, -0.01, 0.01, 0.01, 0.01, 0.02, 0.03, 0.03, 0.04, 0.05, 0.05, 0.06, 0.06, 0.07, 0.08, 0.08, 0.09, 0.1]
}
df = pd.DataFrame(df_dict)

Затем я создаю график, состоящий из двух прямых графиков и гистограммы, используя

fig, ax1 = plt.subplots()
ax1.plot([0, 1], [0, 1], color = "red", linestyle = ":", label = "Perfect Model")
ax1.plot(df['test_predictions'], df['y_true'], label = "NN3", color='blue')
ax2 = ax1.twinx()
ax2.hist(df['test_predictions'], bins=10, alpha=0.7, color='darkgreen', label='Histogram')

Я хотел бы раскрасить гистограмму на основе значений в df['distance'], а также включить карту цветов. Таким образом, по сути, в одном интервале гистограммы может быть несколько цветов. Любая помощь будет высоко ценится. Спасибо!

Редактировать:

Я пробовал это сделать раньше ax2.hist(df['test_predictions'], bins=10, alpha=0.7, color='darkgreen', label='Histogram')

bins = np.linspace(df['test_predictions'].min(), df['test_predictions'].max(), 10)

for index, row in df.iterrows():
    bin_index = np.digitize(row['test_predictions'], bins)
    color = plt.cm.viridis(row['distance']/df['distance'].max())
    ax2.bar(bins[bin_index-1], 1, width=np.diff(bins)[0], color = color, alpha  = 0.7)

Тем не менее, меня беспокоит тот факт, что я использую цикл for, потому что мой фактический фрейм данных может содержать более 10000 строк, а также я не получаю желаемый результат, когда делаю это, он получается в виде отдельных столбцов, а не сложенных, как гистограмма.

Вы смотрели галерею matplotlib?

import random 11.03.2024 22:39

@importrandom да, я пытался сделать что-то похожее на то, что они делают в цикле for раздела обновления цветов гистограммы, но обнаружил, что могу обновлять цвет только одного столбца за раз. В то время как в моем случае, поскольку я раскрашиваю на основе значений в другом столбце, я мог бы использовать несколько цветов в одной полосе.

Caesar 11.03.2024 22:47
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
2
88
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Пример данных из поста

Вы можете использовать imshow для значений расстояний, принадлежащих каждой полосе.

Код ниже сначала создает дополнительный столбец данных с идентификаторами ячеек каждой строки. Затем выбираются расстояния для каждого интервала и используются в качестве входных данных для imshow().

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import pandas as pd
import numpy as np

df_dict = {
    "test_predictions": [0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4, 0.5, 0.5, 0.6, 0.6, 0.6, 0.7, 0.7, 0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
    "y_true": [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1],
    "distance": [-0.1, -0.09, -0.08, -0.08, -0.07, -0.05, -0.05, -0.05, -0.05, -0.05, -0.04, -0.04, -0.04, -0.03, -0.02, -0.01, 0.01, 0.01, 0.01, 0.02, 0.03, 0.03, 0.04, 0.05, 0.05, 0.06, 0.06, 0.07, 0.08, 0.08, 0.09, 0.1]
}
df = pd.DataFrame(df_dict)

cmap = plt.get_cmap('RdYlBu')
norm = plt.Normalize(vmin=df['distance'].min(), vmax=df['distance'].max())

num_bins = 10
bins = np.linspace(df['test_predictions'].min(), df['test_predictions'].max() + 0.001, num_bins + 1)
df['bin'] = np.digitize(df['test_predictions'], bins)

fig, ax1 = plt.subplots()
for bin_id, bin_df in df.groupby('bin'):
    ax1.imshow(bin_df['distance'].values. Reshape(-1, 1), interpolation='nearest', cmap=cmap, norm=norm,
               extent=[bins[bin_id - 1], bins[bin_id], 0, len(bin_df)], aspect='auto')

ax1.use_sticky_edges = False # remove stickiness due to imshow
ax1.autoscale_view()
ax1.set_ylim(ymin=0)

plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), label='Distance', ax=ax1)
plt.tight_layout()
plt.show()

Набор данных «советы»

Вот еще один пример с набором данных «советы» от Seaborn.

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import seaborn as sns  # to get the 'tips' dataset
import numpy as np

df = sns.load_dataset('tips')
df.sort_values(by='tip', ascending=True, inplace=True, ignore_index=True)

cmap = plt.get_cmap('RdYlBu_r')
norm = plt.Normalize(vmin=df['tip'].min(), vmax=df['tip'].max())

num_bins = 10
bins = np.linspace(df['total_bill'].min(), df['total_bill'].max() + 0.001, num_bins + 1)
df['bin'] = np.digitize(df['total_bill'], bins)

fig, ax1 = plt.subplots()
for bin_id, bin_df in df.groupby('bin'):
    ax1.imshow(bin_df['tip'].values. Reshape(-1, 1), interpolation='nearest', cmap=cmap, norm=norm,
               extent=[bins[bin_id - 1], bins[bin_id], 0, len(bin_df)], aspect='auto')

ax1.use_sticky_edges = False  # remove stickiness due to imshow
ax1.autoscale_view()
ax1.set_ylim(ymin=0)
ax1.set_xlabel('Total Bill')
ax1.set_ylabel('Count')

plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), label='Tip', ax=ax1)
plt.tight_layout()
plt.show()

Другие вопросы по теме