Как выполнять вычисления с помощью скользящего окна, сохраняя при этом эффективность использования памяти?

Я работаю с очень большими (несколько ГБ) двумерными квадратными массивами NumPy . Учитывая входной массив a, для каждого элемента я хотел бы найти направление его крупнейшего соседнего соседа. Я использую предоставленный раздвижной вид окна, чтобы избежать создания ненужных копий:

# a is an L x L array of type np.float32
swv = sliding_window_view(a, (3, 3)) # (L-2) x (L-2) x 3 x 3
directions = swv.reshape(L-2, L-2, 9)[:,:,1::2].argmax(axis = 2).astype(np.uint8)

Однако вызов reshape здесь создает копию (L-2) x (L-2) x 9 вместо представления, что потребляет нежелательно большой кусок памяти. Есть ли способ выполнить эту операцию векторизованно, но с меньшим объемом памяти?

Обновлено: Многие ответы ориентированы на NumPy, который использует процессор (поскольку я изначально спрашивал об этом, чтобы упростить проблему). Будет ли оптимальная стратегия использования CuPy другой, то есть NumPy для графического процессора? Насколько я знаю, это делает использование Numba гораздо менее простым.

Можете ли вы сделать срез 1::2 для (3,3)? Изменение формы раздвижного окна плохое, нарезка прямо сохраняет вид. Агрегация, такая как max, avg и т. д., также подходит.

hpaulj 24.08.2024 04:35

Возможно, лучший вариант — выбросить sliding_window_view и реализовать его с помощью numba.

ken 24.08.2024 06:32

@ken Это было бы более эффективно из-за лучшего использования кэша и регистров ЦП. Однако код будет намного больше из-за связанной проверки (4 условия) в сочетании с argmax (требующим индекса значения, а не только максимального) из 5 элементов. Ради производительности мы можем написать цикл для центральных значений, избегая бесполезных условий, но это сделает код еще больше и сложнее, хотя и более эффективным. Преимущество реализации Numba также заключается в том, что затраты памяти незначительны (т. е. O(1), что является оптимальным).

Jérôme Richard 24.08.2024 12:40

@JérômeRichard Я не уверен, что слежу... Поскольку sliding_window_view не обрабатывает края, я не думаю, что проверка привязки необходима. Итак, argmax — единственная сложная часть. Однако, ИМХО, логическая сложность будет проще, если учесть, что текущая реализация разрезает скользящее окно только для того, чтобы получить четырех соседей.

ken 24.08.2024 13:26

Мне также интересно, можно ли использовать для этого трафарет (я не проверял). В этом случае argmax будет буквально единственной частью.

ken 24.08.2024 13:27

@ken Ах да, действительно, нет необходимости в связанных проверках, поскольку выходные данные имеют размер L-2. Я пропустил это. Тем не менее, есть еще 3 условия для 4 соседей с некоторой скучной обработкой индекса, направление, максимизирующее значение, должно быть сохранено в каждом условии, а также максимальное значение, найденное на данный момент. Таким образом, около 10 строк в 2 вложенных циклах. Лично я также предпочитаю циклы, хотя здесь это не так уж и хорошо. Это действительно своего рода необычный трафарет (настолько необычный, что я не ожидаю, что модули Python предоставят способ легко его выразить).

Jérôme Richard 24.08.2024 15:40
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
6
6
501
3
Перейти к ответу Данный вопрос помечен как решенный

Ответы 3

вызов reshape здесь создает копию (L-2) x (L-2) x 9 вместо представления

Это связано с тем, что две последние оси целевого массива не могут быть изменены вместе здесь, в Numpy. Действительно, это означало бы, что шаг последнего измерения будет варьироваться между элементами, что не поддерживается (и, конечно, никогда не будет, потому что это сделает многие операции очень медленными и намного более сложными). Шаги могут быть постоянными только для данной оси. Когда представление не может быть создано, Numpy выполняет дорогостоящее копирование.

Есть ли способ выполнить эту операцию векторизованно, но с меньшим объемом памяти?

Ключевым моментом в таком случае является выполнение операции по частям. Если входные данные большие, это может быть быстрее, чем вычисление окончательного массива за одну уникальную операцию из-за ошибок кэша ЦП и страниц. Действительно, временный массив можно повторно использовать в памяти. При этом куски не должны быть слишком маленькими, иначе накладные расходы на вызов функции Numpy и цикла CPython станут дорогими.

