Я пытаюсь создать собственную функцию Loss в Keras 3, которая будет использоваться либо в jax, либо в бэкэнде torch. Я хочу замаскировать из y_pred и y_true все индексы, где y_true — определенное значение. Передача оставшихся значений в данную функцию loss_function.
Но каждый раз, когда я пытаюсь подогнать модель под свою функцию потерь как с помощью jax-сервера, так и с помощью Torch, она ломается, практически говоря, что она не может принять индексы или выполнить маскировку. Потому что для этого потребуется доступ к значениям тензора.
Я использую два способа:
import keras
from keras import Loss, ops
class NanValueLossA(Loss):
def __init__(
self,
loss_to_use=None,
nan_value=None,
name = "nan_value_loss",
**kwargs,
):
self.nan_value = nan_value
self.loss_to_use=loss_to_use
super().__init__(name=name, **kwargs)
def call(self, y_true, y_pred):
valid_mask = ops.not_equal(y_true, self.nan_value)
return self.loss_to_use(y_true[valid_mask], y_pred[valid_mask])
class NanValueLossB(Loss):
def __init__(
self,
loss_to_use=None,
nan_value=None,
name = "nan_value_loss",
**kwargs,
):
self.nan_value = nan_value
self.loss_to_use=loss_to_use
super().__init__(name=name, **kwargs)
def call(self, y_true, y_pred):
valid_mask = ops.not_equal(y_true, self.nan_value)
valid_indices = ops.where(valid_mask)
masked_y_pred = ops.take(y_pred,valid_indices)
masked_y_true = ops.take(y_true,valid_indices)
return self.loss_to_use(masked_y_true, masked_y_pred)
Я пробовал эти две формы как в Jax, так и в Torch. Я пробовал еще пару способов, но проблема каждый раз одна и та же. Вот ошибки:
НаНвалуэлоссА: факел:
File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
x = x.to(device)
^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!
Джекс:
File "c:....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6976, in _expand_bool_indices
raise errors.NonConcreteBooleanIndexError(abstract_i)
jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[32,1,128,128,1])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
НаНвалуэлоссБ: факел:
File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
x = x.to(device)
^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!
Джекс:
File "C:....\advanced_losses.py", line 651, in call
valid_indices = ops.where(valid_mask)
^^^^^^^^^^^^^^^^^^^^^
File "....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 1946, in where
return nonzero(condition, size=size, fill_value=fill_value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2378, in nonzero
calculated_size = core.concrete_dim_or_error(calculated_size,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function wrapped_fn at c:.....\Lib\site-packages\keras\src\backend\jax\core.py:153 for jit. This concrete value was not available in Python because it depends on the value of the argument args[1].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
До keras 3 я использовал функцию потерь на основе тензорного потока, и она работала, но теперь я хочу, чтобы что-то работало с факелом. Это была моя реализация тензорного потока:
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
def nan_mean_squared_error_loss(nan_value=np.nan):
# Create a loss function
def loss(y_true, y_pred):
indices = tf.where(tf.not_equal(y_true, nan_value))
return tf.keras.losses.mean_squared_error(
tf.gather_nd(y_true, indices), tf.gather_nd(y_pred, indices)
)
# Return a function
return loss
В среде, которую я использовал, у меня не был установлен тензорный поток. Я имел:
керас 3.4.1 факел 2.3.1 торчаудио 2.3.1 факелвидение 0.18.1
Но я установил tensorflow для тестирования, и теперь работа с факелом работает. Я предполагаю, что были некоторые внутренние функции, которые Torch может получить из тензорного потока.
тензорный поток 2.16.2
Теперь решение NaNValueLossA работает!