causalOT
was developed to reproduce the methods in "Optimal transport methods for causal inference". The functions in the package construct weights that make covariate distributions more alike between treatment groups and then estimate causal effects. We recommend the Causal Optimal Transport methods since they are semi- to non-parametric. This document describes some simple uses of the functions in the package and should be enough to get users started.
The weights can be estimated with the calc_weight function in the package. We select optimal hyperparameters through our bootstrap-based algorithm and target the average treatment effect (ATE).
library(causalOT)
set.seed(1111)
hainmueller <- Hainmueller$new(n = 128)
hainmueller$gen_data()

weights <- calc_weight(data = hainmueller, method = "Wasserstein",
                       add.divergence = TRUE, grid.search = TRUE,
                       estimand = "ATE",
                       verbose = FALSE)
These weights will balance the covariate distributions between treatment groups, helping to remove bias from estimates of treatment effects. We can then estimate effects with
tau_hat <- estimate_effect(data = hainmueller, weights = weights,
                           hajek = TRUE, doubly.robust = FALSE,
                           estimand = "ATE")
This creates an object of class causalEffect, which can be fed into the native R function confint to calculate asymptotic confidence intervals.
ci_tau <- confint(object = tau_hat, level = 0.95,
                  method = "asymptotic",
                  model = "lm",
                  formula = list(control = "y ~ .", treated = "y ~ ."))
This then gives the following estimate and confidence interval.
print(tau_hat$estimate)
#> [1] 0.2574432
print(ci_tau$CI)
#> [1] -0.08814654 0.60303297
Diagnostics are also an important part of deciding whether the weights perform well. There are several areas that we will explore: the effective sample size, Pareto smoothed importance sampling diagnostics, mean balance, and distributional balance.
Typically, estimated sample sizes with weights are calculated as \(1/\sum_i w_i^2\) (for weights normalized to sum to one), which gives us a measure of how much information is in the sample. The lower the effective sample size (ESS), the higher the variance of the estimate, and the more the weight concentrates on a few individuals. Of course, we can calculate this in causalOT!
ESS(weights)
#> Control Treated
#> 36.44976 31.14867
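As a quick check, we can compute the same quantity directly from the raw weights. A minimal sketch, assuming the causalWeights object stores the control and treated weights in the fields w0 and w1 (as used later in this document):

# Kish's effective sample size; reduces to 1/sum(w^2) when the weights sum to one
sum(weights$w0)^2 / sum(weights$w0^2)
sum(weights$w1)^2 / sum(weights$w1^2)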
Of course, this measure can fail to diagnose problems with highly variable weights. In response, Vehtari et al. use Pareto smoothed importance sampling (PSIS). We offer some shell code to adapt the class causalWeights to the loo package:
raw_psis <- PSIS(weights)
This will also return the Pareto smoothed weights and log weights.
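If you prefer to call the loo package directly, a sketch like the following should work; the use of the w0 and w1 fields is an assumption carried over from above, and r_eff = NA simply skips loo's MCMC-specific correction:

library(loo)
# PSIS expects log importance ratios
psis_w0 <- psis(log(weights$w0), r_eff = NA)
psis_w1 <- psis(log(weights$w1), r_eff = NA)
c(w0 = pareto_k_values(psis_w0), w1 = pareto_k_values(psis_w1))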
If we want to examine the PSIS diagnostics more easily, we can pull those out too:
PSIS_diag(raw_psis)
#> $w0
#> $w0$pareto_k
#> [1] 0.2173782
#>
#> $w0$n_eff
#> [1] 35.64768
#>
#>
#> $w1
#> $w1$pareto_k
#> [1] 0.3378445
#>
#> $w1$n_eff
#> [1] 30.33031
We can see all of the \(k\) values are below the recommended 0.5, indicating finite variance and that the central limit theorem holds. Note the estimated sample sizes are a bit lower than those from the ESS method above.
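To put the two diagnostics side by side, a small snippet reusing the objects created above:

# compare the crude ESS with the PSIS-based effective sample sizes
rbind(ESS  = ESS(weights),
      PSIS = sapply(PSIS_diag(raw_psis), function(d) d$n_eff))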
Many authors consider the standardized absolute mean balance as an important marker of balance: see Stuart (2010). That is \[ \frac{|\overline{X}_c - \overline{X}_t|}{\sigma_{\text{pool}}}, \] where \(\overline{X}_c\) is the mean in the controls, \(\overline{X}_t\) is the mean in the treated, and \(\sigma_{\text{pool}}\) is the pooled standard deviation. We offer such checks in causalOT as well.
First, we consider the mean balance before weighting:
mean_bal(data = hainmueller)
#> X1 X2 X3 X4 X5 X6
#> 1.3156637 1.1497004 1.0696805 0.6061098 0.1089106 0.1555189
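To see what this statistic is doing, we can reproduce the first entry by hand. A small sketch, assuming get_x0() and get_x1() return the covariate matrices for controls and treated (as used later in this document) and taking the pooled standard deviation as the usual average-of-variances version:

# standardized absolute mean balance for X1, computed manually
x0 <- hainmueller$get_x0()
x1 <- hainmueller$get_x1()
sd_pool <- sqrt((var(x0[, 1]) + var(x1[, 1])) / 2)
abs(mean(x0[, 1]) - mean(x1[, 1])) / sd_pool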
Then we check the mean balance after weighting:
mean_bal(data = hainmueller, weights = weights)
#> X1 X2 X3 X4 X5
#> 0.001290803 0.006736102 0.005986706 0.005668089 0.006262544
#> X6
#> 0.002943358
Pretty good!
However, mean balance doesn't ensure distributional balance, and distributional balance is ultimately what we care about in causal inference. Fortunately, we can measure that too. If we have Python installed, we can use the GPU-enabled GeomLoss package; otherwise, the approxOT package provides similar calculations. We consider the 2-Sinkhorn divergence of Genevay et al. since it metrizes convergence in distribution.
Before weighting, distributional balance looks poor:
# geomloss method
list(w0 = sinkhorn(x = hainmueller$get_x0(), y = hainmueller$get_x(),
                   a = rep(1/64, 64), b = rep(1/128, 128),
                   power = 2, blur = 1e3, debias = TRUE)$loss,
     w1 = sinkhorn(x = hainmueller$get_x1(), y = hainmueller$get_x(),
                   a = rep(1/64, 64), b = rep(1/128, 128),
                   power = 2, blur = 1e3, debias = TRUE)$loss
)
#> $w0
#> [1] 1.549706
#>
#> $w1
#> [1] 1.549707
But after weighting, it looks much better!
# geomloss method
list(w0 = sinkhorn(x = hainmueller$get_x0(), y = hainmueller$get_x(),
                   a = weights$w0, b = rep(1/128, 128),
                   power = 2, blur = 1e3, debias = TRUE)$loss,
     w1 = sinkhorn(x = hainmueller$get_x1(), y = hainmueller$get_x(),
                   a = weights$w1, b = rep(1/128, 128),
                   power = 2, blur = 1e3, debias = TRUE)$loss
)
#> $w0
#> [1] 6.205006e-06
#>
#> $w1
#> [1] 0.0002435214
After Causal Optimal Transport, the distributions are much more similar.
The calc_weight function can also handle other methods. We have implemented methods for logistic and probit regression, the covariate balancing propensity score (CBPS), stable balancing weights (SBW), and the synthetic control method (SCM).
calc_weight(data = hainmueller, method = "Logistic",
estimand = "ATE")
calc_weight(data = hainmueller, method = "Probit",
estimand = "ATE")
calc_weight(data = hainmueller, method = "CBPS",
estimand = "ATE",
verbose = FALSE)
calc_weight(data = hainmueller, method = "SBW",
grid.search = TRUE,
estimand = "ATE", solver = "osqp",
verbose = FALSE, formula = "~.")
calc_weight(data = hainmueller, method = "SCM", penalty = "none",
estimand = "ATE", solver = "osqp",
verbose = FALSE)
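For example, one might store a couple of fits and compare methods by their effective sample sizes; a small illustrative snippet:

# fit one more method, then compare ESS across methods
w_logistic <- calc_weight(data = hainmueller, method = "Logistic",
                          estimand = "ATE")
sapply(list(Logistic = w_logistic, Wasserstein = weights), ESS)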