У меня есть tf.Estimator, model_fn
которого содержит tf.Variable, инициализированный значением 1.0. Я хотел бы изменять значение переменной в каждую эпоху в зависимости от точности набора разработчика. Для этого я реализовал SessionRunHook
, но когда я пытаюсь изменить значение, я получаю следующую ошибку:
raise RuntimeError("Graph is finalized and cannot be modified.")
Это код для крючка:
class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0
def before_run(self, run_context):
self.steps += 1
def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
with tf.variable_scope("lambda_scope", reuse=True):
lambda_tensor = tf.get_variable("lambda_value")
tf.assign(lambda_tensor, self.gamma_value)
self.gamma_value += 0.1
Я понимаю, что график завершается, когда я запускаю ловушку, но я хотел бы знать, есть ли другой способ изменить значение переменной в графе model_fn с помощью API-интерфейса Estimator во время обучения.
То, как ваш хук настроен прямо сейчас, вы, по сути, пытаетесь создавать новые переменные / операции после каждого запуска сеанса. Вместо этого вы должны заранее определить операцию tf.assign
и передать ее ловушке, чтобы она могла запускать операцию сама, если это необходимо, или определить операцию назначения в __init__
ловушки. Вы можете получить доступ к сеансу внутри after_run
через аргумент run_context
. Так что-то вроде
class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value, lambda_tensor):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0
self.update_op = tf.assign(lambda_tensor, self.gamma_placeholder)
def before_run(self, run_context):
self.steps += 1
def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
run_context.session.run(self.update_op)
self.gamma += 0.1
Здесь есть некоторые предостережения. Во-первых, я не уверен, можно ли сделать tf.assign
с таким целым числом Python, то есть будет ли он правильно обновляться после изменения gamma
. Если это не сработает, вы можете попробовать следующее:
class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value, lambda_tensor):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0
self.gamma_placeholder = tf.placeholder(tf.float32, [])
self.update_op = tf.assign(lambda_tensor, self.gamma_placeholder)
def before_run(self, run_context):
self.steps += 1
def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
run_context.session.run(self.update_op, feed_dict = {self.gamma_placeholder: self.gamma})
self.gamma += 0.1
Здесь мы используем дополнительный заполнитель, чтобы иметь возможность передавать «текущую» гамму операции назначения в любое время.
Во-вторых, поскольку хукам нужен доступ к переменным, вам нужно будет определить ловушку внутри функции модели. Вы можете передать такие хуки в тренировочный процесс в EstimatorSpec
(см. здесь).
Спасибо, это была хорошая идея. Думаю, об этом стоит упомянуть в документации Tensorflow. Я добавил свои операции в функцию __init __ () и запустил их в before_run (), и это сработало.