Я сохранил обученную политику с помощью заставки политики следующим образом:
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
tf_policy_saver.save(policy_dir)
Я хочу продолжить обучение с сохраненной политикой. Поэтому я попытался инициализировать обучение с сохраненной политикой, что вызвало некоторую ошибку.
agent = dqn_agent.DqnAgent(
tf_env.time_step_spec(),
tf_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=train_step_counter)
agent.initialize()
agent.policy=tf.compat.v2.saved_model.load(policy_dir)
ОШИБКА:
File "C:/Users/Rohit/PycharmProjects/pythonProject/waypoint.py", line 172, in <module>
agent.policy=tf.compat.v2.saved_model.load('waypoints\\Two_rewards')
File "C:\Users\Rohit\anaconda3\envs\btp36\lib\site-packages\tensorflow\python\training\tracking\tracking.py", line 92, in __setattr__
super(AutoTrackable, self).__setattr__(name, value)
AttributeError: can't set attribute
Я просто хочу сэкономить время на переобучение с первого раза каждый раз. Как я могу загрузить сохраненную политику и продолжить обучение??
заранее спасибо
Вы должны проверить Checkpointer для этой цели.
Да, как указывалось ранее, для этого следует использовать Checkpointer. Взгляните на приведенный ниже пример кода.
agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
policy=policy)
... # Train the agent
# Policy --> X
policy_checkpointer.save(global_step=epoch_counter.numpy())
Когда вы позже захотите перезагрузить политику, вы просто запускаете тот же код инициализации.
agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y1, possibly Y1==Y depending on agent class you are using, if it's DQN
# then they are different because of random initialization of network weights
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
policy=policy)
# Policy --> X
После создания policy_checkpointer
автоматически определит, существуют ли какие-либо ранее существовавшие контрольные точки. Если они есть, он автоматически обновит значение отслеживаемых переменных при создании.
Пара заметок, чтобы сделать:
train_checkpointer = common.Checkpointer(ckpt_dir=first/dir,
agent=tf_agent, # tf_agent.TFAgent
train_step=train_step, # tf.Variable
epoch_counter=epoch_counter, # tf.Variable
metrics=metric_utils.MetricsGroup(
train_metrics, 'train_metrics'))
policy_checkpointer = common.Checkpointer(ckpt_dir=second/dir,
policy=agent.policy)
rb_checkpointer = common.Checkpointer(ckpt_dir=third/dir,
max_to_keep=1,
replay_buffer=replay_buffer # TFUniformReplayBuffer
)
DqnAgent
agent.policy
и agent.collect_policy
по сути являются обертками вокруг QNetwork. Последствия этого показаны в приведенном ниже коде (посмотрите на комментарии к состоянию переменной политики).agent = DqnAgent(...)
policy = agent.policy # Random initial policy ---> X
dataset = replay_buffer.as_dataset(...)
for data in dataset:
experience, _ = data
loss_agent_info = agent.train(experience=experience)
# policy variable stores a trained Policy object ---> Y
Это происходит потому, что тензоры в TF используются во время выполнения. Поэтому, когда вы обновляете QNetwork
веса вашего агента с помощью agent.train
, те же самые веса будут неявно обновляться и в вашей policy
переменной QNetwork
. На самом деле дело не в том, что тензор policy
обновляется, а в том, что они просто такие же, как тензоры в вашем agent
.
Буфер воспроизведения не обязательно играет роль в сохранении модели, но если по какой-то причине вы хотите прервать обучение и продолжить его в другое время, вам следует сохранить буфер воспроизведения, так как он является частью процесса обучения. Если ваша единственная цель — обучить агента, а затем сохранить оптимальную политику, вы можете вместо этого использовать PolicySaver
(tensorflow.org/agents/api_docs/python/tf_agents/policies/…) и сохранить жадную политику агента.
Checkpointer сохранит состояние обучения, состояние политики и состояние replay_buffer. Я не видел, как сохранение буфера играет роль в сохранении модели. Но если моя цель состоит в том, чтобы в основном сохранить веса и восстановить веса, когда это необходимо, подойдет ли контрольный указатель?