Мне нужно вычислить функцию ошибок V
V = Σi Σj X[i] X[j] σ[i][j]
Где σ[i][j]
— заданная матрица, и мне нужно относительно быстрое решение. Для этого я хочу создать еще одну матрицу Y, где
Y[i][j] = X[i]*X[j]
Итак, я могу просто суммировать Y * σ
. Есть ли хороший способ реализовать это с помощью функций numpy?
До сих пор я пытался использовать meshgrid(X, X), а затем применять np.prod к каждой строке, однако это дало ожидаемый результат и потребовало бы цикла for в Python.
Обновлено: Минимальный воспроизводимый пример:
cov = np.array(((0.1, 0.05), (0.05, 0.25)))
x = np.array((0.6, 0.4)
desired = x[0]*x[0]*cov[0][0] + x[0]*x[1]*cov[0][1] + x[1]*x[0]*cov[1][0] + x[1]*x[1] * cov[1][1]
np.outer
и np.einsum
— это то, что нужно, но обратите внимание, что создание такой (временной?) матрицы в памяти неэффективно, прежде всего, если она большая (потому что она привязана к памяти, а DRAM в настоящее время довольно медленны.
Используйте x @ s @ x
np.dot поддерживает умножение матриц.
Просто сделай
np.dot(x,sigma).dot(x)
ЕСЛИ вы видите его первое уравнение, ОП хочет получить скалярную ошибку. Внешний продукт он делает только как промежуточный расчет. Обновлено, чтобы отразить, что в этом случае x и y - это то же самое, что вы правильно указали.
вы также можете добавить x @ sigma @ x
.
@Onyambu np.dot(a, b)
эквивалентно a @ b
, когда можно применить матричное умножение. Оба будут вызывать процедуры BLAS, которые должны занимать почти все время при достаточно большом вводе. Единственное отличие заключается в том, что накладные расходы Numpy видны только на небольших входах.
np.dot(x,sigma).dot(x)
быстрее, чем np.einsum("i,j,ij->", x, x, s)
решение для задач любого размера.einsum(..., optimize=True)
имеет высокую стоимость запуска, но вступает в тройную ничью первым после N = 1000 или около того.(x @ s) @ x
и einsum()
с заранее рассчитанным путем сокращения, но они были очень похожи на другие результаты.Код решения Numba
@nb.njit(fastmath=True)
def numba_multiply(x, s):
V = 0
N = len(x)
assert s.ndim == 2 and s.shape[0] == N and s.shape[1] == N
# Note: this iteration order is fastest if s is a C ordered
# array. If it is Fortran ordered, it is better to swap these
# two for loops.
# See https://numpy.org/doc/1.20/reference/internals.html#multidimensional-array-indexing-order-issues
for i in range(N):
for j in range(N):
V += x[i] * x[j] * s[i, j]
return V
import numpy as np
import numba as nb
import perfplot
@nb.njit(fastmath=True)
def numba_multiply(x, s):
V = 0
N = len(x)
assert s.ndim == 2 and s.shape[0] == N and s.shape[1] == N
# Note: this iteration order is fastest if s is a C ordered
# array. If it is Fortran ordered, it is better to swap these
# two for loops.
# See https://numpy.org/doc/1.20/reference/internals.html#multidimensional-array-indexing-order-issues
for i in range(N):
for j in range(N):
V += x[i] * x[j] * s[i, j]
return V
# precompile Numba function
x_tmp = np.random.normal(size=10)
s_tmp = np.random.normal(size=(10, 10))
numba_multiply(x_tmp, s_tmp)
# precompute path
einsum_path, _ = np.einsum_path("i,j,ij", x_tmp, x_tmp, s_tmp)
import matplotlib.pyplot as plt
plt.figure(figsize=(20,10))
perfplot.show(
setup=lambda n: (np.random.rand(n), np.random.rand(n, n)),
kernels=[
# Credit: Nick ODell
lambda x, s: numba_multiply(x, s),
# Credit: Tarifazo
lambda x, s: np.dot(x, s).dot(x),
# Credit: simon
lambda x, s: np.einsum("i,j,ij", x, x, s),
# Credit: Jérôme Richard
lambda x, s: np.einsum("i,j,ij", x, x, s, optimize=True),
# Credit: simon
lambda x, s: (np.outer(x, x) * s).sum(),
],
labels=[
"numba",
"dot product",
"einsum",
"einsum optimize",
"outer product",
],
n_range=[2**k for k in np.linspace(0, 14, 20)],
xlabel = "len(x)",
equality_check=np.allclose,
)
Обратите внимание, что код Numba с fastmath
менее безопасен, чем другие методы, поскольку fastmath
предполагается, например, что нет NaN
или Inf
или субнормальных чисел (в то время как другое решение не должно).
зачем вам сравнивать x @ s @ s
вместо x @ s @ x
?
Не могли бы вы привести небольшой явный минимально воспроизводимый пример с соответствующим ожидаемым результатом?