Это мой воспроизводимый код:
tf_ent = tf.Variable([ [9.96, 8.65, 0.99, 0.1 ],
[0.7, 8.33, 0.1 , 0.1 ],
[0.9, 0.1, 6, 7.33],
[6.60, 0.1, 3, 5.5 ],
[9.49, 0.2, 0.2, 0.2 ],
[0.4, 8.45, 0.2, 0.2 ],
[0.3, 0.2, 5.82, 8.28]])
tf_ent_var = tf.constant([True, False, False, False, False, True, False])
Я хочу сохранить строки в tf_ent
, в которых соответствующие индексы в tf_ent_var
равны True, и сделать остальные строки минимальными во всей матрице.
поэтому ожидаемый результат будет таким:
[[9.96, 8.65, 0.99, 0.1 ],
[0.1, 0.1, 0.1 , 0.1 ],
[0.1, 0.1, 0.1, 0.1 ],
[0.1, 0.1, 0.1, 0.1 ],
[0.1, 0.1, 0.1, 0.1 ],
[0.4, 8.45, 0.2, 0.2 ],
[0.1, 0.1, 0.1, 0.1 ]]
Любая идея, как я могу это сделать?
Я пытался получить индексы из замаскированного тензора, а затем использовать tf.gather для выполнения этого, но индексы, которые я получаю, были такими [[0], [6]]
, что имеет смысл, потому что он давал индекс одного вектора.
@ImperishableNight, мне пришлось уменьшить матрицу и забыть заменить нули :|, я обновлю свой вопрос
Обновлено: для тензорного потока 1.x используйте:
val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)
К сожалению, правила вещания не являются подмножеством правил 2.0 (что то же самое, что и numpy), а «просто разные». Tensorflow — не лучшая библиотека, когда речь идет о совместимости версий.
Основная идея состоит в том, чтобы использовать tf.where
, но вам нужно будет сначала изменить tf_ent_var
на тензор с формой (7, 1)
, чтобы тензорный поток знал, что он транслирует его по второй оси, а не по первой оси. Так:
val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var[:, tf.newaxis], tf_ent, val)
Конечно, вы также можете преобразовать его в (-1, 1)
, но я думаю, что нарезка с помощью tf.newaxis
короче и понятнее.
Вот мой интерактивный сеанс Python с 1.13.1 для устранения неполадок.
Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 16:52:21)
[Clang 6.0 (clang-600.0.57)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
2019-06-22 15:51:09.210852: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
>>> tf_ent = tf.Variable([ [9.96, 8.65, 0.99, 0.1 ],
... [0.7, 8.33, 0.1 , 0.1 ],
... [0.9, 0.1, 6, 7.33],
... [6.60, 0.1, 3, 5.5 ],
... [9.49, 0.2, 0.2, 0.2 ],
... [0.4, 8.45, 0.2, 0.2 ],
... [0.3, 0.2, 5.82, 8.28]])
WARNING:tensorflow:From /Users/REDACTED/Documents/test/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
>>>
>>> tf_ent_var = tf.constant([True, False, False, False, False, True, False])
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> val = tf.math.reduce_min(tf_ent)
>>> tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)
<tf.Tensor 'Select:0' shape=(7, 4) dtype=float32>
>>> _.eval()
array([[9.96, 8.65, 0.99, 0.1 ],
[0.1 , 0.1 , 0.1 , 0.1 ],
[0.1 , 0.1 , 0.1 , 0.1 ],
[0.1 , 0.1 , 0.1 , 0.1 ],
[0.1 , 0.1 , 0.1 , 0.1 ],
[0.4 , 8.45, 0.2 , 0.2 ],
[0.1 , 0.1 , 0.1 , 0.1 ]], dtype=float32)
>>> tf.__version__
'1.13.1'
Спасибо за ответ, хотя я получил эту ошибку «tensorflow.python.framework.errors_impl.InvalidArgumentError: входные данные для операции Select типа Select должны иметь одинаковый размер и форму. Вход 0: [7,1] != вход 1: [7,4] [Op:Select] '
О, я использую tensorflow 2.0.0-beta1 (и tensorflow убедился, что я знаю об этом, потому что каждая ошибка, которую он выдает, исходит от функции с именем something_v2
, например where_v2
и select_v2
). В tensorflow 1.x where
, вероятно, не такой гибкий. Попробую поискать решение в 1.x.
Я не уверен, стоит ли мне обновляться до бета-версии, дайте мне знать, если вы найдете какой-либо подход, работающий с tf.13.
значение val
не будет добавлено к тензору tf.zeros, поэтому на выходе не будет 0,1 в окончательной матрице
Я пытался обновить его, но похоже, что это связано с версией tensorflow, вы получили точно такой же вывод, который я поделился в вопросе?
Хм? Когда я тестировал, все работало нормально. Вы уверены, что ничего не перепутали? Моя версия тензорного потока для этого — 1.13.1.
min_mat = tf.broadcast_to(tf.reduce_min(tf_ent), tf_ent.shape)
output = tf.where(tf_ent_var, tf_ent, min_mat)
sess.run(output)
Вот моя реализация с использованием операторов tf.concat()
и if-else
. Это не так элегантно, как другой ответ, но работает:
import tensorflow as tf
tf.enable_eager_execution()
def slice_tensor_based_on_mask(tf_ent, tf_ent_var):
res = tf.fill([1, 4], 0.0)
min_value_tensor = tf.fill([1,int(tf_ent.shape[1])], tf.reduce_min(tf_ent))
for i in range(int(tf_ent.shape[0])):
if tf_ent_var[i:i+1].numpy()[0]: # true value in tf_ent_var
res = tf.concat([res, tf_ent[i:i+1]], 0)
else:
res = tf.concat([res, min_value_tensor], 0)
return res[1:]
tf_ent = tf.Variable([[9.96, 8.65, 0.99, 0.1 ],
[0.7, 8.33, 0.1 , 0.1 ],
[0.9, 0.1, 6, 7.33],
[6.60, 0.1, 3, 5.5 ],
[9.49, 0.2, 0.2, 0.2 ],
[0.4, 8.45, 0.2, 0.2 ],
[0.3, 0.2, 5.82, 8.28]])
tf_ent_var = tf.constant([True, False, False, False, False, True, False])
print(slice_tensor_based_on_mask(tf_ent, tf_ent_var))
выход:
tf.Tensor(
[[9.96 8.65 0.99 0.1 ]
[0.1 0.1 0.1 0.1 ]
[0.1 0.1 0.1 0.1 ]
[0.1 0.1 0.1 0.1 ]
[0.1 0.1 0.1 0.1 ]
[0.4 8.45 0.2 0.2 ]
[0.1 0.1 0.1 0.1 ]], shape=(7, 4), dtype=float32)
Как «минимум во всей матрице» равен 0,1, если 0 существует как в нужных, так и в ненужных строках?