Как использовать разные слои свертки в разных ветвях tf.map_fn?

Я попытался установить простой слой внимания с несколькими головками в tensorflow1.14. Каждая голова содержит три разных conv1d слоя. И я хочу использовать tf.map_fn для параллельных вычислений.

import tensorflow as tf

n_head = 50 # heads counts

conv1d = tf.layers.conv1d
normalize = tf.contrib.layers.instance_norm 
activation = tf.nn.elu

f1d = tf.placeholder(shape=(None, 42),dtype=tf.float32) # input feats

f1ds = tf.tile(f1d[None, ...], [n_head, 1, 1]) # n_head copys to apply different attention heads

def apply_attention(f1):
    f1 = activation(normalize(f1[None, ...]))
    q = conv1d(f1, 32, 3, padding='same')
    k = conv1d(f1, 32, 3, padding='same')  # [1,ncol, 32]
    v = conv1d(f1, 1, 3, padding='same')  # [1,ncol, 1]
    attention_map = tf.nn.softmax(tf.reduce_sum(q[0, None, :, :] * k[0, :, None, :], axis=-1) / (32 ** .5),
                                  axis=0)  # [ncol,ncol]
    return attention_map * v[0]

f1d_attention = tf.map_fn(lambda x: apply_attention(x), f1ds, dtype=tf.float32) 

Но когда я обнаруживаю переменные в этой модели, кажется, что во всей модели есть только одна группа слоев conv1d.

conv1d/bias/Adam [32]
conv1d/bias/Adam_1 [32]
conv1d/kernel [3, 42, 32]
conv1d/kernel/Adam [3, 42, 32]
conv1d/kernel/Adam_1 [3, 42, 32]
conv1d_1/bias [32]
conv1d_1/bias/Adam [32]
conv1d_1/bias/Adam_1 [32]
conv1d_1/kernel [3, 42, 32]
conv1d_1/kernel/Adam [3, 42, 32]
conv1d_1/kernel/Adam_1 [3, 42, 32]
conv1d_2/bias [1]
conv1d_2/bias/Adam [1]
conv1d_2/bias/Adam_1 [1]
conv1d_2/kernel [3, 42, 1]
conv1d_2/kernel/Adam [3, 42, 1]
conv1d_2/kernel/Adam_1 [3, 42, 1]

Что не так с моим кодом?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
168
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

В этом случае вы не можете использовать tf.map_fn. tf.map_fn оценит вашу функцию один раз и пропустит разные входные данные через одну и ту же функцию, эффективно используя одни и те же слои свертки для каждого входа.

Вы можете добиться того, чего хотите, с помощью простого цикла for:

# Creating a different set of conv for each head
multi_head = [apply_attention(f1d) for _ in range(n_head)]
# stacking the result together on the first axis
f1d_attention = tf.stack(multi_head, axis=0)

Я уменьшил количество головок до 2 для наглядности, но если мы посмотрим на переменные, мы увидим, что были созданы 2 группы свертки.

>>> tf.global_variables()
[<tf.Variable 'InstanceNorm/beta:0' shape=(42,) dtype=float32_ref>,
 <tf.Variable 'InstanceNorm/gamma:0' shape=(42,) dtype=float32_ref>,
 <tf.Variable 'conv1d/kernel:0' shape=(3, 42, 32) dtype=float32_ref>,
 <tf.Variable 'conv1d/bias:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'conv1d_1/kernel:0' shape=(3, 42, 32) dtype=float32_ref>,
 <tf.Variable 'conv1d_1/bias:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'conv1d_2/kernel:0' shape=(3, 42, 1) dtype=float32_ref>,
 <tf.Variable 'conv1d_2/bias:0' shape=(1,) dtype=float32_ref>,
 <tf.Variable 'InstanceNorm_1/beta:0' shape=(42,) dtype=float32_ref>,
 <tf.Variable 'InstanceNorm_1/gamma:0' shape=(42,) dtype=float32_ref>,
 <tf.Variable 'conv1d_3/kernel:0' shape=(3, 42, 32) dtype=float32_ref>,
 <tf.Variable 'conv1d_3/bias:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'conv1d_4/kernel:0' shape=(3, 42, 32) dtype=float32_ref>,
 <tf.Variable 'conv1d_4/bias:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'conv1d_5/kernel:0' shape=(3, 42, 1) dtype=float32_ref>,
 <tf.Variable 'conv1d_5/bias:0' shape=(1,) dtype=float32_ref>]

Боковое примечание: если у вас нет действительно веской причины, вам следует отказаться от TensorFlow 1 и вместо этого использовать TensorFlow 2. Поддержка TF1 ограничена.

Другие вопросы по теме