Сохранять веса моделей в конце каждых N эпох

Я тренирую NN и хочу сохранять веса модели каждые N эпох для фазы прогнозирования. Я предлагаю этот черновик кода, он вдохновлен ответом @grovina здесь. Не могли бы вы внести предложения? Заранее спасибо.

from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs = {}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1

Затем добавьте его к подходящему вызову: чтобы сохранять веса каждые 5 эпох:

model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])

Спасибо @Ghilas_BELHADJ :)

Belkacem Thiziri 05.07.2018 13:23
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
10
1
13 600
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Ответ принят как подходящий

Вам не нужно передавать модель для обратного вызова. У него уже есть доступ к модели через супер. Поэтому удалите аргумент __init__(..., model, ...) и self.model = model. В любом случае вы сможете получить доступ к текущей модели через self.model. Вы также сохраняете его на каждом конце партии, а это не то, что вам нужно, вы, вероятно, хотите, чтобы это был on_epoch_end.

Но в любом случае то, что вы делаете, можно сделать через наивный обратный вызов modelcheckpoint. Вам не нужно писать собственный. Вы можете использовать это следующим образом;

mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', 
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])

Спасибо за ответ, работает хорошо :) именно то, что я ищу.

Belkacem Thiziri 05.07.2018 11:23

Вы должны реализовать on_epoch_end, а не on_batch_end. А также передача модели в качестве аргумента для __init__ является избыточной.

from keras.callbacks import Callback
class WeightsSaver(Callback):
  def __init__(self, N):
    self.N = N
    self.epoch = 0

  def on_epoch_end(self, epoch, logs = {}):
    if self.epoch % self.N == 0:
      name = 'weights%08d.h5' % self.epoch
      self.model.save_weights(name)
    self.epoch += 1

Другие вопросы по теме