Как оптимизировать функцию агрегации с условиями?

У меня есть функция агрегирования, которая суммирует группы данных, затем создает флаг на основе набора условий и присваивает его группе. Проблема в том, что необходимо объединить большое количество групп, и каждая группа очень мала.

Это означает, что время, необходимое для выполнения агрегирования, даже для набора данных небольшого размера, непомерно велико, и это необходимо для работы с наборами данных с миллионами строк.

Ниже я создал воспроизводимый пример, который иллюстрирует проблему и имеет стиль логики, аналогичный моей функции агрегации (у меня больше условий, но они похожи по своей природе):

#install.packages("palmerpenguins")
library(data.table)
library(palmerpenguins)

# Create data
GROUPS <- 1:100
penguin_list <- lapply(GROUPS, \(x) data.table(group = x, penguins))
penguin_table <- rbindlist(penguin_list)

# Aggregation function
aggregatePenguinMass <- function(mass, sex, ratio = 2/3){
    data <- data.table(mass, sex)
    
    total <- sum(data[,mass], na.rm = TRUE)
    
    n_sex <- data[,.N, by = sex]
    n_male <- n_sex[sex == "male", N]
    n_female <- n_sex[sex == "female", N]
    
    if (n_female >= ratio * (n_male + n_female)){
        return(data.table(total = total,
                          flag = "F"))
    } else {
        return(data.table(total = total,
                          flag = "M"))
    }
}

# Perform aggregation and time
system.time(
penguin_table[, aggregatePenguinMass(body_mass_g, sex), by = .(group, species, year)]
)
#   user  system elapsed 
#   2.66    0.47    7.30 

Как я могу изменить эту функцию или способ выполнения агрегации, чтобы сделать ее на порядок быстрее?


Тесты

Unit: milliseconds
         expr       min        lq       mean    median        uq       max neval
       base() 1793.4883 1803.1108 2491.32386 1821.8588 1928.4540 8195.9058    10
 RBarradas1() 1554.5428 1577.5579 1614.60087 1583.7645 1637.0681 1764.4564    10
 RBarradas2()  119.5481  127.6831  186.11583  131.7930  141.9157  657.2762    10
    NGraham()   23.2437   23.8428   25.50738   24.5770   27.4004   30.6569    10
       Miff()  427.2106  440.3240  462.32146  455.1878  462.6939  577.3803    10

Оптимизация заключается в том, чтобы sum(mass, na.rm = TRUE) вместо создания таблицы data.table затем подмножество и суммирование извлеченного вектора.

Rui Barradas 11.07.2024 11:56

Не могли бы вы предоставить более крупные и реалистичные данные?

s_baldur 11.07.2024 18:52
Стоит ли изучать PHP в 2023-2024 годах?
Стоит ли изучать PHP в 2023-2024 годах?
Привет всем, сегодня я хочу высказать свои соображения по поводу вопроса, который я уже много раз получал в своем сообществе: "Стоит ли изучать PHP в...
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
В JavaScript одним из самых запутанных понятий является поведение ключевого слова "this" в стрелочной и обычной функциях.
Приемы CSS-макетирования - floats и Flexbox
Приемы CSS-макетирования - floats и Flexbox
Здравствуйте, друзья-студенты! Готовы совершенствовать свои навыки веб-дизайна? Сегодня в нашем путешествии мы рассмотрим приемы CSS-верстки - в...
Тестирование функциональных ngrx-эффектов в Angular 16 с помощью Jest
В системе управления состояниями ngrx, совместимой с Angular 16, появились функциональные эффекты. Это здорово и делает код определенно легче для...
Концепция локализации и ее применение в приложениях React ⚡️
Концепция локализации и ее применение в приложениях React ⚡️
Локализация - это процесс адаптации приложения к различным языкам и культурным требованиям. Это позволяет пользователям получить опыт, соответствующий...
Пользовательский скаляр GraphQL
Пользовательский скаляр GraphQL
Листовые узлы системы типов GraphQL называются скалярами. Достигнув скалярного типа, невозможно спуститься дальше по иерархии типов. Скалярный тип...
2
2
114
5
Перейти к ответу Данный вопрос помечен как решенный

Ответы 5

Ответ принят как подходящий

