Я использую следующий экстрактор пользовательских функций для своей модели StableBaselines3:
import torch.nn as nn
from stable_baselines3 import PPO
class Encoder(nn.Module):
def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim=2):
super(Encoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(input_dim, embedding_dim),
nn.ReLU()
)
self.regressor = nn.Sequential(
nn.Linear(embedding_dim, hidden_dim),
nn.ReLU(),
)
def forward(self, x):
x = self.encoder(x)
x = self.regressor(x)
return x
model = Encoder(input_dim, embedding_dim, hidden_dim)
model.load_state_dict(torch.load('trained_model.pth'))
# Freeze all layers
for param in model.parameters():
param.requires_grad = False
class CustomFeatureExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space, features_dim):
super(CustomFeatureExtractor, self).__init__(observation_space, features_dim)
self.model = model # Use the pre-trained model as the feature extractor
self._features_dim = features_dim
def forward(self, observations):
features = self.model(observations)
return features
policy_kwargs = {
"features_extractor_class": CustomFeatureExtractor,
"features_extractor_kwargs": {"features_dim": 64}
}
model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs)
На данный момент модель хорошо обучена, без проблем и с хорошими результатами. Теперь я хочу не замораживать веса и попытаться обучить Feature Extractor, начиная с начального предварительно обученного веса. Как я могу сделать это с помощью такого специального экстрактора функций, определенного как класс внутри другого класса? Мой экстрактор функций отличается от описанного в документации , поэтому я не уверен, что он будет обучен. Или он начнет обучение, если я разморозлю слои?
ОБНОВЛЕННЫЙ ответ
Поскольку ваш CustomFE уже импортирует кодировщик заморозки (с requires_grad = False
), у вас возникает такая ситуация, когда все веса CustomFE заморожены. Таким образом, по умолчанию CustomFE не поддается обучению. Вам нужно будет разморозить его вручную:
model = PPO("MlpPolicy", env='FrozenLake8x8', policy_kwargs=policy_kwargs)
# get model feature extractor
feature_extr: CustomFeatureExtractor = model.policy.features_extractor
# convert all parameters to trainable
for name, param in feature_extr.named_parameters():
param.requires_grad = True
# check parameters before training
encoder = feature_extr.model.encoder
for name, param in encoder[0].named_parameters():
print(name, param.mean())
# train the model
model.learn(total_timesteps = 5)
# check parameters after training (if mean changed parameters are training)
feature_extr: CustomFeatureExtractor = model.policy.features_extractor
encoder = feature_extr.model.encoder
for name, param in encoder[0].named_parameters():
print(name, param.mean())
Простой вызов `model = PPO(...)` не запускает процедуру обучения. А модель PPO наследует слои FE с их атрибутами. Позже во время обучения будет обучено все, что имеет градиенты. Я предполагаю, что вы заморозили слои FE перед началом процедуры обучения и, возможно, раньше, прежде чем объединить Encoder с PPO. Однако это неясно только из вашего фрагмента, поскольку ваш фрагмент не позволяет воспроизвести используемую вами процедуру обучения. Количество nn.modules в FE не имеет значения, поскольку модель PPO заботится только о выходных размерах (hidden_dims).
Понятно, спасибо за продолжение! Меня беспокоило то, что CustomFeatureExtractor
вызывает мою предварительно обученную сеть следующим образом: features = self.model(observations)
поэтому я подумал, что она не может быть обучена, поскольку ее можно рассматривать как функцию. Но это не проблема?
Вызывая self.model(observations)
, вы рассчитываете градиенты и, таким образом, позже обновляете веса. Можно сделать только вывод, если фрагмент находится в контексте with torch.no_grad():
, но я сомневаюсь, что это ваш случай. В любом случае, если вы предоставите фрагмент кода для обучения, я тоже смогу это проверить.
Понял, теперь стало гораздо яснее. Определив модель, как указано выше, для обучения я просто бегу model.learn()
.
@SayyorY Я обновил свой ответ. На самом деле вы были правы, вам нужно разморозить параметры FE, иначе их невозможно будет обучить. Проверьте новый ответ.
Спасибо за ваши усилия! @Джонни Чизкаттер
Спасибо за ответ, Джонни. Да, я это понимаю. Однако мой вопрос больше касался Stable Baselines3. Запускаем ли мы
model = PPO("MlpPolicy", env=envs, policy_kwargs=policy_kwargs)
и обновляем Feature Encoder или нет? Или это только политика обучения и сети ценностей. Кроме того, мой класс Feature Extractor состоит из 2 классов, а не одного, как в документации: Stable-baselines3.readthedocs.io/en/v1.0/guide/…. Это проблема?