У меня есть два массива x
и y
одинаковой формы (B 1 N
).
x
представляет данные, а y
представляет, какому классу (от 1
до C
) принадлежит каждая точка данных в x
.
Я хочу создать новый тензор z
(с формой B C
), где
x
разделены на каналы в зависимости от их классов в y
N
Я могу добиться этого, если использую горячее кодирование. Однако для больших тензоров (особенно с большим количеством классов) горячее кодирование PyTorch быстро использует всю память графического процессора.
Есть ли более эффективный с точки зрения использования памяти способ осуществить эту трансляцию без явного выделения тензора B C N
?
Вот MWE того, что мне нужно:
import torch
B, C, N = 2, 10, 1000
x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))
one_hot = torch.nn.functional.one_hot(y, C) # B 1 N C
one_hot = one_hot.squeeze().permute(0, -1, 1) # B C N
z = x * one_hot # B C N
z = z.sum(-1) # B C
Если z
— желаемый выходной тензор, то вам придется так или иначе выделить BxCxN
в памяти. Альтернативное решение — расширить значения x
и y
и разбросать в нулевой тензор:
>>> x, y = x.expand(-1,C,-1), y.expand(-1,C,-1)
>>> z = torch.zeros(B,C,N).scatter_(1,y,x).sum(-1)
Вы можете проверить сами, но этот подход, похоже, требует меньше памяти.
Обновлено: если вы хотите потом уменьшить N
, то C
не нужен. Поскольку вы использовали горячее кодирование, стандартной операции разброса без сокращения будет достаточно. Кроме того, дополнительные синглтоны не нужны, поэтому предположим, что x
и y
оба являются BxN
:
>>> z = torch.zeros(B,C).scatter_(1,y,x)
Думаю, reduce = "add"
тоже надо включить, но это невероятно! Я понятия не имел, что scatter
такое бывает. Спасибо!
Или вместо этого используйте .scatter_add_
. Кажется, аргумент reduce
будет признан устаревшим.
Вы использовали горячую кодировку, поэтому в заданной позиции не может быть размещено более одного элемента. Именно поэтому режим понижения не требуется, поэтому scatter
следует делать в одиночку.
Последний тензор, который мне нужен, на самом деле
z.sum(-1)
имеет формуBxC
(я обновлю вопрос). Есть ли более быстрый способ добраться до этого?