Срез на основе замаскированного тензора в тензорном потоке

Это мой воспроизводимый код:

tf_ent = tf.Variable([   [9.96,    8.65,    0.99,    0.1 ],
                         [0.7,     8.33,    0.1  ,   0.1   ],
                         [0.9,     0.1,     6,       7.33],
                         [6.60,    0.1,     3,       5.5 ],
                         [9.49,    0.2,     0.2,     0.2   ],
                         [0.4,     8.45,    0.2,     0.2 ],
                         [0.3,     0.2,     5.82,    8.28]])

tf_ent_var = tf.constant([True, False, False, False, False, True, False])

Я хочу сохранить строки в tf_ent, в которых соответствующие индексы в tf_ent_var равны True, и сделать остальные строки минимальными во всей матрице.

поэтому ожидаемый результат будет таким:

                    [[9.96,    8.65,    0.99,   0.1 ],
                     [0.1,     0.1,     0.1  ,  0.1 ],
                     [0.1,     0.1,     0.1,    0.1 ],
                     [0.1,     0.1,     0.1,    0.1 ],
                     [0.1,     0.1,     0.1,    0.1 ],
                     [0.4,     8.45,    0.2,      0.2 ],
                     [0.1,     0.1,     0.1,    0.1 ]]

Любая идея, как я могу это сделать?

Я пытался получить индексы из замаскированного тензора, а затем использовать tf.gather для выполнения этого, но индексы, которые я получаю, были такими [[0], [6]], что имеет смысл, потому что он давал индекс одного вектора.

Как «минимум во всей матрице» равен 0,1, если 0 существует как в нужных, так и в ненужных строках?

Imperishable Night 22.06.2019 23:22

@ImperishableNight, мне пришлось уменьшить матрицу и забыть заменить нули :|, я обновлю свой вопрос

sariii 22.06.2019 23:29
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
2
359
3
Перейти к ответу Данный вопрос помечен как решенный

Ответы 3

Ответ принят как подходящий

Обновлено: для тензорного потока 1.x используйте:

val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)

К сожалению, правила вещания не являются подмножеством правил 2.0 (что то же самое, что и numpy), а «просто разные». Tensorflow — не лучшая библиотека, когда речь идет о совместимости версий.


Основная идея состоит в том, чтобы использовать tf.where, но вам нужно будет сначала изменить tf_ent_var на тензор с формой (7, 1), чтобы тензорный поток знал, что он транслирует его по второй оси, а не по первой оси. Так:

val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var[:, tf.newaxis], tf_ent, val)

Конечно, вы также можете преобразовать его в (-1, 1), но я думаю, что нарезка с помощью tf.newaxis короче и понятнее.


Вот мой интерактивный сеанс Python с 1.13.1 для устранения неполадок.

Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 16:52:21) 
[Clang 6.0 (clang-600.0.57)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
2019-06-22 15:51:09.210852: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
>>> tf_ent = tf.Variable([   [9.96,    8.65,    0.99,    0.1 ],
...                          [0.7,     8.33,    0.1  ,   0.1   ],
...                          [0.9,     0.1,     6,       7.33],
...                          [6.60,    0.1,     3,       5.5 ],
...                          [9.49,    0.2,     0.2,     0.2   ],
...                          [0.4,     8.45,    0.2,     0.2 ],
...                          [0.3,     0.2,     5.82,    8.28]])
WARNING:tensorflow:From /Users/REDACTED/Documents/test/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
>>> 
>>> tf_ent_var = tf.constant([True, False, False, False, False, True, False])
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> val = tf.math.reduce_min(tf_ent)
>>> tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)
<tf.Tensor 'Select:0' shape=(7, 4) dtype=float32>
>>> _.eval()
array([[9.96, 8.65, 0.99, 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.4 , 8.45, 0.2 , 0.2 ],
       [0.1 , 0.1 , 0.1 , 0.1 ]], dtype=float32)
>>> tf.__version__
'1.13.1'

Спасибо за ответ, хотя я получил эту ошибку «tensorflow.python.framework.errors_impl.InvalidArgumentErro‌​r: входные данные для операции Select типа Select должны иметь одинаковый размер и форму. Вход 0: [7,1] != вход 1: [7,4] [Op:Select] '

sariii 22.06.2019 23:56

О, я использую tensorflow 2.0.0-beta1 (и tensorflow убедился, что я знаю об этом, потому что каждая ошибка, которую он выдает, исходит от функции с именем something_v2, например where_v2 и select_v2). В tensorflow 1.x where, вероятно, не такой гибкий. Попробую поискать решение в 1.x.

Imperishable Night 22.06.2019 23:59

Я не уверен, стоит ли мне обновляться до бета-версии, дайте мне знать, если вы найдете какой-либо подход, работающий с tf.13.

sariii 23.06.2019 00:09

значение val не будет добавлено к тензору tf.zeros, поэтому на выходе не будет 0,1 в окончательной матрице

sariii 23.06.2019 00:49

Я пытался обновить его, но похоже, что это связано с версией tensorflow, вы получили точно такой же вывод, который я поделился в вопросе?

sariii 23.06.2019 00:52

Хм? Когда я тестировал, все работало нормально. Вы уверены, что ничего не перепутали? Моя версия тензорного потока для этого — 1.13.1.

Imperishable Night 23.06.2019 00:53
min_mat = tf.broadcast_to(tf.reduce_min(tf_ent), tf_ent.shape)
output = tf.where(tf_ent_var, tf_ent, min_mat)
sess.run(output)

Вот моя реализация с использованием операторов tf.concat() и if-else. Это не так элегантно, как другой ответ, но работает:

import tensorflow as tf
tf.enable_eager_execution()

def slice_tensor_based_on_mask(tf_ent, tf_ent_var):
    res = tf.fill([1, 4], 0.0)  
    min_value_tensor = tf.fill([1,int(tf_ent.shape[1])], tf.reduce_min(tf_ent))

    for i in range(int(tf_ent.shape[0])):
        if tf_ent_var[i:i+1].numpy()[0]: # true value in tf_ent_var
            res = tf.concat([res, tf_ent[i:i+1]], 0)
        else:
            res = tf.concat([res, min_value_tensor], 0)
    return res[1:]

tf_ent = tf.Variable([[9.96,    8.65,    0.99,   0.1 ],
                     [0.7,     8.33,    0.1  ,   0.1 ],
                     [0.9,     0.1,     6,       7.33],
                     [6.60,    0.1,     3,       5.5 ],
                     [9.49,    0.2,     0.2,     0.2 ],
                     [0.4,     8.45,    0.2,     0.2 ],
                     [0.3,     0.2,     5.82,    8.28]])

tf_ent_var = tf.constant([True, False, False, False, False, True, False])
print(slice_tensor_based_on_mask(tf_ent, tf_ent_var))

выход:

tf.Tensor(
[[9.96 8.65 0.99 0.1 ]
 [0.1  0.1  0.1  0.1 ]
 [0.1  0.1  0.1  0.1 ]
 [0.1  0.1  0.1  0.1 ]
 [0.1  0.1  0.1  0.1 ]
 [0.4  8.45 0.2  0.2 ]
 [0.1  0.1  0.1  0.1 ]], shape=(7, 4), dtype=float32)

Другие вопросы по теме