Если L довольно большой, вы можете просто перебирать построчно. В противном случае вам придется вычислять куски строк. Вот пример:

swv = np.lib.stride_tricks.sliding_window_view(a, (3, 3)) # (L-2) x (L-2) x 3 x 3

out = np.empty((L-2, L-2), dtype=np.uint8)
k = 8
for i in range(0, L-2, k):
    out[i:i+k] = swv[i:i+k,:,:,:].reshape(-1, L-2, 9)[:,:,1::2].argmax(axis = 2).astype(np.uint8)

Контрольный показатель

Вот результаты производительности a = np.random.rand(L, L).astype(np.float32) с L = 1000 на моей машине с i5-9600KF и Numpy 1.24.3:

Initial implementation:
    - time: 51 ms
    - memory overhead: O(L**2)

Proposed implementation:
    - time: 43 ms
    - memory overhead: O(L * k)

Вычисление здесь происходит быстрее для всех k<30. При k>=30 массивы слишком велики, чтобы вычисления были эффективными, и они занимают примерно то же время, что и ваши вычисления (на самом деле, предлагаемая реализация даже в этом случае немного быстрее). Мы также можем заключить, что циклы CPython не являются медленными, пока фрагменты достаточно велики, чтобы накладные расходы были небольшими по сравнению со временем вычислений. Вычисления также занимают меньше памяти. Единственным недостатком является то, что код больше. Бесплатного обеда не существует.


Примечания и более быстрые реализации

Обратите внимание, что разумным значением k может быть max(int(512*1024 / (3*3*a.itemsize*(L-2)) + 0.5), 1). По этой формуле вычисления, если это возможно, должны занимать не более нескольких МБ ОЗУ. Если это невозможно, потому что k=1, тогда необходимо взять C*a.itemsize*(L-2)*3*3/1024**2 MiB, где C — небольшая константа (обычно 2).

Вот расширенный тест с другими конкурентными реализациями:

nocomment's first implementation ("mine"):
    - time: 19 ms
    - memory overhead: O(L**2)

nocomment's second implementation ("mine4"):
    - time: 13 ms
    - memory overhead: O(L**2)

Native scalar code:
    - time: 1.5 ms
    - memory overhead: O(1)

ken's best implementation ("neighbor_argmax"):
    - time: 0.38 ms
    - memory overhead: O(1)

Optimized native SIMD code:
    - time: 0.25 ms
    - memory overhead: O(1)

Первая реализация nocomment быстрее, но требует значительно больше памяти (хотя и меньше, чем исходный код). Действительно, необходимо одновременно выделить как минимум 3 временных логических массива. Размер каждого логического массива составляет (L-2)**2 байт. Это означает, что необходимо выделить как минимум 3 * (L-2)**2 байт. Это значительно больше, чем C * k * L (где C — константа, которая должна быть от 30 до 50), пока k остается небольшим, а L относительно большим. Вторая реализация на моей машине работает быстрее, но для нее также потребуется больше памяти.

Реализация ken (лучшая последовательная) великолепна, поскольку Numba генерирует ассемблерный код, используя инструкции SIMD, и его использование памяти очень незначительно (как и собственные коды). Он не так хорош, как оптимизированный нативный код, но довольно близок к этому. Я думаю, что основным недостатком является значительное время компиляции (800 мс платится только один раз при самом первом вызове).

Кроме того, можно отметить, что Numpy намного медленнее того, что можно реализовать в собственном коде (например, в C/C++). Разрыв становится еще больше, когда его собственный код оптимизирован для использования модулей SIMD, доступных на всех основных процессорах. Собственный код требует <1 КБ дополнительной памяти. У кода Numpy нет шансов приблизиться к такой производительности. Оптимизированный SIMD-совместимый собственный код примерно в 34 раза быстрее, чем предлагаемая реализация Numpy, и в 200 раз быстрее, чем исходный код, а также использует еще меньше памяти!

В целом, мы видим, что существует огромный разрыв между собственными/обработанными кодами и кодами, использующими только Numpy, как с точки зрения использования памяти, так и с точки зрения скорости.

