Я пытаюсь выяснить рекомендуемый способ использования API dataset вместе с API estimator. Все, что я видел в сети, - это некоторые вариации этого:
def train_input_fn():
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
return dataset
который затем можно передать функции поезда оценки:
classifier.train(
input_fn=train_input_fn,
#...
)
но руководство по набору данных предупреждает, что:
the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer.
а затем описывает метод, который включает определение заполнителей, которые затем заполняются feed_dict:
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
sess.run(iterator.initializer, feed_dict = {features_placeholder: features,
labels_placeholder: labels})
Но если вы используете API estimator, вы не запускаете сеанс вручную. Так как же использовать API dataset с оценщиками, избегая проблем, связанных с from_tensor_slices()?






Чтобы использовать инициализируемые или повторно инициализируемые итераторы, необходимо создать класс, наследующий от tf.train.SessionRunHook, который имеет доступ к сеансу несколько раз на этапах обучения и оценки.
Затем вы можете использовать этот новый класс для инициализации итератора, как вы обычно делаете в классической настройке. Вам просто нужно передать этот вновь созданный хук функциям обучения / оценки или правильной спецификации поезда.
Вот быстрый пример, который вы можете адаптировать к своим потребностям:
class IteratorInitializerHook(tf.train.SessionRunHook):
def __init__(self):
super(IteratorInitializerHook, self).__init__()
self.iterator_initializer_func = None # Will be set in the input_fn
def after_create_session(self, session, coord):
# Initialize the iterator with the data feed_dict
self.iterator_initializer_func(session)
def get_inputs(X, y):
iterator_initializer_hook = IteratorInitializerHook()
def input_fn():
X_pl = tf.placeholder(X.dtype, X.shape)
y_pl = tf.placeholder(y.dtype, y.shape)
dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
dataset = ...
...
iterator = dataset.make_initializable_iterator()
next_example, next_label = iterator.get_next()
iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,
feed_dict = {X_pl: X, y_pl: y})
return next_example, next_label
return input_fn, iterator_initializer_hook
...
train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)
...
estimator.train(input_fn=train_input_fn,
hooks=[train_iterator_initializer_hook]) # Don't forget to pass the hook !
estimator.evaluate(input_fn=test_input_fn,
hooks=[test_iterator_initializer_hook])