Как создать в Matplotlib график, похожий на роевой, но с перекрывающимися точками?

Я пытаюсь создать (своего рода) роевой график, который должен четко показывать формы распределений, но позволять быстро строить графики десятков тысяч точек данных путем перекрытия точечных представлений точек данных. Например:

Моя идея состоит в том, чтобы по существу создать точечный график, но разделить каждое распределение на квантили и применить джиттер к горизонтальным положениям точек данных, величина которых пропорциональна количеству точек в данном квартиле. Это отлично работает, когда распределения имеют одинаковый размер, но мне нужен какой-то способ масштабирования джиттера, чтобы, когда одно из распределений имеет только несколько точек данных, представляющие их точки были расположены на (почти) вертикальной линии, т.е. НЕ, как показано ниже. :

Вот мой код для создания сюжета:

import matplotlib.pyplot as plt
import numpy as np


def fancy_distribution_plot(distributions: list, tick_labels: list, max_plot_width: int = 1, alpha=0.7,
                            number_of_segments=12,
                            separation_between_plots=0.1,
                            separation_between_subplots=0.1,
                            vertical_limits=None,
                            grid=False,
                            remove_outlier_above_segment=None,
                            remove_outlier_below_segment=None,
                            y_label=None,
                            title=None):
    fig, ax = plt.subplots()

    number_of_plots = len(distributions)
    # print(f" number of plots {number_of_plots}")
    # print(f" max x line {number_of_plots * (max_plot_wwidth + separation_between_plots) + separation_between_plots}")

    ax.set_xlim(left=0, right=number_of_plots * (max_plot_width + separation_between_plots) + separation_between_plots)

    ticks = [separation_between_plots + max_plot_width / 2 + (max_plot_width + separation_between_plots) * i
             for i in range(0, number_of_plots)]
    print(ticks)

    for i in range(len(distributions)):
        distribution = distributions[i]
        # print(f"distribution {distribution}")
        segments = np.linspace(np.min(distribution), np.max(distribution), number_of_segments + 1)[1:-1]
        # print(f"segments {segments}")
        segment_indices = number_of_segments - 1 - np.where(segments[:, None] >= distribution[None, :], 1, 0).sum(0)
        # print(f"quantile indices {segment_indices}")
        if remove_outlier_above_segment:
            a = remove_outlier_above_segment[i]
            distribution = distribution[segment_indices <= a]
            segment_indices = segment_indices[segment_indices <= a]

        if remove_outlier_below_segment:
            b = remove_outlier_below_segment[i]
            distribution = distribution[segment_indices >= b - 1]
            segment_indices = segment_indices[segment_indices >= b - 1]

        values, counts = np.unique(segment_indices, return_counts=True)
        # print(f"values {values}")
        # print(f"counts {counts}")
        counts_filled = []
        j = 0
        for k in range(number_of_segments):
            if k in values:
                counts_filled.append(counts[j])
                j += 1
            else:
                counts_filled.append(0)
        variances = (max_plot_width / 2) * (counts_filled / np.max(counts))
        # print(f"variances {variances}")
        jitter_unadjusted = np.random.uniform(-1, 1, len(distribution))
        jitter = np.take(variances, segment_indices) * jitter_unadjusted
        # print(f"jitter {jitter}")
        ax.scatter(jitter + ticks[i], distribution, alpha=alpha)

    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    if vertical_limits:
        ax.set_ylim(bottom=vertical_limits[0], top=vertical_limits[1])
    if not grid:
        ax.grid(False)
    if y_label:
        ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    plt.show()

И код для воссоздания второй диаграммы выше:

# Create example random data
np.random.seed(0)
distro1 = np.random.normal(0, 2, 4)
distr2 = np.random.normal(1, 1, 10)
distr3 = np.random.normal(2, 3, 1000)

distributions = [distro1, distr2, distr3]
fancy_distribution_plot(distributions, tick_labels=['distro1', 'distro2', 'distro3'], number_of_segments=100,
                        grid=False)

Может быть, здесь будет интересен violinplot Seaborn?

JohanC 16.06.2024 11:44