Спасибо за то, что научили меня разбиению на фрагменты, а также за описание всех других предлагаемых решений в комментариях! Я пошел дальше и отметил ответ Кена как принятый, поскольку он самый быстрый, но хотел отметить, что ваше профилирование всего было очень ценным.

DanDan面 27.08.2024 21:14

Несколько решений, которые занимают гораздо меньше памяти, чем исходное, и работают быстрее. Память можно еще больше уменьшить, разбивая на части, как это сделал Жером. Умеренное остроумие L = 1000:

Memory:
59.76 bytes/element  original
 3.06 bytes/element  mine
 9.00 bytes/element  mine4
 8.96 bytes/element  mine5

Speed:
  9.8 ± 0.3 ms  mine
 11.0 ± 0.5 ms  mine5
 22.2 ± 1.7 ms  mine4
 99.1 ± 4.9 ms  original

Python: 3.12.2 (main, Jun 12 2024, 09:13:57) [GCC 14.1.1 20240522]
NumPy:  1.26.4

Я получаю значения в четырех направлениях (вверх/влево/вправо/вниз). В mine я сравниваю все шесть пар, сохраняю результаты сравнения в виде 6-битных чисел, а затем смотрю, какое направление означает каждое 6-битное значение. В mine4 и mine5 я отслеживаю максимумы.

import numpy as np


def original(a):
  swv = np.lib.stride_tricks.sliding_window_view(a, (3, 3))
  return swv.reshape(L-2, L-2, 9)[:,:,1::2].argmax(axis = 2).astype(np.uint8)


def mine(a):
  L = a[1:-1, :-2]
  R = a[1:-1, 2:]
  U = a[:-2, 1:-1]
  D = a[2:, 1:-1]
  cmp = (U < L).view(np.uint8)
  for (x, y) in (U, R), (U, D), (L, R), (L, D), (R, D):
    cmp = (cmp << 1) | (x < y).view(np.uint8)
  table = np.array([
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 3,
      0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 3,
      1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0,
      1, 0, 0, 0, 2, 0, 0, 0, 1, 1, 0, 3, 2, 0, 2, 3
  ], dtype=np.uint8)
  return table[cmp]

# How I got my magic table
# table = np.zeros(64)
# table[mine(a)] = original(a)
# print(repr(table.astype(np.uint8)))


def mine4(a):
  U = a[:-2, 1:-1]
  L = a[1:-1, :-2]
  R = a[1:-1, 2:]
  D = a[2:, 1:-1]

  max = np.maximum(U, L)
  dir = (L > U).view(np.uint8)

  dir[R > max] = 2
  max = np.maximum(max, R)

  dir[D > max] = 3

  return dir


def mine5(a):
  U = a[:-2, 1:-1]
  L = a[1:-1, :-2]
  R = a[1:-1, 2:]
  D = a[2:, 1:-1]

  return np.where(
    np.maximum(R, D) > np.maximum(U, L),
    (D > R).view(np.uint8) + 2,
    (L > U).view(np.uint8)
  )


funcs = [original, mine, mine4, mine5]

from timeit import timeit
from statistics import mean, stdev
import sys
import random
import tracemalloc as tm

L = 1000
a = np.random.random((L, L)).astype(np.float32)

# Correctness
print('Correctness:')
expect = original(a)
for f in funcs:
  print((f(a) == expect).all(), f.__name__)

# Memory
print('\nMemory:')
for f in funcs * 2:
  tm.start()
  f(a)
  print(f'{tm.get_traced_memory()[1] / L**2 :5.2f} bytes/element ', f.__name__)
  tm.stop()

# Speed
times = {f: [] for f in funcs}
def stats(f):
    ts = [t * 1e3 for t in sorted(times[f])[:5]]
    return f'{mean(ts):5.1f} ± {stdev(ts):3.1f} ms '
for _ in range(25):
    random.shuffle(funcs)
    for f in funcs:
        t = timeit(lambda: f(a), number=1) / 1
        times[f].append(t)
for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)

print('\nPython:', sys.version)
print('NumPy: ', np.__version__)

Попробуйте это онлайн!

Хорошая попытка! Однако он должен занимать больше памяти, чем мой код, поскольку этот код работает со всем массивом, хотя логические массивы занимают всего 1 байт/элемент. Подробный анализ смотрите в моем обновленном ответе ;) . Более того, код довольно сложен для понимания, ИМХО. Примечание. Я добавил производительность собственного кода в тесты, чтобы увидеть, насколько медленными являются коды Numpy.

