Как динамически обновлять норму партии в TF2?

Я нашел реализацию PyTorch, которая уменьшает параметр пакетной нормы momentum с 0.1 в первой эпохе до 0.001 в последней эпохе. Любые предложения о том, как сделать это с параметром нормы партии momentum в TF2? (т. е. начинать с 0.9 и заканчивать на 0.999). Например, вот что делается в коде PyTorch:

# in training script
momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum))
model_pos_train.set_bn_momentum(momentum)

# model class function
def set_bn_momentum(self, momentum):
    self.expand_bn.momentum = momentum
    for bn in self.layers_bn:
        bn.momentum = momentum

РЕШЕНИЕ:

Выбранный ниже ответ обеспечивает жизнеспособное решение при использовании tf.keras.Model.fit() API. Тем не менее, я использовал пользовательский цикл обучения. Вот что я сделал вместо этого:

После каждой эпохи:

mi = 1 - initial_momentum  # i.e., inital_momentum = 0.9, mi = 0.1
mf = 1 - final_momentum  # i.e., final_momentum = 0.999, mf = 0.001
momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
model = set_bn_momentum(model, momentum)

Функция set_bn_momentum (спасибо этой статье):

def set_bn_momentum(model, momentum):
    for layer in model.layers:
        if hasattr(layer, 'momentum'):
            print(layer.name, layer.momentum)
            setattr(layer, 'momentum', momentum)

    # When we change the layers attributes, the change only happens in the model config file
    model_json = model.to_json()

    # Save the weights before reloading the model.
    tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
    model.save_weights(tmp_weights_path)

    # load the model from the config
    model = tf.keras.models.model_from_json(model_json)

    # Reload the model weights
    model.load_weights(tmp_weights_path, by_name=True)
    return model

Этот метод не добавлял значительных накладных расходов в программу обучения.

Не понятно, что ты задумал. Можете ли вы показать код pytorch, который делает именно то, что вы хотите?

Innat 10.12.2020 13:15
Udacity Nanodegree Capstone Project: Классификатор пород собак
Udacity Nanodegree Capstone Project: Классификатор пород собак
Вы можете ознакомиться со скриптами проекта и данными на github .
0
1
379
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Вы можете установить действие в начале/конце каждого пакета, чтобы вы могли контролировать любой параметр в течение эпохи.

Ниже варианты обратных вызовов:

class CustomCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

Вы можете получить доступ к импульсу

batch = tf.keras.layers.BatchNormalization()
batch.momentum = 0.001

Внутри модели вы должны указать правильный слой

model.layers[1].momentum = 0.001

Дополнительную информацию и примеры можно найти на странице writing_your_own_callbacks

Спасибо. Знаете ли вы, как получить доступ к параметру импульса пакетной нормы вашей модели или, что еще лучше, к конкретным слоям в обратном вызове?

wmcnally 11.12.2020 14:06

Я думаю, вы можете получить доступ через self.model.layers

wmcnally 11.12.2020 14:08

Извините, я прочитал импульс, но я написал скорость обучения, я обновил свой ответ.

Fernando Silva 11.12.2020 15:25

Другие вопросы по теме