Как получить векторизованный случайный тай-брейк np.argmax?

Я знаю, что могу векторизовать np.argmax, введя двумерный массив и указав ось, например: np.argmax(2Darray,axis=1), чтобы получить максимальный индекс для строки.

Я знаю, что если две записи равны в одном одномерном векторе, и я хочу вернуть максимальный индекс, я могу разбить их через np.random.choice(np.flatnonzero(1Dvector == 1Dvector.max()))

Вопрос в том, как я могу сделать и то, и другое вместе? То есть: как векторизовать np.argmax, в результате чего одинаковые записи случайным образом разбиваются?

Вы проверили это: stackoverflow.com/questions/17568612/…

marco romelli 30.05.2019 09:13

Похоже, это только для 1D-вектора. Я не могу найти аргумент оси для np.argwhere. Также я хотел бы вернуть максимум, а не получить список максимальных индексов, хотя я уверен, что эта часть была бы тривиальной, если бы np.argwhere можно было векторизовать.

user4779 30.05.2019 09:31
Структурированный массив Numpy
Структурированный массив Numpy
Однако в реальных проектах я чаще всего имею дело со списками, состоящими из нескольких типов данных. Как мы можем использовать массивы numpy, чтобы...
2
2
626
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Вот один из способов. Для больших данных можно подумать о замене permutation чем-то более дешевым. Я жестко запрограммировал axis=1, но это не должно скрывать принцип.

def fair_argmax_2D(a):
    y, x = np.where((a.T==a.max(1)).T)
    aux = np.random.permutation(len(y))
    xa = np.empty_like(x)
    xa[aux] = x
    return xa[np.maximum.reduceat(aux, np.where(np.diff(y, prepend=-1))[0])]

a = np.random.randint(0,5,(4,5))
a
# array([[2, 2, 2, 2, 1],
#        [3, 3, 3, 3, 2],
#        [3, 4, 2, 1, 4],
#        [3, 2, 4, 2, 1]])

# draw 10000 times
res = np.array([fair_argmax_2D(a) for _ in range(10000)])

# check
np.array([np.bincount(r, None, 5) for r in res.T])
# array([[ 2447,  2567,  2449,  2537,     0],
#        [ 2511,  2465,  2536,  2488,     0],
#        [    0,  5048,     0,     0,  4952],
#        [    0,     0, 10000,     0,     0]])

Шикарное решение, спасибо!

user4779 30.05.2019 09:51

Другие вопросы по теме