Обучение экстрактора пользовательских функций в стабильных базовых условиях3. Начиная с предварительно обученных весов?

Я использую следующий экстрактор пользовательских функций для своей модели StableBaselines3:

import torch.nn as nn
from stable_baselines3 import PPO

class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim=2):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.ReLU()
        )
        self.regressor = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
        )
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.regressor(x)
        return x
    
model = Encoder(input_dim, embedding_dim, hidden_dim)
model.load_state_dict(torch.load('trained_model.pth'))

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

class CustomFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim):
        super(CustomFeatureExtractor, self).__init__(observation_space, features_dim)
        self.model = model  # Use the pre-trained model as the feature extractor

        self._features_dim = features_dim

    def forward(self, observations):
        features = self.model(observations)
        return features

policy_kwargs = {
        "features_extractor_class": CustomFeatureExtractor,
        "features_extractor_kwargs": {"features_dim": 64}
    }

 model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs)

На данный момент модель хорошо обучена, без проблем и с хорошими результатами. Теперь я хочу не замораживать веса и попытаться обучить Feature Extractor, начиная с начального предварительно обученного веса. Как я могу сделать это с помощью такого специального экстрактора функций, определенного как класс внутри другого класса? Мой экстрактор функций отличается от описанного в документации , поэтому я не уверен, что он будет обучен. Или он начнет обучение, если я разморозлю слои?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
2
0
116
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

ОБНОВЛЕННЫЙ ответ

Поскольку ваш CustomFE уже импортирует кодировщик заморозки (с requires_grad = False), у вас возникает такая ситуация, когда все веса CustomFE заморожены. Таким образом, по умолчанию CustomFE не поддается обучению. Вам нужно будет разморозить его вручную:


model = PPO("MlpPolicy", env='FrozenLake8x8', policy_kwargs=policy_kwargs)

# get model feature extractor
feature_extr: CustomFeatureExtractor = model.policy.features_extractor

# convert all parameters to trainable
for name, param in feature_extr.named_parameters():
    param.requires_grad = True

# check parameters before training
encoder = feature_extr.model.encoder
for name, param in encoder[0].named_parameters():
    print(name, param.mean())


# train the model
model.learn(total_timesteps = 5)


# check parameters after training (if mean changed parameters are training)
feature_extr: CustomFeatureExtractor = model.policy.features_extractor
encoder = feature_extr.model.encoder
for name, param in encoder[0].named_parameters():
    print(name, param.mean())

Спасибо за ответ, Джонни. Да, я это понимаю. Однако мой вопрос больше касался Stable Baselines3. Запускаем ли мы model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs) и обновляем Feature Encoder или нет? Или это только политика обучения и сети ценностей. Кроме того, мой класс Feature Extractor состоит из 2 классов, а не одного, как в документации: Stable-baselines3.readthedocs.io/en/v1.0/guide/…. Это проблема?

Sayyor Y 16.07.2024 15:29

Простой вызов `model = PPO(...)` не запускает процедуру обучения. А модель PPO наследует слои FE с их атрибутами. Позже во время обучения будет обучено все, что имеет градиенты. Я предполагаю, что вы заморозили слои FE перед началом процедуры обучения и, возможно, раньше, прежде чем объединить Encoder с PPO. Однако это неясно только из вашего фрагмента, поскольку ваш фрагмент не позволяет воспроизвести используемую вами процедуру обучения. Количество nn.modules в FE не имеет значения, поскольку модель PPO заботится только о выходных размерах (hidden_dims).

Johnny Cheesecutter 16.07.2024 15:40

Понятно, спасибо за продолжение! Меня беспокоило то, что CustomFeatureExtractor вызывает мою предварительно обученную сеть следующим образом: features = self.model(observations) поэтому я подумал, что она не может быть обучена, поскольку ее можно рассматривать как функцию. Но это не проблема?

Sayyor Y 16.07.2024 15:51

Вызывая self.model(observations), вы рассчитываете градиенты и, таким образом, позже обновляете веса. Можно сделать только вывод, если фрагмент находится в контексте with torch.no_grad():, но я сомневаюсь, что это ваш случай. В любом случае, если вы предоставите фрагмент кода для обучения, я тоже смогу это проверить.

Johnny Cheesecutter 16.07.2024 22:14

Понял, теперь стало гораздо яснее. Определив модель, как указано выше, для обучения я просто бегу model.learn().

Sayyor Y 17.07.2024 14:40

@SayyorY Я обновил свой ответ. На самом деле вы были правы, вам нужно разморозить параметры FE, иначе их невозможно будет обучить. Проверьте новый ответ.

Johnny Cheesecutter 17.07.2024 15:47

Спасибо за ваши усилия! @Джонни Чизкаттер

Sayyor Y 17.07.2024 16:12

Другие вопросы по теме