Какова альтернатива pytorch module.register_parameter(name, param) в tensorflow?

Я пытаюсь преобразовать код pytorch в tensorflow. В коде pytorch они добавляют дополнительный параметр к каждому модулю модели, используя module.register_parameter(name, param). Как я могу скрыть эту часть кода в тензорном потоке?

Пример кода ниже:

for module_name, module in self.model.named_modules():
    module.register_parameter(name, new_parameter)
     
Анализ настроения постов в Twitter с помощью Python, Tweepy и Flair
Анализ настроения постов в Twitter с помощью Python, Tweepy и Flair
Анализ настроения текстовых сообщений может быть настолько сложным или простым, насколько вы его сделаете. Как и в любом ML-проекте, вы можете выбрать...
7 лайфхаков для начинающих Python-программистов
7 лайфхаков для начинающих Python-программистов
В этой статье мы расскажем о хитростях и советах по Python, которые должны быть известны разработчику Python.
Установка Apache Cassandra на Mac OS
Установка Apache Cassandra на Mac OS
Это краткое руководство по установке Apache Cassandra.
Сертификатная программа "Кванты Python": Бэктестер ансамблевых методов на основе ООП
Сертификатная программа "Кванты Python": Бэктестер ансамблевых методов на основе ООП
В одном из недавних постов я рассказал о том, как я использую навыки количественных исследований, которые я совершенствую в рамках программы TPQ...
Создание персонального файлового хранилища
Создание персонального файлового хранилища
Вы когда-нибудь хотели поделиться с кем-то файлом, но он содержал конфиденциальную информацию? Многие думают, что электронная почта безопасна, но это...
Создание приборной панели для анализа данных на GCP - часть I
Создание приборной панели для анализа данных на GCP - часть I
Недавно я столкнулся с интересной бизнес-задачей - визуализацией сбоев в цепочке поставок лекарств, которую могут просматривать врачи и...
1
0
25
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

tf.Variable — это эквивалент nn.Parameter в PyTorch. tf.Variable в основном используется для хранения параметров модели, так как их значения постоянно обновляются во время обучения.

Чтобы использовать тензор в качестве нового параметра модели, вам нужно преобразовать его в tf.Variable. Вы можете проверить здесь, как создавать переменные из тензоров.

Если вы хотите добавить параметр модели в TensorFlow внутри самой модели, вы можете просто создать переменную внутри класса модели, и она будет автоматически зарегистрирована TensorFlow как параметр модели.

Если вы хотите добавить tf.Variableвнешне к модели в качестве параметра модели, вы можете вручную добавить его к атрибуту обучаемые_весаtf.keras.layers.Layer с помощью расширение, например, это -

model.layers[-1].trainable_weights.extend([new_parameter])

Спасибо за Ваш ответ. Я обнаружил, что могу использовать layer.add_weight(), чтобы добавить новую переменную в слой. Это эквивалентно module.register_parameter() на pytorch?

Al Shahreyaj 17.05.2022 12:42

Другие вопросы по теме