У меня возникла странная проблема при индексировании массива Jax с использованием списка. Если я помещу отладчик в середину своего кода, у меня будет следующее:
Этот массив создается путем преобразования массива numpy.
Однако когда я пробую это в новом экземпляре Python, у меня получается правильное поведение:
[
Что происходит?
Это работает так, как ожидалось. JAX следует семантике индексации NumPy, а в случае расширенной индексации с несколькими скалярами и целочисленными массивами, разделенными срезами, индексированные измерения объединяются посредством широковещательной передачи и перемещаются в начало выходного массива. Подробнее о деталях такого индексирования вы можете прочитать в документации NumPy: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing. В частности:
Следует различать два случая комбинации индексов:
- Расширенные индексы разделены срезом, эллипсисом или новой осью. Например
x[arr1, :, arr2]
.- Все продвинутые индексы расположены рядом друг с другом. Например,
x[..., arr1, arr2, :]
, но неx[arr1, :, 1]
, поскольку 1 в этом отношении является расширенным индексом.В первом случае измерения, полученные в результате расширенной операции индексации, идут первыми в массиве результатов, а затем размеры подпространства. Во втором случае измерения из расширенных операций индексирования вставляются в результирующий массив в том же месте, где они были в исходном массиве.
Код вашей программы подпадает под первый случай, а код вашего отдельного интерпретатора — под второй случай. Вот почему вы видите разные результаты.
Вот краткий пример этой разницы:
>>> import numpy as np
>>> x = np.zeros((3, 4, 5))
>>> x[0, :, [1, 2]].shape # size-2 dimension moved to front
(2, 4)
>>> x[:, 0, [1, 2]].shape # size-2 dimension not moved to front
(3, 2)
Это потому, что вы выполняете другой код. coords_[0, :, :, 0, [1, 2]]
имеет скалярные индексы и индексы массива, разделенные срезами, поэтому индексированные измерения объединяются посредством широковещательной передачи и перемещаются в начало выходного массива. aj[:, :, [1, 2]]
не имеет индексов скаляра и массива, разделенных срезом, поэтому индексированные измерения не перемещаются в начало выходного массива.
Я отредактировал свой ответ, процитировав соответствующий отрывок из документации NumPy, на которую я дал ссылку. Надеюсь, это поможет!
Спасибо, но я до сих пор не понимаю, почему поведение меняется независимо от того, запускаю ли я его «внутри» своего кода или в новом экземпляре Python.