У меня есть следующий код:
zero_indices = np.where(A==0)
nonzero_indices = np.where(A!=0)
Но я хочу найти нулевые и ненулевые индексы с помощью всего лишь одного вызова np.where.
Один из способов, который я нашел, это:
zero_indices = np.where(array == 0)
non_zero_indices = np.setdiff1d(np.arange(array.size), zero_indices)
Однако при тестировании на больших массивах я обнаружил, что на самом деле это в 10 раз медленнее, чем простой вызов np.where дважды.
Есть ли более эффективный способ сделать это?
.where
уже очень эффективно. Вы можете рассмотреть возможность небольшой оптимизации, определив mask = A == 0
и используя mask
и ~mask
в вызовах. И вместо np.where(mask)
рассмотрите возможность использования np. asarray(mask).nonzero()
, как указано в документации numpy.org/doc/stable/reference/generated/numpy.where.htmlДвойной подход где кажется самым быстрым. Как отмечено в комментарии, вы можете немного улучшить его, сохранив маску.
Я также тестировал другие подходы: numba и argsort + разделение, numba работало гораздо медленнее, а argsort
+diff
было не лучше.
Предполагая этот тестовый массив:
np.random.seed(0)
A = np.random.randint(0, 2, size=10_000_000)
Ниже приведены сроки для различных подходов:
zero_indices = np.where(A==0)
nonzero_indices = np.where(A!=0)
# 137 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
mask = A == 0
zero_indices = np.where(mask)
nonzero_indices = np.where(~mask)
# 122 ms ± 959 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
argsort
+split
mask = A==0
tmp = np.argsort(mask, kind='stable')
nonzero_indices, nonzero_indices = np.split(tmp, [len(A)-mask.sum()])
# 125 ms ± 17.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
from numba import jit
@jit(nopython=True)
def zero_nonzero(A):
zero_indices = []
nonzero_indices = []
for i, item in enumerate(A):
if item == 0:
zero_indices.append(i)
else:
nonzero_indices.append(i)
return zero_indices, nonzero_indices
zero_indices, nonzero_indices = zero_nonzero(A)
# 437 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
zero_indices = np.where(A == 0)
non_zero_indices = np.setdiff1d(np.arange(A.size), zero_indices)
# 519 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
A = np.zeros(10_000_000)
:# double where + reusing the mask
22 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# argsort + split
34.5 ms ± 2.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Спасибо за подробные тесты! Мне интересно, почему np.where не предназначен для возврата как нулевых, так и ненулевых_indices? В принципе, не может ли np.where(A) создать оба этих вывода за один проход через массив A, не жертвуя при этом скоростью?
Это правда, но для этого потребуется немного больше вычислений и, возможно, гораздо больше памяти. Я думаю, это может быть необязательный флаг.
входы и ожидаемые результаты?