Я пишу процедуру интерполяции и имею словарь, в котором хранятся значения функции в точках подгонки. В идеале, ключи словаря должны представлять собой 2D-массивы Numpy координат подходящей точки, np.array([x, y])
, но поскольку массивы Numpy не подлежат хэшированию, они преобразуются в кортежи для ключей.
# fit_pt_coords: (n_pts, n_dims) array
# fn_vals: (n_pts,) array
def fit(fit_pt_coords, fn_vals):
pt_map = {tuple(k): v for k, v in zip(fit_pt_coords, fn_vals)}
...
Далее в коде мне нужно получить значения функции, используя координаты в качестве ключей, чтобы выполнить интерполяцию. Мне бы хотелось, чтобы это было в коде @jax.jit
, но значения координат имеют тип <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
, который невозможно преобразовать в кортеж. Я пробовал и другие вещи, например, создание ключа словаря как (x + y, x - y)
, но опять же, для этого требуются конкретные значения, и вызов .item()
приводит к ConcretizationTypeError
.
На данный момент я @jax.jit
использовал весь код, который только мог, и только что оставил этот код необработанным. Однако было бы здорово, если бы я мог использовать и этот код. Есть ли какие-нибудь лучшие способы индексации словаря (или лучшие структуры данных, совместимые с Jax), которые позволили бы отбрасывать весь код? Я новичок в Jax и до сих пор не понимаю, как он работает, поэтому уверен, что должны быть лучшие способы сделать это...
Невозможно использовать отслеживаемые значения JAX в качестве ключей словаря. Проблема в том, что значения ключей не будут известны до момента выполнения компилятора XLA, а у XLA нет словарной структуры данных, до которой можно было бы свести такие запросы.
Существуют несовершенные решения, такие как хранение словаря на хосте и использование чего-то вроде io_callback для выполнения поиска по словарю на хосте, но этот подход приводит к снижению производительности, что, вероятно, сделает его непрактичным.
К сожалению, лучшим способом сделать это эффективно в JIT, вероятно, будет переключение на другой алгоритм интерполяции, который не зависит от поиска в хеш-таблице.
Я согласен с @jakevdp, что это может быть не лучшее решение. Python не самый быстрый, когда встроенные модули зацикливаются.
Python может делать что угодно... За исключением циклов for. Для этого мы используем numpy.
pandas.DataFrame
со столбцами ["x", "y", "v"].К сожалению, нужные мне функции интерполяции Scipy на данный момент несовместимы с Jax :( И я вынужден использовать Jax из-за остальной части моей реализации. Это нормально — мне нравится писать свои собственные!
Большое спасибо за ваш ответ. Да, переключение на другую реализацию здесь кажется хорошей идеей — я думаю, что нашел лучший способ использовать собственную реализацию kd-tree, похожую на код Scipy.