Допустим, у меня есть две матрицы A и B, заданные формулой
set.seed(123)
m1 = matrix(runif (10*5), nrow = 10, ncol = 5)
m2 = matrix(runif (10*5), nrow = 10, ncol = 5)
Я хочу найти для каждой строки в матрице A строку в матрице B, которая ближе всего к строке в матрице A. Я знаю, что могу сделать это, перебирая каждую строку в A и сравнивая ее с каждой строкой в B следующим образом:
for(i in 1:nrow(m1)){
dist = 9999
index = -1
for(j in 1:nrow(m2)){
test = sqrt(sum(abs(m1[i,]-m2[j,])))
if (test < dist) {
dist = test
index = j
}
}
print(index)
}
Однако у меня миллион строк, и это занимает вечность. Я изо всех сил пытаюсь найти эффективный способ. Есть идеи?
Я обдумывал норму L1 или L2 и сделал что-то среднее...
Вот одно из базовых решений R с использованием apply
:
apply(m1, 1, \(x) which.min(sqrt(colSums(abs(x - t(m2))))))
#[1] 8 3 2 3 3 1 2 3 6 10
Сравнивая его с вашим текущим решением, он работает хорошо:
set.seed(123)
m1 = matrix(runif (10 * 5), nrow = 10, ncol = 5)
m2 = matrix(runif (10 * 5), nrow = 10, ncol = 5)
baseR_sol <- function(m1, m2) {
apply(m1, 1, \(x) which.min(sqrt(colSums(abs(x - t(m2))))))
}
for_loop_sol <- function(m1, m2) {
for(i in 1:nrow(m1)){
dist = 9999
index = -1
for(j in 1:nrow(m2)){
test = sqrt(sum(abs(m1[i,]-m2[j,])))
if (test < dist) {
dist = test
index = j
}
}
print(index)
}
}
microbenchmark::microbenchmark(
baseR_sol = baseR_sol(m1, m2),
for_loop_sol = for_loop_sol(m1, m2), times = 10L
)
# expr min lq mean median uq max neval
# baseR_sol 158.0 185.2 865.81 195.35 224.8 6902.8 10
# for_loop_sol 764.6 830.2 1051.29 973.45 1312.0 1348.9 10
Можно попробовать collapse::dapply(x, f, MARGIN = 1)
быстрее apply
Последующий тест, основанный на решении @Ronak Shah
# Ronak Shah's original solution
f1 <- \() {
apply(m1, 1, \(x) which.min(sqrt(colSums(abs(x - t(m2))))))
}
# a variant to Ronak Shah's solution, by removing `sqrt` and moving `t(m2)` out of `apply`
f2 <- \() {
tm2 <- t(m2)
apply(m1, 1, \(x) which.min(colSums(abs(x - tm2))))
}
# Friede's solution
f3 <- \() {
tm2 <- t(m2)
collapse::dapply(X = m1, FUN = \(x) which.min(Rfast::colsums((abs(x - tm2)))), MARGIN = 1L)
}
# G. Grothendieck's solution
f4 <- \(){
apply(Rfast::dista(m1, m2, type = "manhattan"), 1, which.min)
}
# a variant to G. Grothendieck's solution, by using `dapply`
f5 <- \(){
collapse::dapply(X = Rfast::dista(m1, m2, type = "manhattan"), which.min, MARGIN = 1)
}
# Onyambu's solution
Rcpp::cppFunction("IntegerVector closestRcpp(NumericMatrix A, NumericMatrix B) {
int n = A.nrow();
IntegerVector closest(n);
for (int i = 0; i < n; ++i) {
double prev_dist = 1e308, dist;
for (int j = 0; j < B.nrow(); ++j){
double dist = sum(abs(A(i, _) - B(j, _)));
if (dist < prev_dist) {
prev_dist = dist;
closest[i] = j+1;
}
}
}
return closest;
}")
f6 <- \() {
closestRcpp(m1, m2)
}
microbenchmark(
f1(),
f2(),
f3(),
f4(),
f5(),
f6(),
unit = "relative",
check = "equivalent"
)
шоу
Unit: relative
expr min lq mean median uq max neval
f1() 28.222593 28.05963 7.522702 28.221577 29.144915 2.4924392 100
f2() 18.704444 19.37280 5.715675 20.383970 21.840241 2.1410204 100
f3() 11.630370 11.87556 4.233875 12.320916 12.728101 2.2702689 100
f4() 12.555926 13.06498 3.480127 13.401387 14.256614 1.0021835 100
f5() 9.074074 9.30748 2.714870 9.563296 9.612287 0.9920494 100
f6() 1.000000 1.00000 1.000000 1.000000 1.000000 1.0000000 100
Может быть f3 = \() { tm2 = t(m2); collapse::dapply(X=m1, FUN=\(x) which.min(Rfast::colsums((abs(x - tm2)))), MARGIN=1L) }
@Friede, спасибо за твой вклад! добавлено в тест
Огромное спасибо за интересный тест!
Не могли бы вы также добавить в выходные данные Rcpp, написанный ниже?
@Onyambu смотри обновление
Мы можем использовать dista
из Rfast
library(Rfast)
apply(dista(m1, m2, type = "manhattan"), 1, which.min)
## [1] 8 3 2 3 3 1 2 3 6 10
Быстрее, если вместо этого использовать collapse::dapply
.
отличное решение! добавлено в бенчмарк
Рассмотрите возможность использования Rcpp/C:
Rcpp::cppFunction("IntegerVector closestRcpp(NumericMatrix A, NumericMatrix B) {
int n = A.nrow();
IntegerVector closest(n);
for (int i = 0; i < n; ++i) {
double prev_dist = 1e308, dist;
for (int j = 0; j < B.nrow(); ++j){
double dist = sum(abs(A(i, _) - B(j, _)));
if (dist < prev_dist) {
prev_dist = dist;
closest[i] = j+1;
}
}
}
return closest;
}")
closestRcpp(m1, m2)
[1] 8 3 2 3 3 1 2 3 6 10
Зачем извлекать квадратный корень?