Jérôme Richard 24.08.2024 16:32

@JérômeRichard Почему «попробовать»? Разве я не преуспел? Вы сами сказали, что он занимает меньше памяти, чем у них. Да, конечно, нужно больше, чем у вас. Но цель – взять меньше, чем у них. Я думаю, что при нескольких байтах на элемент я беру гораздо меньше, чем их по крайней мере 9*4 байта на элемент (если я правильно понимаю), так что это может быть достаточно мало. А для дальнейшего сокращения я упомянул разделение на фрагменты. Кстати, они сказали, что у них есть float32, поэтому, пожалуйста, измерьте его. Использование float64, кажется, уменьшает мое преимущество в скорости.

no comment 24.08.2024 16:45

«Разве мне это не удалось?» Конечно, вам удалось улучшить решение OP. Я не носитель языка, поэтому не могу использовать наиболее подходящие слова (я рассматриваю ответы как предложение, поэтому использую слово «попробуй»). Я также рассматриваю свой ответ как «попытку» (особенно до тех пор, пока он не принят или нет отзывов от ОП). Я согласен с вами, что этот ответ значительно улучшает использование памяти по сравнению с первоначальной реализацией.

Jérôme Richard 24.08.2024 18:57

Хороший момент для 32-битных чисел с плавающей запятой. Я обновил тест/ответ. Интересно, что на большинство решений это изменение не влияет. Есть только родное решение SIMD. Я думаю, это связано с тем, что все коды являются скалярными, привязанными к вычислениям, и тип не влияет на скорость скалярных операций. Я думаю, что тайминги немного лучше, потому что массив в два раза меньше и поэтому он должен лучше помещаться в кеше и уменьшать нагрузку на ОЗУ, а также ошибки страниц (которые не являются узким местом, но влияют на производительность).

Jérôme Richard 24.08.2024 19:01

Кстати, я думаю, что ваша реализация может быть немного быстрее с .view(np.uint8) вместо .astype(np.uint8). Последний создает новый массив, а не первый. Помимо этого, я не думаю, что существует (намного) лучшая полностью векторизованная реализация, использующая только Numpy.

Jérôme Richard 24.08.2024 19:07

@JérômeRichard Ах, ок. Вместе с вашим «Однако он должен занимать больше памяти, чем мой код» звучало так, как будто я пытался/заявлял, что мой занимает меньше памяти, чем ваш (хотя я тоже не носитель языка). Спасибо за view, кажется, это немного быстрее (хотя сложно сказать, учитывая мои простые измерения на этой машине). Меня интересует разница в скорости между нашими машинами, почему я измерил, что моя скорость более чем в 10 раз быстрее, чем оригинал, в то время как вы измерили ее менее чем в 3 раза быстрее...

no comment 24.08.2024 20:19

Хорошо, теперь я понимаю. Извините за путаницу. Это действительно странно. Это, безусловно, связано с версией Numpy или вашим процессором (например, довольно недавним процессором). Какую версию вы используете и какой именно процессор?

Jérôme Richard 24.08.2024 20:43

@JérômeRichard NumPy 1.26.4 и некоторые модели семейства Skylake Xeon 6 модели 85 (говорит lscpu, когда я помещаю это туда в «Дополнительно»). Я только что добавил еще одно решение, оно менее быстрое для меня, не могли бы вы тоже попробовать?

no comment 24.08.2024 21:05

Я попытался установить Numpy 2.0 и CPython 3.12 и получил примерно такое же поведение (хотя производительность вашей функции была немного лучше: ~ 16 мс для моей и ~ 12 мс для моей4). У меня также есть процессор, похожий на Skylake, так что это не архитектура процессора. Единственная разница, которую я вижу, - это, возможно, больший/другой кеш L3 и другая оперативная память (пропускная способность и задержка). Обратите внимание, что для этого стенда я использовал Windows.

Jérôme Richard 24.08.2024 22:39

@JérômeRichard Спасибо. Еще одна интересная разница в скорости.

no comment 24.08.2024 23:12
Ответ принят как подходящий

Поскольку использование sliding_window_view неэффективно для вашего случая, я предложу альтернативу с использованием Numba.

