Decision Tree Training, Pruning, and Hyperparameter Tuning
Article Outline
- What is a decision tree?
- Why use them?
- Data Background
- Descriptive Statistics
- Decision Tree Training and Evaluation
- Decision Tree Pruning
- Hyperparameter Tuning
What is a decision tree?
A decision tree is a flowchart-like representation of a decision process. The classification and regression tree (a.k.a. decision tree) algorithm is usually attributed to Breiman et al. (1984), but that work was certainly not the earliest. Wei-Yin Loh of the University of Wisconsin has written about the history of decision trees; you can read it in "Fifty Years of Classification and Regression Trees".
In a decision tree, the top node is called the "root node" and the bottom nodes are called "terminal nodes" (or leaf nodes). The nodes in between are called "internal nodes"; each internal node holds a binary split condition, while each leaf node holds an associated class label.
A classification tree uses split conditions to predict a class label from the provided input variables. The splitting process starts at the top (root) node; at each internal node, the supplied input values are routed recursively to the left or right branch according to that node's split condition, which is chosen during training using an impurity measure such as Gini or information gain. The process terminates when a leaf (terminal) node is reached.
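To make the impurity measures concrete, here is a minimal sketch (an addition, not part of the original analysis; the helper names gini_impurity and entropy are ours) showing how the Gini impurity and the entropy of a node are computed from its class proportions. rpart evaluates measures like these internally when scoring candidate splits.
# Minimal sketch: impurity measures for a node's class proportions
gini_impurity <- function(p) {
  1 - sum(p^2)        # Gini impurity: 1 minus the sum of squared class proportions
}
entropy <- function(p) {
  p <- p[p > 0]       # drop zero proportions to avoid log2(0)
  -sum(p * log2(p))   # entropy in bits
}
# Example: a node with 70% negative and 30% positive cases
p <- c(neg = 0.7, pos = 0.3)
gini_impurity(p)      # 0.42
entropy(p)            # ~0.881
A split is preferred when it produces child nodes whose weighted impurity is lower than the parent's; the drop in entropy is the "information gain".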
Why use them?
A single decision-tree-based model is easy to build, plot, and interpret, which is what makes this algorithm so popular. You can use it for classification as well as regression tasks.
Data Background
In this example, we are going to use the Pima Indian Diabetes 2 data set obtained from the UCI Repository of machine learning databases (Newman et al. 1998).
This data set is originally from the National Institute of Diabetes and Digestive and Kidney Diseases. The objective of the data set is to diagnostically predict whether or not a patient has diabetes, based on certain diagnostic measurements included in the data set. Several constraints were placed on the selection of these instances from a larger database. In particular, all patients here are females at least 21 years old of Pima Indian heritage.
The Pima Indian Diabetes 2 data set is the refined version (all missing values were assigned as NA) of the Pima Indian diabetes data. The data set contains the following independent and dependent variables.
Independent variables (symbol: I)
- I1: pregnant: Number of times pregnant
- I2: glucose: Plasma glucose concentration (glucose tolerance test)
- I3: pressure: Diastolic blood pressure (mm Hg)
- I4: triceps: Triceps skinfold thickness (mm)
- I5: insulin: 2-Hour serum insulin (mu U/ml)
- I6: mass: Body mass index (weight in kg/(height in m)²)
- I7: pedigree: Diabetes pedigree function
- I8: age: Age (years)
Dependent Variable (symbol: D)
- D1: diabetes: diabetes case (pos/neg)
Aim of the Modelling
- Fitting a decision tree classification model that accurately predicts whether or not the patients in the data set have diabetes
- Pruning the decision tree to reduce overfitting
- Tuning the decision tree hyperparameters
Loading relevant libraries
The first step of data analysis starts with loading relevant libraries.
library(mlbench) # Diabetes dataset
library(rpart) # Decision tree
library(rpart.plot) # Plotting decision tree
library(caret) # Accuracy estimation
library(Metrics) # For different model evaluation metrics
Loading dataset
The very next step is to load the data into the R environment. As the dataset ships with the mlbench package, one can load it by calling data( ).
# load the diabetes dataset
data(PimaIndiansDiabetes2)
Data Preprocessing
The next step is to perform exploratory analysis. First, we remove the missing values using the na.omit( ) function. Then we print the data types using the glimpse( ) function from the dplyr library. You can see that all the variables except the dependent variable (diabetes: categorical/factor) are of double type.
Diabetes <- na.omit(PimaIndiansDiabetes2) # Data for modeling
dplyr::glimpse(Diabetes)
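Descriptive Statistics
The outline promises a quick look at descriptive statistics; a minimal way to get them (an added step, using only base R) is the summary( ) function, which reports the five-number summary and mean for each numeric variable and the class counts for the factor outcome.
# Descriptive statistics for all variables, plus class counts for diabetes
summary(Diabetes)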
Train and Test Split
The next step is to split the dataset into 80% train and 20% test. Here, we use the sample( ) function to randomly assign each observation an index of 1 or 2 (sampling with replacement) with probabilities 0.8 and 0.2, which yields an approximately 80/20 split. Next, based on this index, we separate the train and test data.
set.seed(123)
index <- sample(2, nrow(Diabetes), prob = c(0.8, 0.2), replace = TRUE)
Diabetes_train <- Diabetes[index==1, ] # Train data
Diabetes_test <- Diabetes[index == 2, ] # Test data
The train data includes 318 observations and the test data includes 74 observations. Both contain 9 variables.
print(dim(Diabetes_train))
print(dim(Diabetes_test))
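Because accuracy is used as the evaluation metric below, it is worth checking how balanced the two classes are in each split. This check is an addition to the original walkthrough:
# Class proportions in the train and test splits
prop.table(table(Diabetes_train$diabetes))
prop.table(table(Diabetes_test$diabetes))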
Model Training
The next step is model training, followed by an evaluation of model performance.
Training a Decision Tree
For decision tree training, we will use the rpart( ) function from the rpart library. The arguments include the model formula, the data, and the method.
formula = diabetes ~ ., i.e., diabetes is predicted by all independent variables
Here, the method should be specified as "class" for a classification task.
# Train a decision tree model
Diabetes_model <- rpart(formula = diabetes ~.,
data = Diabetes_train,
method = "class")
Model Plotting
The main advantage of a tree-based model is that you can plot the tree structure and figure out the decision mechanism.
# type = 0: draw a split label at each split and a node label at each leaf
# yesno = 2: write yes/no at each split
# extra = 0: display no extra information
rpart.plot(x = Diabetes_model, yesno = 2, type = 0, extra = 0)
Model Performance Evaluation
The next step is to see how our trained model performs on the test/unseen dataset. To predict the test data classes, we supply the model object, the test dataset, and type = "class" to the predict( ) function.
# class prediction
class_predicted <- predict(object = Diabetes_model,
newdata = Diabetes_test,
type = "class")
(a) Confusion matrix
To evaluate the test performance, we use confusionMatrix( ) from the caret library. We can observe that out of 74 observations, the model wrongly predicts 17. It has achieved about 77.03% accuracy using a single decision tree.
# Generate a confusion matrix for the test data
confusionMatrix(data = class_predicted,
reference = Diabetes_test$diabetes)
(b) Test accuracy
We can also supply the original test labels and the predicted class labels to the accuracy( ) function from the Metrics library to estimate the model accuracy. (Note that actual takes the true labels and predicted takes the model's predictions.)
accuracy(actual = Diabetes_test$diabetes,
predicted = class_predicted)
Splitting Criteria Based Model Comparison
While building the model, the decision tree algorithm uses a splitting criterion. Two popular criteria are used in decision trees: one is called "gini" and the other "information gain". Here, we compare model performance on the test set after training with each criterion. The splitting criterion is supplied as a list via the parms argument.
# Model training based on gini-based splitting criteria
Diabetes_model1 <- rpart(formula = diabetes ~ .,
data = Diabetes_train,
method = "class",
parms = list(split = "gini"))
# Model training based on information gain-based splitting criteria
Diabetes_model2 <- rpart(formula = diabetes ~ .,
data = Diabetes_train,
method = "class",
parms = list(split = "information"))
Model Evaluation on Test Data
After model training, the next step is to predict the class labels of the test dataset.
# Generate class predictions on the test data using gini-based splitting criteria
pred1 <- predict(object = Diabetes_model1,
newdata = Diabetes_test,
type = "class")
# Generate class predictions on test data using information gain based splitting criteria
pred2 <- predict(object = Diabetes_model2,
newdata = Diabetes_test,
type = "class")
Prediction Accuracy Comparison
Next, we compare the accuracy of the two models. Here, we can observe that the "gini"-based splitting criterion provides a more accurate model than "information"-based splitting.
# Compare classification accuracy on test data
accuracy(actual = Diabetes_test$diabetes,
predicted = pred1)
accuracy(actual = Diabetes_test$diabetes,
predicted = pred2)
The initial model (Diabetes_model) and the "gini"-based model (Diabetes_model1) provide the same accuracy, as rpart uses "gini" as its default splitting criterion.
Decision Tree Pruning
The initial model (Diabetes_model) plot shows that the tree structure is deep and complex, which makes the decision-making process hard to interpret. Thus, here we explore ways to make the tree more interpretable without losing performance. One way of doing this is by pruning away the overgrown part of the tree (the part that contributes to model overfitting).
(a) Plotting the error vs Complexity Parameter (CP)
The decision tree has a parameter called the complexity parameter (cp), which controls the size of the tree. A split that does not improve the overall fit of the tree by at least a factor of cp is not attempted. We can generate the cp vs. error plot using the plotcp( ) function.
# Plotting the Complexity Parameter (CP) table
plotcp(Diabetes_model1)
(b) Generating complexity parameter table
We can also generate the cp table by calling model$cptable. Here, you can observe that xerror is at its minimum at a CP value of 0.025.
# Printing the Complexity Parameter (CP) table
print(Diabetes_model1$cptable)
(c) Obtaining an optimal pruned model
We can obtain the optimal CP value by identifying the index of the minimum xerror and using it to index into the CP table.
# Retrieve the optimal cp value based on cross-validated error
index <- which.min(Diabetes_model1$cptable[, "xerror"])
cp_optimal <- Diabetes_model1$cptable[index, "CP"]
The next step is to prune the tree using the prune( ) function by supplying the optimal CP value. If we plot the pruned tree, we can see that it is now very simple and easy to interpret:
A person with a glucose level above 128 and an age greater than 25 is designated diabetes positive; otherwise negative (see the sanity-check sketch after the plot below).
# Pruning tree based on optimal CP value
Diabetes_model1_opt <- prune(tree = Diabetes_model1, cp = cp_optimal)
# Plotting pruned tree
rpart.plot(x = Diabetes_model1_opt, yesno = 2, type = 0, extra = 0)
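As a sanity check (an added sketch; the thresholds glucose > 128 and age > 25 are read off the plot, so verify them against your own output), the two-split tree can be hand-coded as a simple rule and compared with the pruned model's predictions:
# Hand-coded version of the pruned tree's decision rule
rule_pred <- factor(ifelse(Diabetes_test$glucose > 128 & Diabetes_test$age > 25,
"pos", "neg"),
levels = levels(Diabetes_test$diabetes))
# Fraction of test cases where the rule matches the pruned tree's predictions
mean(rule_pred == predict(Diabetes_model1_opt, newdata = Diabetes_test, type = "class"))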
(d) Pruned tree performance
The next step is to check whether the pruned tree performs similarly or whether performance has been compromised. After the performance check, we can see that the pruned tree is as capable as the earlier, more complex tree, but now it is simple and easy to interpret.
pred3 <- predict(object = Diabetes_model1_opt,
newdata = Diabetes_test,
type = "class")
accuracy(actual = Diabetes_test$diabetes,
predicted = pred3)
Decision Tree Hyperparameter Tuning
Next, we try to increase the performance of the decision tree model by tuning its hyperparameters. rpart( ) offers several hyperparameters, but here we tune two important ones: minsplit and maxdepth.
- minsplit: the minimum number of observations that must exist in a node in order for a split to be attempted.
- maxdepth: The maximum depth of any node of the final tree.
(a) Generating hyperparameter grid
First, we generate a sequence from 1 to 20 for both minsplit and maxdepth. Then we build a grid of all parameter combinations using the expand.grid( ) function.
# Setting values for minsplit and maxdepth
## the minimum number of observations that must exist in a node in order for a split to be attempted.
## Set the maximum depth of any node of the final tree
minsplit <- seq(1, 20, 1)
maxdepth <- seq(1, 20, 1)
# Generate a search grid
hyperparam_grid <- expand.grid(minsplit = minsplit, maxdepth = maxdepth)
(b) Training grid-based models
The next step is to train different models based on each grid hyperparameter combination. This could be done through the following steps:
- using a for loop to loop through each hyperparameter in the grid and then supplying it to rpart( ) function for model training
- storing each model into an empty list (diabetes_models)
# Number of potential models in the grid
num_models <- nrow(hyperparam_grid)
# Create an empty list
diabetes_models <- list()
# Write a loop over the rows of hyper_grid to train the grid of models
for (i in 1:num_models) {
minsplit <- hyperparam_grid$minsplit[i]
maxdepth <- hyperparam_grid$maxdepth[i]
# Train a model and store in the list
diabetes_models[[i]] <- rpart(formula = diabetes ~ .,
data = Diabetes_train,
method = "class",
minsplit = minsplit,
maxdepth = maxdepth)
}
(c) Computing test accuracy
The next step is to check the performance of each model on the test data and retrieve the best model. This could be done through the following steps:
- using a for loop to loop through each model in the list, and then predicting the test data and computing accuracy
- storing each model accuracy into an empty vector (accuracy_values)
# Number of models inside the grid
num_models <- length(diabetes_models)
# Create an empty vector to store accuracy values
accuracy_values <- c()
# Use for loop for models accuracy estimation
for (i in 1:num_models) {
# Retrieve the model i from the list
model <- diabetes_models[[i]]
# Generate predictions on test data
pred <- predict(object = model,
newdata = Diabetes_test,
type = "class")
# Compute test accuracy and add to the empty vector accuracy_values
accuracy_values[i] <- accuracy(actual = Diabetes_test$diabetes,
predicted = pred)
}
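Before picking a winner, it can be informative to glance at the spread of accuracies across the grid (an added check, not in the original walkthrough); a narrow range would suggest that tuning matters little for this data.
# Distribution of test accuracies across all grid models
summary(accuracy_values)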
(d) Identifying the best model
The next step is to retrieve the best-performing model (maximum accuracy) and print its hyperparameters using model$control. We can observe that with a minimum split of 17 and a maximum depth of 6, the model provides the most accurate results when evaluated on the unseen/test dataset.
# Identify the model with maximum accuracy
best_model <- diabetes_models[[which.max(accuracy_values)]]
# Print the model hyper-parameters of the best model
best_model$control
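Since best_model$control prints every control parameter, a more direct alternative (an added convenience using the grid built earlier) is to look up the winning row of the grid itself:
# The hyperparameter combination that produced the highest test accuracy
hyperparam_grid[which.max(accuracy_values), ]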
(e) Best model evaluation on test data
After identifying the best-performing model, the next step is to see how accurate it is. With the best hyperparameters, the model achieves an accuracy of 81.08%, a clear improvement over the 77.03% of the default tree.
# Best_model accuracy on test data
pred <- predict(object = best_model,
newdata = Diabetes_test,
type = "class")
accuracy(actual = Diabetes_test$diabetes,
predicted = pred)
(f) Best model plot
Now it is time to plot the best model.
rpart.plot(x = best_model, yesno = 2, type = 0, extra = 0)
Even though the above plot is for the best-performing model, it still looks a little complex. So your next task would be to prune it and see whether you get a more interpretable decision tree without losing accuracy.
I hope you learned something new. See you next time!
Note
This article was first published on onezero.blog, a data science, machine learning and research blogging platform maintained by myself.
If you learned something new and liked this article, say 👋 / follow me on onezero.blog (my personal blogging website), Twitter, LinkedIn, YouTube and GitHub.
Featured Image Credit: Photo by Sara Dubler on Unsplash
References
[1] Breiman, L., Friedman, J., Stone, C.J. and Olshen, R.A., 1984. Classification and regression trees. CRC press.
[2] Loh, W.-Y. (2014). Fifty years of classification and regression trees. International Statistical Review, 82(3), 329–348.
[3] Newman, C. B. D. & Merz, C. (1998). UCI Repository of machine learning databases, Technical report, University of California, Irvine, Dept. of Information and Computer Sciences.