У меня возникли проблемы с запуском PyTorch 2.2 с TPU в Google Colab. Я получаю сообщение об ошибке JAX, но меня это смущает, потому что я ничего не делаю с JAX.
Мой процесс настройки очень прост:
!pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
А потом
import torch
import torch_xla.core.xla_model as xm
что дает ошибку
/usr/local/lib/python3.10/dist-packages/jax/__init__.py:27: UserWarning: cloud_tpu_init failed: KeyError('')
This a JAX bug; please report an issue at https://github.com/google/jax/issues
_warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
Затем пытаюсь
t1 = torch.tensor(100, device=xm.xla_device())
t2 = torch.tensor(200, device=xm.xla_device())
print(t1 + t2)
выдает ошибку
2 frames
/usr/local/lib/python3.10/dist-packages/torch_xla/runtime.py in xla_device(n, devkind)
121
122 if n is None:
--> 123 return torch.device(torch_xla._XLAC._xla_get_default_device())
124
125 devices = xm.get_xla_supported_devices(devkind=devkind)
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: No ba16c7433 device found.






В настоящее время Colab предоставляет только TPU старого поколения, которое несовместимо с последними выпусками JAX или PyTorch. Вполне возможно, что в будущем это может измениться, но я не знаю официальных сроков, когда это может произойти. А пока вы можете получить доступ к TPU последнего поколения через Kaggle или Google Cloud.