Я хотел бы отправить тип второго аргумента в функции в numba, но мне это не удалось.
Если это целое число, то должен быть возвращен вектор, если он сам является массивом целых чисел, то должна быть возвращена матрица.
Первый код не работает
@njit
def test_dispatch(X, indices):
if isinstance(indices, nb.int64):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
elif isinstance(indices, nb.int64[:]):
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
а второй, с else, так и есть.
@njit
def test_dispatch(X, indices):
if isinstance(indices, nb.int64):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
else:
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
Я предполагаю, что проблема в объявлении типа через nb.int64[:], но по-другому я не могу заставить его работать.
У вас есть идея?
Обратите внимание, что этот вопрос относится к numba>=0.59.
generated_jit устарел в более ранних версиях и фактически удален из
версии 0.59 и далее.






Вам не следует использовать isinstance в такой функции JIT, а вместо этого используйте @overload ( @generated_jit был старым устаревшим способом сделать это), который специально создан для этой цели. Это позволяет Numba генерировать код быстрее, поскольку для каждого случая компилируется только часть функции, а не весь случай для каждой специализации. Более того, isinstance является экспериментальным, как указано Numba в предупреждении при выполнении вашего первого кода (предупреждения выдаются, чтобы пользователи могли их прочитать ;)).
@overloadНачиная с Numba 0.59, вместо него необходимо использовать перегрузку:
import numba as nb
import numpy as np
def test_dispatch_scalar(X, indices):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
def test_dispatch_vector(X, indices):
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
# Pure-python fallback implementation
def test_dispatch_impl(X, indices):
if isinstance(indices, (int, np.integer)):
return test_dispatch_scalar(X, indices)
elif isinstance(indices, np.ndarray) and indices.ndim == 1 and np.issubdtype(indices.dtype, np.integer):
return test_dispatch_vector(X, indices)
else:
assert False # Unsupported
# Numba-specific overload
@nb.extending.overload(test_dispatch_impl)
def test_dispatch_impl_overload(X, indices):
if isinstance(indices, nb.types.Integer):
return test_dispatch_scalar
elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and isinstance(indices.dtype, nb.types.Integer):
return test_dispatch_vector
else:
assert False # Unsupported
@nb.njit
def test_dispatch(X, indices):
return test_dispatch_impl(X, indices)
Вот пример рассуждений об универсальных типах:
import numba as nb
import numpy as np
@nb.generated_jit(nopython=True)
def test_dispatch(X, indices):
if isinstance(indices, nb.types.Integer):
def test_dispatch_scalar(X, indices):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
return test_dispatch_scalar
elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and isinstance(indices.dtype, nb.types.Integer):
def test_dispatch_vector(X, indices):
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
return test_dispatch_vector
else:
assert False # Unsupported
Вот пример рассуждений о конкретных типах:
import numba as nb
import numpy as np
@nb.generated_jit(nopython=True)
def test_dispatch(X, indices):
if indices == nb.types.int64:
def test_dispatch_scalar(X, indices):
ref_pos = np.empty(3, np.float64)
ref_pos[:] = X[:, indices]
return ref_pos
return test_dispatch_scalar
elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and indices.dtype == nb.types.int64:
def test_dispatch_vector(X, indices):
ref_pos = np.empty((3, len(indices)), np.float64)
ref_pos[:, :] = X[:, indices]
return ref_pos
return test_dispatch_vector
else:
assert False # Unsupported
Запрос конкретных 64-битных целых чисел может быть слишком ограничительным, поэтому я советую вам смешивать тесты общего типа и тесты конкретных типов. По той же причине вам следует избегать прямого тестирования массивов определенного типа просто потому, что они часто могут быть смежными или нет или могут содержать типы элементов, совместимые с вашей функцией.
Обратите внимание, что общие функции JIT предназначены для создания функций, которые компилируются отдельно относительно целевого типа входных данных (а не значений).
Фактический код, который в настоящее время работает нормально и его необходимо переписать без generated_jit, — это github.com/mcocdawc/chemcoord/blob/master/src/chemcoord/…
Обратите внимание, что в реальном коде я использую nb.types.Array и nb.types.Integer, но для целей вопроса я хотел упростить и запросил специально int64.
Вы можете использовать extending.overload вместо generated_jit, как обсуждалось в этом выпуске.
@mcocdawc Я не знал об устаревании. Я просто следую инструкциям в документе (об использовании overload) по вашей ссылке, и все сработало.
Большое спасибо за вашу помощь. Мне не нравится новый синтаксис перегрузки, потому что он очень многословен и повторяется. В той же ссылке они также привели гораздо более хороший пример, в котором напрямую использовался синтаксис моего вопроса. Они пишут об этом [...] Кроме того, пользователи, которые используют сгенерированный_jit для отправки некоторых более примитивных типов, могут обнаружить, что поддержки Numba для isinstance достаточно [...] Видимо, уже просто массив целых чисел не является «примитивным». "Хватит уже. Я оставлю вопрос открытым еще немного в надежде, что можно обойтись без повторяющегося синтаксиса.
Я согласен с вами. Я немного изменил реализацию, чтобы сделать ее менее повторяющейся, но она все еще немного многословна и громоздка. К сожалению, это по своей сути связано с тем, как overload работает. Жаль, что они приняли такое решение. Здесь решение isinstance может сработать, если Numba будет поддерживать np.issubdtype, но пока нет :/ . Возможно, этот можно заменить, чтобы было проще. Но имейте в виду, что один и тот же код реплицируется для нескольких реализаций, поэтому для этого кода это подходит, но не для более крупных.
Большое спасибо! Ваш рефакторинг кода — лучшее решение, которое я видел в отношении нового механизма overload. Могу ли я спросить, можете ли вы изменить первое предложение своего ответа, чтобы новые пользователи не были обмануты, заставляя их поверить, что generated_jit по-прежнему является каноническим решением, возможно, вы можете обратиться к обновлению внизу? Тогда я, конечно, приму ваш ответ.
Хорошая точка зрения. Я изменил ответ, чтобы продвигать новое решение @overload, и явно пометил исходное решение как устаревшее.
Раньше я использовал
generated_jit. Мой вопрос поступил от пользователей моей библиотеки, которые жаловались на удалениеgenerated jitиз numba. github.com/mcocdawc/chemcoord/issues/76 Поэтому я хотел провести рефакторинг своего кода, используяisinstancenumba.readthedocs.io/en/stable/reference/…