У меня есть структурированный массив numpy, и я передаю в нем один элемент функции, как показано ниже.
from numba import njit
import numpy as np
dtype = np.dtype([
("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4"),
])
a = np.array([(1, b"24q1", 1.0)], dtype=dtype)
@njit
def upsert_numba(a, sid, qtrnm, val):
a[1] = qtrnm
a[2] = val
#i = 0
#a[i+1] = qtrnm
#a[i+2] = val
return a
x = (1, b"24q2", 3.0)
print(upsert_numba(a[0].copy(), *x))
Приведенный выше код работает без проблем. Но если обновление происходит через закомментированные коды, т. е. i=0;a[i+1]=qtrnm;a[i+2]=val, numba выдает следующую ошибку.
No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(Record(id[type=int32;offset=0],qtrnm0[type=[char x 4];offset=4],qtr0[type=float32;offset=8];12;False), int64, readonly bytes(uint8, 1d, C))
Похоже, что индексирование разрешено только константой, которая может быть целым числом или CharSeq, известным во время компиляции, но не выражением константы, которое также известно во время компиляции. Могу ли я узнать, что происходит под капотом?
Я пробовал другую константу в качестве индекса, например «j=i; a[j]», которая тоже работает. Но неудивительно, что «j=i+1;a[j]» терпит неудачу.
Это кажется тот самый (нерешённый) вопрос. И в этом, наверное, проблема. MRE прояснит, так это или нет.
@ken, я только что обновил вопрос, добавив дополнительную информацию.
Рассмотрим следующие функции.
from numba import njit
@njit
def func():
i = 777
t = i
return t
@njit
def func2():
i = 776
t = i + 1
return t
Вы можете проверить, как выводится тип каждой переменной, используя следующий метод.
func()
func.inspect_types()
Это ключевые строки:
# i = const(int, 777) :: Literal[int](777)
# t = i :: Literal[int](777)
Часть после ::
— это тип переменной.
Это указывает на то, что оба i
и t
имеют целочисленный литеральный тип.
Далее для func2
:
func2()
func2.inspect_types()
# i = const(int, 776) :: Literal[int](776)
# t = i + $const10.2 :: int64
По сравнению с func
вы можете видеть, что t
выводится как int64
, а не как целочисленный литеральный тип.
Это означает, что numba выполняет вывод типа кода перед оптимизацией.
Это разумный выбор. Типизированный код необходим для оптимизации, но для генерации типизированного кода требуется определение типа. Таким образом, сначала выполняется вывод типа для байт-кода Python, а затем выполняется оптимизация на основе выведенных типов. Более точную и подробную информацию об этом потоке можно найти в официальной документации .
Таким образом, вам нужна постоянная переменная на этапе байт-кода Python.
Дополнительное примечание: numba не поддерживает индексацию записей с нелитеральными переменными. Однако это каким-то образом возможно, если явно определить сопоставление посредством перегрузки.
from operator import setitem
import numpy as np
from numba import njit, types
from numba.core.extending import overload
a_dtype = np.dtype([("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4")])
@overload(setitem)
def setitem_overload_for_a(a, index, value):
if getattr(a, "dtype", None) != a_dtype:
return None
if isinstance(value, (types.Integer, types.Float)):
def numeric_impl(a, index, value):
# You need to map these indexes correctly according to the dtype.
if index == 0:
a[0] = value
elif index == 2:
a[2] = value
else:
raise ValueError()
return numeric_impl
elif isinstance(value, (types.Bytes, types.CharSeq)):
def bytes_impl(a, index, value):
if index == 1:
a[1] = value
else:
raise ValueError()
return bytes_impl
else:
raise TypeError(f"Unsupported type: {index=}, {value=}, {a.dtype=}")
@njit
def upsert_numba(a, sid, qtrnm, val):
i = 0
a[i + 1] = qtrnm
a[i + 2] = val
return a
x = (1, b"24q2", 3.0)
a = np.array([(1, b"24q1", 1.0)], dtype=a_dtype)
print(upsert_numba(a[0].copy(), *x)) # (1, b'24q2', 3.)
Обратите внимание, что это специальная стратегия, которая требует жесткого кодирования setitem для каждого типа записи и может не работать в некоторых случаях. Тем не менее, это должно сработать, если только вы не делаете что-то очень сложное.
Спасибо за подробности. Я приму этот ответ.
Пожалуйста, предоставьте минимально воспроизводимый пример. Что такое dtype у
a
? Что именно значит не работает? Если вы получили ошибку, пожалуйста, добавьте ее в вопрос. А еще мне интересно, работает лиa[1]
.