Расширенная индексация Numpy в многомерных массивах; не просто брать простые строки или столбцы

У меня есть трехмерный массив numpy. Каков самый быстрый способ получить 3D-массив, содержащий самый большой элемент каждой конечной оси массива, без написания цикла. (Позже я буду использовать CuPy с тем же синтаксисом, и циклы уберут параллелизм графического процессора и скорость что здесь является самым важным фактором.)

Получить индексы для самых крупных элементов легко:

>>> arr = np.array(
        [[[ 6, -2, -6, -5],
        [ 1, 12,  3,  9],
        [21,  7,  9,  8]],

       [[15, 12, 20, 12],
        [17, 15, 17, 23],
        [22, 18, 27, 32]]])
>>> indexes = arr.argmax(axis=2, keepdims=True)
>>> indexes
array([[[0],
        [1],
        [0]],

       [[2],
        [3],
        [3]]])

но как использовать эти индексы для получения выбранных значений из arr? Все способы, которые я пробовал, либо приводят к ошибкам (например, arr[indexes]), либо к неверным результатам. В этом примере я хотел бы получить

array([[[6],
        [12],
        [21]],

       [[20],
        [23],
        [32]]])
argmaxдокументы советуют вам посмотреть take_along_axisnumpy.org/doc/stable/reference/generated/…
hpaulj 09.08.2024 06:04

Спасибо! Я предполагал, что будет простой ответ, но не смог найти его среди множества функций numpy.

Mikael 09.08.2024 09:54

Действительно ли вам нужен промежуточный этап получения индексов самых больших элементов? Если вас интересуют только самые крупные элементы, вы можете использовать np.amax(), как предложено в ответе «неплательщика налогов» ниже. Также вы можете отдать должное тем, кто вам помог, (1) проголосовав за ответы, которые вы считаете полезными, и (2) приняв ответ, который решает вашу проблему.

simon 09.08.2024 10:53

Я обязательно проголосую за тех, кто помог, когда я впервые наберу достаточно очков для голосования (= 15 очков репутации), сейчас у меня 11 очков.... В моем конкретном случае мне нужна информация об индексах, но в этом нет необходимости. np.amax() действительно был бы лучше. Я собирался прокомментировать это после ответа «неплательщика налогов», но потом прочитал руководство, в котором якобы «Комментарии используются для запроса разъяснений или для указания проблем в сообщении». Это мой второй пост, я перестраховался и еще не оставлял подобных комментариев.

Mikael 09.08.2024 11:49

@Mikael Спасибо за разъяснения. Извините, я не помню, чтобы было минимальное количество баллов для голосования. В любом случае: добро пожаловать в Stack Overflow!

simon 09.08.2024 12:07
Структурированный массив Numpy
Структурированный массив Numpy
Однако в реальных проектах я чаще всего имею дело со списками, состоящими из нескольких типов данных. Как мы можем использовать массивы numpy, чтобы...
2
5
50
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Думаю, для этого можно использовать np.amax

import numpy as np

arr = np.array(
        [[[ 6, -2, -6, -5],
        [ 1, 12,  3,  9],
        [21,  7,  9,  8]],

       [[15, 12, 20, 12],
        [17, 15, 17, 23],
        [22, 18, 27, 32]]])

max_arr = np.amax(arr, axis=2, keepdims=True)
print(max_arr)

Выход

[[[ 6]
  [12]
  [21]]

 [[20]
  [23]
  [32]]]

Спасибо! Если бы информация о том, какие индексы были самыми большими, не требовалась, этот ответ был бы идеальным.

Mikael 09.08.2024 11:54
Ответ принят как подходящий

Как указал hpaulj, вы можете использовать take_along_axis

>>>arr = np.array(
        [[[ 6, -2, -6, -5],
        [ 1, 12,  3,  9],
        [21,  7,  9,  8]],

       [[15, 12, 20, 12],
        [17, 15, 17, 23],
        [22, 18, 27, 32]]])
>>>indexes = arr.argmax(axis=2, keepdims=True)
>>>np.take_along_axis(arr , indexes , axis=2)

array([[[6],
        [12],
        [21]],

       [[20],
        [23],
        [32]]])

Другие вопросы по теме