Применить массив перестановок по нескольким осям в numpy

Допустим, у меня есть массив перестановок perm, который может выглядеть так:

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])

Если я хочу применить его к одной оси, я могу написать что-то вроде:

v = np.arange(9).reshape(3, 3)
print(v[perm])

Выход:

array([[[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]],

       [[3, 4, 5],
        [6, 7, 8],
        [0, 1, 2]],

       [[0, 1, 2],
        [6, 7, 8],
        [3, 4, 5]],

       [[6, 7, 8],
        [3, 4, 5],
        [0, 1, 2]]])

Теперь я хотел бы применить его к двум осям одновременно. Я понял, что могу сделать это через:

np.array([v[tuple(np.meshgrid(p, p, indexing = "ij"))] for p in perm])

Но я считаю это совершенно неэффективным, потому что он должен создавать сетку-сетку, а также требует цикла for. В этом примере я создал небольшой массив, но на самом деле у меня массивы гораздо большего размера с большим количеством перестановок, поэтому мне бы очень хотелось иметь что-то такое же быстрое и простое, как одноосная версия.

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
5
0
84
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Вы можете избавиться от meshgrid с помощью

a = np.array([v[p][:,p] for p in perm])
b = np.array([v[tuple(np.meshgrid(p, p, indexing = "ij"))] for p in perm])
print(np.all(b == a)) # True

Это в 5 раз быстрее в вашем примере массива:

import timeit
%timeit np.array([v[tuple(np.meshgrid(p, p, indexing = "ij"))] for p in perm]) # 42.7 µs
%timeit np.array([v[p][:,p] for p in perm]) # 8.18 µs

Я бы предположил, что цикл for по большей части не имеет значения. Если вас интересует дальнейшая оптимизация, укажите формы, с которыми вы работаете...

Ответ принят как подходящий

Как насчет:

p1 = perm[:, :, np.newaxis]
p2 = perm[:, np.newaxis, :]
v[p1, p2]

Нулевая ось p1 и p2 — это всего лишь «пакетное» измерение perm, которое позволяет выполнять множество перестановок за одну операцию.

Другое измерение perm, соответствующее индексам, выровнено по первой оси в p1 и второй в p2. Поскольку оси ортогональны, массивы транслируются, в основном, как массивы, которые вы использовали meshgrid, но они по-прежнему имеют размерность пакета.

Это лучшее, что я могу сделать со своего мобильного телефона :) При необходимости могу попытаться уточнить позже, но основная идея - это вещание.

Сравнение:

import numpy as np
perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

ref = np.array([v[tuple(np.meshgrid(p, p, indexing = "ij"))] for p in perm])

p1 = perm[:, :, np.newaxis]
p2 = perm[:, np.newaxis, :]
res = v[p1, p2]

np.testing.assert_equal(res, ref)
# passes

%timeit np.array([v[tuple(np.meshgrid(p, p, indexing = "ij"))] for p in perm])
# 107 µs ± 20.6 µs per loop

%timeit v[perm[:, :, np.newaxis], perm[:, np.newaxis, :]]
# 3.73 µs ± 1.07 µs per loop

Более простой (без пакетного измерения) пример индексов вещания:

import numpy as np
i = np.arange(3)
ref = np.meshgrid(i, i, indexing = "ij")
res = np.broadcast_arrays(i[:, np.newaxis], i[np.newaxis, :])
np.testing.assert_equal(res, ref)
# passes

В коде решения вверху широковещательная рассылка неявна. Нам не нужно вызывать broadcast_arrays, потому что это происходит автоматически во время индексации.

лучше, чем мой ответ по всем фронтам. (может быть, добавить какие-нибудь пояснения?)

Julien 15.07.2024 09:48

Другие вопросы по теме