У меня есть тензор a
с плавающими элементами и torch.Size([64,2])
, а также тензор b
с torch.Size([64])
. Записи b
только 0
или 1
.
Я хотел бы получить новый тензор c
с torch.Size([64])
таким, что c[i] == a[i,b[i]]
для каждого индекса i. Как я могу это сделать?
Моя попытка
Я пытался с torch.gather
, но безуспешно. Следующий код дает мне RuntimeError: Index tensor must have the same number of dimensions as input tensor
import torch
a = torch.zeros([64,2])
b = torch.ones(64).long()
torch.gather(input=a, dim=1,index=b)
Любая помощь будет высоко оценена!
Не уверен, что понимаю ваш вопрос, но я думаю, что вы можете перебрать свой тензор
a = torch.zeros([64,2])
b = torch.ones(64).long()
c = torch.empty([64])
for i, _ in enumerate(a):
c[i] = a[i,b[i]]
c
Вы можете выполнить это прямо с индексацией a
в обоих измерениях:
На dimension=0
: «последовательная» индексация с использованием torch.arange
.
На dimension=1
: индексирование с помощью b
.
В целом это дает:
>>> a[torch.arange(len(a)), b]
В качестве альтернативы вы можете использовать torch.gather
, операция, которую вы ищете:
# c[i] == a[i,b[i]]
Предоставленная операция сбора при применении к dim=1
дает что-то вроде:
# c[i,j] == a[i,b[i,j]]
Как видите, нам нужно учитывать разницу в формах между a
и b
. Для этого вы можете распаковать одноэлементное измерение в b
(обозначенное буквой j
выше), например, #b=(64, 1)
, например, с помощью b.unsqueeze(-1)
или b[...,None]
:
>>> a.gather(dim=1, index=b[...,None]).flatten()
Большое спасибо!!! Теперь я понимаю, чего мне не хватало, чтобы использовать сбор.. :)