Как векторизовать это для циклов в python?

Код ниже

import numpy as np


data = np.random.randint(0, 10, 12).reshape(3, 4)
print(data)

h, w = data.shape[:2]
dataMask = np.zeros((h, w, 10), np.int)

r = 2

for i in range(h):
    for j in range(w):
        for ir in range(i - r, i + r):
            for jr in range(j - r, j + r):
                if ir >= 0 and ir < h and jr >= 0 and jr < w:
                    dataMask[i, j, data[ir, jr]] += 1

print(dataMask)

У меня есть "данные" массива numpy с формой (h, w). Его элементами является целое число ∈ [0, 10).
Я создаю массив dataMask с формой (h, w, 10). dataMask[i, j, k] указывает количество точек, значение которых равно k в области данных. Эта область данных имеет центр (i,j) и r = 2 и представляет собой квадрат.

Как векторизовать циклы for в коде? Спасибо!

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
2
0
132
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Ответ принят как подходящий

Вот один из способов использования cumsum:

import numpy as np


data = np.random.randint(0, 10, 1200).reshape(30, 40)
print(data)

h, w = data.shape[:2]
dataMask = np.zeros((h, w, 10), np.int)

r = 20

from time import time
T = []

T.append(time())

for i in range(h):
    for j in range(w):
        for ir in range(i - r, i + r):
            for jr in range(j - r, j + r):
                if ir >= 0 and ir < h and jr >= 0 and jr < w:
                    dataMask[i, j, data[ir, jr]] += 1

T.append(time())

m1 = np.zeros((h, w, 10), np.int)
np.put_along_axis(m1, data[...,None], 1, 2)
m2 = np.empty_like(m1)
m1 = m1.cumsum(1)
m2[: ,:-r+1] = m1[:, r-1:]
m2[:, -r+1:] = m1[:, -1, None]
m2[:, r+1:] -= m1[:, :-r-1]
m2 = m2.cumsum(0)
m1[:-r+1] = m2[r-1:]
m1[-r+1:] = m2[-1, None]
m1[r+1:] -= m2[:-r-1]

T.append(time())


assert (dataMask==m1).all()

print(np.diff(T))

Пример запуска с h,w,r = 30,40,20

# time [seconds] used by
# OP            cumsum
[9.23162699e-01 3.41892242e-04]

Это «частично векторизованное» решение, которое просто перебирает размер окна.

import numpy as np
from itertools import product

# Input data
np.random.seed(0)
data = np.random.randint(0, 10, 12).reshape(3, 4)
h, w = data.shape[:2]
dataMask = np.zeros((h, w, 10), np.int)
r = 2

# Original solution
for i in range(h):
    for j in range(w):
        for ir in range(i - r, i + r):
            for jr in range(j - r, j + r):
                if ir >= 0 and ir < h and jr >= 0 and jr < w:
                    dataMask[i, j, data[ir, jr]] += 1

# Partially vectorized solution
idx_i, idx_j = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
idx_i = idx_i.ravel()
idx_j = idx_j.ravel()
idx_k = data.ravel()
dataMask2 = np.zeros((h, w, 10), np.int)
for i, j in product(range(-r + 1, r + 1), repeat=2):
    ii = idx_i + i
    jj = idx_j + j
    m = (ii >= 0) & (ii < h) & (jj >= 0) & (jj < w)
    ii = ii[m]
    jj = jj[m]
    kk = idx_k[m]
    np.add.at(dataMask2, (ii, jj, kk), 1)

print(np.all(dataMask == dataMask2))
# True

На самом деле вы можете сделать это полностью векторизованным, просто разбивая данные (что использует больше памяти):

import numpy as np

# Fully vectorized
idx_i, idx_j = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
w_i, w_j = np.meshgrid(np.arange(-r + 1, r + 1), np.arange(-r + 1, r + 1), indexing='ij')
ii = (idx_i[:, :, np.newaxis, np.newaxis] + w_i).ravel()
jj = (idx_j[:, :, np.newaxis, np.newaxis] + w_j).ravel()
kk = np.tile(data[:, :, np.newaxis, np.newaxis], (1, 1, 2 * r, 2 * r)).ravel()
m = (ii >= 0) & (ii < h) & (jj >= 0) & (jj < w)
ii = ii[m]
jj = jj[m]
kk = kk[m]
dataMask3 = np.zeros((h, w, 10), np.int)
np.add.at(dataMask3, (ii, jj, kk), 1)
print(np.all(dataMask == dataMask3))
# True

Другие вопросы по теме