Speed comparison

Neuroblastoma data

Consider the neuroblastoma data, which has 3418 labeled examples. For subsets of increasing size, how long does it take to compute the AUM and its directional derivatives?

data(neuroblastomaProcessed, package="penaltyLearning")
library(data.table)
nb.err <- data.table(neuroblastomaProcessed$errors)
nb.err[, example := paste0(profile.id, ".", chromosome)]
nb.X <- neuroblastomaProcessed$feature.mat
max.log <- if(interactive())3.5 else 3
(N.pred.vec <- as.integer(10^seq(1, max.log, by=0.5)))
#> [1]   10   31  100  316 1000
timing.dt.list <- list()
for(N.pred in N.pred.vec){
  N.pred.names <- rownames(nb.X)[1:N.pred]
  N.diffs.dt <- aum::aum_diffs_penalty(nb.err, N.pred.names)
  pred.dt <- data.table(example=N.pred.names, pred.log.lambda=0)
  timing.df <- microbenchmark::microbenchmark(penaltyLearning={
    roc.list <- penaltyLearning::ROChange(nb.err, pred.dt, "example")
  }, aum={
    aum.list <- aum::aum(N.diffs.dt, pred.dt$pred.log.lambda)
  }, times=10)
  timing.dt.list[[paste(N.pred)]] <- with(timing.df, data.table(
    package=expr, N.pred, seconds=time/1e9))
}
(timing.dt <- do.call(rbind, timing.dt.list))
#>              package N.pred     seconds
#>               <fctr>  <int>       <num>
#>   1:             aum     10 0.000097301
#>   2: penaltyLearning     10 0.509706500
#>   3:             aum     10 0.000072701
#>   4:             aum     10 0.000030701
#>   5: penaltyLearning     10 0.498848601
#>   6: penaltyLearning     10 0.605855601
#>   7: penaltyLearning     10 0.525148501
#>   8: penaltyLearning     10 1.007690101
#>   9: penaltyLearning     10 1.902207401
#>  10:             aum     10 0.000083001
#>  11:             aum     10 0.000042702
#>  12: penaltyLearning     10 3.013572501
#>  13:             aum     10 0.000061801
#>  14: penaltyLearning     10 1.327742202
#>  15: penaltyLearning     10 0.670776101
#>  16:             aum     10 0.000077101
#>  17:             aum     10 0.000023202
#>  18:             aum     10 0.000015601
#>  19: penaltyLearning     10 0.532483001
#>  20:             aum     10 0.000067102
#>  21: penaltyLearning     31 0.715787702
#>  22: penaltyLearning     31 0.779073901
#>  23: penaltyLearning     31 1.532825702
#>  24:             aum     31 0.000066401
#>  25:             aum     31 0.000023201
#>  26:             aum     31 0.000018201
#>  27:             aum     31 0.000016700
#>  28:             aum     31 0.000015201
#>  29: penaltyLearning     31 1.923791901
#>  30: penaltyLearning     31 0.484168101
#>  31:             aum     31 0.000066101
#>  32:             aum     31 0.000024601
#>  33: penaltyLearning     31 0.431930701
#>  34: penaltyLearning     31 0.541138100
#>  35:             aum     31 0.000076000
#>  36: penaltyLearning     31 0.549841300
#>  37: penaltyLearning     31 0.751044400
#>  38:             aum     31 0.000075600
#>  39: penaltyLearning     31 0.467962400
#>  40:             aum     31 0.000074801
#>  41:             aum    100 0.000101400
#>  42: penaltyLearning    100 0.739468902
#>  43:             aum    100 0.000085401
#>  44: penaltyLearning    100 0.742329401
#>  45: penaltyLearning    100 0.792718902
#>  46: penaltyLearning    100 0.765943401
#>  47:             aum    100 0.000081800
#>  48: penaltyLearning    100 0.671504001
#>  49:             aum    100 0.000080301
#>  50: penaltyLearning    100 1.365807801
#>  51: penaltyLearning    100 1.154336400
#>  52:             aum    100 0.000086300
#>  53:             aum    100 0.000045001
#>  54:             aum    100 0.000031501
#>  55: penaltyLearning    100 0.851427900
#>  56:             aum    100 0.000093801
#>  57: penaltyLearning    100 0.831670201
#>  58:             aum    100 0.000085201
#>  59: penaltyLearning    100 0.736335900
#>  60:             aum    100 0.000080301
#>  61: penaltyLearning    316 1.153899301
#>  62:             aum    316 0.000109401
#>  63: penaltyLearning    316 1.089968301
#>  64: penaltyLearning    316 1.275377001
#>  65:             aum    316 0.000113800
#>  66: penaltyLearning    316 1.313045701
#>  67: penaltyLearning    316 0.882887001
#>  68:             aum    316 0.000100200
#>  69:             aum    316 0.000048601
#>  70: penaltyLearning    316 1.169340700
#>  71:             aum    316 0.000105602
#>  72: penaltyLearning    316 0.699940901
#>  73:             aum    316 0.000113301
#>  74:             aum    316 0.000052801
#>  75:             aum    316 0.000044600
#>  76:             aum    316 0.000041801
#>  77:             aum    316 0.000044300
#>  78: penaltyLearning    316 1.028359901
#>  79: penaltyLearning    316 1.045927102
#>  80: penaltyLearning    316 0.888910101
#>  81: penaltyLearning   1000 2.285941701
#>  82: penaltyLearning   1000 2.011495601
#>  83:             aum   1000 0.000178101
#>  84:             aum   1000 0.000120001
#>  85:             aum   1000 0.000112201
#>  86:             aum   1000 0.000106101
#>  87: penaltyLearning   1000 1.959409600
#>  88:             aum   1000 0.000153801
#>  89:             aum   1000 0.000107101
#>  90:             aum   1000 0.000085100
#>  91: penaltyLearning   1000 2.345941000
#>  92: penaltyLearning   1000 2.092365800
#>  93:             aum   1000 0.000179900
#>  94:             aum   1000 0.000112001
#>  95: penaltyLearning   1000 2.053903501
#>  96: penaltyLearning   1000 2.491802701
#>  97: penaltyLearning   1000 2.234580400
#>  98: penaltyLearning   1000 2.153815601
#>  99:             aum   1000 0.000191501
#> 100: penaltyLearning   1000 2.016691202
#>              package N.pred     seconds
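
The loop above records only timings; to see what the aum function returns, we can inspect the result for the last (and largest) subset. A minimal sketch, re-using N.diffs.dt and pred.dt from the final loop iteration, and using str so that no particular element names are assumed:

## Inspect the AUM value and the matrix of directional derivatives computed
## by aum::aum for the largest subset considered above.
str(aum::aum(N.diffs.dt, pred.dt$pred.log.lambda))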

Below we summarize and plot these timings.

stats.dt <- timing.dt[, .(
  q25=quantile(seconds, 0.25),
  median=median(seconds),
  q75=quantile(seconds, 0.75)
), by=.(package, N.pred)]
library(ggplot2)
gg <- ggplot()+
  geom_line(aes(
    N.pred, median, color=package),
    data=stats.dt)+
  geom_ribbon(aes(
    N.pred, ymin=q25, ymax=q75, fill=package),
    data=stats.dt,
    alpha=0.5)+
  scale_x_log10(limits=stats.dt[, c(min(N.pred), max(N.pred)*5)])+
  scale_y_log10()
directlabels::direct.label(gg, "right.polygons")

Figure: median seconds versus number of predicted examples (log-log scale), with interquartile ribbons, for aum and penaltyLearning.

From the plot above we can see that both packages have similar asymptotic time complexity. However, aum is faster by roughly four orders of magnitude (speedups shown below).

stats.wide <- data.table::dcast(
  stats.dt, N.pred ~ package, value.var = "median")
stats.wide[, speedup := penaltyLearning/aum][]
#> Key: <N.pred>
#>    N.pred penaltyLearning         aum   speedup
#>     <int>           <num>       <num>     <num>
#> 1:     10       0.6383159 6.44515e-05  9903.817
#> 2:     31       0.6328145 4.53510e-05 13953.706
#> 3:    100       0.7793312 8.35005e-05  9333.251
#> 4:    316       1.0679477 7.65005e-05 13960.009
#> 5:   1000       2.1230907 1.16101e-04 18286.584

R implementation

In this section we compare the C++ implementation in the aum package with an equivalent implementation written in base R.

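## Toy error difference table: example is a zero-based example index, pred is
## the predicted value at which the errors change, and fp_diff/fn_diff give
## the change in false positive/negative counts at that threshold.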
diffs.df <- data.frame(
  example=c(0,1,1,2,3),
  pred=c(0,0,1,0,0),
  fp_diff=c(1,1,1,0,0),
  fn_diff=c(0,0,0,-1,-1))
pred.log.lambda <- c(0,1,-1,0)
microbenchmark::microbenchmark("C++"={
  aum::aum(diffs.df, pred.log.lambda)
}, R={
  thresh.vec <- with(diffs.df, pred-pred.log.lambda[example+1])
  s.vec <- order(thresh.vec)
  sort.diffs <- data.frame(diffs.df, thresh.vec)[s.vec,]
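  ## Accumulate error totals over the sorted thresholds: FP from the smallest
  ## threshold upward and FN from the largest threshold downward, storing the
  ## totals just before and just after each distinct threshold.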
  for(fp.or.fn in c("fp","fn")){
    ord.fun <- if(fp.or.fn=="fp")identity else rev
    fwd.or.rev <- sort.diffs[ord.fun(1:nrow(sort.diffs)),]
    fp.or.fn.diff <- fwd.or.rev[[paste0(fp.or.fn,"_diff")]]
    last.in.run <- c(diff(fwd.or.rev$thresh.vec) != 0, TRUE)
    after.or.before <-
      ifelse(fp.or.fn=="fp",1,-1)*cumsum(fp.or.fn.diff)[last.in.run]
    distribute <- function(values)with(fwd.or.rev, structure(
      values,
      names=thresh.vec[last.in.run]
    )[paste(thresh.vec)])
    out.df <- data.frame(
      before=distribute(c(0, after.or.before[-length(after.or.before)])),
      after=distribute(after.or.before))
    sort.diffs[
      paste0(fp.or.fn,"_",ord.fun(c("before","after")))
    ] <- as.list(out.df[ord.fun(1:nrow(out.df)),])
  }
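  ## AUM = sum over intervals between adjacent thresholds of the interval
  ## width times min(FP, FN) within that interval.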
  AUM.vec <- with(sort.diffs, diff(thresh.vec)*pmin(fp_before,fn_before)[-1])
  list(
    aum=sum(AUM.vec),
    deriv_mat=sapply(c("after","before"),function(b.or.a){
      s <- if(b.or.a=="before")1 else -1
      f <- function(p.or.n,suffix=b.or.a){
        sort.diffs[[paste0("f",p.or.n,"_",suffix)]]
      }
      fp <- f("p")
      fn <- f("n")
      aggregate(
        s*(pmin(fp+s*f("p","diff"),fn+s*f("n","diff"))-pmin(fp, fn)),
        list(sort.diffs$example),
        sum)$x
    }))
}, times=10)
#> Unit: microseconds
#>  expr     min        lq       mean    median        uq       max neval
#>   C++    10.0    14.201    36.6109    34.801    58.502    65.301    10
#>     R 33817.3 34736.901 41731.8009 38217.101 48950.300 58765.100    10

Comparing medians (about 35 microseconds versus 38 milliseconds), the C++ implementation is roughly three orders of magnitude faster than the base R implementation.

Synthetic data

The timings below compare the C++ sort-based computation in aum (via the internal function aum:::aum_sort_interface) with data.table ordering and base R sorting of the predicted values, for synthetic binary labels of increasing size.

library(data.table)
max.N <- 1e6
(N.pred.vec <- as.integer(10^seq(1, log10(max.N), by=0.5)))
#>  [1]      10      31     100     316    1000    3162   10000   31622  100000
#> [10]  316227 1000000
max.y.vec <- rep(c(0,1), l=max.N)
max.diffs.dt <- aum::aum_diffs_binary(max.y.vec)
set.seed(1)
max.pred.vec <- rnorm(max.N)
timing.dt.list <- list()
for(N.pred in N.pred.vec){
  print(N.pred)
  N.diffs.dt <- max.diffs.dt[1:N.pred]
  subset.pred.vec <- max.pred.vec[1:N.pred] # renamed to avoid overwriting N.pred.vec
  timing.df <- microbenchmark::microbenchmark(dt_sort={
    N.diffs.dt[order(subset.pred.vec)]
  }, R_sort_radix={
    sort(subset.pred.vec, method="radix")
  }, R_sort_quick={
    sort(subset.pred.vec, method="quick")
  }, aum_sort={
    aum.list <- aum:::aum_sort_interface(N.diffs.dt, subset.pred.vec)
  }, times=10)
  }, times=10)
  timing.dt.list[[paste(N.pred)]] <- with(timing.df, data.table(
    package=expr, N.pred, seconds=time/1e9))
}
#> [1] 10
#> [1] 31
#> [1] 100
#> [1] 316
#> [1] 1000
#> [1] 3162
#> [1] 10000
#> [1] 31622
#> [1] 100000
#> [1] 316227
#> [1] 1000000
(timing.dt <- do.call(rbind, timing.dt.list))
#>           package  N.pred     seconds
#>            <fctr>   <int>       <num>
#>   1:     aum_sort      10 0.000056900
#>   2: R_sort_radix      10 0.000169800
#>   3:     aum_sort      10 0.000026800
#>   4: R_sort_quick      10 0.000061301
#>   5:     aum_sort      10 0.000013301
#>  ---                                 
#> 436:     aum_sort 1000000 0.529329601
#> 437:      dt_sort 1000000 0.159101501
#> 438:     aum_sort 1000000 0.531869401
#> 439: R_sort_radix 1000000 0.132736401
#> 440: R_sort_quick 1000000 0.099426701

Below we summarize and plot these timings.

stats.dt <- timing.dt[, .(
  q25=quantile(seconds, 0.25),
  median=median(seconds),
  q75=quantile(seconds, 0.75)
), by=.(package, N.pred)]
library(ggplot2)
gg <- ggplot()+
  geom_line(aes(
    N.pred, median, color=package),
    data=stats.dt)+
  geom_ribbon(aes(
    N.pred, ymin=q25, ymax=q75, fill=package),
    data=stats.dt,
    alpha=0.5)+
  scale_x_log10(limits=stats.dt[, c(min(N.pred), max(N.pred)*5)])+
  scale_y_log10()
directlabels::direct.label(gg, "right.polygons")

Figure: median seconds versus number of predictions (log-log scale), with interquartile ribbons, for dt_sort, R_sort_radix, R_sort_quick, and aum_sort.
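
As in the neuroblastoma comparison, the median timings can be reshaped to compare the methods directly. A sketch only (output not shown; taking R_sort_quick as an arbitrary baseline, and assuming the wide column names produced by dcast match the method names above):

## Sketch: one column of median seconds per sorting method, plus the ratio of
## the aum sort to R's quick sort as a rough measure of relative overhead.
sort.wide <- data.table::dcast(
  stats.dt, N.pred ~ package, value.var="median")
sort.wide[, .(N.pred, aum_over_quick=aum_sort/R_sort_quick)]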