Я пытаюсь создать собственный обучающий код для модели TensorFlow Keras.
Поскольку нам нужно разделить train_step (который обновляет веса модели) и test_step (который делает только вывод), интересно, можем ли мы создать одну функцию, которая будет работать для обоих.
Моя идея состоит в том, чтобы написать что-то вроде этого.
def _train_or_evaluate(self, inputs, gt1, gt2, is_training=False):
with tf.GradientTape() as tape1 'if is_training': # Enable or disable the with statement by condition
model1_output = self.model1(inputs)
model1_loss = self.loss_obj(gt1, model1_output)
inputs2 = self.process_output(model1_output)
with tf.GradientTape() as tape2 'if is_training': # Enable or disable the with statement by condition
model2_output = self.model2(inputs2)
model2_loss = self.loss_obj2(gt2, model2_output)
if is_training:
model1_gradients = tape1.gradient(model1_loss, self.model1.trainable_variables)
self.optimizer1.apply_gradients(model1_gradients, self.model1.trainable_variables)
model2_gradients = tape2.gradient(model2_loss, self.model2.trainable_variables)
self.optimizer2.apply_gradients(model2_gradients, self.model2.trainable_variables)
return model1_loss, model2_loss
def train_step(self, inputs):
inputs = inputs, (gt1, gt2)
return self._train_or_evaluate(inputs, gt1, gt2, True)
def test_step(self, inputs):
inputs = inputs, (gt1, gt2)
return self._train_or_evaluate(inputs, gt1, gt2, False)
Есть ли способ указать условие для включения или отключения оператора with при сохранении работы блока? Поэтому, когда условие равно False, блок продолжает работать так же, как и без оператора with.
Такой, что:
Когда self._train_or_evaluate(inputs, gt1, gt2, True) это эквивалентно:
def _train(self, inputs, gt1, gt2):
with tf.GradientTape() as tape1:
model1_output = self.model1(inputs)
model1_loss = self.loss_obj(gt1, model1_output)
inputs2 = self.process_output(model1_output)
with tf.GradientTape() as tape2:
model2_output = self.model2(inputs2)
model2_loss = self.loss_obj2(gt2, model2_output)
model1_gradients = tape1.gradient(model1_loss, self.model1.trainable_variables)
self.optimizer1.apply_gradients(model1_gradients, self.model1.trainable_variables)
model2_gradients = tape2.gradient(model2_loss, self.model2.trainable_variables)
self.optimizer2.apply_gradients(model2_gradients, self.model2.trainable_variables)
return model1_loss, model2_loss
И когда self._train_or_evaluate(inputs, gt1, gt2, False) это эквивалентно:
def _evaluate(self, inputs, gt1, gt2):
model1_output = self.model1(inputs)
model1_loss = self.loss_obj(gt1, model1_output)
inputs2 = self.process_output(model1_output)
model2_output = self.model2(inputs2)
model2_loss = self.loss_obj2(gt2, model2_output)
return model1_loss, model2_loss
Спасибо.
@ShadowRanger, учитывая, что tape1 и tape2 нигде в коде не используются, я не уверен, что это имеет какое-то значение. Я очень смущен этим вопросом.
@MarkRansom: Мое предположение состоит в том, что их создание каким-то образом изменяет поведение содержащегося кода (во многом подобно контекстам модуля decimal, управляемым операторами with), но я могу придавать здесь слишком большое значение.
нет, даже не написав класс с очень умным (или запутанным) __enter__ модулем, можно пропустить выполнение блока with. Просто используйте блок if, вложив в него оператор ẁith` (или наоборот)
@MarkRansom @ShadowRange: Да, верно. Когда код написан внутри блока tf.GradientTape(), он будет отслеживать поток градиентов переменных. И это потребует больше вычислительных ресурсов. Он используется, когда вызывается tape.gradient(). Поэтому мне интересно, можем ли мы эффективно включить или отключить оператор with, учитывая условие в python. Я обновил вопрос, чтобы привести более четкие примеры моих ожиданий.






Как я уже говорил в комментарии выше, независимо от того, как вы настраиваете метод __enter__ класса, который вызывается оператором with, невозможно пропустить сам блок with.
Тем не менее, простое вложение дополнительного блока if нежелательно (или вложение with в if, если на то пошло), можно полностью эмулировать блок with, используя try/finally - и тогда вы можете использовать блок if вместо with. Тем не менее, это все равно потребует дополнительного уровня отступа из-за того, что необходимы блоки if и try/finally.
def _train_or_evaluate(self, inputs, gt1, gt2, is_training=False):
if is_training:
tape1 = (cm:=tf.GradientTape()).__enter__()
try:
model1_output = self.model1(inputs)
model1_loss = self.loss_obj(gt1, model1_output)
finally:
if is_training:
cm.__exit__(*sys.exc_info())
...
# second with block and rest of the function
(обратите внимание, что я использую оператор walrus ( :=), чтобы сохранить ссылку на экземпляр, из которого оператор with будет вызывать методы __.enter__ и __exit__. Обычно в блоке with об этом не нужно заботиться)
Ваш заказ здесь немного не тот; cm должен быть создан вне try, поэтому, если tf.GradientTape() вызывает исключение, он не пытается получить доступ к cm в finally (я забыл, следует ли вызывать __enter__ внутри try при точном моделировании with; я так не думаю, но я слишком устал, чтобы искать это). Кроме того, вы, вероятно, имели в виду cm.__exit__, а не cm.exit.
если быть точным __enter__ должно быть и снаружи (исключение внутри __enter__ не проходит __exit__ в том же экземпляре).- Я только что отредактировал - спасибо.
В любом случае, код более высокого уровня должен использовать ExitStack, как в вашем ответе - я оставлю это опубликованным, поскольку он описывает «под капотом» with и может помочь людям, попавшим сюда, понять вещи, и может быть меньше вещей, которые нужно обернуть. -загляните в более простые сценарии.
Вы можете сделать это с помощью грамотного использования contextlib.ExitStack для разделения with и создания управляемого объекта, чтобы управляемый объект можно было создавать и управлять им условно. Для первого условного with вы должны заменить его на:
with contextlib.ExitStack() as stack:
if is_training:
tape1 = stack.enter_context(tf.GradientTape())
model1_output = self.model1(inputs)
model1_loss = self.loss_obj(gt1, model1_output)
Опустите tape1 =, если вам на самом деле не нужно, чтобы он был привязан к чему-то с именем tape1 (поскольку вы не используете и не можете использовать tape1 благодаря условному определению, это, вероятно, не нужно).
Когда is_training ложно, ExitStack в основном не работает; он создается и удаляется, и, поскольку он ничем не управляет, очистка не выполняется. Когда is_training соответствует действительности, экземпляр GradientTape создается и немедленно регистрируется в ExitStack, поэтому, когда with завершается, он очищается как обычно.
Если вы хотите получить немного орехов, вы можете использовать ExitStack в качестве строительного блока для условного with создания (фактически задокументировано, что он предназначен для работы в качестве такого строительного блока):
@contextlib.contextmanager
def conditional_ctx(condition, callable, *args, **kwargs):
with contextlib.ExitStack() as stack:
yield stack.enter_context(callable(*args, **kwargs)) if condition else None
который после определения позволяет вам реализовать исходный код с помощью:
with conditional_ctx(is_training, tf.GradientTape) as tape1:
model1_output = self.model1(inputs)
model1_loss = self.loss_obj(gt1, model1_output)
Для одного или двух применений это не такая уж большая экономия, но если вы делаете это во многих местах, это экономит часть работы.
Вы хотите отключить создание
tape2? Или управление ею? Или все содержимое блока? В принципе, какие линии полностью отключены?