Избегайте tf.data.Dataset.from_tensor_slices с помощью api оценки

Я пытаюсь выяснить рекомендуемый способ использования API dataset вместе с API estimator. Все, что я видел в сети, - это некоторые вариации этого:

def train_input_fn():
   dataset = tf.data.Dataset.from_tensor_slices((features, labels))
   return dataset

который затем можно передать функции поезда оценки:

 classifier.train(
    input_fn=train_input_fn,
    #...
 )

но руководство по набору данных предупреждает, что:

the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer.

а затем описывает метод, который включает определение заполнителей, которые затем заполняются feed_dict:

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))

sess.run(iterator.initializer, feed_dict = {features_placeholder: features,
                                          labels_placeholder: labels})

Но если вы используете API estimator, вы не запускаете сеанс вручную. Так как же использовать API dataset с оценщиками, избегая проблем, связанных с from_tensor_slices()?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
5
0
1 192
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Чтобы использовать инициализируемые или повторно инициализируемые итераторы, необходимо создать класс, наследующий от tf.train.SessionRunHook, который имеет доступ к сеансу несколько раз на этапах обучения и оценки.

Затем вы можете использовать этот новый класс для инициализации итератора, как вы обычно делаете в классической настройке. Вам просто нужно передать этот вновь созданный хук функциям обучения / оценки или правильной спецификации поезда.

Вот быстрый пример, который вы можете адаптировать к своим потребностям:

class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator with the data feed_dict
        self.iterator_initializer_func(session) 


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()


        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,
                                                                                    feed_dict = {X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook]) # Don't forget to pass the hook !
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])

Другие вопросы по теме