Decision Trees using tidymodels
Catalina Cañizares, Ph.D. and Raymond Balise, Ph.D.
Decision Trees using tidymodels
© 2024 by Catalina Cañizares and Raymond Balise is licensed under Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International
This material is freely available under the Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License.
For more information on this license, please visit: https://creativecommons.org/licenses/by-nc-nd/4.0/
Agenda 🌲
🌲 Decision Trees
🌲 Basic concepts (Root, Feature, Leaf)
🌲 View and understand the splits
🌲 Entropy and Information Gain
🌲 Early stopping and Pruning
🌲 A tree in tidymodels
Trees
🌲 Decision Trees are widely used algorithms for supervised machine learning.
🌲 They provide interpretable models for making predictions in both regression and classification tasks.
How it works
🌲 Consists of a series of sequential decisions on some data set's features.
What the Splits Look Like
How does the algorithm determine where to partition the data?
🌲 Instead of minimizing the Sum of Squared Errors, you can minimize entropy…
🌲 Entropy measures the amount of information (uncertainty) in some variable or event.
🌲 We'll make use of it to identify regions consisting of mostly a single class, as in the sketch below.
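A minimal sketch of those two ideas in R, using made-up class proportions (not code from these slides): the entropy of a node and the information gain from a candidate split.

# Entropy of a node, given the class proportions in that node
entropy <- function(p) {
  p <- p[p > 0]                     # treat 0 * log2(0) as 0
  -sum(p * log2(p))
}

parent <- entropy(c(0.60, 0.40))    # mixed node: high entropy (~0.97)
left   <- entropy(c(0.90, 0.10))    # mostly one class: low entropy (~0.47)
right  <- entropy(c(0.30, 0.70))    # somewhat mixed (~0.88)

# Information gain of a split that sends half the rows left and half right
gain <- parent - (0.5 * left + 0.5 * right)
gain

The algorithm considers many candidate splits and keeps the one with the largest information gain, i.e., the one that produces the purest child nodes.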
Classification trees
🌲 One of the questions that arises in a decision tree algorithm is: what is the optimal size of the final tree?
🌲 A tree that is too large risks over-fitting the training data and generalizing poorly to new samples.
🌲 A small tree might not capture important structural information about the sample space.
🌲 However, it is hard to tell when a tree algorithm should stop!
tree_depth
🌲 Caps the maximum tree depth.
🌲 A method to stop the tree early.
🌲 Used to prevent overfitting; see the sketch below.
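A minimal sketch (assuming the parsnip package, which the deck uses later; the depth value here is arbitrary, and this is not the tuned specification built below):

library(parsnip)

# A shallow tree: at most 3 levels of splits below the root.
# With the rpart engine, tree_depth maps to rpart's maxdepth argument.
shallow_tree_spec <-
  decision_tree(tree_depth = 3) |>
  set_engine("rpart") |>
  set_mode("classification")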
min_n
🌲 An integer for the minimum number of data points in a node that are required for the node to be split further.
🌲 Sets the minimum n to split at any node.
🌲 Another early stopping method.
🌲 Used to prevent overfitting.
🌲 min_n = 1 would lead to the most overfit tree; a short sketch follows this list.
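A companion sketch (again assuming parsnip with the rpart engine, where min_n maps to rpart's minsplit argument; the value 50 is arbitrary): raising min_n keeps small, noisy nodes from being split further.

library(parsnip)

# Require at least 50 observations in a node before it can be split further.
conservative_tree_spec <-
  decision_tree(min_n = 50) |>
  set_engine("rpart") |>
  set_mode("classification")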
cost_complexity - tree pruning
🌲 Adds a cost or penalty to the error rates of more complex trees.
🌲 Used to prevent overfitting.
🌲 Closer to zero ➡️ larger trees.
🌲 Higher penalty ➡️ smaller trees.
cost_complexity
\[
R_\alpha(T) = R(T) + \alpha|\widetilde{T}|
\]
🌲 \(R(T)\) = the misclassification rate of tree \(T\).
🌲 For any subtree \(T < T_{max}\) we define its complexity as \(|\widetilde{T}|\).
🌲 \(|\widetilde{T}|\) = the number of terminal (leaf) nodes in \(T\).
🌲 \(\alpha \geq 0\) is a real number called the complexity parameter.
🌲 If \(\alpha = 0\), the biggest tree will be chosen because the complexity penalty term is essentially dropped.
🌲 As \(\alpha\) approaches infinity, the tree of size 1 (the root alone) will be selected. A numeric sketch follows this list.
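A toy numeric sketch of that trade-off (all numbers are hypothetical, chosen only for illustration): the penalized cost \(R_\alpha(T)\) computed for three candidate subtrees at different values of \(\alpha\).

# Hypothetical misclassification rates R(T) and leaf counts |T~| for three subtrees
R      <- c(big = 0.20, medium = 0.22, small = 0.28)
leaves <- c(big = 12,   medium = 5,    small = 2)

penalized_cost <- function(alpha) R + alpha * leaves

penalized_cost(alpha = 0)     # no penalty: the biggest tree has the lowest cost
penalized_cost(alpha = 0.01)  # moderate penalty: the medium tree wins
penalized_cost(alpha = 0.10)  # heavy penalty: the smallest tree wins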
Recap
Classification Tree with tidymodels
Task
Predict whether an adolescent has consumed alcohol or not based on a set of various risk behaviors.
Data Cleaning
data ("riskyBehaviors" )
riskyBehaviors_analysis <-
riskyBehaviors |>
mutate (UsedAlcohol = case_when (
AgeFirstAlcohol == 1 ~ 0 ,
AgeFirstAlcohol %in% c (2 , 3 , 5 , 6 , 4 , 7 ) ~ 1 ,
TRUE ~ NA
)) |>
mutate (UsedAlcohol = factor (UsedAlcohol)) |>
drop_na (UsedAlcohol) |>
select (- c (AgeFirstAlcohol, DaysAlcohol, BingeDrinking, LargestNumberOfDrinks, SourceAlcohol, SourceAlcohol))
Splitting the data
set.seed(2023)
alcohol_split <- initial_split(riskyBehaviors_analysis,
                               strata = UsedAlcohol)
alcohol_train <- training(alcohol_split)
alcohol_test <- testing(alcohol_split)
alcohol_split
<Training/Testing/Total>
<9889/3297/13186>
Let's Check Our Work
alcohol_train |>
  tabyl(UsedAlcohol) |>
  adorn_pct_formatting(0) |>
  adorn_totals()
UsedAlcohol n percent
0 4354 44%
1 5535 56%
Total 9889 -
alcohol_test |>
  tabyl(UsedAlcohol) |>
  adorn_pct_formatting(0) |>
  adorn_totals()
UsedAlcohol n percent
0 1452 44%
1 1845 56%
Total 3297 -
Creating the Resampling Object
set.seed(2023)
cv_alcohol <- rsample::vfold_cv(alcohol_train,
                                strata = UsedAlcohol)
cv_alcohol
# 10-fold cross-validation using stratification
# A tibble: 10 × 2
splits id
<list> <chr>
1 <split [8899/990]> Fold01
2 <split [8899/990]> Fold02
3 <split [8899/990]> Fold03
4 <split [8899/990]> Fold04
5 <split [8900/989]> Fold05
6 <split [8901/988]> Fold06
7 <split [8901/988]> Fold07
8 <split [8901/988]> Fold08
9 <split [8901/988]> Fold09
10 <split [8901/988]> Fold10
The Recipe
alcohol_recipe <-
  recipe(formula = UsedAlcohol ~ ., data = alcohol_train) |>
  step_impute_mode(all_nominal_predictors()) |>
  step_impute_mean(all_numeric_predictors())
The Specification
cart_spec <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune(),
    min_n = tune()) |>
  set_engine("rpart") |>
  set_mode("classification")
cart_spec
Decision Tree Model Specification (classification)
Main Arguments:
cost_complexity = tune()
tree_depth = tune()
min_n = tune()
Computational engine: rpart
The Workflow
cart_workflow <-
  workflow() |>
  add_recipe(alcohol_recipe) |>
  add_model(cart_spec)
cart_workflow
── Workflow ────────────────────────────────────────────────────────────────────
Preprocessor: Recipe
Model: decision_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_impute_mode()
• step_impute_mean()
── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)
Main Arguments:
cost_complexity = tune()
tree_depth = tune()
min_n = tune()
Computational engine: rpart
Tuning for the tree - The Grid
tree_grid <-
  grid_regular(cost_complexity(),
               tree_depth(c(2, 5)),
               min_n(),
               levels = 4)
tree_grid
# A tibble: 64 × 3
cost_complexity tree_depth min_n
<dbl> <int> <int>
1 0.0000000001 2 2
2 0.0000001 2 2
3 0.0001 2 2
4 0.1 2 2
5 0.0000000001 3 2
6 0.0000001 3 2
7 0.0001 3 2
8 0.1 3 2
9 0.0000000001 4 2
10 0.0000001 4 2
# ℹ 54 more rows
Tuning for the tree
doParallel::registerDoParallel()
cart_tune <-
  cart_workflow |>
  tune_grid(resamples = cv_alcohol,
            grid = tree_grid,
            metrics = metric_set(roc_auc),
            control = control_grid(save_pred = TRUE))
doParallel::stopImplicitCluster()
Choosing the best CP
show_best(cart_tune, metric = "roc_auc")
# A tibble: 5 × 9
cost_complexity tree_depth min_n .metric .estimator mean n std_err
<dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl>
1 0.0000000001 5 2 roc_auc binary 0.834 10 0.00520
2 0.0000001 5 2 roc_auc binary 0.834 10 0.00520
3 0.0000000001 5 40 roc_auc binary 0.834 10 0.00484
4 0.0000001 5 40 roc_auc binary 0.834 10 0.00484
5 0.0001 5 40 roc_auc binary 0.834 10 0.00484
# ℹ 1 more variable: .config <chr>
Choosing the best hyperparameters
bestPlot_cart <-
  autoplot(cart_tune)
bestPlot_cart
Choosing the best CP
best_cart <- select_best(
  cart_tune,
  metric = "roc_auc")
best_cart
# A tibble: 1 × 4
cost_complexity tree_depth min_n .config
<dbl> <int> <int> <chr>
1 0.0000000001 5 2 Preprocessor1_Model13
Finalizing Workflow
cart_final_wf <- finalize_workflow(cart_workflow, best_cart)
cart_final_wf
── Workflow ────────────────────────────────────────────────────────────────────
Preprocessor: Recipe
Model: decision_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_impute_mode()
• step_impute_mean()
── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (classification)
Main Arguments:
cost_complexity = 1e-10
tree_depth = 5
min_n = 2
Computational engine: rpart
Fit the tree
cart_fit <- fit(
  cart_final_wf,
  alcohol_train)
cart_fit
── Workflow [trained] ──────────────────────────────────────────────────────────
Preprocessor: Recipe
Model: decision_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps
• step_impute_mode()
• step_impute_mean()
── Model ───────────────────────────────────────────────────────────────────────
n= 9889
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 9889 4354 1 (0.44028719 0.55971281)
2) Vaping=0 5109 1608 0 (0.68526130 0.31473870)
4) AgeFirstMarihuana< 1.5 4397 1086 0 (0.75301342 0.24698658)
8) AgeFirstCig< 2.5 4282 1003 0 (0.76576366 0.23423634)
16) TextingDriving< 2.659047 3911 844 0 (0.78419841 0.21580159) *
17) TextingDriving>=2.659047 371 159 0 (0.57142857 0.42857143)
34) DrivingDrinking< 2.5 363 151 0 (0.58402204 0.41597796) *
35) DrivingDrinking>=2.5 8 0 1 (0.00000000 1.00000000) *
9) AgeFirstCig>=2.5 115 32 1 (0.27826087 0.72173913)
18) Grade=12,9 53 21 1 (0.39622642 0.60377358)
36) Race=Hispanic/Latino,Native Hawaiian/Other PI 5 1 0 (0.80000000 0.20000000) *
37) Race=Asian,Black or African American,Multiple-Hispanic,Multiple-Non-Hispanic,White 48 17 1 (0.35416667 0.64583333) *
19) Grade=10,11 62 11 1 (0.17741935 0.82258065) *
5) AgeFirstMarihuana>=1.5 712 190 1 (0.26685393 0.73314607)
10) AgeFirstCig< 1.423554 448 162 1 (0.36160714 0.63839286)
20) AgeFirstMarihuana< 2.809703 59 24 0 (0.59322034 0.40677966)
40) TextingDriving< 2.659047 52 17 0 (0.67307692 0.32692308) *
41) TextingDriving>=2.659047 7 0 1 (0.00000000 1.00000000) *
21) AgeFirstMarihuana>=2.809703 389 127 1 (0.32647815 0.67352185) *
11) AgeFirstCig>=1.423554 264 28 1 (0.10606061 0.89393939) *
3) Vaping=1 4780 853 1 (0.17845188 0.82154812)
6) AgeFirstMarihuana< 1.5 1658 556 1 (0.33534379 0.66465621)
12) SourceVaping=1 937 396 1 (0.42262540 0.57737460)
24) SexualAbuseByPartner< 1.884927 450 220 0 (0.51111111 0.48888889)
48) AgeFirstCig< 5.5 424 198 0 (0.53301887 0.46698113) *
49) AgeFirstCig>=5.5 26 4 1 (0.15384615 0.84615385) *
25) SexualAbuseByPartner>=1.884927 487 166 1 (0.34086242 0.65913758) *
13) SourceVaping=2,3,4,5,6,7,8 721 160 1 (0.22191401 0.77808599) *
7) AgeFirstMarihuana>=1.5 3122 297 1 (0.09513133 0.90486867) *
Review fit on the training data
tree_pred <-
  augment(cart_fit, alcohol_train) |>
  select(UsedAlcohol, .pred_class, .pred_1, .pred_0)
tree_pred
# A tibble: 9,889 × 4
UsedAlcohol .pred_class .pred_1 .pred_0
<fct> <fct> <dbl> <dbl>
1 0 1 0.905 0.0951
2 0 0 0.216 0.784
3 0 0 0.216 0.784
4 0 0 0.216 0.784
5 0 1 0.905 0.0951
6 0 0 0.467 0.533
7 0 0 0.216 0.784
8 0 0 0.216 0.784
9 0 0 0.416 0.584
10 0 0 0.216 0.784
# ℹ 9,879 more rows
Review fit on the training data
roc_tree <-
  tree_pred |>
  roc_curve(truth = UsedAlcohol,
            .pred_1,
            event_level = "second") |>
  autoplot()
roc_tree
tree_pred |>
  roc_auc(truth = UsedAlcohol,
          .pred_1,
          event_level = "second")
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.838
Review on Resamples
fit_resamples(cart_final_wf, resamples = cv_alcohol) |>
  collect_metrics()
# A tibble: 3 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.786 10 0.00439 Preprocessor1_Model1
2 brier_class binary 0.156 10 0.00270 Preprocessor1_Model1
3 roc_auc binary 0.834 10 0.00520 Preprocessor1_Model1
The tree
cart_fit |>
  extract_fit_engine() |>
  rpart.plot::rpart.plot(roundint = FALSE)
To be continued…
This model has not been tested yet because additional analysis is planned. In the next presentation, we will use the same training data with the Random Forest algorithm and then evaluate its performance on the testing set.