В статье OpenAI Five упоминается, что «градиенты дополнительно обрезаются для каждого параметра, чтобы они находились в пределах ± 5√v, где v — текущая оценка второго момента (неотсеченного) градиента». Это то, что я хотел бы реализовать в своем проекте, но я не уверен, как это сделать ни в теории, ни на практике.
Из них я бы предположил, что √v — это экспоненциальное скользящее среднее стандартного dev. градиентов и может быть рассчитано с помощью:
estimate = alpha * torch.std(list(param.grad for param in model.parameters())) + (1-alpha) * estimate
Верна ли моя теория? Есть ли лучший способ сделать это? Заранее спасибо.
Обновлено: исправлен сбор градиента после ответа Mr. For Example.
Я думаю, что вы на правильном пути, мое предположение в основном такое же, как и у вас, только немного отличается.
Во-первых, что такое момент?
N-й момент случайной величины определяется как ожидаемое значение этой переменной в степени n. Более формально:
m — момент, X — случайная величина
Таким образом, первый момент — это среднее значение, а второй момент — это нецентрированная дисперсия (это означает, что мы не вычитаем среднее значение при вычислении дисперсии), интуитивно понятно, что отсечение градиентов с помощью скользящего среднего значения его стандартного отклонения относительно нуля имеет смысл.
Во-вторых, какой правильный код?
list(network.parameters())
дайте вам только параметры, чтобы получить градиент каждого параметра, который вам нужен [param.grad for param in network.parameters()]
Учитывая все то, что мы знаем выше, правильный код должен быть (вы можете попытаться оптимизировать его любыми способами):
grads_square = torch.FloatTensor([torch.square(param.grad) for param in network.parameters()])
estimate = alpha * torch.sqrt(torch.mean(grads_square)) + (1-alpha) * estimate
Спасибо, я намеревался получить градиенты именно так, как вы советовали. Но я не смог сделать это со своим мозгом "30 минут после ночи ума". Отредактирую мой вопрос только об этом!
Нет проблем, рад видеть такой хороший вопрос, хорошего дня :)
Чем больше я думаю над вашим ответом, тем больше я думаю, что вы правы. Однако sqrt
в вашем коде кажется немного преждевременным, так как нам нужно сначала обновить estimate
и взять только квадратный корень из обновленной оценки.
Я только что понял, что ваш код сохраняет оценку √*v*, а не v, что, как подсказывает моя интуиция, так же хорошо. Это просто не соответствует нотации статьи. А теперь еще и понял, что мой вопрос тоже не в лад с ним.
Ну, я думал, это то, что вы хотите, но вы можете изменить это, просто удалив torch.sqrt()
в холоде
Этот вопрос относится не только к pytorch, но и к другой более общей области, такой как машинное обучение, я предлагаю вам добавить к нему больше тегов.