Как использовать vmap для нескольких экземпляров Dense в модели льна? пытаясь избежать циклического перебора списка экземпляров Dense

from jax import random,vmap
from jax import numpy as jnp
import pprint

def f(s,layers,do,dx):
    x = jnp.zeros((do,dx))
    for i,layer in enumerate(layers):
        x=x.at[i].set( layer( s[i] ) )
    return x

class net(nn.Module):
    dx: int 
    do: int 
    def setup(self):
        self.layers = [ nn.Dense( self.dx, use_bias=False )
                        for _ in range(self.do) ]
    def __call__(self, s):
        x = vmap(f,in_axes=(0,None,None,None))(s,self.layers,self.do,self.dx)
        return x

if __name__ == '__main__':
    seed = 123
    key = random.PRNGKey( seed )
    key,subkey = random.split( key )
    outer_batches = 4
    s_observations = 5 # AKA the inner batch
    x_features = 2
    s_features = 3
    s_shape = (outer_batches,s_observations, s_features)
    s = random.uniform( subkey, s_shape )

    key,subkey = random.split( key )    
    model = net(x_features,s_observations)
    p = model.init( subkey, s )
    x = model.apply( p, s )    

    params = p['params']
    pkernels = jnp.array([params[key]['kernel'] for key in params.keys()])
    x_=jnp.zeros((outer_batches,s_observations,x_features))
    
    g = vmap(vmap(lambda a,b: a@b),in_axes=(0,None))
    
    x_=g(s,pkernels)
    print('s shape:',s.shape)
    print('p shape:',pkernels.shape)
    print('x shape:',x.shape)
    print('x_ shape:',x_.shape)
    print('sum of difference:',jnp.sum(x-x_))

Привет. Мне нужны некоторые параметры, специфичные для партии, в моей модели. Здесь существует «внутренний пакет» длины do, в котором есть экземпляр flax.linen.Dense для каждого элемента в этом пакете. Внешний пакет просто передает несколько экземпляров данных в эти слои. Я добиваюсь этого, создавая список экземпляров flax.linen.Dense в методе setup. Затем в методе __call__ я перебираю эти слои, чтобы заполнить массив. Эта итерация инкапсулирована в функцию f, а эта функция обернута в jax.vmap.

Я также включил эквивалентную логику, написанную в виде умножения матриц (см. функцию g), чтобы было ясно, какую операцию я надеялся захватить с помощью этого класса.

Я хотел бы заменить цикл for в методе __call__ вызовом jax.vmap. Я получаю сообщение об ошибке, когда передаю список vmap, и я получаю сообщение об ошибке, когда пытаюсь поместить несколько экземпляров Dense в массив jax. Есть ли альтернатива использованию списка для хранения нескольких экземпляров Dense? Ограничение состоит в том, что я должен иметь возможность создать произвольное количество экземпляров Dense во время инициализации модели.

Стоит ли изучать PHP в 2023-2024 годах?
Стоит ли изучать PHP в 2023-2024 годах?
Привет всем, сегодня я хочу высказать свои соображения по поводу вопроса, который я уже много раз получал в своем сообществе: "Стоит ли изучать PHP в...
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
В JavaScript одним из самых запутанных понятий является поведение ключевого слова "this" в стрелочной и обычной функциях.
Приемы CSS-макетирования - floats и Flexbox
Приемы CSS-макетирования - floats и Flexbox
Здравствуйте, друзья-студенты! Готовы совершенствовать свои навыки веб-дизайна? Сегодня в нашем путешествии мы рассмотрим приемы CSS-верстки - в...
Тестирование функциональных ngrx-эффектов в Angular 16 с помощью Jest
В системе управления состояниями ngrx, совместимой с Angular 16, появились функциональные эффекты. Это здорово и делает код определенно легче для...
Концепция локализации и ее применение в приложениях React ⚡️
Концепция локализации и ее применение в приложениях React ⚡️
Локализация - это процесс адаптации приложения к различным языкам и культурным требованиям. Это позволяет пользователям получить опыт, соответствующий...
Пользовательский скаляр GraphQL
Пользовательский скаляр GraphQL
Листовые узлы системы типов GraphQL называются скалярами. Достигнув скалярного типа, невозможно спуститься дальше по иерархии типов. Скалярный тип...
2
0
159
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

vmap можно использовать для сопоставления одной функции с пакетами данных. Вы пытаетесь использовать его для сопоставления нескольких функций с пакетами данных, чего он сделать не может.


