Индексирование словаря с помощью Numpy/Jax

Я пишу процедуру интерполяции и имею словарь, в котором хранятся значения функции в точках подгонки. В идеале, ключи словаря должны представлять собой 2D-массивы Numpy координат подходящей точки, np.array([x, y]), но поскольку массивы Numpy не подлежат хэшированию, они преобразуются в кортежи для ключей.

# fit_pt_coords: (n_pts, n_dims) array
# fn_vals: (n_pts,) array
def fit(fit_pt_coords, fn_vals):
    pt_map = {tuple(k): v for k, v in zip(fit_pt_coords, fn_vals)}
    ...

Далее в коде мне нужно получить значения функции, используя координаты в качестве ключей, чтобы выполнить интерполяцию. Мне бы хотелось, чтобы это было в коде @jax.jit, но значения координат имеют тип <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, который невозможно преобразовать в кортеж. Я пробовал и другие вещи, например, создание ключа словаря как (x + y, x - y), но опять же, для этого требуются конкретные значения, и вызов .item() приводит к ConcretizationTypeError.

На данный момент я @jax.jitиспользовал весь код, который только мог, и только что оставил этот код необработанным. Однако было бы здорово, если бы я мог использовать и этот код. Есть ли какие-нибудь лучшие способы индексации словаря (или лучшие структуры данных, совместимые с Jax), которые позволили бы отбрасывать весь код? Я новичок в Jax и до сих пор не понимаю, как он работает, поэтому уверен, что должны быть лучшие способы сделать это...

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
2
0
57
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Ответ принят как подходящий

Невозможно использовать отслеживаемые значения JAX в качестве ключей словаря. Проблема в том, что значения ключей не будут известны до момента выполнения компилятора XLA, а у XLA нет словарной структуры данных, до которой можно было бы свести такие запросы.

Существуют несовершенные решения, такие как хранение словаря на хосте и использование чего-то вроде io_callback для выполнения поиска по словарю на хосте, но этот подход приводит к снижению производительности, что, вероятно, сделает его непрактичным.

К сожалению, лучшим способом сделать это эффективно в JIT, вероятно, будет переключение на другой алгоритм интерполяции, который не зависит от поиска в хеш-таблице.

Большое спасибо за ваш ответ. Да, переключение на другую реализацию здесь кажется хорошей идеей — я думаю, что нашел лучший способ использовать собственную реализацию kd-tree, похожую на код Scipy.

LordCat 21.07.2024 13:44

Я согласен с @jakevdp, что это может быть не лучшее решение. Python не самый быстрый, когда встроенные модули зацикливаются.

Python может делать что угодно... За исключением циклов for. Для этого мы используем numpy.

  1. Возможно, подойдет pandas.DataFrame со столбцами ["x", "y", "v"].
  2. Можете ли вы использовать функции scipy.interpolate?

К сожалению, нужные мне функции интерполяции Scipy на данный момент несовместимы с Jax :( И я вынужден использовать Jax из-за остальной части моей реализации. Это нормально — мне нравится писать свои собственные!

LordCat 21.07.2024 13:48

Другие вопросы по теме