Decision Tree Training, Pruning and Hyperparameter Tuning

Article Outline

  • What is a decision tree?
  • Why use them?
  • Data Background
  • Descriptive Statistics
  • Decision Tree Training and Evaluation
  • Decision Tree Pruning
  • Hyperparameter Tuning

What is a decision tree?

A decision tree is a model with a flowchart-like structure. The classification and regression tree (a.k.a. decision tree) algorithm is usually attributed to Breiman et al. (1984), but that was certainly not the earliest work on the topic. Wei-Yin Loh of the University of Wisconsin has written about the history of decision trees in “Fifty Years of Classification and Regression Trees”.

In a decision tree, the top node is called the “root node” and the bottom nodes are called “terminal nodes” (or leaf nodes). The nodes in between are called “internal nodes”; each of them holds a binary split condition, while each leaf node holds an associated class label.

A classification tree uses split conditions to predict a class label from the provided input variables. Prediction starts at the top node (root node), and at each node the input values are checked against that node's split condition to decide whether to continue to the left or to the right. The split conditions themselves are chosen during training using a splitting criterion (Gini or information gain). The process terminates when a leaf or terminal node is reached.

Why use them?

A single decision tree-based model is easy to build, plot and interpret, which is what makes this algorithm so popular. You can use it for classification as well as regression tasks.

Data Background

In this example, we are going to use the Pima Indian Diabetes 2 data set obtained from the UCI Repository of machine learning databases (Newman et al. 1998).

This data set is originally from the National Institute of Diabetes and Digestive and Kidney Diseases. The objective of the data set is to diagnostically predict whether or not a patient has diabetes, based on certain diagnostic measurements included in the data set. Several constraints were placed on the selection of these instances from a larger database. In particular, all patients here are females at least 21 years old of Pima Indian heritage.

The Pima Indian Diabetes 2 data set is the refined version of the Pima Indian diabetes data, in which all missing values are coded as NA. The data set contains the following independent and dependent variables.

Independent variables (symbol: I)

  • I1: pregnant: Number of times pregnant
  • I2: glucose: Plasma glucose concentration (glucose tolerance test)
  • I3: pressure: Diastolic blood pressure (mm Hg)
  • I4: triceps: Triceps skinfold thickness (mm)
  • I5: insulin: 2-Hour serum insulin (mu U/ml)
  • I6: mass: Body mass index (weight in kg/(height in m)²)
  • I7: pedigree: Diabetes pedigree function
  • I8: age: Age (years)

Dependent Variable (symbol: D)

  • D1: diabetes: diabetes case (pos/neg)

Aim of the Modelling

  • Fitting a decision tree classification model that accurately predicts whether or not the patients in the data set have diabetes
  • Decision tree pruning for reducing overfitting
  • Decision tree hyperparameters tuning

Loading relevant libraries

The first step of data analysis starts with loading relevant libraries.

library(mlbench) # Diabetes dataset
library(rpart) # Decision tree
library(rpart.plot) # Plotting decision tree
library(caret) # Accuracy estimation
library(Metrics) # For different model evaluation metrics

Loading dataset

The very next step is to load the data into the R environment. As the data set ships with the mlbench package, it can be loaded by calling data( ).

# load the diabetes dataset
data(PimaIndiansDiabetes2)

Data Preprocessing

The next step is to perform exploratory analysis. First, we remove the rows with missing values using the na.omit( ) function. Then we print the data types using the glimpse( ) function from the dplyr library. You can see that all the variables except the dependent variable (diabetes: categorical/factor) are of type double.

Diabetes <- na.omit(PimaIndiansDiabetes2) # Data for modeling
dplyr::glimpse(Diabetes)

Figure: Data Types

Train and Test Split

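The next step is to split the dataset into roughly 80% train and 20% test. Here, we use the sample( ) function to randomly assign each observation a label of 1 (train) or 2 (test) with probabilities 0.8 and 0.2. Because the labels are drawn independently with replacement, the resulting split is only approximately 80/20. Next, we use these labels to index out the train and test data.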

set.seed(123)
index <- sample(2, nrow(Diabetes), prob = c(0.8, 0.2), replace = TRUE)
Diabetes_train <- Diabetes[index == 1, ] # Train data
Diabetes_test <- Diabetes[index == 2, ] # Test data

The train data includes 318 observations and the test data includes 74 observations. Both contain 9 variables.

print(dim(Diabetes_train))
print(dim(Diabetes_test))

Figure: Train and Test Dimension
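
If an exact 80/20 split is preferred over the probabilistic assignment used above, one option is to sample row indices without replacement. The following is a minimal sketch on a toy data frame; the object names (toy, train_idx, toy_train, toy_test) and the values are hypothetical and not part of the original analysis.

```r
# Toy data frame standing in for the diabetes data (hypothetical values)
set.seed(123)
toy <- data.frame(
  x = rnorm(100),
  y = factor(sample(c("neg", "pos"), 100, replace = TRUE))
)

# Draw exactly 80% of the row indices, without replacement
n_train   <- floor(0.8 * nrow(toy))
train_idx <- sample(seq_len(nrow(toy)), size = n_train)

toy_train <- toy[train_idx, ]   # the sampled 80 rows
toy_test  <- toy[-train_idx, ]  # the remaining 20 rows

print(dim(toy_train))
print(dim(toy_test))
```

With this approach the train and test sizes are fixed in advance, whereas the label-based approach in the article lets them vary slightly from seed to seed.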

Model Training

The next step is to train the model and evaluate its performance.

Training a Decision Tree

For decision tree training, we will use the rpart( ) function from the rpart library. The arguments include the model formula, the data and the method.

formula = diabetes ~ ., i.e., diabetes is predicted by all the independent variables (every column except diabetes itself)

Here, the method should be set to "class" for the classification task.

# Train a decision tree model
Diabetes_model <- rpart(formula = diabetes ~., 
                        data = Diabetes_train, 
                        method = "class")

Model Plotting

The main advantage of a tree-based model is that you can plot the tree structure and figure out the decision mechanism.

# type = 0: draw a split label at each split and a node label at each leaf
# yesno = 2: write yes and no at every split
# extra = 0: no extra information

rpart.plot(x = Diabetes_model, yesno = 2, type = 0, extra = 0)

Figure: Diabetes_model Tree Structure

Model Performance Evaluation

The next step is to see how our trained model performs on the test/unseen dataset. To predict the test data classes, we supply the model object, the test dataset and type = “class” to the predict( ) function.

# class prediction
class_predicted <- predict(object = Diabetes_model,  
                            newdata = Diabetes_test,   
                            type = "class")

(a) Confusion matrix

To evaluate the test performance, we use the confusionMatrix( ) function from the caret library. We can observe that the model misclassifies 17 of the 74 test observations, which corresponds to an accuracy of about 77.03% for a single decision tree.

# Generate a confusion matrix for the test data
confusionMatrix(data = class_predicted,       
                reference = Diabetes_test$diabetes)

Figure: Diabetes_model Test Evaluation Statistics

(b) Test accuracy

We can also supply the original test dataset labels and the predicted class labels to the accuracy( ) function from the Metrics library to estimate the model accuracy.

accuracy(actual = Diabetes_test$diabetes,
         predicted = class_predicted)

Figure: Diabetes_model Test Accuracy
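
To make the link between the confusion matrix and the accuracy value explicit, here is a small base R sketch. The first part only redoes the arithmetic with the counts reported above (17 errors out of 74); the second part uses hypothetical label vectors, not the article's data, to show that the share of matching labels equals the sum of the confusion matrix diagonal divided by its total.

```r
# Accuracy implied by the reported counts: 17 errors out of 74 test rows
n_test  <- 74
n_wrong <- 17
acc_from_counts <- (n_test - n_wrong) / n_test
print(round(100 * acc_from_counts, 2))  # 77.03

# The same quantity on a toy pair of label vectors (hypothetical values)
actual    <- factor(c("neg", "pos", "neg", "pos", "neg"), levels = c("neg", "pos"))
predicted <- factor(c("neg", "pos", "pos", "pos", "neg"), levels = c("neg", "pos"))

cm <- table(Predicted = predicted, Actual = actual)  # confusion matrix
acc_from_table <- sum(diag(cm)) / sum(cm)            # correct / total
acc_from_mean  <- mean(actual == predicted)          # share of matching labels

print(cm)
print(c(acc_from_table, acc_from_mean))
```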

Splitting Criteria Based Model Comparison

While building the model, the decision tree algorithm uses a splitting criterion. Two popular splitting criteria are used in decision trees: one is called “gini” and the other “information gain”. Here, we compare the model performance on the test set after training with each criterion. The splitting criterion is supplied through the parms argument as a list.

# Model training based on gini-based splitting criteria
Diabetes_model1 <- rpart(formula = diabetes ~ ., 
                         data = Diabetes_train, 
                         method = "class",
                         parms = list(split = "gini"))
# Model training based on information gain-based splitting criteria
Diabetes_model2 <- rpart(formula = diabetes ~ ., 
                         data = Diabetes_train, 
                         method = "class",
                         parms = list(split = "information"))
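
The two criteria differ in the impurity measure used to score a candidate split. The sketch below computes both for a hypothetical node and one candidate split, using base R only. The helper names (gini_impurity, entropy, weighted) and the labels are illustrative assumptions, and the use of log base 2 is also an assumption; this is not how rpart is implemented internally.

```r
# Gini impurity of a node: 1 - sum(p_k^2)
gini_impurity <- function(labels) {
  p <- as.numeric(table(labels)) / length(labels)
  1 - sum(p^2)
}

# Entropy of a node: -sum(p_k * log2(p_k)), skipping empty classes
entropy <- function(labels) {
  p <- as.numeric(table(labels)) / length(labels)
  p <- p[p > 0]
  -sum(p * log2(p))
}

# Hypothetical parent node and a candidate split into two children
parent <- c(rep("pos", 5), rep("neg", 5))
left   <- c(rep("pos", 4), rep("neg", 1))
right  <- c(rep("pos", 1), rep("neg", 4))

# Size-weighted impurity of the two children for a given impurity function
weighted <- function(f) {
  (length(left) * f(left) + length(right) * f(right)) / length(parent)
}

# Impurity decrease = parent impurity - weighted child impurity
gini_decrease <- gini_impurity(parent) - weighted(gini_impurity)
info_gain     <- entropy(parent) - weighted(entropy)

print(gini_decrease)
print(info_gain)
```

Both measures reward the same kind of split (purer children), which is why the two criteria often produce similar trees.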

Model Evaluation on Test Data

After model training, the next step is to predict the class labels of the test dataset.

# Generate class predictions on the test data using gini-based splitting criteria
pred1 <- predict(object = Diabetes_model1, 
                 newdata = Diabetes_test,
                 type = "class")
# Generate class predictions on test data using information gain based splitting criteria
pred2 <- predict(object = Diabetes_model2, 
                 newdata = Diabetes_test,
                 type = "class")

Prediction Accuracy Comparison

Next, we compare the accuracy of the models. Here, we can observe that the “gini” based splitting criterion provides a more accurate model than “information” based splitting.

# Compare classification accuracy on test data

accuracy(actual = Diabetes_test$diabetes,
         predicted = pred1)

accuracy(actual = Diabetes_test$diabetes,
         predicted = pred2)

Figure: Diabetes_model1 Test Accuracy
Figure: Diabetes_model2 Test Accuracy

The initial model (Diabetes_model) and the “gini” based model (Diabetes_model1) provide the same accuracy, because rpart uses “gini” as its default splitting criterion.

Decision Tree Pruning

The plot of the initial model (Diabetes_model) shows that the tree structure is deep and fragile, which makes the decision-making process harder to interpret. Here we explore ways to make the tree more interpretable without losing performance. One way of doing this is to prune the fragile part of the tree (the part that contributes to model overfitting).

(a) Plotting the error vs Complexity Parameter (CP)

The decision tree has a parameter called the complexity parameter (cp), which controls the size of the tree. A split is attempted only if it decreases the overall lack of fit by at least a factor of cp; otherwise, tree building does not continue from that node. We can generate the cp vs error plot using the plotcp( ) function.

# Plotting the error vs Complexity Parameter (CP)
plotcp(Diabetes_model1)

Figure: Error vs CP Plot
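
The effect of cp on tree size can be seen by fitting the same model with two different cp values. The sketch below is an illustration under stated assumptions: it uses the small kyphosis data set that ships with rpart rather than the diabetes data, the helper count_leaves is a hypothetical name, and minsplit is lowered only so that the small data set leaves room for extra splits.

```r
library(rpart)  # also provides the small `kyphosis` example data set

# Helper (hypothetical name): count the terminal nodes of a fitted rpart tree
count_leaves <- function(fit) sum(fit$frame$var == "<leaf>")

# Same data and formula; only cp differs between the two fits
fit_strict <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                    method = "class", cp = 0.10, minsplit = 5)
fit_loose  <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                    method = "class", cp = 0.001, minsplit = 5)

# A larger cp should never give a bigger tree than a smaller cp
print(count_leaves(fit_strict))
print(count_leaves(fit_loose))
```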

(b) Generating complexity parameter table

We can also print the cp table by calling model$cptable. Here, you can observe that xerror (the cross-validated error) is at its minimum for a CP value of 0.025.

# Printing the Complexity Parameter (CP) table
print(Diabetes_model1$cptable)

(c) Obtaining an optimal pruned model

We can extract the optimal CP value by identifying the row index of the minimum xerror and using it to look up the corresponding CP in the CP table.

# Retrieve the optimal cp value based on cross-validated error
index <- which.min(Diabetes_model1$cptable[, "xerror"])
cp_optimal <- Diabetes_model1$cptable[index, "CP"]

The next step is to prune the tree using the prune( ) function by supplying the optimal CP value. If we plot the optimally pruned tree, we can see that it is now very simple and easy to interpret.

A person with a glucose level above 128 and an age greater than 25 is classified as diabetes positive; otherwise, the person is classified as negative.

# Pruning tree based on optimal CP value
Diabetes_model1_opt <- prune(tree = Diabetes_model1, cp = cp_optimal)

# Plotting pruned tree
rpart.plot(x = Diabetes_model1_opt, yesno = 2, type = 0, extra = 0)
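
Because the pruned tree is this small, its decision rule can be written out by hand. The sketch below encodes the rule exactly as stated in the text; the function name predict_pruned_rule and the example patients are hypothetical, and whether the thresholds are strict or inclusive is an assumption.

```r
# Rule read off the pruned tree as described in the text:
# glucose above 128 AND age above 25 -> "pos", otherwise "neg".
# Whether the boundaries are strict (>) or inclusive (>=) is an assumption.
predict_pruned_rule <- function(glucose, age) {
  factor(ifelse(glucose > 128 & age > 25, "pos", "neg"),
         levels = c("neg", "pos"))
}

# Hypothetical patients (not taken from the data set)
patients <- data.frame(glucose = c(150, 150, 110), age = c(40, 22, 50))
patients$predicted <- predict_pruned_rule(patients$glucose, patients$age)
print(patients)
```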

(d) Pruned tree performance

The next step is to check whether the pruned tree performs similarly or whether performance has been compromised. After the performance check, we can see that the pruned tree is as capable as the earlier fragile tree, but it is now simple and easy to interpret.

pred3 <- predict(object = Diabetes_model1_opt, 
                 newdata = Diabetes_test,
                 type = "class")
accuracy(actual = Diabetes_test$diabetes, 
         predicted = pred3)

Decision Tree Hyperparameter Tuning

Next, we try to increase the performance of the decision tree model by tuning its hyperparameters. The rpart( ) function offers several hyperparameters, but here we tune two important ones: minsplit and maxdepth.

  • minsplit: the minimum number of observations that must exist in the node in order for a split to be attempted.
  • maxdepth: The maximum depth of any node of the final tree.

(a) Generating hyperparameter grid

First, we generate a sequence from 1 to 20 for both minsplit and maxdepth. Then we build a grid of all parameter combinations using the expand.grid( ) function.

# Setting values for minsplit and maxdepth
## the minimum number of observations that must exist in a node in order for a split to be attempted.
## Set the maximum depth of any node of the final tree
minsplit <- seq(1, 20, 1)
maxdepth <- seq(1, 20, 1)
# Generate a search grid 
hyperparam_grid <- expand.grid(minsplit = minsplit, maxdepth = maxdepth)
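
To see what expand.grid( ) returns, it helps to look at a much smaller grid first. The sketch below uses hypothetical values for small_grid and then rebuilds the article's 20 by 20 grid (under the name full_grid, chosen here for illustration) to confirm how many models the search will train.

```r
# A deliberately tiny grid so that every row can be inspected
small_grid <- expand.grid(minsplit = c(5, 10, 20), maxdepth = c(2, 4))
print(small_grid)        # one row per combination
print(nrow(small_grid))  # 3 x 2 = 6 combinations

# The grid from the text: 20 minsplit values x 20 maxdepth values
full_grid <- expand.grid(minsplit = seq(1, 20, 1), maxdepth = seq(1, 20, 1))
print(nrow(full_grid))   # 20 x 20 = 400 candidate models
```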

(b) Training grid-based models

The next step is to train one model for each hyperparameter combination in the grid. This can be done through the following steps:

  • using a for loop to go through each hyperparameter combination in the grid and supplying it to the rpart( ) function for model training
  • storing each model in an initially empty list (diabetes_models)

# Number of potential models in the grid
num_models <- nrow(hyperparam_grid)
# Create an empty list 
diabetes_models <- list()
# Write a loop over the rows of hyperparam_grid to train the grid of models
for (i in 1:num_models) {
  
  minsplit <- hyperparam_grid$minsplit[i]
  maxdepth <- hyperparam_grid$maxdepth[i]
  
  # Train a model and store in the list
  diabetes_models[[i]] <- rpart(formula = diabetes ~ .,
                                data = Diabetes_train,
                                method = "class",
                                minsplit = minsplit,
                                maxdepth = maxdepth)
}

(c) Computing test accuracy

The next step is to check the performance of each model on the test data and retrieve the best model. This can be done through the following steps:

  • using a for loop to go through each model in the list, predicting the test data and computing the accuracy
  • storing each model's accuracy in an initially empty vector (accuracy_values)

# Number of models inside the grid
num_models <- length(diabetes_models)
# Create an empty vector to store accuracy values
accuracy_values <- c()
# Use for loop for models accuracy estimation
for (i in 1:num_models) {
  
  # Retrieve the model i from the list
  model <- diabetes_models[[i]]
  
  # Generate predictions on test data 
  pred <- predict(object = model,
                  newdata = Diabetes_test,
                  type = "class")
  
  # Compute test accuracy and add to the empty vector accuracy_values 
  accuracy_values[i] <- accuracy(actual = Diabetes_test$diabetes,
                                 predicted = pred)
}
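
The training loop and the accuracy loop can also be collapsed into a single pass over the grid. The following is a compact sketch of that idea, not the article's code: it uses the kyphosis data that ships with rpart as a stand-in, a much smaller hypothetical grid, and base R's mean( ) in place of accuracy( ). All object names (ky_train, ky_test, grid) are illustrative.

```r
library(rpart)  # ships the small `kyphosis` data set used as a stand-in

set.seed(123)
train_idx <- sample(seq_len(nrow(kyphosis)), size = 60)
ky_train  <- kyphosis[train_idx, ]
ky_test   <- kyphosis[-train_idx, ]

# Small hypothetical grid; the article's grid is 20 x 20
grid <- expand.grid(minsplit = c(5, 10, 20), maxdepth = c(1, 3, 5))

# Fit one tree per grid row and return its accuracy on the held-out rows
grid$accuracy <- vapply(seq_len(nrow(grid)), function(i) {
  fit <- rpart(Kyphosis ~ Age + Number + Start,
               data = ky_train, method = "class",
               control = rpart.control(minsplit = grid$minsplit[i],
                                       maxdepth = grid$maxdepth[i]))
  pred <- predict(fit, newdata = ky_test, type = "class")
  mean(pred == ky_test$Kyphosis)
}, numeric(1))

# Grid rows ordered from most to least accurate
print(grid[order(-grid$accuracy), ])
```

Keeping the accuracy next to its hyperparameters in one data frame makes it easy to see which combinations tie for the best score.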

(d) Identifying the best model

The next step is to retrieve the best performing model (maximum accuracy) and print its hyperparameters using best_model$control. We can observe that a minimum split of 17 and a maximum depth of 6 give the most accurate results when the model is evaluated on the unseen/test dataset.

# Identify the model with maximum accuracy
best_model <- diabetes_models[[which.max(accuracy_values)]]

# Print the model hyper-parameters of the best model
best_model$control

(e) Best model evaluation on test data

After identifying the best performing model, the next step is to see how accurate it is. With the best hyperparameters, the model achieves an accuracy of 81.08%, up from about 77.03% for the initial tree.

# Best_model accuracy on test data
pred <- predict(object = best_model,
                newdata = Diabetes_test,
                type = "class")
accuracy(actual = Diabetes_test$diabetes,
         predicted = pred)

(f) Best model plot

Now it is time to plot the best model.

rpart.plot(x = best_model, yesno = 2, type = 0, extra = 0)

Figure: Best Model’s Layout

Even though the above plot shows the best performing model, it still looks a little fragile. So your next task is to prune it and see whether you get a more interpretable decision tree.

I hope you learned something new. See you next time!

Note

This article was first published on onezero.blog, a data science, machine learning and research related blogging platform that I maintain.

If you learned something new and liked this article, say 👋 / follow me on onezero.blog (my personal blogging website), Twitter, LinkedIn, YouTube and GitHub.

Featured Image Credit: Photo by Sara Dubler on Unsplash

References

[1] Breiman, L., Friedman, J., Stone, C.J. and Olshen, R.A. (1984). Classification and Regression Trees. CRC Press.

[2] Loh, W.-Y. (2014). Fifty Years of Classification and Regression Trees. International Statistical Review, 82(3).

[3] Newman, C. B. D. & Merz, C. (1998). UCI Repository of machine learning databases, Technical report, University of California, Irvine, Dept. of Information and Computer Sciences.