Я думаю, вы хотите избежать создания такого количества таблиц данных, сколько у вас есть групп, и работать только с одной большой таблицей данных. Я могу воспроизвести вашу агрегацию с помощью dtyplr (поскольку я не очень разбираюсь в синтаксисе data.table). это похоже

library(dtplyr)

agg_dtyplr <- function(dt,mass, ratio = 2 / 3) {
  
  lazy_dt(dt) |> group_by(group, species, year)   |>
    summarise(
       total = sum({{mass}}, na.rm = TRUE),
    
       n_female = sum(1 * (sex == "female"),na.rm=TRUE),
       n = n())|> 
      mutate(
       flag = if_else(n_female >= ratio * n,
         "F", "M"
       )
     ) |> 
    select(group,species,year,total,flag) |> 
   as.data.table()
}
  agg_dtyplr(penguin_table,body_mass_g)

Вам не нужен 1* в 1 * (sex == "female"). Внутренне условие становится нулями или единицами.

Rui Barradas 11.07.2024 13:27

да, я написал более многословно, чем нужно; основная идея — использовать одну таблицу data.table;

Nir Graham 11.07.2024 16:12

Вот более быстрая функция.
Он не создает, а затем подмножество data. Он работает непосредственно с входными векторами.

library(data.table)
library(palmerpenguins)

# Create data
GROUPS <- 1:100
penguin_list <- lapply(GROUPS, \(x) data.table(group = x, penguins))
penguin_table <- rbindlist(penguin_list)

# Aggregation function (OP)
aggregatePenguinMass <- function(mass, sex, ratio = 2/3){
  data <- data.table(mass, sex)
  
  total <- sum(data[,mass], na.rm = TRUE)
  
  n_sex <- data[,.N, by = sex]
  n_male <- n_sex[sex == "male", N]
  n_female <- n_sex[sex == "female", N]
  
  if (n_female >= ratio * (n_male + n_female)){
    return(data.table(total = total,
                      flag = "F"))
  } else {
    return(data.table(total = total,
                      flag = "M"))
  }
}

# Aggregation function 2
aggregatePenguinMass2 <- function(mass, sex, ratio = 2/3){
  i <- !is.na(mass) & !is.na(sex)  
  total <- sum(mass[!is.na(mass)])
  
  f <- sum(sex[i] == "female")  
  if (f >= ratio * total){
    data.table(total = total, flag = "F")
  } else {
    data.table(total = total, flag = "M")
  }
}

# Perform aggregation and time
system.time(
  res <- penguin_table[, aggregatePenguinMass(body_mass_g, sex), by = .(group, species, year)]
)
#>    user  system elapsed 
#>    0.41    0.08    2.40

system.time(
  res2 <- penguin_table[, aggregatePenguinMass2(body_mass_g, sex), by = .(group, species, year)]
)
#>    user  system elapsed 
#>    0.04    0.00    0.11

all.equal(res, res2)
#> [1] TRUE

Created on 2024-07-11 with reprex v2.1.0

В этом случае с dplyr работать быстрее:

Оригинальная версия

system.time(
  a<- penguin_table[, aggregatePenguinMass(body_mass_g, sex), by = .(group, species, year)]
)

#user  system elapsed 
#1.45    0.47    1.61 

Использование dplyr Включая пару оптимизаций в суммирование

library(dplyr)

aggregatePenguinMass2 <- function(mass, sex, ratio = 2/3){
  total <- sum(mass, na.rm = TRUE)
  n_sex <- table(sex)

  n_male <- n_sex["male"]
  n_female <- n_sex["female"]
  
  if (n_female >= ratio * sum(n_sex)){
    return(data.table(total = total,
                      flag = "F"))
  } else {
    return(data.table(total = total,
                      flag = "M"))
  }
}

system.time(
 b <- penguin_table %>% summarise(aggregatePenguinMass2(body_mass_g, sex), .by = c(group, species, year))
)

#user  system elapsed 
#0.19    0.00    0.19 

all.equal(as.data.frame(a), b)
#[1] TRUE

Преобразование в data.table в конце мало влияет на время, если важно иметь вывод в этом формате.

Пара более быстрых вариантов. Опция 1:

ratio <- 2/3
r <- 2 - ratio

