У меня есть тензор формы [bn, k, 2]
. Последнее измерение — это координаты, и я хочу, чтобы каждая партия сортировалась независимо в зависимости от координаты y ([:, :, 0]
). Мой подход выглядит примерно так:
import torch
a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]
a_sorted = a[:, indices]
print(a)
print(a_sorted)
Пока все хорошо, но теперь я сортирую обе партии по обоим спискам индексов, поэтому всего получаю 4 партии:
a
tensor([[[ 0.5160, 0.3257],
[-1.2410, -0.8361],
[ 1.3826, -1.1308],
[ 0.0338, 0.1665],
[-0.9375, -0.3081]],
[[ 0.4140, -1.0962],
[ 0.9847, -0.7231],
[-0.0110, 0.6437],
[-0.4914, 0.2473],
[-0.0938, -0.0722]]])
a_sorted
tensor([[[[-1.2410, -0.8361],
[-0.9375, -0.3081],
[ 0.0338, 0.1665],
[ 0.5160, 0.3257],
[ 1.3826, -1.1308]],
[[ 0.0338, 0.1665],
[-0.9375, -0.3081],
[ 1.3826, -1.1308],
[ 0.5160, 0.3257],
[-1.2410, -0.8361]]],
[[[ 0.9847, -0.7231],
[-0.0938, -0.0722],
[-0.4914, 0.2473],
[ 0.4140, -1.0962],
[-0.0110, 0.6437]],
[[-0.4914, 0.2473],
[-0.0938, -0.0722],
[-0.0110, 0.6437],
[ 0.4140, -1.0962],
[ 0.9847, -0.7231]]]])
Как видите, я хочу вернуть только 1-ю и 4-ю партии. Как мне это сделать?
Что вы хотите: объединение a[0, indices[0]]
и a[1, indices[1]]
.
Что вы закодировали: конкатенация a[0, indices]
и a[1, indices]
.
Проблема, с которой вы столкнулись, заключается в том, что индексы, возвращаемые sort
, имеют форму первого измерения, но значения являются только индексами второго измерения. Когда вы собираетесь использовать их, вы хотите сопоставить indices[0]
с a[0]
, но pytorch не делает этого неявно (потому что необычное индексирование очень мощное и нуждается в этом синтаксисе для его мощности). Итак, все, что вам нужно сделать, это дать параллельный список индексов для первого измерения.
т. е. вы хотите использовать что-то вроде: a[[[0], [1]], indices]
.
Чтобы немного обобщить это, вы можете использовать что-то вроде:
n = a.shape[0]
first_indices = torch.arange(n)[:, None]
a[first_indices, indices]
Это немного сложно, поэтому вот пример:
>>> a = torch.randn(2,4,2)
>>> a
tensor([[[-0.2050, -0.1651],
[ 0.5688, 1.0082],
[-1.5964, -0.9236],
[ 0.3093, -0.2445]],
[[ 1.0586, 1.0048],
[ 0.0893, 2.4522],
[ 2.1433, -1.2428],
[ 0.1591, 2.4945]]])
>>> indices = a[:, :, 0].sort()[1]
>>> indices
tensor([[2, 0, 3, 1],
[1, 3, 0, 2]])
>>> a[:, indices]
tensor([[[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688, 1.0082]],
[[ 0.5688, 1.0082],
[ 0.3093, -0.2445],
[-0.2050, -0.1651],
[-1.5964, -0.9236]]],
[[[ 2.1433, -1.2428],
[ 1.0586, 1.0048],
[ 0.1591, 2.4945],
[ 0.0893, 2.4522]],
[[ 0.0893, 2.4522],
[ 0.1591, 2.4945],
[ 1.0586, 1.0048],
[ 2.1433, -1.2428]]]])
>>> a[0, indices]
tensor([[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688, 1.0082]],
[[ 0.5688, 1.0082],
[ 0.3093, -0.2445],
[-0.2050, -0.1651],
[-1.5964, -0.9236]]])
>>> a[1, indices]
tensor([[[ 2.1433, -1.2428],
[ 1.0586, 1.0048],
[ 0.1591, 2.4945],
[ 0.0893, 2.4522]],
[[ 0.0893, 2.4522],
[ 0.1591, 2.4945],
[ 1.0586, 1.0048],
[ 2.1433, -1.2428]]])
>>> a[0, indices[0]]
tensor([[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688, 1.0082]])
>>> a[1, indices[1]]
tensor([[ 0.0893, 2.4522],
[ 0.1591, 2.4945],
[ 1.0586, 1.0048],
[ 2.1433, -1.2428]])
>>> a[[[0], [1]], indices]
tensor([[[-1.5964, -0.9236],
[-0.2050, -0.1651],
[ 0.3093, -0.2445],
[ 0.5688, 1.0082]],
[[ 0.0893, 2.4522],
[ 0.1591, 2.4945],
[ 1.0586, 1.0048],
[ 2.1433, -1.2428]]])
Потрясающий!! Большое спасибо, я многому научился из вашего поста. Я уже предполагал, что что-то подобное должно быть решением, но я не знал, как это реализовать. Хорошего дня!