Keras вызывает model.fit, где x - это два кортежа np.ndarray

У меня есть регрессия tf.keras.Model, которая учитывает:

  • x: tuple[np.ndarray, np.ndarray], где два предмета имеют разную форму
    • Фигуры (128, 1152) и (1, 256)
  • y: float

У меня есть моя модель и обучение, кодифицированное следующим образом:

class MyModel(tf.keras.Model):

    def __init__(self):
        ...  # Omitted for brevity

    def call(self, inputs: tuple[tf.Tensor, tf.Tensor], training=None, mask=None):
        # Unpacks the two-tuple
        weights_1, weights_2 = inputs
        ...  # Omitted for brevity


# NOTE: item 0's shape is (128, 1152), item 1's shape is (1, 256)
datapoint_x: tuple[np.ndarray, np.ndarray]
datapoint_y: float

model = MyModel()
model(inputs=datapoint_x)  # Works fine

Однако, когда я перехожу к fit модели, я получаю Exception:

>>> model.fit(x=datapoint_x, y=np.array(datapoint_y))

Traceback (most recent call last):
  File "/path/to/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-a5dfb3dd4846>", line 1, in <module>
    model.fit(x=datapoint_x, y=np.array(datapoint_y))
  File "/path/to/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/path/to/python3.10/site-packages/tensorflow/python/framework/tensor_shape.py", line 910, in __getitem__
    return self._dims[key]
IndexError: tuple index out of range

Я исследовал это, и self._dims это (), а key это 0.

Как правильно вызвать Model.fit в наборе данных с двумя кортежами x?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
53
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Ответ таков: Model.fit перебирает x и y, поэтому мне пришлось заранее добавить пакетное измерение к моим x[0] и y.

Это легко сделать с помощью np.newaxis или np.expand_dims.

import numpy as np

# NOTE: item 0's shape is (128, 1152), item 1's shape is (1, 256)
datapoint_x: tuple[np.ndarray, np.ndarray]
datapoint_y: float

# NOTE: now item 0's shape is (1, 128, 1152), and item 1's shape remains (1, 256)
batch_x = (datapoint_x[0][np.newaxis, :], datapoint_x[1])
# NOTE: now y's shape is (1,)
batch_y = np.array([datapoint_y])

model.fit(x=batch_x, y=batch_y)

Другие вопросы по теме