Как я могу получить индексы подмассива двоичного массива, используя numpy?

У меня есть массив, который выглядит так

r = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1, 1])

и я хочу получить результат

[(0, 0), (3, 5), (7, 9)]

прямо сейчас я могу сделать это с помощью следующей функции

def get_indicies(array):
    indicies = []
    xstart = None
    for x, col in enumerate(array):
        if col == 0 and xstart is not None:
            indicies.append((xstart, x - 1))
            xstart = None
        elif col == 1 and xstart is None:
            xstart = x

    if xstart is not None:
        indicies.append((xstart, x))

    return indicies

Однако для массивов с 2 миллионами элементов этот метод медленный (~ 8 секунд). Есть ли способ использовать «встроенные модули» numpy (.argwhere, .split и т. д.), чтобы сделать это быстрее? Эта тема — самое близкое, что я нашел, однако я не могу найти правильную комбинацию для решения моей проблемы.

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
3
0
90
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Ответ принят как подходящий

Решение, которое я придумал, состоит в том, чтобы отдельно найти индексы, в которых впервые встречается 1, и индексы, в которых встречается последнее появление 1.

def get_indices2(arr):
    value_diff = np.diff(arr, prepend=0, append=0)
    start_idx = np.nonzero(value_diff == 1)[0]
    end_idx = np.nonzero(value_diff == -1)[0] - 1  # -1 to include endpoint
    idx = np.stack((start_idx, end_idx), axis=-1)
    return idx

Обратите внимание, что результатом является не список кортежей, а двумерный массив, подобный показанному ниже.

array([[0, 0],
       [3, 5],
       [7, 9]])

Вот эталон:

import timeit

import numpy as np


def get_indices(array):
    indices = []
    xstart = None
    for x, col in enumerate(array):
        if col == 0 and xstart is not None:
            indices.append((xstart, x - 1))
            xstart = None
        elif col == 1 and xstart is None:
            xstart = x

    if xstart is not None:
        indices.append((xstart, x))

    return indices


def get_indices2(arr):
    value_diff = np.diff(arr, prepend=0, append=0)
    start_idx = np.nonzero(value_diff == 1)[0]
    end_idx = np.nonzero(value_diff == -1)[0] - 1  # -1 to include endpoint
    idx = np.stack((start_idx, end_idx), axis=-1)
    return idx


def benchmark():
    rng = np.random.default_rng(0)
    arr = rng.integers(0, 1, endpoint=True, size=20_000_000)
    expected = np.asarray(get_indices(arr))

    for f in [get_indices, get_indices2]:
        t = np.asarray(f(arr))
        assert expected.shape == t.shape and np.array_equal(expected, t), f.__name__
        elapsed = min(timeit.repeat(lambda: f(arr), repeat=10, number=1))
        print(f"{f.__name__:25}: {elapsed}")


benchmark()

Результат:

get_indices              : 4.708652864210308
get_indices2             : 0.21052680909633636

Меня беспокоит то, что на моем компьютере вашей функции требуется менее 5 секунд для обработки 20 миллионов элементов, в то время как вы упоминаете, что обработка 2 миллионов элементов занимает 8 секунд. Так что возможно я что-то упускаю.

Обновлять

Мэтт в своем ответе предложил элегантное решение, используя изменение формы. Однако, если производительность важна, я бы посоветовал сначала оптимизировать часть np.diff.

def custom_int8_diff(arr):
    out = np.empty(len(arr) + 1, dtype=np.int8)
    out[0] = arr[0]
    out[-1] = -arr[-1]
    np.subtract(arr[1:], arr[:-1], out=out[1:-1])
    return out


def get_indices2_custom_diff(arr):
    mask = custom_int8_diff(arr)  # Use custom diff. Others unchanged.
    start_idx = np.nonzero(mask == 1)[0]
    end_idx = np.nonzero(mask == -1)[0] - 1
    return np.stack((start_idx, end_idx), axis=-1)

Для решения Мэтта по изменению формы мы можем использовать logical_xor, что еще быстрее.

def custom_bool_diff(arr):
    out = np.empty(len(arr) + 1, dtype=np.bool_)
    out[0] = arr[0]
    out[-1] = arr[-1]
    np.logical_xor(arr[1:], arr[:-1], out=out[1:-1])
    return out


def get_indices3_custom_diff(arr):
    value_diff = custom_bool_diff(arr)  # Use custom diff. Others unchanged.
    idx = np.nonzero(value_diff)[0]
    idx[1::2] -= 1
    return idx.reshape(-1, 2)

Тест (2 миллиона элементов):

get_indices              : 0.463582425378263
get_indices2             : 0.01675519533455372
get_indices3             : 0.01814895775169134
get_indices2_custom_diff : 0.010258681140840054
get_indices3_custom_diff : 0.006368924863636494

Тест (20 миллионов элементов):

get_indices              : 4.708652864210308
get_indices2             : 0.21052680909633636
get_indices3             : 0.19463363010436296
get_indices2_custom_diff : 0.14093663357198238
get_indices3_custom_diff : 0.08207075204700232

Как насчет:

import numpy as np
r = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1, 1])
s = np.diff(r, prepend=0, append=0)
t = np.where(s)[0]
t[1::2] -= 1
# don't do `tolist` if you don't need to, though
t.reshape(-1, 2).tolist() 
# [[0, 0], [3, 5], [7, 9]]

Обновление: я вижу, что примерно в то же время есть независимая публикация того же основного решения. Я подозреваю, что операции на месте + reshape здесь немного предпочтительнее разделения, создания копии с арифметикой и укладки (которое создает еще одну копию). tolist занимает здесь почти все время, поэтому просто остановитесь после reshape, если вам действительно не нужен список.

Другие вопросы по теме