kernelshap

Introduction

SHAP values (Lundberg and Lee, 2017) decompose model predictions into additive contributions of the features in a fair way. Kernel SHAP is a model-agnostic method for estimating them. It was introduced in Lundberg and Lee (2017) and investigated in detail in Covert and Lee (2021).

The “kernelshap” package implements the Kernel SHAP Algorithm 1 described in the supplement of Covert and Lee (2021). An advantage of their algorithm is that SHAP values are supplemented by standard errors. Furthermore, convergence can be monitored and controlled.
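
For orientation, the optimization problem behind Kernel SHAP, as formulated in the two cited papers, can be sketched as follows. The notation is chosen here for illustration and is not taken from the package documentation: d is the number of features, z is a binary vector marking which features are "switched on", and v(z) is the average prediction when the switched-on features are fixed at the values of the explained row and the others are taken from the background data.

```latex
\min_{\beta_1, \dots, \beta_d}
  \sum_{\substack{z \in \{0,1\}^d \\ 0 < \mathbf{1}^T z < d}}
  \mu(z)\,\bigl(v(\mathbf{0}) + z^T \beta - v(z)\bigr)^2
\quad \text{subject to} \quad
\mathbf{1}^T \beta = v(\mathbf{1}) - v(\mathbf{0}),
\qquad
\mu(z) \propto \frac{d - 1}{\binom{d}{\mathbf{1}^T z}\,\mathbf{1}^T z\,\bigl(d - \mathbf{1}^T z\bigr)}
```

The minimizing coefficients are the SHAP values of the d features, and v(0) is the baseline. Because the sum has 2^d - 2 terms, it is approximated in practice by sampling vectors z, which is where the standard errors and the convergence monitoring come in.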

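The main function kernelshap() has three key arguments, all of which appear in the examples below: X (the rows to be explained, containing feature columns only), pred_fun (a function that takes data structured like X and returns one numeric prediction per row), and bg_X (the background data, with the same columns as X).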
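
A minimal sketch of how X, pred_fun, and bg_X fit together, using base R only. The sanity checks are an illustrative suggestion, not something the package requires, and the final call is commented out so that the snippet runs without kernelshap installed.

```r
fit <- lm(Sepal.Length ~ ., data = iris)

X        <- iris[1:5, -1]                # rows to explain, feature columns only
bg_X     <- iris[-1]                     # background data, same columns as X
pred_fun <- function(X) predict(fit, X)  # one numeric prediction per row

# Cheap checks before the (potentially slow) SHAP computation
p <- pred_fun(X)
stopifnot(
  is.numeric(p),
  length(p) == nrow(X),
  identical(colnames(X), colnames(bg_X))
)

# With the package installed, the call would then be:
# s <- kernelshap::kernelshap(X, pred_fun = pred_fun, bg_X = bg_X)
```

Dropping the response column from both X and bg_X keeps it from being treated as a feature.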

Remarks

Installation

# install.packages("devtools")
devtools::install_github("mayer79/kernelshap")

Example: linear regression

library(kernelshap)
library(shapviz)

fit <- lm(Sepal.Length ~ ., data = iris)
pred_fun <- function(X) predict(fit, X)

# Crunch SHAP values (9 seconds)
s <- kernelshap(iris[-1], pred_fun = pred_fun, bg_X = iris[-1])
s

# Output (partly)
# SHAP values of first 2 observations:
#      Sepal.Width Petal.Length Petal.Width   Species
# [1,]  0.21951350    -1.955357   0.3149451 0.5823533
# [2,] -0.02843097    -1.955357   0.3149451 0.5823533
# 
#  Corresponding standard errors:
#       Sepal.Width Petal.Length  Petal.Width      Species
# [1,] 1.526557e-15 1.570092e-16 1.110223e-16 1.554312e-15
# [2,] 2.463307e-16 5.661049e-16 1.110223e-15 1.755417e-16

# Plot with shapviz
shp <- shapviz(s)  # until shapviz 0.2.0: shapviz(s$S, s$X, s$baseline)
sv_waterfall(shp, 1)
sv_importance(shp)
sv_dependence(shp, "Petal.Length")
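
The near-zero standard errors above reflect that the model is linear and additive. In that case, and with the full training data as background, the SHAP values should coincide with the centered term contributions that base R already provides via predict(..., type = "terms"). The following cross-check is a sketch that does not use kernelshap at all. The names contrib, recon, and pred are introduced here for illustration.

```r
fit <- lm(Sepal.Length ~ ., data = iris)

# Centered per-term contributions for the first two observations;
# these should agree with the SHAP values printed above
contrib <- predict(fit, newdata = iris[1:2, ], type = "terms")
contrib

# The centering constant plays the role of the baseline:
# contributions plus constant reproduce the predictions
baseline <- attr(contrib, "constant")
recon <- unname(rowSums(contrib) + baseline)
pred  <- unname(predict(fit, iris[1:2, ]))
all.equal(recon, pred)
```

This shortcut only applies to additive linear models without interactions. For other models, the contributions have to be estimated, which is what Kernel SHAP does.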

Example: logistic regression on probability scale

library(kernelshap)
library(shapviz)

fit <- glm(I(Species == "virginica") ~ Sepal.Length + Sepal.Width, data = iris, family = binomial)
pred_fun <- function(X) predict(fit, X, type = "response")

# Crunch SHAP values (4 seconds)
s <- kernelshap(iris[1:2], pred_fun = pred_fun, bg_X = iris[1:2])

# Plot with shapviz
shp <- shapviz(s)  # until shapviz 0.2.0: shapviz(s$S, s$X, s$baseline)
sv_waterfall(shp, 51)
sv_dependence(shp, "Sepal.Length")
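
With only two features, the Shapley values on the probability scale can also be computed by brute force, which makes the role of the background data explicit. The sketch below uses base R only. The helper v() and the objects bg, x, and phi are hypothetical names introduced for illustration and are not part of kernelshap. One would expect the Kernel SHAP result for row 51 to agree with phi up to numerical error.

```r
fit <- glm(I(Species == "virginica") ~ Sepal.Length + Sepal.Width,
           data = iris, family = binomial)
pred_fun <- function(X) predict(fit, X, type = "response")

bg <- iris[1:2]      # background data: both features, all 150 rows
x  <- iris[51, 1:2]  # the observation to explain

# Value function: average prediction when the features named in `on` are
# fixed at the values of x and the remaining ones come from the background
v <- function(on) {
  Z <- bg
  for (nm in on) Z[[nm]] <- x[[nm]]
  mean(pred_fun(Z))
}

v_none <- v(character(0))  # baseline: average prediction on the background
v_len  <- v("Sepal.Length")
v_wid  <- v("Sepal.Width")
v_both <- v(c("Sepal.Length", "Sepal.Width"))

# Shapley value: marginal contribution averaged over both feature orderings
phi <- c(
  Sepal.Length = ((v_len - v_none) + (v_both - v_wid)) / 2,
  Sepal.Width  = ((v_wid - v_none) + (v_both - v_len)) / 2
)
phi

# Additivity: baseline plus contributions gives the predicted probability
all.equal(sum(phi) + v_none, unname(pred_fun(x)))
```

The number of feature subsets doubles with every additional feature, so this exact computation quickly becomes infeasible. That is the situation in which the sampling approach of Kernel SHAP is useful.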

Example: Keras neural net

library(kernelshap)
library(keras)
library(shapviz)

model <- keras_model_sequential()
model %>% 
  layer_dense(units = 6, activation = "tanh", input_shape = 3) %>% 
  layer_dense(units = 1)

model %>% 
  compile(loss = "mse", optimizer = optimizer_nadam(0.005))

X <- data.matrix(iris[2:4])

model %>%
  fit(
    x = X,
    y = iris[, 1],
    epochs = 50,
    batch_size = 30
  )

# Predict all rows in a single batch and return a plain numeric vector
pred_fun <- function(X) as.numeric(predict(model, X, batch_size = nrow(X)))

# Crunch SHAP values (takes about 40 seconds)
system.time(
  s <- kernelshap(X, pred_fun = pred_fun, bg_X = X)
)

# Plot with shapviz
shp <- shapviz(s)  # until shapviz 0.2.0: shapviz(s$S, s$X, s$baseline)
sv_waterfall(shp, 1)
sv_importance(shp)
sv_dependence(shp, "Petal.Length")

Example: mlr3

library(mlr3)
library(mlr3learners)
library(kernelshap)
library(shapviz)

task_iris <- TaskRegr$new(id = "iris", backend = iris, target = "Sepal.Length")
fit_lm <- lrn("regr.lm")
fit_lm$train(task_iris)

# Explain feature columns only; drop the response Sepal.Length
s <- kernelshap(iris[-1], function(X) fit_lm$predict_newdata(X)$response, bg_X = iris[-1])
sv <- shapviz(s)  # until shapviz 0.2.0: shapviz(s$S, s$X, s$baseline)
sv_waterfall(sv, 1)
sv_dependence(sv, "Species")

Example: caret

library(caret)
library(kernelshap)
library(shapviz)

fit <- train(
  Sepal.Length ~ ., 
  data = iris, 
  method = "lm", 
  tuneGrid = data.frame(intercept = TRUE),
  trControl = trainControl(method = "none")
)

s <- kernelshap(iris[1, -1], function(X) predict(fit, X), bg_X = iris[-1])
sv <- shapviz(s)  # until shapviz 0.2.0: shapviz(s$S, s$X, s$baseline)
sv_waterfall(sv, 1)

References

[1] Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30, 2017.

[2] Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value Estimation Using Linear Regression. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.