Меня интересует определение длины последовательности единиц вдоль одной оси в многомерном массиве.
Для одномерного массива я нашел решение, используя ответы на этот старый вопрос. Например. [0,1,0,0,1,1,1,0,1,1] --> [нан,1,нан,нан,3,нан,нан,нан,2,нан]
Для 3D-массива я, конечно, могу создать цикл, но не хочу. (Базовая климатология: цикл по всем ячейкам сетки широты и долготы сделает это очень медленно.)
Я пытаюсь найти решение, соответствующее решению 1D. Помощь будет очень признательна, в соответствии с кодом 1D, но, конечно, приветствуются и полные другие решения.
Для справки, это мое рабочее 1D-решение:
import xarray as xr
import numpy as np
def run_lengths(da):
n = len(da)
y = da.values[1:] != da.values[:-1]
i = np.append(np.where(y), n - 1)
z = np.diff(np.append(-1,i))
p = np.cumsum(np.append(0,z))[:-1]
runs = np.where(da[i]==1)[0]
runs_len = z[runs] # length of sequence
time_val = da.time[p[runs]] # date of first day in sequence
da_runs = xr.DataArray(runs_len,coords = {'time':time_val})
_,da_runs = xr.align(da,da_runs,join='outer') # make sure we have full time axis
return da_runs
da = xr.DataArray(np.array([[[0,1,1,0,0,0],[1,0,1,1,0,1],[1,1,1,1,0,1]],[[0,1,1,0,0,0],[1,0,1,1,0,1],[1,1,1,1,0,1]]]),coords = {'lat':[0,1],'lon':[0,1,2],'time':[0,1,2,3,4,5]})
da_runs = run_lengths(da[0,1])
print(da_runs)
<xarray.DataArray (time: 6)>
array([ 1., nan, 2., nan, nan, 1.])
Coordinates:
* time (time) int64 0 1 2 3 4 5
А это попытка в 3D. Я застрял в том, как переместить действительные записи в i вперед/удалить NaN из i. (А может быть, и помимо этого?)
def run_lengths_3D(da):
n = len(da.time)
y = da.values[:,:,1:] != da.values[:,:,:-1]
y = xr.DataArray(y,coords = {'lat':da.lat,'lon':da.lon,'time':da.time[0:-1]})
i = y.where(y)*xr.DataArray(np.arange(0,len(da.time[0:-1])),coords = {'time':y.time}) -1
@olivier-gauthé Спасибо, я добавил пример кода. Обратите внимание: здесь я создал искусственный массив данных о 3D-климате da, на самом деле он имеет размер [256,512,3650].
Требуется скорость, используйте numba. Все еще нужно выжать немного больше скорости, используя оптимизированный C. - в конечном итоге будет немного лучше.






Для этой задачи вы можете попробовать использовать цифру, например:
import numba
import numpy as np
@numba.njit
def calculate(arr):
out = np.empty_like(arr, dtype = "uint16")
lat, lon, time = arr.shape
for i in range(lat):
for j in range(lon):
curr = 0
for k in range(time):
out[i, j, k] = 0
if arr[i, j, k] == 0:
if curr > 0:
out[i, j, k - curr] = curr
curr = 0
else:
curr += 1
if curr > 0:
out[i, j, -curr] = curr
return out
arr = np.array(
[
[[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
[[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
]
)
print(calculate(arr))
Распечатки:
[[[0 2 0 0 0 0]
[1 0 2 0 0 1]
[4 0 0 0 0 1]]
[[0 2 0 0 0 0]
[1 0 2 0 0 1]
[4 0 0 0 0 1]]]
Тестирование с использованием timeit + параллельной версии:
from timeit import timeit
import numba
import numpy as np
@numba.njit
def calculate(arr):
out = np.empty_like(arr, dtype = "uint16")
lat, lon, time = arr.shape
for i in range(lat):
for j in range(lon):
curr = 0
for k in range(time):
out[i, j, k] = 0
if arr[i, j, k] == 0:
if curr > 0:
out[i, j, k - curr] = curr
curr = 0
else:
curr += 1
if curr > 0:
out[i, j, -curr] = curr
return out
@numba.njit(parallel=True)
def calculate_parallel(arr):
out = np.empty_like(arr, dtype = "uint16")
lat, lon, time = arr.shape
for i in range(lat):
for j in numba.prange(lon):
curr = 0
for k in range(time):
out[i, j, k] = 0
if arr[i, j, k] == 0:
if curr > 0:
out[i, j, k - curr] = curr
curr = 0
else:
curr += 1
if curr > 0:
out[i, j, -curr] = curr
return out
arr = np.array(
[
[[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
[[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
],
dtype = "uint8",
)
# compile calculate()/calculate_parallel()
assert np.allclose(calculate(arr), calculate_parallel(arr))
np.random.seed(42)
arr = np.random.randint(low=0, high=2, size=(256, 512, 3650), dtype = "uint8")
t_serial = timeit("calculate(arr)", number=1, globals=globals())
t_parallel = timeit("calculate_parallel(arr)", number=1, globals=globals())
print(f"{t_serial * 1_000_000:.2f} usec")
print(f"{t_parallel * 1_000_000:.2f} usec")
Печатает на моей машине (AMD 5700x):
1575227.47 usec
320453.57 usec
пожалуйста, добавьте минимальный воспроизводимый пример. Что такое
da?