Что говорит название. Я ищу быстрый и питонический подход для извлечения строк массива конечных точек A, который содержит элементы другого массива v
Простой пример того, чего я хочу достичь, выглядит следующим образом:
Вход:
A = [[ 4 9]
[15 19]
[20 28]
[31 37]
[43 43]]
v = [ 0 1 2 3 11 12 13 14 26 29 30 31 43]
Поскольку A является массивом конечных точек, это означает, что в каждой строке первый элемент и второй элемент представляют начало и конец интервала. Поскольку только интервалы [20 28], [31 37] и [43 43] содержат элементы в v (в этом случае 26,31 and 43 содержатся в интервалах, созданных массивом конечных точек A), желаемый результат:
[[20 28]
[31 37]
[43 43]]
Ниже приведен код для создания фактических входных массивов:
import numpy as np
np.random.seed(0)
size = 32000
base_arr = np.arange(size)*10
t1 = np.random.randint(0,6, size)+base_arr
t2 = np.random.randint(5,10, size)+base_arr
A = np.vstack((t1,t2)).T
v = np.sort(np.random.randint(0,10,3*size)+np.repeat(base_arr,3))
заранее спасибо
Обновлено: добавлено больше деталей в объяснение






Сравните в третьем измерении
import numpy as np
a = np.array([[ 4, 9],
[15, 19],
[20, 28],
[31, 37],
[43, 43]])
v = np.array([ 0, 1, 2, 3, 11, 12, 13, 14, 26, 29, 30, 31, 43])
between = np.logical_and(v >= a[:,0,None], v <= a[:,1,None])
print(a[between.any(-1)])
>>>
[[20 28]
[31 37]
[43 43]]
>>>
К сожалению, мой взрывается на больших массивах с ошибкой памяти
На самом деле я заметил это и на своем компьютере. Интересно, есть ли способы сделать это в вашем подходе, не создавая дополнительного измерения. Если это возможно, то функция должна потреблять гораздо меньше памяти.
Чтобы это работало, необходимо дополнительное измерение, оно допускает Вещание. Searchsorted - это путь. Я застрял и ввел себя в заблуждение, когда прочитал термин векторизованный.
Подход №1
Мы можем использовать np.searchsorted, чтобы получить левый и правый позиционные индексы для начального и конечного элементов в каждой строке по значениям v и искать несоответствующие, которые будут указывать на то, что конкретная строка имеет по крайней мере один элемент в пределах этих границ. Следовательно, мы могли бы просто сделать -
A[np.searchsorted(v,A[:,0],'left')!=np.searchsorted(v,A[:,1],'right')]
Подход №2
Другой способ — использовать левые индексы для индексации v, а затем посмотреть, меньше ли они, чем правые конечные точки. Следовательно, было бы -
idx = np.searchsorted(v,A[:,0],'left')
out = A[(idx<len(v)) & (v[idx.clip(max=len(v)-1)]<=A[:,1])]
Обратите внимание, что это предполагает сортировку v и ввод в виде массивов. Если v еще не отсортировано, нам нужно отсортировать его, а затем передать.
Тайминги для большего набора данных в моем конце -
In [65]: %timeit A[np.searchsorted(v,A[:,0],'left')!=np.searchsorted(v,A[:,1],'right')]
2 ms ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [66]: %%timeit
...: idx = np.searchsorted(v,A[:,0],'left')
...: out = A[(idx<len(v)) & (v[idx.clip(max=len(v)-1)]<=A[:,1])]
1.32 ms ± 7.87 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Спасибо за ваш ответ. На сегодняшний день самый быстрый подход среди уже хороших подходов
@Divakar Я просто перебрал ответы range(400), я обнаружил, что for i in [5,29,130,139,146,150,156,178,295,313,315,337,342]: np.random.seed(i) результаты подхода 1 и подхода 2 разные.
@Divakar, на полигоне с 400 итерациями, да, могут быть дубликаты. Но в тех местах, где я действительно хочу применить свою функцию, не должно быть дубликатов. А если есть, то я воспользуюсь v = np.unique(v) для их удаления.
@mathguy Да, мне нужно было проверить одно условие. Отредактированное приложение № 2.
@Divakar Только что проверил это с разными видами v, пока все хорошо. Очень ценю ваше продолжение
Я не считаю это полностью Pythonic, но это как минимум O (n).
def find_bounding_intervals(A, v):
rows = []
i = 0
for row in A:
while all(v[i] < row):
i += 1
if row[0] <= v[i] <= row[1]:
rows.append(row)
return np.array(rows)
A = np.array([[ 4, 9],
[15, 19],
[20, 28],
[31, 37],
[43, 43]])
v = np.array([ 0, 1, 2, 3, 11, 12, 13, 14, 26, 29, 30, 31, 43])
print(find_bounding_intervals(A, v))
Мой недорогой ноутбук выдает решение за ~ 0,28 с для гораздо больших данных, определенных в вашем вопросе.
Спасибо за ваш ответ. Не знал, что встроенная функция all() существует до сих пор.
Только что использовал ваш код, чтобы поэкспериментировать, сможет ли numba-njited версия вашей реализации превзойти векторизованную версию. Я обнаружил, что с точки зрения производительности версия numba является самой быстрой. Я опубликую это ниже.
from numba import njit
@njit
def find_bounding_intervals(A, v):
rows_L = []
rows_R = []
i = 0
for row in range(A.shape[0]):
while v[i] < A[row,0] and v[i] < A[row,1]:
i += 1
if A[row,0] <= v[i] <= A[row,1]:
rows_L.append(A[row,0])
rows_R.append(A[row,1])
return np.array([rows_L, rows_R]).T
Хотя эта реализация технически не является векторизованной функцией, она действительно является самой быстрой почти для всех размеров n.
Я должен прояснить, что алгоритм исходит от @brentertainer
Связано: Найдите, эффективно ли содержит отсортированный массив чисел с плавающей запятой числа в определенном диапазоне.