How to use the breakDown package for models created with xgboost

Przemyslaw Biecek


This example demonstrates how to use the breakDown package for models created with the xgboost package.

# Fit a boosted regression model on the HR data. The whole chunk is skipped
# when the optional xgboost package is not installed.
# NOTE(review): HR_data is assumed to come from an attached breakDown
# package -- confirm against the setup code preceding this chunk.
if (requireNamespace("xgboost", quietly = TRUE)) {
  # Numeric design matrix of all predictors; "- 1" removes the intercept
  # column. (The "martix" spelling is kept because later chunks refer to
  # this object by exactly this name.)
  model_martix_train <- model.matrix(left ~ . - 1, HR_data)

  # requireNamespace() only loads xgboost without attaching it, so its
  # functions are called through the namespace explicitly.
  data_train <- xgboost::xgb.DMatrix(
    model_martix_train,
    label = as.numeric(HR_data$left)
  )

  # NOTE(review): xgboost reports "reg:linear" as deprecated in favour of
  # "reg:squarederror" (see the warning recorded below). It is kept here so
  # the recorded output still matches; switch it when the output is
  # regenerated.
  param <- list(objective = "reg:linear")
  HR_xgb_model <- xgboost::xgb.train(param, data_train, nrounds = 50)

  # Print the fitted booster; this produces the summary shown below.
  HR_xgb_model
}
#> [11:41:44] WARNING: amalgamation/../src/objective/ reg:linear is now deprecated in favor of reg:squarederror.
#> ##### xgb.Booster
#> raw: 189.4 Kb 
#> call:
#>   xgb.train(params = param, data = data_train, nrounds = 50)
#> params (as set within xgb.train):
#>   objective = "reg:linear", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.print.evaluation(period = print_every_n)
#> # of features: 19 
#> niter: 50
#> nfeatures : 19

Now we are ready to call the broken() function.

# Decompose the model's prediction for a single observation into
# per-variable contributions. Skipped when xgboost is not installed.
# NOTE(review): relies on HR_xgb_model and model_martix_train created in the
# earlier chunk, and on broken() being available from an attached breakDown
# package -- confirm.
if (requireNamespace("xgboost", quietly = TRUE)) {
  # drop = FALSE keeps the first row as a 1-row matrix instead of letting
  # it collapse to a plain vector.
  nobs <- model_martix_train[1L, , drop = FALSE]
  explain_2 <- broken(
    HR_xgb_model,
    new_observation = nobs,
    data = model_martix_train
  )

  # Print the explanation; this produces the contribution table shown below.
  explain_2
}
#>                              contribution
#> (Intercept)                         0.238
#> + time_spend_company = 3           -0.058
#> + number_project = 2               -0.008
#> + average_montly_hours = 157       -0.023
#> + satisfaction_level = 0.38         0.212
#> + last_evaluation = 0.53            0.620
#> + salarylow = 1                     0.009
#> + Work_accident = 0                 0.002
#> + salesRandD = 0                    0.001
#> + saleshr = 0                       0.001
#> + salesIT = 0                       0.000
#> + salarymedium = 0                  0.000
#> + salesaccounting = 0               0.000
#> + salesmarketing = 0                0.000
#> + salessales = 1                    0.000
#> + salessupport = 0                  0.000
#> + salesmanagement = 0               0.000
#> + salesproduct_mng = 0              0.000
#> + promotion_last_5years = 0         0.000
#> + salestechnical = 0               -0.001
#> final_prognosis                     0.992
#> baseline:  0

Finally, we plot the resulting explanation.

# Draw the explanation stored in explain_2 and add a title to the plot.
# Skipped when the optional xgboost package is not installed.
# NOTE(review): ggtitle() is called unqualified, so this presumably relies on
# ggplot2 being attached elsewhere -- confirm.
if (requireNamespace("xgboost", quietly = TRUE)) {
  plot(explain_2) + ggtitle("breakDown plot for xgboost model")