Модель обучения на Colab TPU с распределенной стратегией

Я пытаюсь обучить и запустить модель классификации изображений в Colab, используя TPU. Нет питорча.

Я знаю, что TPU работает только с файлами из корзин GCS, поэтому я загружаю набор данных из корзины, а также прокомментировал функции контрольной точки и ведения журнала, чтобы не было ошибок такого типа. Я просто хочу, чтобы он тренировался без ошибок на ТПУ.

На CPU и GPU код работает, но проблема возникает, когда я использую with strategy.scope(): перед созданием модели. Это функция, которая доставляет мне проблемы при обучении модели:

def train_step(self, images, labels):
    with tf.GradientTape() as tape:
        predictionProbs = self(images, training=True)
        loss = self.loss_fn(labels, predictionProbs)

    grads = tape.gradient(loss, self.trainable_weights)

    predictionLabels = tf.squeeze(tf.cast(predictionProbs > PROB_THRESHOLD_POSITIVE, tf.float32), axis=1)
    acc = tf.reduce_mean(tf.cast(predictionLabels == labels, tf.float32))

    self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) # here is the problem

    return loss, acc

И это ошибка, с которой я сталкиваюсь:

RuntimeError: `apply_gradients() cannot be called in cross-replica context. Use `tf.distribute.Strategy.run` to enter replica context.

Я посмотрел на https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy и я думаю, что вот решение, но я никогда не делал этого раньше, и я не знаю, откуда Я могу начать.

Может кто-нибудь, пожалуйста, дать мне совет по этой проблеме?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
2
0
338
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Вы должны вызвать эту процедуру с помощью Strategy.run():

strategy.run(train_step, args=(images, labels))

Спасибо чувак. Мне просто нужно было поставить @tf.function перед функцией, и это сработало.

AndreiV6 11.12.2020 13:26

Другие вопросы по теме