Мне трудно понять, как работают mplcursors
курсоры. Позвольте мне привести пример.
import pandas as pd
import matplotlib.pyplot as plt
import mplcursors
%matplotlib qt5
def random_data(rows_count):
data = []
for i in range(rows_count):
row = {}
row["x"] = np.random.uniform()
row["y"] = np.random.uniform()
if (i%2 == 0):
row["type"] = "sith"
row["color"] = "red"
else:
row["type"] = "jedi"
row["color"] = "blue"
data.append(row)
return pd.DataFrame(data)
data_df = random_data(30)
fig, ax = plt.subplots(figsize=(8,8))
ax = plt.gca()
types = ["jedi","sith"]
for scat_type in types:
local_data_df = data_df.loc[data_df["type"] == scat_type]
scat = ax.scatter(local_data_df["x"],
local_data_df["y"],
c=local_data_df["color"],
label=scat_type)
cursor = mplcursors.cursor(scat, hover=mplcursors.HoverMode.Transient)
@cursor.connect("add")
def on_add(sel):
annotation = (local_data_df.iloc[sel.index]["type"]+
"\n"+str(local_data_df.iloc[sel.index]["x"])+
"\n"+str(local_data_df.iloc[sel.index]["y"]))
sel.annotation.set(text=annotation)
ax.legend()
plt.title("a battle of Force users")
plt.xlabel("x")
plt.ylabel("y")
plt.xlim(-1, 2)
plt.ylim(-1, 2)
ax.set_aspect('equal', adjustable='box')
plt.show()
Этот код должен генерировать DataFrame таким образом, что каждая строка имеет случайные свойства x
, y
, type
, который является jedi
или sith
, и color
, который является blue
или red
, в зависимости от того, является ли строка jedi
или sith
, затем диаграмма рассеяния джедаев в их цвете, прикрепите к ним курсор, а затем постройте диаграмму рассеяния ситхов в их цвете, прикрепите к ним другой курсор и отобразите поле легенды, сообщающее читателю, что синие точки соответствуют jedi
рядам, а красные — sith
.
Однако при наведении точек в аннотациях написано, что все точки sith
и координаты выглядят не очень.
Я хотел бы понять, почему код не делает то, что я хотел бы.
Просто для уточнения: я вызываю .scatter()
для каждого типа (jedi
или sith
), а затем пытаюсь прикрепить курсор к каждому из графиков, потому что я пробовал вызывать scatter
в целом data_df
, но тогда .legend()
не отображает то, что я хочу.
Я надеюсь, что ответа, который вы мне дадите, будет достаточно, чтобы я смог написать код, который отображает точки jedi
и sith
, показывает правильные аннотации и правильное поле легенды.
Происходит много странных вещей.
Одна из путаниц заключается в том, что наличие переменной local_data_df
внутри цикла for
создаст переменную, которая будет локальной только для одного цикла цикла. Вместо этого это просто глобальная переменная, которая переопределяется для каждого цикла. Точно так же определение функции on_add
внутри цикла for
не делает ее локальной. Также on_add
будет глобальным и будет переопределяться каждым циклом цикла for
.
Другая путаница заключается в том, что подключенная функция будет иметь доступ к локальным переменным из другой функции или цикла. Вместо этого такие локальные переменные становятся недоступными после завершения функции или цикла.
Далее не то, что sel.index
будет индексом не в датафрейм, а в точки точечной диаграммы. Вы можете сбросить индекс «local df», чтобы он был похож на то, как упорядочивается sel.index
.
Чтобы имитировать вашу локальную переменную, вы можете добавить дополнительные данные к объекту scat
. Например. scat.my_data = local_df
добавит эту переменную к глобальному объекту, содержащему точечный элемент (PathCollection
, который содержит всю информацию, необходимую matplotlib для представления точечных точек). Хотя переменная scat
переопределяется, для каждого вызова PathCollection
существует одна ax.scatter
. (Вы также можете получить к ним доступ через ax.collections
).
Вот переписываем ваш код, стараясь максимально приблизиться к оригиналу:
import pandas as pd
import matplotlib.pyplot as plt
import mplcursors
def random_data(rows_count):
df = pd.DataFrame({'x': np.random.uniform(0, 1, rows_count),
'y': np.random.uniform(0, 1, rows_count),
'type': np.random.choice(['sith', 'jedi'], rows_count)})
df['color'] = df['type'].replace({'sith': 'red', 'jedi': 'blue'})
return df
def on_add(sel):
local_data_df = sel.artist.my_data
annotation = (local_data_df.iloc[sel.index]["type"] +
"\n" + str(local_data_df.iloc[sel.index]["x"]) +
"\n" + str(local_data_df.iloc[sel.index]["y"]))
sel.annotation.set(text=annotation)
data_df = random_data(30)
fig, ax = plt.subplots(figsize=(8, 8))
types = ["jedi", "sith"]
for scat_type in types:
local_data_df = data_df.loc[data_df["type"] == scat_type].reset_index() # resetting the index is necessary
scat = ax.scatter(local_data_df["x"],
local_data_df["y"],
c=local_data_df["color"],
label=scat_type)
scat.my_data = local_data_df # store the data into the scat object
cursor = mplcursors.cursor(scat, hover=mplcursors.HoverMode.Transient)
cursor.connect("add", on_add)
ax.legend()
ax.set_title("a battle of Force users")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_xlim(-1, 2)
ax.set_ylim(-1, 2)
ax.set_aspect('equal', adjustable='box')
plt.show()