Во-первых, чтобы упростить реализацию, определите следующую альтернативу argmax.

from numba import njit


@njit
def argmax(*values):
    """argmax alternative that can take an arbitrary number of arguments.

    Usage: argmax(0, 1, 3, 2)  # 2
    """
    max_arg = 0
    max_value = values[0]
    for i in range(1, len(values)):
        value = values[i]
        if value > max_value:
            max_value = value
            max_arg = i
    return max_arg

Это стандартная функция argmax, за исключением того, что она принимает несколько скалярных аргументов вместо одного массива numpy.

Используя эту альтернативу argmax, вашу операцию можно легко реализовать повторно.

@njit(cache=True)
def neighbor_argmax(a):
    height, width = a.shape[0] - 2, a.shape[1] - 2
    out = np.empty((height, width), dtype=np.uint8)
    for y in range(height):
        for x in range(width):
            # window: a[y:y + 3, x:x + 3]
            # center: a[y + 1, x + 1]
            out[y, x] = argmax(
                a[y, x + 1],  # up
                a[y + 1, x],  # left
                a[y + 1, x + 2],  # right
                a[y + 2, x + 1],  # down
            )
    return out

Для работы этой функции требуется всего несколько переменных, исключая входной и выходной буферы. Поэтому нам не нужно беспокоиться об объеме памяти.

Альтернативно вы можете использовать трафарет, утилиту для раздвижных окон для Numba. С помощью stencil вам нужно только определить ядро. Нумба позаботится обо всем остальном.

from numba import njit, stencil

@stencil
def kernel(window):
    # window: window[-1:2, -1:2]
    # center: window[0, 0]
    return np.uint8(  # Don't forget to cast to np.uint8.
        argmax(
            window[-1, 0],  # up
            window[0, -1],  # left
            window[0, 1],  # right
            window[1, 0],  # down
        )
    )


@njit(cache=True)
def neighbor_argmax_stencil(a):
    return kernel(a)[1:-1, 1:-1]  # Slicing is not mandatory.

Если хотите, его также можно встроить.

@njit(cache=True)
def neighbor_argmax_stencil_inlined(a):
    f = stencil(lambda w: np.uint8(argmax(w[-1, 0], w[0, -1], w[0, 1], w[1, 0])))
    return f(a)[1:-1, 1:-1]  # Slicing is not mandatory.

Однако stencil очень ограничен по функциональности и не может полностью заменить sliding_window_view. Единственное отличие состоит в том, что нет возможности пропускать края. Он всегда дополняется постоянным значением (по умолчанию 0). То есть, если вы поставите матрицу (L, L), вы получите результат (L, L), а не (L-2, L-2).

Вот почему я вырезаю выходные данные приведенного выше кода, чтобы они соответствовали вашей реализации. Однако это может быть нежелательным поведением, поскольку оно нарушает непрерывность памяти. Вы можете копировать после нарезки, но имейте в виду, что это увеличит пиковое использование памяти.

Кроме того, следует отметить, что эти функции также можно легко адаптировать для многопоточности. Подробную информацию см. в эталонном коде ниже.

Вот эталон.

import math
import timeit

import numpy as np
from numba import njit, prange, stencil
from numpy.lib.stride_tricks import sliding_window_view


def baseline(a):
    L = a.shape[0]
    swv = sliding_window_view(a, (3, 3))  # (L-2) x (L-2) x 3 x 3
    directions = swv.reshape(L - 2, L - 2, 9)[:, :, 1::2].argmax(axis=2).astype(np.uint8)
    return directions


@njit
def argmax(*values):
    """argmax alternative that can accept an arbitrary number of arguments.

    Usage: argmax(0, 1, 3, 2)  # 2
    """
    max_arg = 0
    max_value = values[0]
    for i in range(1, len(values)):
        value = values[i]
        if value > max_value:
            max_value = value
            max_arg = i
    return max_arg


@njit(cache=True)
def neighbor_argmax(a):
    height, width = a.shape[0] - 2, a.shape[1] - 2
    out = np.empty((height, width), dtype=np.uint8)
    for y in range(height):
        for x in range(width):
            # window: a[y:y + 3, x:x + 3]
            # center: a[y + 1, x + 1]
            out[y, x] = argmax(
                a[y, x + 1],  # up
                a[y + 1, x],  # left
                a[y + 1, x + 2],  # right
                a[y + 2, x + 1],  # down
            )
    return out


