Думаю, большинство людей, знакомых с jax, видели этот пример в документации и знают, что он не работает:
import jax.numpy as jnp
from jax import jit
class CustomClass:
def __init__(self, x: jnp.ndarray, mul: bool):
self.x = x
self.mul = mul
@jit # <---- How to do this correctly?
def calc(self, y):
if self.mul:
return self.x * y
return y
c = CustomClass(2, True)
c.calc(3)
Упоминаются 3 обходных пути, но похоже, что применение jit напрямую как функции, а не декоратора, также работает нормально. То есть JAX не жалуется на то, что не знает, как бороться с типом CustomClass
self
:
import jax.numpy as jnp
from jax import jit
class CustomClass:
def __init__(self, x: jnp.ndarray, mul: bool):
self.x = x
self.mul = mul
# No decorator here !
def calc(self, y):
if self.mul:
return self.x * y
return y
c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
6 # works fine!
Хотя это и не задокументировано (что, возможно, должно быть?), похоже, это работает идентично пометке self как статического с помощью @partial(jax.jit, static_argnums=0)
, поскольку изменение self
ничего не делает для последующих вызовов, т.е.:
c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
c.mul = False
print(jitted_calc(3))
6
6 # no update
Поэтому я изначально предполагал, что декораторы в целом могут иметь дело с self как со статическим параметром при их непосредственном применении. Потому что метод может быть сохранен в другой переменной с конкретным экземпляром (копией) self. В качестве проверки работоспособности я проверил, действительно ли это делают не-jit-декораторы, но, похоже, это не так, поскольку приведенная ниже не-jit-декоративная функция успешно обрабатывает изменения в себе:
def decorator(func):
def wrapper(*args, **kwargs):
x = func(*args, **kwargs)
return x
return wrapper
custom = CustomClass(2, True)
decorated_calc = decorator(custom.calc)
print(decorated_calc(3))
custom.mul = False
print(decorated_calc(3))
6
3
Я видел несколько других вопросов о применении декораторов напрямую как функций в сравнении со стилем декоратора (например, здесь и здесь), и там упоминается, что между двумя версиями есть небольшая разница, но это почти никогда не должно иметь значения.
Мне остается задаться вопросом, что такого в декораторе jit, который заставляет эти версии вести себя так по-разному, поскольку JAX.jit может работать с типом self
, если не в декорированном стиле. Если у кого-то есть ответ, мы будем очень признательны.
Декораторы не имеют ничего общего со статическими аргументами: статические аргументы — это концепция, специфичная для jax.jit
.
Выполняя резервное копирование, вы должны иметь в виду, что всякий раз, когда jax.jit
компилирует функцию, он кэширует артефакт компиляции на основе нескольких величин, в том числе:
shape
и dtype
static_argnums
или static_argnames
Имея это в виду, давайте рассмотрим этот фрагмент:
c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
c.mul = False
print(jitted_calc(3))
причина, по которой jitted_calc
не обновляется при обновлении атрибутов c
, заключается в том, что ничего, связанное с ключом кэша, не изменилось: (1) идентификатор функции тот же, (2) форма и тип аргумента не изменились, (3 ) нет статических аргументов, (4) глобальные конфигурации не изменились. Таким образом, предыдущий кэшированный артефакт компиляции (с предыдущим значением mul
) выполняется снова. Это основная причина, по которой я не упомянул эту стратегию в документе, на который вы ссылаетесь: это редко бывает то поведение, которое хотелось бы пользователям.
Этот подход к обертыванию связанного метода в JIT, кстати, похож на обертывание определения метода с помощью @partial(jit, static_argnums=0)
, но детали не такие: в версии static_argnums
self
помечен как статический аргумент, и поэтому его хэш становится частью JIT. кэш. Метод __hash__
по умолчанию для класса просто основан на идентификаторе экземпляра, поэтому изменение c.mul
не меняет хеш и не запускает перекомпиляцию. Вы можете увидеть пример того, как это исправить, в разделе Стратегия 2 в документе, на который вы ссылаетесь: по сути, определите соответствующие методы __hash__
и __eq__
для класса:
class CustomClass:
def __init__(self, x: jnp.ndarray, mul: bool):
self.x = x
self.mul = mul
@partial(jit, static_argnums=0)
def calc(self, y):
if self.mul:
return self.x * y
return y
def __hash__(self):
return hash((self.x, self.mul))
def __eq__(self, other):
return (isinstance(other, CustomClass) and
(self.x, self.mul) == (other.x, other.mul))
В вашем последнем примере вы определяете это:
def decorator(func):
def wrapper(*args, **kwargs):
x = func(*args, **kwargs)
return x
return wrapper
Этот код вообще не использует jax.jit
. Тот факт, что изменения в c.mul
приводят к изменениям в выходных данных, не имеет ничего общего с синтаксисом декоратора, а скорее связан с тем, что здесь не задействован JIT-кеш.
Надеюсь, все понятно!
Спасибо за развернутый ответ, Джейк. Что мне остается немного неясным, так это то, почему именно вторая версия вообще работает и не дает ошибок. Я предполагаю, что тип self
по-прежнему является типом, с которым JAX не знает, как справиться. Я вижу, что когда я это делаю jax.jit(c.calc)
, в игре вообще нет self
, так что, очевидно, проблем быть не должно. Но объект все равно должен где-то жить, чтобы JAX мог получить к нему доступ. Все свойства под self
просто преобразуются в неявные константы в jaxpr, как вы упомянули здесь?
когда вы пишете c.calc
, вы не возвращаете функцию, которая принимает self
в качестве аргумента, вы возвращаете связанный объект метода. JAX не видит self
нигде в этой конструкции; это похоже на запись f = lambda x: c.calc(x)
: это допустимый вызываемый объект, который не принимает self
в качестве явного аргумента. Имеет ли это смысл?
Ах, я понимаю, я думал, что с типом self
нужно где-то иметь дело (в конце концов, нам нужны свойства), но, похоже, это не так, по крайней мере, не на границе jit (?). Спасибо @jakevdp
Я думаю, разница в том, что
CustomClass
еще не существует, когда используется декоратор, но если вместо этого он используется как функция после определения класса, класс существует.