Как восстановить контрольную точку orbax с помощью jax/льна?

Я сохранил контрольную точку orbax с помощью кода ниже:

check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(check_path, options=check_options, item_names=('state', 'metadata'))
checkpoint_manager.save(
                    step=iter_num,
                    args=ocp.args.Composite(
                        state=ocp.args.StandardSave(state),
                        metadata=ocp.args.JsonSave((model_args, iter_num, best_val_loss, losses['val'].item(), config))))

Когда я пытаюсь возобновить работу с сохраненных контрольных точек, я использовал приведенный ниже код для восстановления переменной state:

state, lr_schedule = init_train_state(model, params['params'], learning_rate, weight_decay, beta1, beta2, decay_lr, warmup_iters, 
                     lr_decay_iters, min_lr)  # Here state is the initialied state variable with type Train_state.
state = checkpoint_manager.restore(checkpoint_manager.latest_step(), items = {'state': state})

Но когда я пытаюсь использовать восстановленное состояние в цикле обучения, я получаю следующую ошибку:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:584, in shaped_abstractify(x)
    583 try:
--> 584   return _shaped_abstractify_handlers[type(x)](x)
    585 except KeyError:

KeyError: <class 'orbax.checkpoint.composite_checkpoint_handler.CompositeArgs'>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[40], line 37
     34 if iter_num == 0 and eval_only:
     35     break
---> 37 state, loss = train_step(state, get_batch('train'))
     39 # timing and logging
     40 t1 = time.time()

    [... skipping hidden 6 frame]

File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:575, in _shaped_abstractify_slow(x)
    573   dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
    574 else:
--> 575   raise TypeError(
    576       f"Cannot interpret value of type {type(x)} as an abstract array; it "
    577       "does not have a dtype attribute")
    578 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
    579                         named_shape=named_shape)

TypeError: Cannot interpret value of type <class 'orbax.checkpoint.composite_checkpoint_handler.CompositeArgs'> as an abstract array; it does not have a dtype attribute

Итак, как мне правильно восстановить контрольную точку state и использовать ее в цикле обучения?

Спасибо!

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
292
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Вы смешиваете старый и новый API таким образом, что это запрещено. Приносим извинения, что ошибка по этому поводу не возникает, я могу разобраться в этом.

Ваша экономия верна, но я бы рекомендовал, чтобы она выглядела примерно так:

with ocp.CheckpointManager(path, options=options, item_names=('state', 'metadata')) as mngr:
  mngr.save(
      step, 
      args=ocp.args.Composite(
          state=ocp.args.StandardSave(state),
          metadata=ocp.args.JsonSave(...),
      )
  )

При восстановлении вы в настоящее время используете items, который является частью старого API, и его использование не соответствует определению CheckpointManager, которое сделано на основе нового API.

item_names и args являются отличительными чертами нового API.

Ты должен сделать:

with ocp.CheckpointManager(...) as mngr:
  mngr.restore(
      mngr.latest_step(), 
      args=ocp.args.Composite(
          state=ocp.args.StandardRestore(abstract_state),
      )
  )

Дайте мне знать, если с этим возникнут какие-либо непредвиденные проблемы.

Спасибо, что изучили это! Что такое абстрактное_состояние здесь? Это просто экземпляр state, который сообщает функции восстановления, какой формат мы хотим иметь? Кроме того, при восстановлении должно быть StandardRestore, а не StandardSave? Спасибо!

Dmitry J 26.04.2024 21:42
abstract_state просто относится к некоторой древовидной структуре с той же структурой, что и ваша контрольная точка, которая имеет значения типа jax.ShapeDtypeStruct. Это позволяет вам передавать такие свойства, как форма, dtype и сегментирование массивов, когда вы хотите, чтобы они были восстановлены. Если вы хотите восстановить какой-либо пользовательский PyTree (например, flax.struct.dataclass, а не простой вложенный словарь), убедитесь, что abstract_state принадлежит к тому же классу. И да, должно было быть StandardRestore, обновил ответ.
Colin Gaffney 30.04.2024 00:00

Другие вопросы по теме