Using causalOT

Introduction

causalOT was developed to reproduce the methods in the paper "Optimal transport methods for causal inference". The functions in the package construct weights that make the covariate distributions of the treatment groups more similar, and then use those weights to estimate causal effects. We recommend the Causal Optimal Transport methods since they are semi- to non-parametric. This document describes some simple uses of the functions in the package and should be enough to get users started.

Estimating weights

The weights can be estimated with the calc_weight function. We select hyperparameters through our bootstrap-based algorithm (grid.search = TRUE) and target the average treatment effect (estimand = "ATE").

library(causalOT)
set.seed(1111)

hainmueller <- Hainmueller$new(n = 128)
hainmueller$gen_data()

weights <- calc_weight(data = hainmueller, method = "Wasserstein",
                       add.divergence = TRUE, grid.search = TRUE,
                       estimand = "ATE",
                       verbose = FALSE)

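These weights balance the covariate distributions of the treated and control groups. Under the usual no-unmeasured-confounding assumption, this makes the weighted groups comparable, so differences in their outcomes can be attributed to the treatment rather than to the covariates. We can then estimate effects with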

tau_hat <- estimate_effect(data = hainmueller, weights = weights,
                           hajek = TRUE, doubly.robust = FALSE,
                           estimand = "ATE")
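With hajek = TRUE, the weighted group means are presumably computed with self-normalized (Hajek) weights, and doubly.robust = FALSE means no outcome model is added. The following base-R sketch shows that kind of point estimate for the ATE. The function hajek_ate and its arguments are hypothetical illustrations, not part of causalOT, and the package's internal computation may differ.

```r
# Hypothetical helper: Hajek (self-normalized) weighted difference in means.
# y: outcomes; z: 0/1 treatment indicator;
# w0, w1: weights for controls and treated, in the order they appear in y.
hajek_ate <- function(y, z, w0, w1) {
  stopifnot(length(w0) == sum(z == 0), length(w1) == sum(z == 1))
  w0 <- w0 / sum(w0)  # normalize control weights to sum to one
  w1 <- w1 / sum(w1)  # normalize treated weights to sum to one
  sum(w1 * y[z == 1]) - sum(w0 * y[z == 0])
}

# Toy example
y <- c(1, 2, 3, 4)
z <- c(0, 0, 1, 1)
hajek_ate(y, z, w0 = c(1, 1), w1 = c(1, 3))
```

Here the weighted treated mean is (3 * 1 + 4 * 3) / 4 = 3.75 and the weighted control mean is 1.5, so the estimate is 2.25. Because the weights are normalized, rescaling either weight vector leaves the estimate unchanged.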

This creates an object of class causalEffect, which can be passed to the base R generic confint to calculate asymptotic confidence intervals.

ci_tau <- confint(object = tau_hat, level = 0.95, 
                  method = "asymptotic",
                  model = "lm",
                  formula = list(control = "y ~.", treated = "y ~ ."))

This gives the following point estimate and 95% confidence interval:

print(tau_hat$estimate)
#> [1] 0.2574432
print(ci_tau$CI)
#> [1] -0.08814654  0.60303297
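The interval printed above is centered on the point estimate, which is what a Wald-type interval \(\hat\tau \pm z_{0.975}\,\widehat{\mathrm{SE}}\) produces. Assuming that form (we have not confirmed how confint computes it for causalEffect objects), a few lines of base R back out the implied standard error from the printed numbers; the variable names are illustrative.

```r
# Printed output from above
tau_hat_est <- 0.2574432
ci <- c(-0.08814654, 0.60303297)

# A Wald interval est +/- z * se is centered at the estimate...
ci_midpoint <- mean(ci)
# ...and its half-width divided by the normal quantile gives the implied se
z_crit <- qnorm(0.975)
se_implied <- (ci[2] - ci[1]) / (2 * z_crit)

print(c(midpoint = ci_midpoint, se = se_implied))
```

The midpoint matches the estimate to the printed precision, and the implied standard error is about 0.176.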

Diagnostics

Diagnostics are also an important part of deciding whether the weights perform well. There are several areas that we will explore:

  1. Effective sample size
  2. Mean balance
  3. Distributional balance

1. Effective sample size

Typically, the effective sample size (ESS) of weights that sum to one is calculated as \(1/\sum_i w_i^2\), which measures how much information is in the weighted sample. The lower the ESS, the higher the variance of weighted estimates and the more the weight is concentrated on a few individuals. Of course, we can calculate this in causalOT!

ESS(weights)
#>  Control  Treated 
#> 36.44976 31.14867
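For intuition, the Kish effective sample size of a weight vector is \((\sum_i w_i)^2 / \sum_i w_i^2\), which reduces to \(1/\sum_i w_i^2\) when the weights sum to one. A minimal base-R sketch follows; kish_ess is a hypothetical helper, and we assume, without having checked the source, that ESS() reports this quantity for each group.

```r
# Hypothetical helper: Kish effective sample size of a weight vector.
# Invariant to rescaling the weights, so they need not sum to one.
kish_ess <- function(w) {
  sum(w)^2 / sum(w^2)
}

kish_ess(rep(1, 10))      # equal weights: ESS equals n = 10
kish_ess(c(10, 1, 1, 1))  # one dominant unit: 169 / 103, about 1.64, far below n = 4
```

A few large weights inflate \(\sum_i w_i^2\), so the ESS drops quickly as weight concentrates on a few units.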

However, this measure has problems: it can fail to diagnose highly variable, heavy-tailed weights. In response, Vehtari et al. use Pareto smoothed importance sampling (PSIS). We offer some wrapper code to adapt the class causalWeights to the loo package:

raw_psis <- PSIS(weights)

This will also return the Pareto smoothed weights and log weights.

If we want to examine the PSIS diagnostics directly, we can pull those out too:

PSIS_diag(raw_psis)
#> $w0
#> $w0$pareto_k
#> [1] 0.2173782
#> 
#> $w0$n_eff
#> [1] 35.64768
#> 
#> 
#> $w1
#> $w1$pareto_k
#> [1] 0.3378445
#> 
#> $w1$n_eff
#> [1] 30.33031

All of the \(k\) values are below the recommended threshold of 0.5, indicating finite variance and that the central limit theorem holds. Note that the PSIS effective sample sizes (n_eff) are a bit lower than those from ESS above.

2. Mean balance

Many authors use the standardized absolute mean difference as a marker of balance: see Stuart (2010). That is, \[ \frac{|\overline{X}_c - \overline{X}_t| }{\sigma_{\text{pool}}},\] where \(\overline{X}_c\) is the mean in the controls, \(\overline{X}_t\) is the mean in the treated, and \(\sigma_{\text{pool}}\) is the pooled standard deviation. We offer such checks in causalOT as well.
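The formula can be computed directly. This base-R sketch uses (optionally weighted) group means and, as one common convention, takes \(\sigma_{\text{pool}} = \sqrt{(s_c^2 + s_t^2)/2}\) from the unweighted group variances. mean_bal may use a different pooling, and std_mean_diff is a hypothetical helper, not a causalOT function.

```r
# Hypothetical helper: standardized absolute mean difference per covariate.
# X: covariate matrix; z: 0/1 treatment indicator;
# w0, w1: optional weights for controls and treated (default: equal weights).
std_mean_diff <- function(X, z, w0 = NULL, w1 = NULL) {
  X  <- as.matrix(X)
  X0 <- X[z == 0, , drop = FALSE]
  X1 <- X[z == 1, , drop = FALSE]
  if (is.null(w0)) w0 <- rep(1, nrow(X0))
  if (is.null(w1)) w1 <- rep(1, nrow(X1))
  # Weighted means: each row is multiplied by its normalized weight
  m0 <- colSums(X0 * (w0 / sum(w0)))
  m1 <- colSums(X1 * (w1 / sum(w1)))
  # Pooled SD from the unweighted group variances
  sd_pool <- sqrt((apply(X0, 2, var) + apply(X1, 2, var)) / 2)
  abs(m0 - m1) / sd_pool
}

X <- matrix(c(0, 2, 1, 3), ncol = 1)
z <- c(0, 0, 1, 1)
std_mean_diff(X, z)                              # 1 / sqrt(2), about 0.707
std_mean_diff(X, z, w0 = c(1, 3), w1 = c(3, 1))  # weighted means agree: 0
```

Keeping the denominator fixed at the unweighted pooled SD means that before and after weighting are compared on the same scale.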

First, we consider mean balance before weighting

mean_bal(data = hainmueller)
#>        X1        X2        X3        X4        X5        X6 
#> 1.3156637 1.1497004 1.0696805 0.6061098 0.1089106 0.1555189

and mean balance after weighting

mean_bal(data = hainmueller, weights = weights)
#>          X1          X2          X3          X4          X5 
#> 0.001290803 0.006736102 0.005986706 0.005668089 0.006262544 
#>          X6 
#> 0.002943358

Pretty good!

However, mean balance doesn’t ensure distributional balance.

3. Distributional balance

Ultimately, distributional balance is what we care about in causal inference, and we can measure that too. If Python is installed, we can use the GPU-enabled GeomLoss package; otherwise, the approxOT package can provide similar calculations. We consider the 2-Sinkhorn divergence of Genevay et al. since it metrizes convergence in distribution.
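To show what the divergence computes, here is a small base-R sketch of the debiased Sinkhorn divergence \(S_\varepsilon(\alpha,\beta) = \mathrm{OT}_\varepsilon(\alpha,\beta) - \tfrac12\mathrm{OT}_\varepsilon(\alpha,\alpha) - \tfrac12\mathrm{OT}_\varepsilon(\beta,\beta)\), using log-domain Sinkhorn iterations. It assumes the GeomLoss convention for power = 2 of a cost \(\|x-y\|^2/2\) with \(\varepsilon = \text{blur}^2\). The helpers entropic_ot and sinkhorn_divergence are hypothetical, run a fixed number of iterations without a convergence check, and are meant for small examples, not as a replacement for GeomLoss or approxOT.

```r
# Numerically stable log(sum(exp(v)))
log_sum_exp <- function(v) {
  m <- max(v)
  m + log(sum(exp(v - m)))
}

# Hypothetical helper: entropic OT value between weighted point clouds
# (rows of x with weights a, rows of y with weights b), computed with
# log-domain Sinkhorn updates. Returns the dual value <a, f> + <b, g>.
entropic_ot <- function(x, y, a, b, eps, n_iter = 500) {
  x <- as.matrix(x)
  y <- as.matrix(y)
  # Quadratic cost |x - y|^2 / 2; clamp tiny negative rounding errors
  C <- outer(rowSums(x^2), rowSums(y^2), "+") - 2 * x %*% t(y)
  C[C < 0] <- 0
  C <- C / 2
  f <- numeric(nrow(x))
  g <- numeric(nrow(y))
  for (it in seq_len(n_iter)) {
    # f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps)
    f <- -eps * apply(sweep(-C, 2, g, "+") / eps, 1,
                      function(r) log_sum_exp(r + log(b)))
    # g_j = -eps * log sum_i a_i exp((f_i - C_ij) / eps)
    g <- -eps * apply(sweep(-C, 1, f, "+") / eps, 2,
                      function(col) log_sum_exp(col + log(a)))
  }
  sum(a * f) + sum(b * g)
}

# Hypothetical helper: debiased Sinkhorn divergence
sinkhorn_divergence <- function(x, y, a, b, eps, n_iter = 500) {
  entropic_ot(x, y, a, b, eps, n_iter) -
    0.5 * entropic_ot(x, x, a, a, eps, n_iter) -
    0.5 * entropic_ot(y, y, b, b, eps, n_iter)
}

x0 <- matrix(c(0, 1, 2), ncol = 1)
x1 <- x0 + 1
w  <- rep(1/3, 3)
sinkhorn_divergence(x0, x0, w, w, eps = 1)    # identical clouds: 0
sinkhorn_divergence(x0, x1, w, w, eps = 1)    # shifted clouds: positive
sinkhorn_divergence(x0, x1, w, w, eps = 1e4)  # large eps: close to 0.5
```

The debiasing terms make the divergence zero for identical inputs. As \(\varepsilon\) grows, the optimal coupling approaches the independent one and, for this quadratic cost, the divergence approaches half the squared distance between the means (0.5 here); smaller \(\varepsilon\) makes it more sensitive to differences in shape.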

Before weighting, distributional balance looks poor:

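# geomloss method
# uniform weights sized from the data rather than assuming 64 per group
n0 <- nrow(hainmueller$get_x0())  # number of controls
n1 <- nrow(hainmueller$get_x1())  # number of treated
list(w0 = sinkhorn(x = hainmueller$get_x0(), y = hainmueller$get_x(),
              a = rep(1/n0, n0), b = rep(1/128, 128),
              power = 2, blur = 1e3, debias = TRUE)$loss,
     w1 = sinkhorn(x = hainmueller$get_x1(), y = hainmueller$get_x(),
              a = rep(1/n1, n1), b = rep(1/128, 128),
              power = 2, blur = 1e3, debias = TRUE)$loss
)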
#> $w0
#> [1] 1.549706
#> 
#> $w1
#> [1] 1.549707

But after weighting, it looks much better!

# geomloss method
list(w0 = sinkhorn(x = hainmueller$get_x0(), y = hainmueller$get_x(),
              a = weights$w0, b = rep(1/128,128),
              power = 2, blur = 1e3, debias = TRUE)$loss,
     w1 = sinkhorn(x = hainmueller$get_x1(), y = hainmueller$get_x(),
              a = weights$w1, b = rep(1/128,128),
              power = 2, blur = 1e3, debias = TRUE)$loss
)
#> $w0
#> [1] 6.205006e-06
#> 
#> $w1
#> [1] 0.0002435214

After Causal Optimal Transport, the distributions are much more similar.

Other methods

The calc_weight function can also handle other methods. We have implemented logistic and probit regression, the covariate balancing propensity score (CBPS), stable balancing weights (SBW), and the synthetic control method (SCM).

calc_weight(data = hainmueller, method = "Logistic",
            estimand = "ATE")

calc_weight(data = hainmueller, method = "Probit",
            estimand = "ATE")

calc_weight(data = hainmueller, method = "CBPS",
            estimand = "ATE",
            verbose = FALSE)

calc_weight(data = hainmueller, method = "SBW",
            grid.search = TRUE,
            estimand = "ATE", solver = "osqp",
            verbose = FALSE, formula = "~.")

calc_weight(data = hainmueller, method = "SCM", penalty = "none",
            estimand = "ATE", solver = "osqp",
            verbose = FALSE)
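For intuition about the "Logistic" and "Probit" options, here is a base-R sketch of inverse propensity weights for the ATE: fit a binary regression of treatment on the covariates, weight controls by \(1/(1-\hat e)\) and treated units by \(1/\hat e\), and normalize within each group. This is the textbook construction, not necessarily what calc_weight does internally (it may, for example, normalize or trim differently), and ipw_ate_weights is a hypothetical helper.

```r
# Hypothetical helper: normalized inverse propensity weights for the ATE.
# X: covariate matrix; z: 0/1 treatment indicator; link: "logit" or "probit".
ipw_ate_weights <- function(X, z, link = "logit") {
  df  <- data.frame(z = z, X)
  fit <- glm(z ~ ., data = df, family = binomial(link = link))
  e   <- fitted(fit)               # estimated propensity scores
  w0  <- 1 / (1 - e[z == 0])       # controls stand in for the whole sample
  w1  <- 1 / e[z == 1]             # treated stand in for the whole sample
  list(w0 = unname(w0 / sum(w0)),  # normalize within each group
       w1 = unname(w1 / sum(w1)))
}

# Simulated data with treatment depending on the covariates
set.seed(1)
n <- 200
X <- matrix(rnorm(n * 2), ncol = 2)
z <- rbinom(n, 1, plogis(0.5 * X[, 1] - 0.5 * X[, 2]))
w_logit  <- ipw_ate_weights(X, z, link = "logit")
w_probit <- ipw_ate_weights(X, z, link = "probit")
```

Weights built this way can be checked with the same effective sample size, mean balance, and distributional balance diagnostics described above.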