В чем разница между сохранением на основе имени и на основе объекта в Tensorflow Eager Execution?
TensorFlow традиционно использует глобальные имена для переменных, чтобы сопоставить значения контрольных точек с переменными на графике. В основном это свойство переменной .name
:
import tensorflow as tf
tf.enable_eager_execution()
dense = tf.keras.layers.Dense(2)
dense(tf.ones([1, 1]))
print(dense.variables[0].name)
Печать:
dense/kernel:0
Это имя, которое tf.train.Saver
записывает в контрольную точку, и ключ, который он использует для сопоставления восстановленного значения. Он хорошо работает, когда программа Python содержит одну модель TensorFlow или когда построение модели является изолированным (как в случае с tf.estimator.Estimator
, который строит модель, которую он обертывает с нуля в новом Graph
).
Объектно-ориентированная контрольная точка, tf.contrib.eager.Checkpoint
/ tfe.Checkpoint
, направлена на то, чтобы сделать сопоставление этой переменной более надежным, когда программа Python изменяется или когда несколько моделей TensorFlow используются в одной программе Python. Он делает это путем построения графа зависимостей объектов с именованными ребрами и сохранения его с контрольной точкой:
(визуализация из пример нетерпеливого GAN; черные узлы - это объекты слоя, синий - переменные, красный - оптимизаторы, а оранжевый - переменные слота, созданные оптимизаторами)
Эти именованные зависимости создаются автоматически при назначении атрибутов объекту Checkpointable
, включая tf.keras.Model
. Например self.conv1 = layers.Conv2D(...)
создает край зависимости с именем "conv1", когда self
- это tf.keras.Model
..
При восстановлении должна совпадать структура модели (объекты и их именованные края), не обязательно точные имена переменных.
Возвращаясь к слою Dense
, мы можем создать для него контрольную точку, а затем восстановить ее во втором объекте, имена переменных которого не совпадают:
import tensorflow.contrib.eager as tfe
save_checkpoint = tfe.Checkpoint(dense=dense)
dense.variables[0].assign([[1., 2.]])
save_path = save_checkpoint.save("/tmp/tensorflow/ckpt")
# save_path = "/tmp/tensorflow/ckpt-1"
Затем при восстановлении все еще в той же программе:
second_dense = tf.keras.layers.Dense(2)
restore_checkpoint = tfe.Checkpoint(dense=second_dense)
restore_checkpoint.restore(save_path)
second_dense(tf.ones([1, 1]))
print(second_dense.variables[0])
Печать:
<tf.Variable 'dense_1/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[1., 2.]], dtype=float32)>
Значение [[1., 2.]]
было восстановлено до использования слоем Dense
(восстановление-при-создании), несмотря на то, что оно имело другое имя (dense_1/kernel
вместо dense/kernel
).
Хотя это особенно полезно при активном выполнении, объектно-ориентированное сохранение работает и при построении графов. Просто добавьте run_restore_ops()
:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
dense = tf.keras.layers.Dense(2)
dense(tf.ones([1, 1]))
save_checkpoint = tfe.Checkpoint(dense=dense)
assign_op = tf.group(dense.variables[0].assign([[1., 2.]]),
dense.variables[1].assign([3., 4.]))
second_dense = tf.keras.layers.Dense(2)
restore_checkpoint = tfe.Checkpoint(dense=second_dense)
second_dense(tf.ones([1, 1]))
with tf.Session() as session:
session.run(assign_op)
save_path = save_checkpoint.save("/tmp/tensorflow/ckpt")
restore_checkpoint.restore(save_path).assert_consumed().run_restore_ops()
print(session.run(second_dense.variables[0]))
Печать:
[[1. 2.]]
Полезные ресурсы:
tfe.Checkpoint
: https://www.tensorflow.org/api_docs/python/tf/contrib/eager/Checkpointtfe.Checkpointable
, который управляет зависимостями между объектами: https://www.tensorflow.org/api_docs/python/tf/contrib/eager/CheckpointableДовольно приятно, что не нужно указывать входное измерение, которое вместе с выходным размером определяет форму переменной в большинстве слоев. Я считаю, что это основная причина наличия отдельной фазы сборки (которую можно вызвать вручную). Также позволяет настроить слой, а затем добавить переменные в график, в котором он вызывается.
Спасибо за ваш ответ. Я заметил, что если вы не запускаете плотный (tf.ones ([1, 1])), плотный будет не содержать переменных, что сильно отличается от Pytorch, почему он спроектирован таким образом? Только после звонка тогда может слой содержать переменные?