Попытка разбить массив numpy на основе условия. Фильтр должен взять столбец split_column и его значение split_value и разделить массив на две части, одна из которых содержит подмассив всех строк <= значение split_value в заданном столбце split_column.
т. е. данный,
a = np.array([[5, 'hi', 23],
[4, 'we', 15],
[3, 'me', 10],
[2, 'be', 67],
[1, 'it', 100]])
split_column = 0
split_value = 3
Ожидаемый результат
[[3, 'me', 10],
[2, 'be', 67],
[1, 'it', 100]]
Я попробовал это решение a[a[:, split_column] <= split_value]
, но оно работает, только если все элементы числовые.
Для смешанных типов в массиве numpy (как показано выше) я получаю
TypeError: «<=» не поддерживается между экземплярами «numpy.ndarray» и «int»
Использование str(), как в a[a[:, split_column] <= str(split_value)]
, не является решением, потому что 10 <= 3 становится истинным, что неверно. Для столбца (1) мне нужно сравнение str, но для других столбцов это должно быть числовое сравнение.
Как мы могли бы сделать это в numpy или нам нужно перебирать все элементы, проверяя типы перед сравнением?
Преобразуйте столбцы в желаемое type
, используя numpy.array.astype
:
a[a[:,0].astype(int) <= 3]
array([['3', 'me', '10'],
['2', 'be', '67'],
['1', 'it', '100']], dtype='<U11')
a.dtype
это строка, верно? Список был смешанным, но не массив. Распечататьa