Использование Jax Jit для метода в качестве декоратора вместо прямого применения функции jit

Думаю, большинство людей, знакомых с jax, видели этот пример в документации и знают, что он не работает:

import jax.numpy as jnp
from jax import jit

class CustomClass:
  def __init__(self, x: jnp.ndarray, mul: bool):
    self.x = x
    self.mul = mul

  @jit  # <---- How to do this correctly?
  def calc(self, y):
    if self.mul:
      return self.x * y
    return y


c = CustomClass(2, True)
c.calc(3)  

Упоминаются 3 обходных пути, но похоже, что применение jit напрямую как функции, а не декоратора, также работает нормально. То есть JAX не жалуется на то, что не знает, как бороться с типом CustomClassself:

import jax.numpy as jnp
from jax import jit

class CustomClass:
  def __init__(self, x: jnp.ndarray, mul: bool):
    self.x = x
    self.mul = mul

  # No decorator here !
  def calc(self, y):
    if self.mul:
      return self.x * y
    return y


c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
6 # works fine!

Хотя это и не задокументировано (что, возможно, должно быть?), похоже, это работает идентично пометке self как статического с помощью @partial(jax.jit, static_argnums=0), поскольку изменение self ничего не делает для последующих вызовов, т.е.:

c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
c.mul = False 
print(jitted_calc(3))
6
6 # no update

Поэтому я изначально предполагал, что декораторы в целом могут иметь дело с self как со статическим параметром при их непосредственном применении. Потому что метод может быть сохранен в другой переменной с конкретным экземпляром (копией) self. В качестве проверки работоспособности я проверил, действительно ли это делают не-jit-декораторы, но, похоже, это не так, поскольку приведенная ниже не-jit-декоративная функция успешно обрабатывает изменения в себе:

def decorator(func):
    def wrapper(*args, **kwargs):
        x = func(*args, **kwargs)
        return x
    return wrapper

custom = CustomClass(2, True)
decorated_calc = decorator(custom.calc)
print(decorated_calc(3))
custom.mul = False
print(decorated_calc(3))
6
3

Я видел несколько других вопросов о применении декораторов напрямую как функций в сравнении со стилем декоратора (например, здесь и здесь), и там упоминается, что между двумя версиями есть небольшая разница, но это почти никогда не должно иметь значения. Мне остается задаться вопросом, что такого в декораторе jit, который заставляет эти версии вести себя так по-разному, поскольку JAX.jit может работать с типом self, если не в декорированном стиле. Если у кого-то есть ответ, мы будем очень признательны.

Я думаю, разница в том, что CustomClass еще не существует, когда используется декоратор, но если вместо этого он используется как функция после определения класса, класс существует.

Michael Butscher 27.08.2024 12:39
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
1
51
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Декораторы не имеют ничего общего со статическими аргументами: статические аргументы — это концепция, специфичная для jax.jit.

Выполняя резервное копирование, вы должны иметь в виду, что всякий раз, когда jax.jit компилирует функцию, он кэширует артефакт компиляции на основе нескольких величин, в том числе:

  1. идентификатор компилируемой функции или вызываемого объекта
  2. статические атрибуты любых нестатических аргументов, таких как shape и dtype
  3. хэш любых аргументов, помеченных как статические с помощью static_argnums или static_argnames
  4. значение любых глобальных конфигураций, которые могут повлиять на результаты

Имея это в виду, давайте рассмотрим этот фрагмент:

c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
c.mul = False 
print(jitted_calc(3))

причина, по которой jitted_calc не обновляется при обновлении атрибутов c, заключается в том, что ничего, связанное с ключом кэша, не изменилось: (1) идентификатор функции тот же, (2) форма и тип аргумента не изменились, (3 ) нет статических аргументов, (4) глобальные конфигурации не изменились. Таким образом, предыдущий кэшированный артефакт компиляции (с предыдущим значением mul) выполняется снова. Это основная причина, по которой я не упомянул эту стратегию в документе, на который вы ссылаетесь: это редко бывает то поведение, которое хотелось бы пользователям.

Этот подход к обертыванию связанного метода в JIT, кстати, похож на обертывание определения метода с помощью @partial(jit, static_argnums=0), но детали не такие: в версии static_argnumsself помечен как статический аргумент, и поэтому его хэш становится частью JIT. кэш. Метод __hash__ по умолчанию для класса просто основан на идентификаторе экземпляра, поэтому изменение c.mul не меняет хеш и не запускает перекомпиляцию. Вы можете увидеть пример того, как это исправить, в разделе Стратегия 2 в документе, на который вы ссылаетесь: по сути, определите соответствующие методы __hash__ и __eq__ для класса:

class CustomClass:
  def __init__(self, x: jnp.ndarray, mul: bool):
    self.x = x
    self.mul = mul

  @partial(jit, static_argnums=0)
  def calc(self, y):
    if self.mul:
      return self.x * y
    return y

  def __hash__(self):
    return hash((self.x, self.mul))

  def __eq__(self, other):
    return (isinstance(other, CustomClass) and
            (self.x, self.mul) == (other.x, other.mul))

В вашем последнем примере вы определяете это:

def decorator(func):
    def wrapper(*args, **kwargs):
        x = func(*args, **kwargs)
        return x
    return wrapper

Этот код вообще не использует jax.jit. Тот факт, что изменения в c.mul приводят к изменениям в выходных данных, не имеет ничего общего с синтаксисом декоратора, а скорее связан с тем, что здесь не задействован JIT-кеш.

Надеюсь, все понятно!

Спасибо за развернутый ответ, Джейк. Что мне остается немного неясным, так это то, почему именно вторая версия вообще работает и не дает ошибок. Я предполагаю, что тип self по-прежнему является типом, с которым JAX не знает, как справиться. Я вижу, что когда я это делаю jax.jit(c.calc), в игре вообще нет self, так что, очевидно, проблем быть не должно. Но объект все равно должен где-то жить, чтобы JAX мог получить к нему доступ. Все свойства под self просто преобразуются в неявные константы в jaxpr, как вы упомянули здесь?

Stackerexp 27.08.2024 15:46

когда вы пишете c.calc, вы не возвращаете функцию, которая принимает self в качестве аргумента, вы возвращаете связанный объект метода. JAX не видит self нигде в этой конструкции; это похоже на запись f = lambda x: c.calc(x): это допустимый вызываемый объект, который не принимает self в качестве явного аргумента. Имеет ли это смысл?

jakevdp 27.08.2024 16:08

Ах, я понимаю, я думал, что с типом self нужно где-то иметь дело (в конце концов, нам нужны свойства), но, похоже, это не так, по крайней мере, не на границе jit (?). Спасибо @jakevdp

Stackerexp 27.08.2024 18:17

Другие вопросы по теме