@njit(cache=True, parallel=True)  # Add parallel=True.
def neighbor_argmax_mt(a):
    height, width = a.shape[0] - 2, a.shape[1] - 2
    out = np.empty((height, width), dtype=np.uint8)
    for y in prange(height):  # Change this to prange.
        for x in range(width):
            # window: a[y:y + 3, x:x + 3]
            # center: a[y + 1, x + 1]
            out[y, x] = argmax(
                a[y, x + 1],  # up
                a[y + 1, x],  # left
                a[y + 1, x + 2],  # right
                a[y + 2, x + 1],  # down
            )
    return out


@stencil
def kernel(window):
    # window: window[-1:2, -1:2]
    # center: window[0, 0]
    return np.uint8(  # Don't forget to cast to np.uint8.
        argmax(
            window[-1, 0],  # up
            window[0, -1],  # left
            window[0, 1],  # right
            window[1, 0],  # down
        )
    )


@njit(cache=True)
def neighbor_argmax_stencil(a):
    return kernel(a)[1:-1, 1:-1]  # Slicing is not mandatory.


@njit(cache=True)
def neighbor_argmax_stencil_with_copy(a):
    return kernel(a)[1:-1, 1:-1].copy()  # Slicing is not mandatory.


@njit(cache=True, parallel=True)
def neighbor_argmax_stencil_mt(a):
    return kernel(a)[1:-1, 1:-1]  # Slicing is not mandatory.


@njit(cache=True)
def neighbor_argmax_stencil_inlined(a):
    f = stencil(lambda w: np.uint8(argmax(w[-1, 0], w[0, -1], w[0, 1], w[1, 0])))
    return f(a)[1:-1, 1:-1]  # Slicing is not mandatory.


def benchmark():
    size = 2000  # Total nbytes (in MB) for a.
    n = math.ceil(math.sqrt(size * (10 ** 6) / 4))
    rng = np.random.default_rng(0)
    a = rng.random(size=(n, n), dtype=np.float32)
    print(f"{a.shape=}, {a.nbytes=:,}")

    expected = baseline(a)
    # expected = neighbor_argmax_mt(a)
    assert expected.shape == (n - 2, n - 2) and expected.dtype == np.uint8

    candidates = [
        baseline,
        neighbor_argmax,
        neighbor_argmax_mt,
        neighbor_argmax_stencil,
        neighbor_argmax_stencil_mt,
        neighbor_argmax_stencil_with_copy,
        neighbor_argmax_stencil_inlined,
    ]
    name_len = max(len(f.__name__) for f in candidates)
    for f in candidates:
        assert np.array_equal(expected, f(a)), f.__name__
        t = timeit.repeat(lambda: f(a), repeat=3, number=1)
        print(f"{f.__name__:{name_len}} : {min(t)}")


if __name__ == "__main__":
    benchmark()

Результат:

a.shape=(22361, 22361), a.nbytes=2,000,057,284
baseline                          : 24.971996600041166
neighbor_argmax                   : 0.1917789001017809
neighbor_argmax_mt                : 0.11929619999136776
neighbor_argmax_stencil           : 0.2940085999434814
neighbor_argmax_stencil_mt        : 0.17756330000702292
neighbor_argmax_stencil_with_copy : 0.46573049994185567
neighbor_argmax_stencil_inlined   : 0.29338629997801036

Я думаю, этих результатов достаточно, чтобы вы подумали о том, чтобы попробовать Numba :)

neighbor_argmax отлично работает на моей машине. Он оказывается конкурентоспособным с оптимизированным собственным SIMD-кодом. Хотя я ожидал хорошей производительности, я не ожидал, что Numba сгенерирует SIMD-дружественный ассемблерный код из-за этого условия.
Jérôme Richard 24.08.2024 22:11

Спасибо за это! На самом деле я полностью использовал Numba, прежде чем перейти на CuPy и отказаться от Numba, хотя они якобы совместимы. Проблема в том, что у меня было много проблем с их совместной работой, особенно после того, как параметры target = "cuda" и target_backend = "cuda" были удалены из numba.jit и numba.njit. Знаете ли вы, совместимо ли ваше решение с CUDA?

DanDan面 27.08.2024 21:12

Другие вопросы по теме