У меня есть n
сетей, каждая с одинаковым вводом/выводом. Я хочу случайным образом выбрать один из выходов в соответствии с категориальным распределением. Tfp.Категорический выводит только целые числа, и я попытался сделать что-то вроде
act_dist = tfp.distributions.Categorical(logits=act_logits) # act_logits are all the same, so the distribution is uniform
rand_out = act_dist.sample()
x = nn_out1 * tf.cast(rand_out == 0., dtype=tf.float32) + ... # for all my n networks
Но rand_out == 0.
всегда ложно, как и другие условия.
Любая идея для достижения того, что мне нужно?
Я думаю, вам нужно использовать tf.equal, потому что Tensor == 0 всегда False.
Однако отдельно вы можете использовать OneHotCategorical. Для обучения вы также можете попробовать использовать RelaxedOneHotCategorical.
Вы также можете взглянуть на MixtureSameFamily, который собирает для вас под одеялом.
nn_out1 = tf.expand_dims(nn_out1, axis=2)
...
outs = tf.concat([nn_out1, nn_nout2, ...], axis=2)
probs = tf.tile(tf.reduce_mean(tf.ones_like(nn_out1), axis=1, keepdims=True) / n, [1, n]) # trick to have ones of shape [None,1]
dist = tfp.distributions.MixtureSameFamily(
mixture_distribution=tfp.distributions.Categorical(probs=probs),
components_distribution=tfp.distributions.Deterministic(loc=outs))
x = dist.sample()
Большое спасибо! Я не знал этот класс, и это именно то, что мне нужно. Я отредактировал ваш ответ с помощью кода, который я написал, используя MixtureSameFamily.
Один вопрос. Если вероятности являются функцией одного и того же ввода компонентов, и я хочу ее обучить, похоже, я не могу. Я читал, что в TF градиент не распространяется обратно через целые числа, и, например, он не проходит через one_hot
и gather
. Поскольку вывод Categorical
является целым числом, может ли это быть причиной?
Да, я знаю, почему это ложь. Я ищу способ сказать, чтобы он сравнивался с выводом функции
output
, а не с самой функцией. В чем преимущество использования OneHotCategorical? Я никогда не задумывался об этом, и, судя по описанию, в основном то же самое для моей цели.