Не могу загрузить сохраненную политику (TF-агенты)

Я сохранил обученную политику с помощью заставки политики следующим образом:

  tf_policy_saver = policy_saver.PolicySaver(agent.policy)
  tf_policy_saver.save(policy_dir)

Я хочу продолжить обучение с сохраненной политикой. Поэтому я попытался инициализировать обучение с сохраненной политикой, что вызвало некоторую ошибку.

agent = dqn_agent.DqnAgent(
tf_env.time_step_spec(),
tf_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=train_step_counter)

agent.initialize()

agent.policy=tf.compat.v2.saved_model.load(policy_dir)

ОШИБКА:

  File "C:/Users/Rohit/PycharmProjects/pythonProject/waypoint.py", line 172, in <module>
agent.policy=tf.compat.v2.saved_model.load('waypoints\\Two_rewards')


File "C:\Users\Rohit\anaconda3\envs\btp36\lib\site-packages\tensorflow\python\training\tracking\tracking.py", line 92, in __setattr__
    super(AutoTrackable, self).__setattr__(name, value)
AttributeError: can't set attribute

Я просто хочу сэкономить время на переобучение с первого раза каждый раз. Как я могу загрузить сохраненную политику и продолжить обучение??

заранее спасибо

Udacity Nanodegree Capstone Project: Классификатор пород собак
Udacity Nanodegree Capstone Project: Классификатор пород собак
Вы можете ознакомиться со скриптами проекта и данными на github .
3
0
983
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Вы должны проверить Checkpointer для этой цели.

Ответ принят как подходящий

Да, как указывалось ранее, для этого следует использовать Checkpointer. Взгляните на приведенный ниже пример кода.

agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)

... # Train the agent

# Policy --> X
policy_checkpointer.save(global_step=epoch_counter.numpy())

Когда вы позже захотите перезагрузить политику, вы просто запускаете тот же код инициализации.

agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y1, possibly Y1==Y depending on agent class you are using, if it's DQN
#               then they are different because of random initialization of network weights
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)
# Policy --> X

После создания policy_checkpointer автоматически определит, существуют ли какие-либо ранее существовавшие контрольные точки. Если они есть, он автоматически обновит значение отслеживаемых переменных при создании.

Пара заметок, чтобы сделать:

  1. Вы можете сэкономить с помощью контрольного указателя гораздо больше, чем просто политику, и я действительно рекомендую это делать. Объект Checkpointer TF-Agent чрезвычайно гибок, например:
train_checkpointer = common.Checkpointer(ckpt_dir=first/dir,
                                         agent=tf_agent,               # tf_agent.TFAgent
                                         train_step=train_step,        # tf.Variable
                                         epoch_counter=epoch_counter,  # tf.Variable
                                         metrics=metric_utils.MetricsGroup(
                                                 train_metrics, 'train_metrics'))

policy_checkpointer = common.Checkpointer(ckpt_dir=second/dir,
                                          policy=agent.policy)

rb_checkpointer = common.Checkpointer(ckpt_dir=third/dir,
                                      max_to_keep=1,
                                      replay_buffer=replay_buffer  # TFUniformReplayBuffer
                                      )
  1. Обратите внимание, что в случае DqnAgentagent.policy и agent.collect_policy по сути являются обертками вокруг QNetwork. Последствия этого показаны в приведенном ниже коде (посмотрите на комментарии к состоянию переменной политики).
agent = DqnAgent(...)
policy = agent.policy      # Random initial policy ---> X

dataset = replay_buffer.as_dataset(...)
for data in dataset:
   experience, _ = data
   loss_agent_info = agent.train(experience=experience)

# policy variable stores a trained Policy object ---> Y

Это происходит потому, что тензоры в TF используются во время выполнения. Поэтому, когда вы обновляете QNetwork веса вашего агента с помощью agent.train, те же самые веса будут неявно обновляться и в вашей policy переменной QNetwork. На самом деле дело не в том, что тензор policy обновляется, а в том, что они просто такие же, как тензоры в вашем agent.

Checkpointer сохранит состояние обучения, состояние политики и состояние replay_buffer. Я не видел, как сохранение буфера играет роль в сохранении модели. Но если моя цель состоит в том, чтобы в основном сохранить веса и восстановить веса, когда это необходимо, подойдет ли контрольный указатель?

user3656142 01.04.2021 01:03

Буфер воспроизведения не обязательно играет роль в сохранении модели, но если по какой-то причине вы хотите прервать обучение и продолжить его в другое время, вам следует сохранить буфер воспроизведения, так как он является частью процесса обучения. Если ваша единственная цель — обучить агента, а затем сохранить оптимальную политику, вы можете вместо этого использовать PolicySaver (tensorflow.org/agents/api_docs/python/tf_agents/policies/…) и сохранить жадную политику агента.

Federico Malerba 02.04.2021 09:43

Другие вопросы по теме