Невозможно загрузить двунаправленную модель LSTM из файла .keras

Я впервые работаю с tensorflow и keras, и у меня возникли проблемы с сохранением и загрузкой моделей. Я понял, что это конкретно связано с двунаправленным слоем LSTM, но не знаю, что делать дальше.

Вот мой код:

model = Sequential()

model.add(Bidirectional(LSTM(128, activation='tanh'))) ### problem line
model.add(Dropout(0.2))
model.add(BatchNormalization())

model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))

model.add(Dense(1, activation='sigmoid'))


optimizer = Adam(learning_rate=0.001)
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)

model.fit(X_train_scaled, y_train, batch_size=32, epochs=3, validation_data=(X_test_scaled, y_test), callbacks=[early_stopping, reduce_lr])
model.summary()

model.save('models/sample_model.keras')
model = load_model('models/sample_model.keras')

Загрузка модели приводит к такой ошибке:

ValueError: A total of 1 objects could not be loaded. Example error message for object <LSTMCell name=lstm_cell, built=True>:

Layer 'lstm_cell' expected 3 variables, but received 0 variables during loading. Expected: ['kernel', 'recurrent_kernel', 'bias']

List of objects that could not be loaded:
[<LSTMCell name=lstm_cell, built=True>]

Я попытался добавить input_shape в двунаправленные параметры, но возникла та же ошибка. Я также пытался сохранить файл .h5, а не .keras, но безуспешно. Я просмотрел документацию по keras и подумал, что делаю именно так, как они предлагали, но, должно быть, я где-то напутал.

Я даже комментировал по одной строке, сохранял и загружал модели. Ошибка возникает только тогда, когда я включаю двунаправленный режим в слой LSTM. Наличие слоя LSTM само по себе не вызывало ошибки.

Есть идеи, почему это происходит?

Обновлено: я использую tensorflow 2.16.1 и keras 3.2.0.

вы не даете полезный минимальный воспроизводимый пример, вы не предоставляете никаких примеров входных или выходных данных, мы не можем помочь, потому что вы не предоставили нам достаточно информации.

UpAndAdam 09.04.2024 22:26
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
1
373
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Кажется, это проблема с версией tensorflow 2.16.1. При обновлении до ночной версии вы можете загрузить модель, но не веса.

Формат .keras похож на zip-файл и содержит архитектуру модели в файле config.json. Попытка построить модель из этого файла json с использованием tf.keras.models.model_from_json() также приводит к этой ошибке, поэтому я подозреваю, что существует проблема с тем, как обрабатывается json в этой версии.

Даже в ночной версии я получаю эту ошибку при выполнении model.load_weights('model.weights.h5')

Ответ принят как подходящий

Единственное решение, которое сработало для меня, - вернуться к tensorflow 2.15.1 и keras 2.15.0, и теперь оно работает так, как задумано.

Другие вопросы по теме