Gradient Boosting Essentials in R Using XGBOOST
Previously, we have described bagging and random forest machine learning algorithms for building a powerful predictive model (Chapter @ref(bagging-and-random-forest)).
Recall that bagging consists of taking multiple subsets of the training data set, building an independent decision tree on each subset, and then averaging the resulting models, which yields a much more powerful predictive model than a single classical CART model (Chapter @ref(decision-tree-models)).
This chapter describes an alternative method called boosting, which is similar to the bagging method, except that the trees are grown sequentially: each successive tree is grown using information from previously grown trees, with the aim to minimize the error of the previous models (James et al. 2014).
For example, given a current regression tree model, the procedure is as follows:
- Fit a decision tree using the residual errors of the current model as the outcome variable.
- Add this new decision tree, scaled by a shrinkage parameter lambda, to the fitted function in order to update the residuals. lambda is a small positive value, typically between 0.001 and 0.01 (James et al. 2014).
This approach slowly and successively improves the fitted model, resulting in a very performant model; a minimal R sketch of this loop is given after the list of tuning parameters below. Boosting has several tuning parameters, including:
- The number of trees B
- The shrinkage parameter lambda
- The number of splits in each tree.
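To make the procedure concrete, below is a minimal sketch of this loop for regression, written with the rpart package rather than xgboost; the function name boost_regression and the default values are purely illustrative.
library(rpart)
# Minimal gradient boosting loop for regression (illustration only)
boost_regression <- function(data, outcome, B = 100, lambda = 0.01, depth = 2) {
  y <- data[[outcome]]
  pred <- rep(mean(y), nrow(data))   # start from a constant prediction
  form <- reformulate(setdiff(names(data), outcome), response = ".resid")
  for (b in seq_len(B)) {
    data$.resid <- y - pred          # current residuals become the outcome
    fit <- rpart(form, data = data, control = rpart.control(maxdepth = depth))
    pred <- pred + lambda * predict(fit, data)  # add the shrunken tree
  }
  pred
}
In practice you would also store the B fitted trees so that new observations can be predicted; xgboost handles all of this (and much more) efficiently.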
There are different variants of boosting, including AdaBoost, gradient boosting and stochastic gradient boosting.
Stochastic gradient boosting, implemented in the R package xgboost, is one of the most commonly used boosting techniques; it resamples observations and columns in each round, which generally yields excellent performance. xgboost stands for extreme gradient boosting.
Boosting can be used for both classification and regression problems.
In this chapter we’ll describe how to compute boosting in R.
Contents:
Loading required R packages
- tidyverse for easy data manipulation and visualization
- caret for easy machine learning workflow
- xgboost for computing the boosting algorithm
library(tidyverse)
library(caret)
library(xgboost)
Classification
Example of data set
Data set: PimaIndiansDiabetes2 [in mlbench package], introduced in Chapter @ref(classification-in-r), for predicting the probability of being diabetes positive based on multiple clinical variables.
Randomly split the data into training set (80% for building a predictive model) and test set (20% for evaluating the model). Make sure to set the seed for reproducibility.
# Load the data and remove NAs
data("PimaIndiansDiabetes2", package = "mlbench")
PimaIndiansDiabetes2 <- na.omit(PimaIndiansDiabetes2)
# Inspect the data
sample_n(PimaIndiansDiabetes2, 3)
# Split the data into training and test set
set.seed(123)
training.samples <- PimaIndiansDiabetes2$diabetes %>%
createDataPartition(p = 0.8, list = FALSE)
train.data <- PimaIndiansDiabetes2[training.samples, ]
test.data <- PimaIndiansDiabetes2[-training.samples, ]
Boosted classification trees
We’ll use the caret workflow, which invokes the xgboost package, to automatically adjust the model parameter values and fit the final best boosted tree, that is, the one that best explains our data.
We’ll use the following arguments in the function train():
- trControl, to set up 10-fold cross-validation
# Fit the model on the training set
set.seed(123)
model <- train(
diabetes ~., data = train.data, method = "xgbTree",
trControl = trainControl("cv", number = 10)
)
# Best tuning parameter
model$bestTune
## nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
## 18 150 1 0.3 0 0.8 1 1
# Make predictions on the test data
predicted.classes <- model %>% predict(test.data)
head(predicted.classes)
## [1] neg pos neg neg pos neg
## Levels: neg pos
# Compute model prediction accuracy rate
mean(predicted.classes == test.data$diabetes)
## [1] 0.744
The prediction accuracy on new test data is 74%, which is good.
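Beyond the overall accuracy, you can, if you wish, inspect per-class performance with caret's confusionMatrix() function and look at the predicted class probabilities; this short sketch is optional and not part of the original workflow:
# Confusion matrix with sensitivity and specificity
confusionMatrix(predicted.classes, test.data$diabetes)
# Predicted class probabilities for the test data
predict(model, test.data, type = "prob") %>% head()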
For more explanation about the boosting tuning parameters, type ?xgboost in R to see the documentation.
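If you want to control the search over these parameters yourself rather than rely on caret's default grid, you can pass a custom tuneGrid to train(). The sketch below is illustrative; the seven column names are the tuning parameters that caret exposes for method = "xgbTree", but the values are chosen arbitrarily.
# Illustrative custom tuning grid (values chosen arbitrarily)
xgb.grid <- expand.grid(
  nrounds = c(50, 100, 150),
  max_depth = c(1, 3, 5),
  eta = c(0.1, 0.3),
  gamma = 0,
  colsample_bytree = 0.8,
  min_child_weight = 1,
  subsample = 1
)
set.seed(123)
model.grid <- train(
  diabetes ~., data = train.data, method = "xgbTree",
  trControl = trainControl("cv", number = 10),
  tuneGrid = xgb.grid
)
model.grid$bestTune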
Variable importance
The function varImp() [in caret] displays the importance of variables in percentage:
varImp(model)
## xgbTree variable importance
##
## Overall
## glucose 100.00
## mass 20.23
## pregnant 15.83
## insulin 13.15
## pressure 9.51
## triceps 8.18
## pedigree 0.00
## age 0.00
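The importance scores can also be visualized, for example with the plot method for varImp objects:
# Plot variable importance
plot(varImp(model))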
Regression
Similarly, you can build boosted trees to perform regression, that is, to predict a continuous variable.
Example of data set
We’ll use the Boston data set [in MASS package], introduced in Chapter @ref(regression-analysis), for predicting the median house value (medv) in Boston suburbs, using different predictor variables.
Randomly split the data into training set (80% for building a predictive model) and test set (20% for evaluating the model).
# Load the data
data("Boston", package = "MASS")
# Inspect the data
sample_n(Boston, 3)
# Split the data into training and test set
set.seed(123)
training.samples <- Boston$medv %>%
createDataPartition(p = 0.8, list = FALSE)
train.data <- Boston[training.samples, ]
test.data <- Boston[-training.samples, ]
Boosted regression trees
Here the prediction error is measured by the RMSE, the square root of the mean squared difference between the observed outcome values and the values predicted by the model.
# Fit the model on the training set
set.seed(123)
model <- train(
medv ~., data = train.data, method = "xgbTree",
trControl = trainControl("cv", number = 10)
)
# Best tuning parameters
model$bestTune
# Make predictions on the test data
predictions <- model %>% predict(test.data)
head(predictions)
# Compute the average prediction error RMSE
RMSE(predictions, test.data$medv)
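As a quick sanity check (not part of the original workflow), the RMSE can also be computed directly from its definition, and caret's R2() function reports the proportion of variance in the test outcome explained by the model:
# RMSE computed from its definition
sqrt(mean((predictions - test.data$medv)^2))
# Proportion of variance explained (R squared)
R2(predictions, test.data$medv)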
Discussion
This chapter describes the boosting machine learning technique and provides examples in R for building a predictive model. See also bagging and random forest methods in Chapter @ref(bagging-and-random-forest).
References
James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. 2014. An Introduction to Statistical Learning: With Applications in R. Springer Publishing Company, Incorporated.