Быстрый способ удалить несколько строк по индексам из 2D-массива Pytorch или Numpy

У меня есть массив numpy (и, что эквивалентно, тензор Pytorch) формы Nx3. У меня также есть список индексов, соответствующих строкам, которые я хочу удалить из этого тензора. Этот список индексов называется remove_ixs. N очень большой, около 5 миллионов строк, а длина remove_ixs составляет 50 тысяч. Сейчас я делаю это следующим образом:

mask = [i not in remove_ixs for i in range(my_array.shape[0])]
new_array = my_array[mask,:]

Но первая строка просто не завершается, занимает вечность. Вышеупомянутое находится в numpy-коде. Эквивалентный код Pytorch также подойдет мне.

Есть ли более быстрый способ сделать это с помощью numpy или pytorch?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
0
51
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Вы можете использовать np.delete():

import numpy as np

A = np.random.rand(5000000, 3)
remove_ixs = np.random.choice(5000000, 50000, replace=False)
B = np.delete(A, remove_ixs, axis=0)

print(len(B))

Принты

4950000
Ответ принят как подходящий

Вы можете создать исходный mask (логический) массив, True для элементов, которые вы хотите удалить, а затем инвертировать его, чтобы получить mask элементов, которые вы хотите сохранить.

remove_mask = np.zeros(my_array.shape[0], dtype=bool)
remove_mask[remove_ixs] = True
mask = ~remove_mask
    
new_array = my_array[mask, :]

Или запустите все True и сделайте наоборот:

mask = np.ones(my_array.shape[0], dtype=bool)
mask[remove_ixs] = False
    
new_array = my_array[mask, :]

По какой-то причине первая версия работает быстрее для массивов меньшего размера.

Другие вопросы по теме