Decision Trees using tidymodels

Catalina Cañizares, Ph.D. and Raymond Balise, Ph.D.

Decision Trees using tidymodels © 2024 by Catalina Cañizares and Raymond Balise is licensed under Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International

This material is freely available under the Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License.

For more information on this license, please visit: Creative Commons License

Agenda 🌲

🌲 Decision Trees

🌲 Basic concepts (Root, Feature, Leaf)

🌲 View and understand the splits

🌲 Entropy and Information Gain

🌲 Early stopping and Pruning

🌲 A tree in tidymodels

Trees

🌲 Decision Trees are widely used algorithms for supervised machine learning.

🌲 They provide interpretable models for making predictions in both regression and classification tasks.

How it works

🌲 A tree consists of a series of sequential decisions on a data set’s features (see the sketch below).
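
A minimal, hand-written sketch of what a fitted tree encodes: a short sequence of yes/no questions about features. The variable names loosely mirror the tree fit at the end of this deck, but the cutoffs and logic here are purely illustrative.

predict_by_hand <- function(vaping, age_first_marihuana) {
  # first decision: did the adolescent vape?
  if (vaping == 0) {
    # second decision: age category at first marihuana use
    if (age_first_marihuana < 1.5) "no alcohol" else "alcohol"
  } else {
    "alcohol"
  }
}

predict_by_hand(vaping = 0, age_first_marihuana = 3)
#> [1] "alcohol"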


What the Splits Look Like


How does the algorithm determine where to partition the data?

🌲 For classification, instead of minimizing the sum of squared errors (as in regression trees), the algorithm minimizes entropy.

🌲 Entropy measures the amount of uncertainty, or impurity, in a variable or event.

🌲 We use it to identify regions consisting of

  • A large number of similar (pure) elements, or

  • A mix of dissimilar (impure) elements (see the sketch after this list).
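
A minimal sketch of the calculation for a vector of class labels, using a small hypothetical helper entropy() (an illustration only, not part of tidymodels):

# Shannon entropy (in bits) of a vector of class labels
entropy <- function(y) {
  p <- prop.table(table(y))
  -sum(p * log2(p))
}

entropy(c(1, 1, 1, 1))  # a pure node: entropy 0
entropy(c(0, 0, 1, 1))  # a maximally impure 50/50 node: entropy 1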

Information Gain - The logic to train

🌲 Measures the quality of a split

🌲 ID3 is the classic algorithm that builds a tree by using information gain to pick each split.

🌲 Information gain for a split is the original (parent) entropy minus the weighted entropies of each branch.

🌲 When training a Decision Tree with these metrics, the best split is chosen by maximizing information gain.

🌲 In other words, select the split that yields the largest reduction in entropy, or equivalently, the largest increase in information (see the sketch after this list).
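
A worked sketch of that subtraction; both helpers are hypothetical illustrations, not tidymodels code:

# same entropy helper as in the sketch above
entropy <- function(y) {
  p <- prop.table(table(y))
  -sum(p * log2(p))
}

# information gain = parent entropy minus the size-weighted
# average entropy of the two child nodes
information_gain <- function(parent, left, right) {
  n <- length(parent)
  entropy(parent) -
    (length(left) / n) * entropy(left) -
    (length(right) / n) * entropy(right)
}

parent <- c(0, 0, 0, 0, 1, 1, 1, 1)  # 50/50 node, entropy = 1
left   <- c(0, 0, 0, 1)              # one child after the candidate split
right  <- c(0, 1, 1, 1)              # the other child

information_gain(parent, left, right)
#> [1] 0.1887219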

Information Gain

Click here to see the animation

Information Gain

If you are interested in the math: A Simple Explanation of Information Gain and Entropy

Classification trees

🌲 One of the questions that arises in a decision tree algorithm is: what is the optimal size of the final tree?

🌲 A tree that is too large risks over-fitting the training data and poorly generalizing to new samples.

🌲 A small tree might not capture important structural information about the sample space.

🌲 However, it is hard to tell when a tree algorithm should stop!

Early Stopping

tree_depth

🌲 Cap the maximum tree depth.

🌲 A method to stop the tree early.

🌲 Used to prevent overfitting.


min_n

🌲 An integer for the minimum number of data points in a node that are required for the node to be split further.

🌲 Set minimum n to split at any node.

🌲 Another early stopping method.

🌲 Used to prevent overfitting.

🌲 min_n = 1 would lead to the most overfit tree (see the spec sketch below).
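
A minimal sketch of how both early-stopping arguments are set directly in a parsnip specification; the values tree_depth = 4 and min_n = 20 are illustrative, not tuned (the full tuned workflow appears later in this deck):

library(tidymodels)

shallow_tree_spec <- 
  decision_tree(
    tree_depth = 4,   # never grow more than 4 levels deep
    min_n      = 20   # a node needs at least 20 rows to be split further
  ) |> 
  set_engine("rpart") |> 
  set_mode("classification")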

Pruning

cost_complexity - tree pruning

🌲 Adds a cost or penalty to error rates of more complex trees

🌲 Used to prevent overfitting.

🌲 Closer to zero ➑️ larger trees.

🌲 Higher penalty ➑️ smaller trees.

cost_complexity

\[ R_\alpha(T) = R(T) + \alpha|\widetilde{T}| \]

🌲 \(R(T)\) = the misclassification rate of tree \(T\).

🌲 For any subtree \(T < T_{max}\), we define its complexity as \(|\widetilde{T}|\).

🌲 \(|\widetilde{T}|\) = the number of terminal (leaf) nodes in \(T\).

🌲 Let \(\alpha \geq 0\) be a real number called the complexity parameter.

🌲 If \(\alpha\) = 0 then the biggest tree will be chosen because the complexity penalty term is essentially dropped.

🌲 As \(\alpha\) approaches infinity, the tree of size 1, will be selected.
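
A minimal sketch of that effect using rpart directly and its built-in kyphosis data (an illustration only, not the alcohol analysis below): a complexity parameter near zero grows a larger tree, and a higher penalty yields a smaller one.

library(rpart)

big_tree   <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                    control = rpart.control(cp = 0))    # alpha near 0 -> larger tree
small_tree <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                    control = rpart.control(cp = 0.1))  # higher penalty -> smaller tree

# count the terminal (leaf) nodes, |T~| in the formula above
sum(big_tree$frame$var == "<leaf>")
sum(small_tree$frame$var == "<leaf>")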


Recap

Classification Tree with tidymodels

Task

Predict whether an adolescent has consumed alcohol based on a set of risk behaviors.

Data Cleaning

data("riskyBehaviors")

riskyBehaviors_analysis <- 
  riskyBehaviors |> 
  # Recode age at first drink into a binary outcome:
  # category 1 (never drank) -> 0, categories 2-7 (drank at some age) -> 1
  mutate(UsedAlcohol = case_when(
    AgeFirstAlcohol == 1 ~ 0, 
    AgeFirstAlcohol %in% c(2, 3, 4, 5, 6, 7) ~ 1, 
    TRUE ~ NA
    )) |> 
  mutate(UsedAlcohol = factor(UsedAlcohol)) |> 
  drop_na(UsedAlcohol) |> 
  # Drop the other alcohol variables so they cannot leak into the predictors
  select(-c(AgeFirstAlcohol, DaysAlcohol, BingeDrinking, 
            LargestNumberOfDrinks, SourceAlcohol))

Splitting the data

set.seed(2023)

alcohol_split <- initial_split(riskyBehaviors_analysis, 
                               strata = UsedAlcohol)

alcohol_train <- training(alcohol_split)
alcohol_test <- testing(alcohol_split)

alcohol_split
<Training/Testing/Total>
<9889/3297/13186>

Lets Check Our Work

alcohol_train |> 
  tabyl(UsedAlcohol)  |> 
  adorn_pct_formatting(0) |> 
  adorn_totals()
 UsedAlcohol    n percent
           0 4354     44%
           1 5535     56%
       Total 9889       -
alcohol_test |>  
  tabyl(UsedAlcohol)  |> 
  adorn_pct_formatting(0) |> 
  adorn_totals()
 UsedAlcohol    n percent
           0 1452     44%
           1 1845     56%
       Total 3297       -

Creating the Resampling Object

set.seed(2023)

cv_alcohol <- rsample::vfold_cv(alcohol_train, 
                                strata = UsedAlcohol)
cv_alcohol
#  10-fold cross-validation using stratification 
# A tibble: 10 Γ— 2
   splits             id    
   <list>             <chr> 
 1 <split [8899/990]> Fold01
 2 <split [8899/990]> Fold02
 3 <split [8899/990]> Fold03
 4 <split [8899/990]> Fold04
 5 <split [8900/989]> Fold05
 6 <split [8901/988]> Fold06
 7 <split [8901/988]> Fold07
 8 <split [8901/988]> Fold08
 9 <split [8901/988]> Fold09
10 <split [8901/988]> Fold10

The Recipe

alcohol_recipe <- 
  recipe(formula = UsedAlcohol ~ ., data = alcohol_train) |>
  step_impute_mode(all_nominal_predictors()) |>
  step_impute_mean(all_numeric_predictors())
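
As a quick sanity check (not part of the modeling pipeline), the recipe can be prepped on the training data to confirm the imputation steps leave no missing values; this is a sketch using standard recipes and dplyr calls:

alcohol_recipe |> 
  prep() |> 
  bake(new_data = NULL) |> 
  summarise(across(everything(), ~ sum(is.na(.x))))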

The Specification

cart_spec <- 
  decision_tree(
   cost_complexity = tune(),
   tree_depth = tune(),
   min_n = tune()) |>  
  set_engine("rpart") |> 
  set_mode("classification")

cart_spec 
Decision Tree Model Specification (classification)

Main Arguments:
  cost_complexity = tune()
  tree_depth = tune()
  min_n = tune()

Computational engine: rpart 

The Workflow

cart_workflow <- 
  workflow() |> 
  add_recipe(alcohol_recipe) |> 
  add_model(cart_spec)

cart_workflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

β€’ step_impute_mode()
β€’ step_impute_mean()

── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)

Main Arguments:
  cost_complexity = tune()
  tree_depth = tune()
  min_n = tune()

Computational engine: rpart 

Tuning for the tree - The Grid

tree_grid <- 
  grid_regular(cost_complexity(),
               tree_depth(c(2, 5)),
               min_n(), 
               levels = 4)
tree_grid
# A tibble: 64 Γ— 3
   cost_complexity tree_depth min_n
             <dbl>      <int> <int>
 1    0.0000000001          2     2
 2    0.0000001             2     2
 3    0.0001                2     2
 4    0.1                   2     2
 5    0.0000000001          3     2
 6    0.0000001             3     2
 7    0.0001                3     2
 8    0.1                   3     2
 9    0.0000000001          4     2
10    0.0000001             4     2
# β„Ή 54 more rows

Tuning for the tree

doParallel::registerDoParallel()  

cart_tune <- 
  cart_workflow %>% 
  tune_grid(resamples = cv_alcohol,
            grid = tree_grid, 
            metrics = metric_set(roc_auc),
            control = control_grid(save_pred = TRUE)
  )

doParallel::stopImplicitCluster()  

Choosing the best CP

show_best(cart_tune, metric = "roc_auc")
# A tibble: 5 Γ— 9
  cost_complexity tree_depth min_n .metric .estimator  mean     n std_err
            <dbl>      <int> <int> <chr>   <chr>      <dbl> <int>   <dbl>
1    0.0000000001          5     2 roc_auc binary     0.834    10 0.00520
2    0.0000001             5     2 roc_auc binary     0.834    10 0.00520
3    0.0000000001          5    40 roc_auc binary     0.834    10 0.00484
4    0.0000001             5    40 roc_auc binary     0.834    10 0.00484
5    0.0001                5    40 roc_auc binary     0.834    10 0.00484
# β„Ή 1 more variable: .config <chr>

Choosing the best hyperparameters

bestPlot_cart <- 
  autoplot(cart_tune)

bestPlot_cart

Choosing the best CP

best_cart <- select_best(
  cart_tune, 
  metric = "roc_auc")

best_cart
# A tibble: 1 Γ— 4
  cost_complexity tree_depth min_n .config              
            <dbl>      <int> <int> <chr>                
1    0.0000000001          5     2 Preprocessor1_Model13

Finalizing Workflow

cart_final_wf <- finalize_workflow(cart_workflow, best_cart)
cart_final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

β€’ step_impute_mode()
β€’ step_impute_mean()

── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)

Main Arguments:
  cost_complexity = 1e-10
  tree_depth = 5
  min_n = 2

Computational engine: rpart 

Fit the tree

cart_fit <- fit(
  cart_final_wf, 
  alcohol_train)

cart_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

β€’ step_impute_mode()
β€’ step_impute_mean()

── Model ───────────────────────────────────────────────────────────────────────
n= 9889 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 9889 4354 1 (0.44028719 0.55971281)  
   2) Vaping=0 5109 1608 0 (0.68526130 0.31473870)  
     4) AgeFirstMarihuana< 1.5 4397 1086 0 (0.75301342 0.24698658)  
       8) AgeFirstCig< 2.5 4282 1003 0 (0.76576366 0.23423634)  
        16) TextingDriving< 2.659047 3911  844 0 (0.78419841 0.21580159) *
        17) TextingDriving>=2.659047 371  159 0 (0.57142857 0.42857143)  
          34) DrivingDrinking< 2.5 363  151 0 (0.58402204 0.41597796) *
          35) DrivingDrinking>=2.5 8    0 1 (0.00000000 1.00000000) *
       9) AgeFirstCig>=2.5 115   32 1 (0.27826087 0.72173913)  
        18) Grade=12,9 53   21 1 (0.39622642 0.60377358)  
          36) Race=Hispanic/Latino,Native Hawaiian/Other PI 5    1 0 (0.80000000 0.20000000) *
          37) Race=Asian,Black or African American,Multiple-Hispanic,Multiple-Non-Hispanic,White 48   17 1 (0.35416667 0.64583333) *
        19) Grade=10,11 62   11 1 (0.17741935 0.82258065) *
     5) AgeFirstMarihuana>=1.5 712  190 1 (0.26685393 0.73314607)  
      10) AgeFirstCig< 1.423554 448  162 1 (0.36160714 0.63839286)  
        20) AgeFirstMarihuana< 2.809703 59   24 0 (0.59322034 0.40677966)  
          40) TextingDriving< 2.659047 52   17 0 (0.67307692 0.32692308) *
          41) TextingDriving>=2.659047 7    0 1 (0.00000000 1.00000000) *
        21) AgeFirstMarihuana>=2.809703 389  127 1 (0.32647815 0.67352185) *
      11) AgeFirstCig>=1.423554 264   28 1 (0.10606061 0.89393939) *
   3) Vaping=1 4780  853 1 (0.17845188 0.82154812)  
     6) AgeFirstMarihuana< 1.5 1658  556 1 (0.33534379 0.66465621)  
      12) SourceVaping=1 937  396 1 (0.42262540 0.57737460)  
        24) SexualAbuseByPartner< 1.884927 450  220 0 (0.51111111 0.48888889)  
          48) AgeFirstCig< 5.5 424  198 0 (0.53301887 0.46698113) *
          49) AgeFirstCig>=5.5 26    4 1 (0.15384615 0.84615385) *
        25) SexualAbuseByPartner>=1.884927 487  166 1 (0.34086242 0.65913758) *
      13) SourceVaping=2,3,4,5,6,7,8 721  160 1 (0.22191401 0.77808599) *
     7) AgeFirstMarihuana>=1.5 3122  297 1 (0.09513133 0.90486867) *

Review fit on the training data

tree_pred <- 
  augment(cart_fit, alcohol_train) |> 
  select(UsedAlcohol, .pred_class, .pred_1, .pred_0)

tree_pred
# A tibble: 9,889 Γ— 4
   UsedAlcohol .pred_class .pred_1 .pred_0
   <fct>       <fct>         <dbl>   <dbl>
 1 0           1             0.905  0.0951
 2 0           0             0.216  0.784 
 3 0           0             0.216  0.784 
 4 0           0             0.216  0.784 
 5 0           1             0.905  0.0951
 6 0           0             0.467  0.533 
 7 0           0             0.216  0.784 
 8 0           0             0.216  0.784 
 9 0           0             0.416  0.584 
10 0           0             0.216  0.784 
# β„Ή 9,879 more rows

Review fit on the training data

roc_tree <- 
  tree_pred |> 
  roc_curve(truth = UsedAlcohol, 
           .pred_1, 
           event_level = "second") |> 
  autoplot()

roc_tree

tree_pred |> 
  roc_auc(truth = UsedAlcohol, 
           .pred_1, 
           event_level = "second")
# A tibble: 1 Γ— 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.838
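
As a complementary look at the same training-set predictions, a confusion matrix can be built with yardstick's conf_mat; this sketch uses the .pred_class column, which reflects the default 0.5 probability cutoff:

tree_pred |> 
  conf_mat(truth = UsedAlcohol, estimate = .pred_class)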

Review on Resamples

fit_resamples(cart_final_wf, resamples = cv_alcohol) |> 
  collect_metrics()
# A tibble: 3 Γ— 6
  .metric     .estimator  mean     n std_err .config             
  <chr>       <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy    binary     0.786    10 0.00439 Preprocessor1_Model1
2 brier_class binary     0.156    10 0.00270 Preprocessor1_Model1
3 roc_auc     binary     0.834    10 0.00520 Preprocessor1_Model1

The tree

cart_fit |> 
  extract_fit_engine() |> 
  rpart.plot::rpart.plot(roundint=FALSE)

To be continued…

This model has not been evaluated on the testing set yet because we are planning an additional analysis. In the next presentation, we will use the same training data with the Random Forest algorithm and then evaluate its performance on the testing set.