У меня есть функция, которая может принимать typing.Union
типов, включая тип torch.float
. Но если я использую тип typing.Union
с torch.float
в качестве аргумента, я получаю сообщение об ошибке. Вот пример:
from typing import Union
import torch
def fct(my_float_or_tensor: Union[torch.float, torch.Tensor]):
pass
И я получаю ошибку
TypeError: Union[t0, t1, ...]: each t must be a type. Got torch.float32.
Что я делаю не так?
Интересно, что та же проблема возникает со специальным типом typing.Tuple
, но не в том случае, если я использую torch.float
напрямую при подсказке типа.
Есть разница между "dtypes" и "types". torch.float
есть dtype
. Для подсказки типа используйте torch.FloatTensor
(есть и другие, например, DoubleTensor
, HalfTensor
и т. д.)