Пользовательская функция потерь Keras 3 для маскировки NaN

Я пытаюсь создать собственную функцию Loss в Keras 3, которая будет использоваться либо в jax, либо в бэкэнде torch. Я хочу замаскировать из y_pred и y_true все индексы, где y_true — определенное значение. Передача оставшихся значений в данную функцию loss_function.

Но каждый раз, когда я пытаюсь подогнать модель под свою функцию потерь как с помощью jax-сервера, так и с помощью Torch, она ломается, практически говоря, что она не может принять индексы или выполнить маскировку. Потому что для этого потребуется доступ к значениям тензора.

Я использую два способа:



import keras
from keras import Loss, ops



class NanValueLossA(Loss):
    def __init__(
        self,
        loss_to_use=None,
        nan_value=None,
        name = "nan_value_loss",
        **kwargs,
    ):
        self.nan_value = nan_value
        self.loss_to_use=loss_to_use
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):

        valid_mask = ops.not_equal(y_true, self.nan_value)
        return self.loss_to_use(y_true[valid_mask], y_pred[valid_mask])
    


class NanValueLossB(Loss):
    def __init__(
        self,
        loss_to_use=None,
        nan_value=None,
        name = "nan_value_loss",
        **kwargs,
    ):
        self.nan_value = nan_value
        self.loss_to_use=loss_to_use
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):

        valid_mask = ops.not_equal(y_true, self.nan_value)
        valid_indices = ops.where(valid_mask)
        masked_y_pred = ops.take(y_pred,valid_indices)
        masked_y_true = ops.take(y_true,valid_indices)

        return self.loss_to_use(masked_y_true, masked_y_pred)

Я пробовал эти две формы как в Jax, так и в Torch. Я пробовал еще пару способов, но проблема каждый раз одна и та же. Вот ошибки:

НаНвалуэлоссА: факел:

  File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
    x = x.to(device)
        ^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!

Джекс:

  File "c:....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6976, in _expand_bool_indices
    raise errors.NonConcreteBooleanIndexError(abstract_i)
jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[32,1,128,128,1])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

НаНвалуэлоссБ: факел:

  File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
    x = x.to(device)
        ^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!

Джекс:

  File "C:....\advanced_losses.py", line 651, in call
    valid_indices = ops.where(valid_mask)
                    ^^^^^^^^^^^^^^^^^^^^^
  File "....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 1946, in where
    return nonzero(condition, size=size, fill_value=fill_value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2378, in nonzero
    calculated_size = core.concrete_dim_or_error(calculated_size,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function wrapped_fn at c:.....\Lib\site-packages\keras\src\backend\jax\core.py:153 for jit. This concrete value was not available in Python because it depends on the value of the argument args[1].

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

До keras 3 я использовал функцию потерь на основе тензорного потока, и она работала, но теперь я хочу, чтобы что-то работало с факелом. Это была моя реализация тензорного потока:

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K




def nan_mean_squared_error_loss(nan_value=np.nan):
    # Create a loss function
    def loss(y_true, y_pred):
        indices = tf.where(tf.not_equal(y_true, nan_value))
        return tf.keras.losses.mean_squared_error(
            tf.gather_nd(y_true, indices), tf.gather_nd(y_pred, indices)
        )

    # Return a function
    return loss
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
53
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

В среде, которую я использовал, у меня не был установлен тензорный поток. Я имел:

керас 3.4.1 факел 2.3.1 торчаудио 2.3.1 факелвидение 0.18.1

Но я установил tensorflow для тестирования, и теперь работа с факелом работает. Я предполагаю, что были некоторые внутренние функции, которые Torch может получить из тензорного потока.

тензорный поток 2.16.2

Теперь решение NaNValueLossA работает!

Другие вопросы по теме