Как отсортировать трехмерный тензор по координатам в последнем измерении (pytorch)

У меня есть тензор формы [bn, k, 2]. Последнее измерение — это координаты, и я хочу, чтобы каждая партия сортировалась независимо в зависимости от координаты y ([:, :, 0]). Мой подход выглядит примерно так:

import torch
a = torch.randn(2, 5, 2)
indices = a[:, :, 0].sort()[1]
a_sorted = a[:, indices]

print(a)
print(a_sorted)

Пока все хорошо, но теперь я сортирую обе партии по обоим спискам индексов, поэтому всего получаю 4 партии:

a
tensor([[[ 0.5160,  0.3257],
         [-1.2410, -0.8361],
         [ 1.3826, -1.1308],
         [ 0.0338,  0.1665],
         [-0.9375, -0.3081]],

        [[ 0.4140, -1.0962],
         [ 0.9847, -0.7231],
         [-0.0110,  0.6437],
         [-0.4914,  0.2473],
         [-0.0938, -0.0722]]])

a_sorted
tensor([[[[-1.2410, -0.8361],
          [-0.9375, -0.3081],
          [ 0.0338,  0.1665],
          [ 0.5160,  0.3257],
          [ 1.3826, -1.1308]],

         [[ 0.0338,  0.1665],
          [-0.9375, -0.3081],
          [ 1.3826, -1.1308],
          [ 0.5160,  0.3257],
          [-1.2410, -0.8361]]],


        [[[ 0.9847, -0.7231],
          [-0.0938, -0.0722],
          [-0.4914,  0.2473],
          [ 0.4140, -1.0962],
          [-0.0110,  0.6437]],

         [[-0.4914,  0.2473],
          [-0.0938, -0.0722],
          [-0.0110,  0.6437],
          [ 0.4140, -1.0962],
          [ 0.9847, -0.7231]]]])

Как видите, я хочу вернуть только 1-ю и 4-ю партии. Как мне это сделать?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
0
691
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Что вы хотите: объединение a[0, indices[0]] и a[1, indices[1]].

Что вы закодировали: конкатенация a[0, indices] и a[1, indices].

Проблема, с которой вы столкнулись, заключается в том, что индексы, возвращаемые sort, имеют форму первого измерения, но значения являются только индексами второго измерения. Когда вы собираетесь использовать их, вы хотите сопоставить indices[0] с a[0], но pytorch не делает этого неявно (потому что необычное индексирование очень мощное и нуждается в этом синтаксисе для его мощности). Итак, все, что вам нужно сделать, это дать параллельный список индексов для первого измерения.

т. е. вы хотите использовать что-то вроде: a[[[0], [1]], indices].

Чтобы немного обобщить это, вы можете использовать что-то вроде:

n = a.shape[0]
first_indices = torch.arange(n)[:, None]
a[first_indices, indices]

Это немного сложно, поэтому вот пример:

>>> a = torch.randn(2,4,2)
>>> a
tensor([[[-0.2050, -0.1651],
         [ 0.5688,  1.0082],
         [-1.5964, -0.9236],
         [ 0.3093, -0.2445]],

        [[ 1.0586,  1.0048],
         [ 0.0893,  2.4522],
         [ 2.1433, -1.2428],
         [ 0.1591,  2.4945]]])
>>> indices = a[:, :, 0].sort()[1]
>>> indices
tensor([[2, 0, 3, 1],
        [1, 3, 0, 2]])
>>> a[:, indices]
tensor([[[[-1.5964, -0.9236],
          [-0.2050, -0.1651],
          [ 0.3093, -0.2445],
          [ 0.5688,  1.0082]],

         [[ 0.5688,  1.0082],
          [ 0.3093, -0.2445],
          [-0.2050, -0.1651],
          [-1.5964, -0.9236]]],


        [[[ 2.1433, -1.2428],
          [ 1.0586,  1.0048],
          [ 0.1591,  2.4945],
          [ 0.0893,  2.4522]],

         [[ 0.0893,  2.4522],
          [ 0.1591,  2.4945],
          [ 1.0586,  1.0048],
          [ 2.1433, -1.2428]]]])
>>> a[0, indices]
tensor([[[-1.5964, -0.9236],
         [-0.2050, -0.1651],
         [ 0.3093, -0.2445],
         [ 0.5688,  1.0082]],

        [[ 0.5688,  1.0082],
         [ 0.3093, -0.2445],
         [-0.2050, -0.1651],
         [-1.5964, -0.9236]]])
>>> a[1, indices]
tensor([[[ 2.1433, -1.2428],
         [ 1.0586,  1.0048],
         [ 0.1591,  2.4945],
         [ 0.0893,  2.4522]],

        [[ 0.0893,  2.4522],
         [ 0.1591,  2.4945],
         [ 1.0586,  1.0048],
         [ 2.1433, -1.2428]]])
>>> a[0, indices[0]]
tensor([[-1.5964, -0.9236],
        [-0.2050, -0.1651],
        [ 0.3093, -0.2445],
        [ 0.5688,  1.0082]])
>>> a[1, indices[1]]
tensor([[ 0.0893,  2.4522],
        [ 0.1591,  2.4945],
        [ 1.0586,  1.0048],
        [ 2.1433, -1.2428]])
>>> a[[[0], [1]], indices]
tensor([[[-1.5964, -0.9236],
         [-0.2050, -0.1651],
         [ 0.3093, -0.2445],
         [ 0.5688,  1.0082]],

        [[ 0.0893,  2.4522],
         [ 0.1591,  2.4945],
         [ 1.0586,  1.0048],
         [ 2.1433, -1.2428]]])

Потрясающий!! Большое спасибо, я многому научился из вашего поста. Я уже предполагал, что что-то подобное должно быть решением, но я не знал, как это реализовать. Хорошего дня!

spadel 21.12.2020 12:16

Другие вопросы по теме