library(tabnet)
library(tidymodels)
library(modeldata)
In this vignette we show how to create a TabNet model using the tidymodels interface.
We are going to use the lending_club dataset available in the modeldata package.
First let’s split our dataset into training and testing sets so we can later assess the performance of our model:
set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)
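To see what the `strata` argument does, here is a minimal sketch on a hypothetical toy data frame (`toy`, with made-up columns `x` and `Class`; it is not the lending_club data). It assumes the rsample package is installed and only checks that both partitions end up with roughly the same class proportions.

```r
library(rsample)

set.seed(123)

# Hypothetical toy data: 200 "bad" rows and 800 "good" rows
toy <- data.frame(
  x = rnorm(1000),
  Class = factor(rep(c("bad", "good"), times = c(200, 800)))
)

# Stratified split: sampling is done within each level of Class
toy_split <- initial_split(toy, strata = Class)
toy_train <- training(toy_split)
toy_test <- testing(toy_split)

# Class proportions in each partition; they should be close to 0.2 / 0.8
train_props <- prop.table(table(toy_train$Class))
test_props <- prop.table(table(toy_test$Class))
print(train_props)
print(test_props)
```

Every row lands in exactly one of the two partitions, so the row counts of `toy_train` and `toy_test` add up to the size of `toy`.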
We now define our pre-processing steps. Note that tabnet handles categorical variables, so we don’t need to do any kind of transformation to them. Normalizing the numeric variables is a good idea though.
rec <- recipe(Class ~ ., train) %>%
  step_normalize(all_numeric())
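As an illustration of what the normalization step does, the sketch below applies it to a hypothetical four-row data frame (`toy`, with made-up columns `y`, `x` and `g`) and compares the result with centering and scaling by hand. It assumes the recipes package is installed and is only meant to show the effect of the step, not how tabnet consumes the data.

```r
library(recipes)

# Hypothetical toy data: one numeric predictor and one categorical predictor
toy <- data.frame(
  y = factor(c("a", "b", "a", "b")),
  x = c(1, 2, 3, 4),
  g = factor(c("u", "v", "u", "v"))
)

# Same kind of recipe as above: normalize every numeric column
toy_rec <- recipe(y ~ ., data = toy)
toy_rec <- step_normalize(toy_rec, all_numeric())

# prep() estimates the mean and standard deviation; bake() applies them
toy_prepped <- prep(toy_rec, training = toy)
baked <- bake(toy_prepped, new_data = NULL)

# Manual equivalent: subtract the mean, divide by the standard deviation
manual_x <- (toy$x - mean(toy$x)) / sd(toy$x)

print(baked$x)
print(manual_x)
```

The factor column `g` is left untouched by the step, which matches the point above that categorical variables need no transformation here.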
Next, we define our model. We are going to train for 50 epochs with a batch size of 128. There are other hyperparameters, but we are going to use the defaults.
mod <- tabnet(epochs = 50, batch_size = 128) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")
We also define our workflow object:
wf <- workflow() %>%
  add_model(mod) %>%
  add_recipe(rec)
We can now define our cross-validation strategy:
folds <- vfold_cv(train, v = 5)
Then we fit the model on the resamples:
fit_rs <- wf %>%
  fit_resamples(folds)
After a few minutes we can get the results:
collect_metrics(fit_rs)
# A tibble: 2 x 5
  .metric  .estimator  mean     n  std_err
  <chr>    <chr>      <dbl> <int>    <dbl>
1 accuracy binary     0.946     5 0.000713
2 roc_auc  binary     0.732     5 0.00539
And finally, we can verify the results in our test set:
model <- wf %>% fit(train)
test %>%
  bind_cols(
    predict(model, test, type = "prob")
  ) %>%
  roc_auc(Class, .pred_bad)
# A tibble: 1 x 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.710
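To make the roc_auc number more concrete, here is a small base R sketch of the rank-based way of computing the area under the ROC curve. The helper `auc_rank()` and the toy vectors `truth` and `pred_bad` are hypothetical and not part of any package. The sketch treats "bad" as the event of interest, mirroring the `.pred_bad` column used above, and it is not meant to reproduce the yardstick implementation.

```r
# Hypothetical toy labels and predicted probabilities of the "bad" class
truth <- factor(
  c("bad", "bad", "good", "good", "good", "bad"),
  levels = c("bad", "good")
)
pred_bad <- c(0.9, 0.4, 0.3, 0.6, 0.1, 0.7)

# Rank-based AUC: the share of (event, non-event) pairs in which the
# event row receives the higher score (ties get average ranks)
auc_rank <- function(truth, score, event) {
  is_event <- truth == event
  n_event <- sum(is_event)
  n_other <- sum(!is_event)
  ranks <- rank(score)
  (sum(ranks[is_event]) - n_event * (n_event + 1) / 2) / (n_event * n_other)
}

# 8 of the 9 bad/good pairs are ordered correctly, so this returns 8/9
auc <- auc_rank(truth, pred_bad, event = "bad")
print(auc)
```

An AUC of 0.5 corresponds to random ordering and 1 to a perfect ordering, which gives a scale for reading the 0.710 obtained on the test set.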