У меня есть регрессия tf.keras.Model
, которая учитывает:
x: tuple[np.ndarray, np.ndarray]
, где два предмета имеют разную форму
(128, 1152)
и (1, 256)
y: float
У меня есть моя модель и обучение, кодифицированное следующим образом:
class MyModel(tf.keras.Model):
def __init__(self):
... # Omitted for brevity
def call(self, inputs: tuple[tf.Tensor, tf.Tensor], training=None, mask=None):
# Unpacks the two-tuple
weights_1, weights_2 = inputs
... # Omitted for brevity
# NOTE: item 0's shape is (128, 1152), item 1's shape is (1, 256)
datapoint_x: tuple[np.ndarray, np.ndarray]
datapoint_y: float
model = MyModel()
model(inputs=datapoint_x) # Works fine
Однако, когда я перехожу к fit
модели, я получаю Exception
:
>>> model.fit(x=datapoint_x, y=np.array(datapoint_y))
Traceback (most recent call last):
File "/path/to/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-5-a5dfb3dd4846>", line 1, in <module>
model.fit(x=datapoint_x, y=np.array(datapoint_y))
File "/path/to/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/path/to/python3.10/site-packages/tensorflow/python/framework/tensor_shape.py", line 910, in __getitem__
return self._dims[key]
IndexError: tuple index out of range
Я исследовал это, и self._dims
это ()
, а key
это 0
.
Как правильно вызвать Model.fit в наборе данных с двумя кортежами x?
Ответ таков: Model.fit
перебирает x и y, поэтому мне пришлось заранее добавить пакетное измерение к моим x[0]
и y
.
Это легко сделать с помощью np.newaxis
или np.expand_dims
.
import numpy as np
# NOTE: item 0's shape is (128, 1152), item 1's shape is (1, 256)
datapoint_x: tuple[np.ndarray, np.ndarray]
datapoint_y: float
# NOTE: now item 0's shape is (1, 128, 1152), and item 1's shape remains (1, 256)
batch_x = (datapoint_x[0][np.newaxis, :], datapoint_x[1])
# NOTE: now y's shape is (1,)
batch_y = np.array([datapoint_y])
model.fit(x=batch_x, y=batch_y)