from jax import random,vmap
from jax import numpy as jnp
import pprint
def f(s,layers,do,dx):
x = jnp.zeros((do,dx))
for i,layer in enumerate(layers):
x=x.at[i].set( layer( s[i] ) )
return x
class net(nn.Module):
dx: int
do: int
def setup(self):
self.layers = [ nn.Dense( self.dx, use_bias=False )
for _ in range(self.do) ]
def __call__(self, s):
x = vmap(f,in_axes=(0,None,None,None))(s,self.layers,self.do,self.dx)
return x
if __name__ == '__main__':
seed = 123
key = random.PRNGKey( seed )
key,subkey = random.split( key )
outer_batches = 4
s_observations = 5 # AKA the inner batch
x_features = 2
s_features = 3
s_shape = (outer_batches,s_observations, s_features)
s = random.uniform( subkey, s_shape )
key,subkey = random.split( key )
model = net(x_features,s_observations)
p = model.init( subkey, s )
x = model.apply( p, s )
params = p['params']
pkernels = jnp.array([params[key]['kernel'] for key in params.keys()])
x_=jnp.zeros((outer_batches,s_observations,x_features))
g = vmap(vmap(lambda a,b: a@b),in_axes=(0,None))
x_=g(s,pkernels)
print('s shape:',s.shape)
print('p shape:',pkernels.shape)
print('x shape:',x.shape)
print('x_ shape:',x_.shape)
print('sum of difference:',jnp.sum(x-x_))
Привет. Мне нужны некоторые параметры, специфичные для партии, в моей модели. Здесь существует «внутренний пакет» длины do
, в котором есть экземпляр flax.linen.Dense
для каждого элемента в этом пакете. Внешний пакет просто передает несколько экземпляров данных в эти слои. Я добиваюсь этого, создавая список экземпляров flax.linen.Dense
в методе setup
. Затем в методе __call__
я перебираю эти слои, чтобы заполнить массив. Эта итерация инкапсулирована в функцию f
, а эта функция обернута в jax.vmap
.
Я также включил эквивалентную логику, написанную в виде умножения матриц (см. функцию g
), чтобы было ясно, какую операцию я надеялся захватить с помощью этого класса.
Я хотел бы заменить цикл for в методе __call__
вызовом jax.vmap
. Я получаю сообщение об ошибке, когда передаю список vmap
, и я получаю сообщение об ошибке, когда пытаюсь поместить несколько экземпляров Dense
в массив jax. Есть ли альтернатива использованию списка для хранения нескольких экземпляров Dense
? Ограничение состоит в том, что я должен иметь возможность создать произвольное количество экземпляров Dense
во время инициализации модели.
vmap
можно использовать для сопоставления одной функции с пакетами данных. Вы пытаетесь использовать его для сопоставления нескольких функций с пакетами данных, чего он сделать не может.
Обновленный ответ на основе обновленного вопроса:
Поскольку каждый слой идентичен, за исключением параметров, соответствующих входным данным, похоже, что вам нужно сопоставить один плотный слой с пакетом данных. Это может выглядеть примерно так:
keys = vmap(random.fold_in, in_axes=(None, 0))(subkey, jnp.arange(s_observations))
model = nn.Dense(x_features, use_bias=False)
p = vmap(model.init, in_axes=(0, 1))(keys, s)
x = vmap(model.apply, in_axes=(0, 1), out_axes=1)(p, s)
pkernels = p['params']['kernel']
g = vmap(vmap(lambda a,b: a@b),in_axes=(0,None))
x_=g(s,pkernels)
print('sum of difference:',jnp.sum(x-x_))
# sum of difference: 0.0
Предыдущий ответ
В общем, исправлением будет определение одного параметризованного слоя, который вы можете передать vmap
. В приведенном вами примере все слои идентичны, поэтому для достижения желаемого результата вы можете написать что-то вроде этого:
def f(s,layer,dx):
return layer(s)
class net(nn.Module):
dx: int
do: int
def setup(self):
self.layer = nn.Dense( self.dx, use_bias=False )
def __call__(self, s):
x = vmap(f,in_axes=(0,None,None))(s,self.layer,self.dx)
return x
Если бы у вас были разные параметризации для каждого слоя, вы могли бы добиться этого в vmap
, передав эти параметры также в vmap
.
Я не уверен, что понимаю: в приведенном вами примере каждый слой определяется как nn.Dense(self.dx, use_bias=False)
с self.dx
константой для всех слоев. Где определяются различные параметры?
Привет, это вопрос о том, как настроить эти льняные объекты для использования автовекторизации, как это предусмотрено jax.vmap
; не вопрос о том, как использовать vmap
. Каждый экземпляр flax.linen.Dense
содержит атрибут, который является параметрами модели машинного обучения, и эти параметры модели машинного обучения специфичны для этих экземпляров Dense
. Таким образом, каждый из вызовов метода __call__
этих Dense
экземпляров будет работать с входными s[:,i]
и параметрами, скрытыми внутри каждого Dense
экземпляра.
Конечно, я это понимаю. Но в вашем примере кода ничего из этого нет, поэтому на ваш вопрос сложно ответить, поскольку основная часть является чисто гипотетической. Возможно, вы сможете обновить свой код, чтобы показать, как параметризуются ваши слои? От этих деталей будет зависеть ответ на ваш вопрос.
Привет. (1) Я нашел/исправил ошибку в функции f
, из-за которой переменная x
фактически не заполнялась :facepalm:. (2) В нижней части скрипта я включил дополнительную логику, которая перезаписывает параметры модели и весь класс net
как умножение матриц. Надеюсь, это сделает структуру моей проблемы менее гипотетической. Я намеревался переписать свою модель, используя эти льняные Dense
примеры, чтобы сделать ее более доступной для широкой аудитории, но этот разговор заставляет меня думать, что эти Dense
примеры не облегчают общение.
Я вижу обновленный код: не думаю, что это меняет мой ответ. Все слои в layers
идентичны, поэтому вместо того, чтобы перебирать их, вы можете создать один слой и vmap
поверх него. В тех случаях, когда слои различаются, вам следует включить разные параметры в vmap
. Если вы покажете пример, где каждый слой имеет разные параметры, я мог бы показать пример того, как это сделать.
если вы добавите строку print(p)
в конец скрипта, он покажет вам разные значения.
В приведенном мной примере существует экземпляр Dense для каждого элемента в этом внутреннем пакете размером
s_observations
, каждый с разными параметрами. Хотя они относятся к одному классу, у них разные значения параметров. Мне нужно будет обработать данные в этих пакетах с помощью дополнительных слоев или операций с независимыми параметрами. Было бы полезно реализовать единый конвейер обработки, с помощью которого я мог бы векторизовать свои данные. Возможно, моя ошибка здесь заключается в том, что я полагаюсь на экземпляр Dense, а не на создание экземпляра массива параметров и векторизацию по соответствующему размеру.