@JohanC Да, определенно, это одно из решений, но мне было интересно, смогу ли я сделать форму распределения еще более точной.

ufghd34 16.06.2024 14:39
Как сделать так, чтобы точки в Swarmlot перекрывались друг с другом предлагает хак, временно изменяющий размер фигуры для создания перекрывающихся точек. Но этот хак, похоже, больше не работает в последней (улучшенной) библиотеке matplotlib. Альтернатива — использование точек меньшего размера (sns.swarmplot(..., size=2).
JohanC 17.06.2024 20:51

Я предлагаю изменить эту строку variances = (max_plot_width / 2) * (counts_filled / np.max(counts)) на что-то вроде variances = (max_plot_width / 2) * (counts_filled / max_counts), где max_counts — максимальное количество для всех распределений, а не только i-го. Для этого вам нужно разделить цикл for на две части: одну для получения счетчиков, а другую для установки масштабированного джиттера и построения графика.

Luca 18.06.2024 16:15
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
4
55
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Развернув мой комментарий, вы можете масштабировать дисперсию (и, следовательно, джиттер), разделив ее на максимум count среди всех распределений.

Возможная реализация (начиная с вашей функции):

import matplotlib.pyplot as plt
import numpy as np


def fancy_distribution_plot(distributions: list, tick_labels: list, max_plot_width: int = 1, alpha=0.7,
                            number_of_segments=12,
                            separation_between_plots=0.1,
                            separation_between_subplots=0.1,
                            vertical_limits=None,
                            grid=False,
                            remove_outlier_above_segment=None,
                            remove_outlier_below_segment=None,
                            y_label=None,
                            title=None):
    fig, ax = plt.subplots()

    number_of_plots = len(distributions)

    ax.set_xlim(left=0, right=number_of_plots * (max_plot_width + separation_between_plots) + separation_between_plots)

    ticks = [separation_between_plots + max_plot_width / 2 + (max_plot_width + separation_between_plots) * i
             for i in range(0, number_of_plots)]
    
    max_counts = 0.0
    counts_filled_list = []
    segment_indices_list = []
    for i in range(len(distributions)):
        distribution = distributions[i]
        
        segments = np.linspace(np.min(distribution), np.max(distribution), number_of_segments + 1)[1:-1]
        
        segment_indices = number_of_segments - 1 - np.where(segments[:, None] >= distribution[None, :], 1, 0).sum(0)
        
        if remove_outlier_above_segment:
            a = remove_outlier_above_segment[i]
            distribution = distribution[segment_indices <= a]
            segment_indices = segment_indices[segment_indices <= a]

        if remove_outlier_below_segment:
            b = remove_outlier_below_segment[i]
            distribution = distribution[segment_indices >= b - 1]
            segment_indices = segment_indices[segment_indices >= b - 1]
        segment_indices_list.append(segment_indices)

        values, counts = np.unique(segment_indices, return_counts=True)
        if np.max(counts) > max_counts:
            max_counts = np.max(counts)
        counts_filled = []
        j = 0
        for k in range(number_of_segments):
            if k in values:
                counts_filled.append(counts[j])
                j += 1
            else:
                counts_filled.append(0)
        counts_filled_list.append(counts_filled)

    for i in range(len(distributions)):    
        #print(f"counts filled {counts_filled}")
        variances = (max_plot_width / 2) * (counts_filled_list[i] / max_counts)
        #print(f"variances {variances}")
        jitter_unadjusted = np.random.uniform(-1, 1, len(distributions[i])) 
        jitter = np.take(variances, segment_indices_list[i]) * jitter_unadjusted

        # print(f"jitter {jitter}")
        ax.scatter(jitter + ticks[i], distributions[i], alpha=alpha)

    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    if vertical_limits:
        ax.set_ylim(bottom=vertical_limits[0], top=vertical_limits[1])
    if not grid:
        ax.grid(False)
    if y_label:
        ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    plt.show()

Это из данных вашего игрушечного примера дает

Код довольно беспорядочный, а дублирование циклов for не является ни элегантным, ни эффективным: я надеюсь, что, по крайней мере, результат будет тем, что вы искали.

Другие вопросы по теме