у меня есть тензор
import torch
a = torch.randn(1, 3, requires_grad=True)
print('a: ', a)
>>> a: tensor([[0.0200, 1.00200, -4.2000]], requires_grad=True)
И маска
mask = torch.zeros_like(a)
mask[0][0] = 1
Я хочу замаскировать свой тензор a
, не распространяя градиенты на тензор маски (в моем реальном случае у него есть градиент). Я попробовал следующее
with torch.no_grad():
b = a * mask
print('b: ', b)
>>> b: tensor([[0.0200, 0.0000, -0.0000]])
Но он полностью удаляет градиент из моего тензора. Как правильно это сделать?
Вы можете вызвать detach
на тензоре маски, чтобы удалить его из цепочки градиентов.
a = torch.randn(1, 3, requires_grad=True)
mask = torch.tensor([[1., 0., 0.]], requires_grad=True)
mask_no_grad = mask.detach()
b = a * mask_no_grad
print(b)
> tensor([[0.3871, 0.0000, -0.0000]], grad_fn=<MulBackward0>)
Что вы подразумеваете под «распространяется на маскирующий тензор»? Сам тензор маски не имеет градиента.
Я выше написал "в моем реальном случае градиент есть"
Что ж, вы сделали это очень запутанным, явно создав a
с градиентом и mask
без него в своем примере. Я обновил ответ, чтобы показать использование detach
, чтобы удалить маску из цепочки градиентов.
Но затем он распространяется на мой тензор маскировки, чего я не хочу.