Обновленный ответ на основе обновленного вопроса:

Поскольку каждый слой идентичен, за исключением параметров, соответствующих входным данным, похоже, что вам нужно сопоставить один плотный слой с пакетом данных. Это может выглядеть примерно так:

keys = vmap(random.fold_in, in_axes=(None, 0))(subkey, jnp.arange(s_observations))
model = nn.Dense(x_features, use_bias=False)
p = vmap(model.init, in_axes=(0, 1))(keys, s)
x = vmap(model.apply, in_axes=(0, 1), out_axes=1)(p, s)

pkernels = p['params']['kernel']
g = vmap(vmap(lambda a,b: a@b),in_axes=(0,None))
x_=g(s,pkernels)

print('sum of difference:',jnp.sum(x-x_))
# sum of difference: 0.0

Предыдущий ответ

В общем, исправлением будет определение одного параметризованного слоя, который вы можете передать vmap. В приведенном вами примере все слои идентичны, поэтому для достижения желаемого результата вы можете написать что-то вроде этого:

def f(s,layer,dx):
  return layer(s)

class net(nn.Module):
    dx: int 
    do: int 
    def setup(self):
        self.layer = nn.Dense( self.dx, use_bias=False )
    def __call__(self, s):
        x = vmap(f,in_axes=(0,None,None))(s,self.layer,self.dx)
        return x

Если бы у вас были разные параметризации для каждого слоя, вы могли бы добиться этого в vmap, передав эти параметры также в vmap.

В приведенном мной примере существует экземпляр Dense для каждого элемента в этом внутреннем пакете размером s_observations, каждый с разными параметрами. Хотя они относятся к одному классу, у них разные значения параметров. Мне нужно будет обработать данные в этих пакетах с помощью дополнительных слоев или операций с независимыми параметрами. Было бы полезно реализовать единый конвейер обработки, с помощью которого я мог бы векторизовать свои данные. Возможно, моя ошибка здесь заключается в том, что я полагаюсь на экземпляр Dense, а не на создание экземпляра массива параметров и векторизацию по соответствующему размеру.

user137146 25.04.2024 19:40

Я не уверен, что понимаю: в приведенном вами примере каждый слой определяется как nn.Dense(self.dx, use_bias=False) с self.dx константой для всех слоев. Где определяются различные параметры?

jakevdp 25.04.2024 20:51

Привет, это вопрос о том, как настроить эти льняные объекты для использования автовекторизации, как это предусмотрено jax.vmap; не вопрос о том, как использовать vmap. Каждый экземпляр flax.linen.Dense содержит атрибут, который является параметрами модели машинного обучения, и эти параметры модели машинного обучения специфичны для этих экземпляров Dense. Таким образом, каждый из вызовов метода __call__ этих Dense экземпляров будет работать с входными s[:,i] и параметрами, скрытыми внутри каждого Dense экземпляра.

user137146 26.04.2024 05:03

Конечно, я это понимаю. Но в вашем примере кода ничего из этого нет, поэтому на ваш вопрос сложно ответить, поскольку основная часть является чисто гипотетической. Возможно, вы сможете обновить свой код, чтобы показать, как параметризуются ваши слои? От этих деталей будет зависеть ответ на ваш вопрос.

jakevdp 26.04.2024 05:46

Привет. (1) Я нашел/исправил ошибку в функции f, из-за которой переменная x фактически не заполнялась :facepalm:. (2) В нижней части скрипта я включил дополнительную логику, которая перезаписывает параметры модели и весь класс net как умножение матриц. Надеюсь, это сделает структуру моей проблемы менее гипотетической. Я намеревался переписать свою модель, используя эти льняные Dense примеры, чтобы сделать ее более доступной для широкой аудитории, но этот разговор заставляет меня думать, что эти Dense примеры не облегчают общение.

user137146 26.04.2024 10:17

Я вижу обновленный код: не думаю, что это меняет мой ответ. Все слои в layers идентичны, поэтому вместо того, чтобы перебирать их, вы можете создать один слой и vmap поверх него. В тех случаях, когда слои различаются, вам следует включить разные параметры в vmap. Если вы покажете пример, где каждый слой имеет разные параметры, я мог бы показать пример того, как это сделать.

jakevdp 26.04.2024 18:11

если вы добавите строку print(p) в конец скрипта, он покажет вам разные значения.

user137146 26.04.2024 18:38

Другие вопросы по теме