jblood1 <- function() {
  penguin_table[, s := match(sex, c("female", "male"))][
    ,.(total = sum(body_mass_g, na.rm = TRUE),
       flag = if (mean(s, na.rm = TRUE) > r) "M" else "F"),
    .(group, species, year)
  ]
}

microbenchmark::microbenchmark(
  base(),
  jblood1(),
  check = "equal",
  times = 10
)
#> Unit: milliseconds
#>       expr      min        lq       mean     median        uq       max neval
#>     base() 1546.044 1554.1209 1582.18637 1569.23260 1606.2057 1661.9998    10
#>  jblood1()    7.428    7.5738   10.16692    8.15465   11.2894   20.9406    10

Вариант 2 работает еще быстрее, поскольку полностью исключает операции группировки с помощью сортировки за счет дополнительной сложности:

jblood2 <- function() {
  setorder(penguin_table, group, species, year)[
    , `:=`(csSex = cumsum(c(1, 1i, 0)[match(sex, c("female", "male", NA))]),
           csMass = fcumsum(body_mass_g))
  ][
    (grp <- rleid(group, species, year)) != shift(grp, -1, 0L),
    .(group, species, year, total = diff(c(0, csMass)),
      flag = fifelse(Im(d <- diff(c(0, csSex)))/Re(d) > 1/ratio - 1, "M", "F"))
  ]
}

all.equal(setorder(jblood1(), group, species, year), jblood2())
#> [1] TRUE

microbenchmark::microbenchmark(
  jblood1(),
  jblood2()
)
#> Unit: milliseconds
#>       expr    min      lq     mean  median      uq     max neval
#>  jblood1() 7.2552 7.45605 8.469040 7.75345 8.21155 12.9634   100
#>  jblood2() 2.6411 2.82530 3.122956 3.06550 3.37760  4.5615   100

Проводил сравнительный анализ с большими данными (34 миллиона строк) и переполнениями fcumsum(...), но работал с as.integer64(body_mass_g).

s_baldur 11.07.2024 19:02

Или используйте as.numeric (и вычтите среднее значение для большей точности при работе с очень большими наборами данных). Или обратите внимание, что body_mass_g увеличивается с шагом 25, поэтому целое число разделите на 25L, а затем в конце умножьте на 25L.

jblood94 11.07.2024 19:19

Использование коллапса с Rcpp:

library(collapse)
penguin_table |>
  fgroup_by(group, species, year) |>
  fsummarise(total = fsum(body_mass_g, na.rm = TRUE), flag = create_flag(sex))

Или, альтернативно, с помощью data.table (медленнее по данным, но может лучше масштабироваться (?))

penguin_table[, .(sum(body_mass_g, na.rm = TRUE), flag = create_flag(sex)),
              .(group, species, year)]

Где

Rcpp::cppFunction("char create_flag(const SEXP sex_vector, const double ratio = 2.0 / 3.0) {
  int m = 0;
  int f = 0;
  
  const R_xlen_t n = XLENGTH(sex_vector);
  const int* sex_ptr = INTEGER(sex_vector);
  
  for (R_xlen_t i = 0; i < n; i++) {
    int val = sex_ptr[i];
    if (val == NA_INTEGER) {
      continue;
    } else if (val == 1) {
      f += 1;
    } else {
      m += 1;
    }
  }
  
  return (f >= ratio * (f + m)) ? 'F' : 'M';
}")

Тесты

microbenchmark::microbenchmark(
  op = penguin_table[, aggregatePenguinMass(body_mass_g, sex), by = .(group, species, year)],
  sb = penguin_table |>
    fgroup_by(group, species, year) |>
    fsummarise(total = fsum(body_mass_g, na.rm = TRUE), flag = create_flag(sex)),
  sbdt = penguin_table[, .(sum(body_mass_g, na.rm = TRUE), flag = create_flag(sex)),
                       .(group, species, year)],
  times = 10
)

# Unit: milliseconds
#  expr       min        lq       mean     median        uq       max neval
#    op 1421.3494 1475.1812 1495.12772 1510.03970 1528.2198 1537.4873    10
#    sb    1.7398    1.7904    1.91444    1.84415    1.9116    2.5670    10
#  sbdt    3.3524    3.3710    3.78065    3.62835    4.1727    4.6842    10

Другие вопросы по теме