For this lab we will be working with the heart dataset, which you can download from here. You should install and load gbm (gradient boosting) and xgboost (extreme gradient boosting).
install.packages(c("gbm", "xgboost", "caret"))
library(tidyverse)
library(gbm)
# read in the heart data and recode the categorical variables
heart <- read.csv("https://raw.githubusercontent.com/JSC370/jsc370-2023/main/data/heart/heart.csv") |>
  mutate(
    AHD = 1 * (AHD == "Yes"), # outcome: 1 = heart disease, 0 = no heart disease
    ChestPain = factor(ChestPain),
    Thal = factor(Thal),
    ChestPain_num = case_match( # numeric recoding of chest pain type
      ChestPain,
      "asymptomatic" ~ 1,
      "nonanginal" ~ 2,
      "nontypical" ~ 3,
      .default = 0
    ),
    Thal_num = case_match( # numeric recoding of thalassemia status
      Thal,
      "fixed" ~ 1,
      "normal" ~ 2,
      .default = 0
    )
  ) |>
  na.omit() # drop rows with missing values
head(heart)
## Age Sex ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca
## 1 63 1 typical 145 233 1 2 150 0 2.3 3 0
## 2 67 1 asymptomatic 160 286 0 2 108 1 1.5 2 3
## 3 67 1 asymptomatic 120 229 0 2 129 1 2.6 2 2
## 4 37 1 nonanginal 130 250 0 0 187 0 3.5 3 0
## 5 41 0 nontypical 130 204 0 2 172 0 1.4 1 0
## 6 56 1 nontypical 120 236 0 0 178 0 0.8 1 0
## Thal AHD ChestPain_num Thal_num
## 1 fixed 0 0 1
## 2 normal 1 1 2
## 3 reversable 1 1 0
## 4 normal 0 2 2
## 5 normal 0 3 2
## 6 normal 0 3 2
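As a quick sanity check (not part of the original lab), cross-tabulating each factor against its numeric recoding confirms that every level maps to the intended code:
table(heart$ChestPain, heart$ChestPain_num)
table(heart$Thal, heart$Thal_num)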
Evaluate the effect of critical boosting parameters (number of boosting iterations, shrinkage/learning rate, and tree depth/interaction). In gbm the number of iterations is controlled by n.trees (default is 100), the shrinkage/learning rate is controlled by shrinkage (default is 0.1), and the interaction depth by interaction.depth (default is 1).
Note that boosting can overfit if the number of trees is too large. The shrinkage parameter controls the rate at which boosting learns: a very small \(\lambda\) can require a very large number of trees to achieve good performance. Finally, interaction.depth controls the interaction order of the boosted model: a value of 1 implies an additive model, a value of 2 a model with up to 2-way interactions, and so on; the default is 1.
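To see why a small \(\lambda\) needs many trees, recall the update performed at each boosting iteration \(b\) (written here for the squared-error version of the algorithm; with the bernoulli loss used below the same shrunken update is applied on the log-odds scale): a small tree \(\hat{f}^b\) is fit to the current residuals and added to the ensemble only after being scaled by \(\lambda\),
\[
\hat{f}(x) \leftarrow \hat{f}(x) + \lambda\,\hat{f}^b(x), \qquad r_i \leftarrow r_i - \lambda\,\hat{f}^b(x_i),
\]
so each tree contributes only a \(\lambda\)-sized step and, with \(\lambda = 0.001\), thousands of iterations may be needed before the error curves flatten out.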
set.seed(370)
train <- sample(1:nrow(heart), floor(nrow(heart) * 0.7))
test <- setdiff(1:nrow(heart), train)
Fit a boosted classification model with gbm using 10-fold cross-validation (cv.folds = 10) on the training data, with n.trees = 5000, shrinkage = 0.001, and interaction.depth = 1. Plot the cross-validation errors as a function of the boosting iteration and calculate the test MSE.
set.seed(2100)
heart_boost <- gbm(
  AHD ~ ., data = heart[train, ],
  distribution = "bernoulli",
  cv.folds = 10,
  n.trees = 5000,
  interaction.depth = 1, # maximum depth of individual tree
  shrinkage = .001, # learning rate
  class.stratify.cv = TRUE
)
# plot train vs cross-validation error by boosting iteration
data.frame(
  n_tree = seq.int(length(heart_boost$cv.error)),
  cv = heart_boost$cv.error,
  train = heart_boost$train.error
) |>
  ggplot(aes(x = n_tree)) +
  theme_minimal() +
  geom_line(aes(y = cv, colour = "10-Fold Cross Validation Error")) +
  geom_line(aes(y = train, colour = "Train")) +
  scale_colour_manual(values = c("steelblue", "#FF69B4")) +
  labs(colour = NULL, y = "Log loss") +
  theme(legend.position = "top")
yhat_boost <- predict(heart_boost, heart[test, ],
n.trees = 5000, type = "response")
mean((yhat_boost - heart$AHD[test])^2)
## [1] 0.1303206
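As a follow-up (a sketch, not part of the original question), gbm.perf() from the gbm package can extract the iteration with the lowest cross-validation error, and predictions can then be made with that number of trees instead of the full 5000; best_iter and yhat_boost_best below are new names introduced only for this check.
# CV-optimal number of boosting iterations for the model fit above
best_iter <- gbm.perf(heart_boost, method = "cv", plot.it = FALSE)
best_iter
# test MSE when predicting with the CV-optimal number of trees
yhat_boost_best <- predict(heart_boost, heart[test, ],
  n.trees = best_iter, type = "response")
mean((yhat_boost_best - heart$AHD[test])^2)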
Repeat the fit, keeping n.trees = 5000, with the following 3 additional combinations of parameters: a) shrinkage = 0.001, interaction.depth = 2; b) shrinkage = 0.01, interaction.depth = 1; c) shrinkage = 0.01, interaction.depth = 2.
set.seed(2100)
heart_boost_2a <- gbm(
  AHD ~ ., data = heart[train, ],
  distribution = "bernoulli",
  cv.folds = 10,
  n.trees = 5000,
  interaction.depth = 2, # maximum depth of individual tree
  shrinkage = .001, # learning rate
  class.stratify.cv = TRUE
)
set.seed(2100)
heart_boost_2b <- gbm(
  AHD ~ ., data = heart[train, ],
  distribution = "bernoulli",
  cv.folds = 10,
  n.trees = 5000,
  interaction.depth = 1, # maximum depth of individual tree
  shrinkage = .01, # learning rate
  class.stratify.cv = TRUE
)
set.seed(2100)
heart_boost_2c <- gbm(
  AHD ~ ., data = heart[train, ],
  distribution = "bernoulli",
  cv.folds = 10,
  n.trees = 5000,
  interaction.depth = 2, # maximum depth of individual tree
  shrinkage = .01, # learning rate
  class.stratify.cv = TRUE
)
# plotting CV vs train errors
err_df_1 <- data.frame(
  n_tree = seq.int(length(heart_boost$cv.error)),
  cv = heart_boost$cv.error,
  train = heart_boost$train.error,
  params = "0.001 shrinkage and 1 max tree depth"
)
err_df_2a <- data.frame(
  n_tree = seq.int(length(heart_boost_2a$cv.error)),
  cv = heart_boost_2a$cv.error,
  train = heart_boost_2a$train.error,
  params = "0.001 shrinkage and 2 max tree depth"
)
err_df_2b <- data.frame(
  n_tree = seq.int(length(heart_boost_2b$cv.error)),
  cv = heart_boost_2b$cv.error,
  train = heart_boost_2b$train.error,
  params = "0.01 shrinkage and 1 max tree depth"
)
err_df_2c <- data.frame(
  n_tree = seq.int(length(heart_boost_2c$cv.error)),
  cv = heart_boost_2c$cv.error,
  train = heart_boost_2c$train.error,
  params = "0.01 shrinkage and 2 max tree depth"
)
rbind(err_df_1, err_df_2a, err_df_2b, err_df_2c) |>
  ggplot(aes(x = n_tree)) +
  theme_minimal() +
  geom_line(aes(y = cv, colour = "10-Fold Cross Validation Error")) +
  geom_line(aes(y = train, colour = "Train")) +
  scale_colour_manual(values = c("steelblue", "#FF69B4")) +
  labs(colour = NULL, y = "Log loss") +
  theme(legend.position = "top") +
  facet_grid(params ~ .)
# MSEs
yhat_boost_2a <- predict(heart_boost_2a, newdata = heart[test, ],
  n.trees = 5000, type = "response")
yhat_boost_2b <- predict(heart_boost_2b, newdata = heart[test, ],
  n.trees = 5000, type = "response")
yhat_boost_2c <- predict(heart_boost_2c, newdata = heart[test, ],
  n.trees = 5000, type = "response")
data.frame(
  params = c(
    "0.001 shrinkage\n1 max tree depth",
    "0.001 shrinkage\n2 max tree depth",
    "0.01 shrinkage\n1 max tree depth",
    "0.01 shrinkage\n2 max tree depth"
  ),
  mses = c(
    mean((yhat_boost - heart$AHD[test])^2),
    mean((yhat_boost_2a - heart$AHD[test])^2),
    mean((yhat_boost_2b - heart$AHD[test])^2),
    mean((yhat_boost_2c - heart$AHD[test])^2)
  )
) |>
  ggplot(aes(y = params, x = mses)) +
  theme_minimal() +
  geom_point(size = 1.5) +
  geom_segment(aes(xend = 0, yend = params), linetype = "dotted") +
  labs(x = "MSE", y = NULL)
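To complement the plots, a compact numeric comparison (a sketch reusing the four fitted objects above) is to report, for each parameter setting, the minimum cross-validation error and the iteration at which it is reached:
# minimum CV error and the iteration achieving it, per parameter setting
sapply(
  list(
    "0.001 / depth 1" = heart_boost,
    "0.001 / depth 2" = heart_boost_2a,
    "0.01 / depth 1"  = heart_boost_2b,
    "0.01 / depth 2"  = heart_boost_2c
  ),
  function(fit) c(best_iter = which.min(fit$cv.error),
                  min_cv_error = min(fit$cv.error))
)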
Train an xgboost model with xgboost and perform a grid search for tuning the number of trees and the maximum depth of the tree. Also perform 10-fold cross-validation and determine the variable importance. Finally, compute the test MSE. See this online tutorial for reference.
# tell the caret package that we want to conduct a grid search with 10-fold CV
train_control <- caret::trainControl(method = "cv", number = 10, search = "grid")
# parameter grid
tune_grid <- expand.grid(
  max_depth = c(1, 3, 5, 7), # max tree depth; larger makes the model more
                             # complex and potentially overfit
  nrounds = 50 * (1:10), # number of trees (boosting iterations)
  eta = c(.3, .1, .01, .001), # learning rate (shrinkage)
  # default values for the remaining parameters
  gamma = 0, # minimum loss reduction required to make a further
             # partition on a leaf node of the tree; the larger
             # gamma is, the more conservative the algorithm will be
  colsample_bytree = .6, # fraction of columns to be sampled per tree
                         # (borrowing from random forests)
  subsample = 1,
  min_child_weight = 1
)
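Before launching the search it is worth noting its size: caret will evaluate every row of tune_grid under 10-fold cross-validation, so the fit below can take a few minutes.
nrow(tune_grid) # 4 depths x 10 values of nrounds x 4 learning rates = 160 combinations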
heart_xgb <- caret::train(
  factor(AHD) ~ ., # caret expects a factor outcome for classification
                   # (xgboost used directly expects numeric values only)
  data = heart[train, ],
  method = "xgbTree", # caret calls xgboost in the backend;
                      # see https://topepo.github.io/caret/available-models.html
  trControl = train_control,
  tuneGrid = tune_grid,
  verbosity = 0 # don't print messages from xgboost while fitting
)
print(heart_xgb)
## eXtreme Gradient Boosting
##
## 207 samples
## 15 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 186, 187, 187, 186, 187, 186, ...
## Resampling results across tuning parameters:
##
## eta max_depth nrounds Accuracy Kappa
## 0.001 1 50 0.7880952 0.5708569
## 0.001 1 100 0.7880952 0.5708569
## 0.001 1 150 0.7880952 0.5708569
## 0.001 1 200 0.7880952 0.5708569
## 0.001 1 250 0.7880952 0.5708569
## 0.001 1 300 0.7880952 0.5708569
## 0.001 1 350 0.7880952 0.5708569
## 0.001 1 400 0.7880952 0.5708569
## 0.001 1 450 0.7880952 0.5708569
## 0.001 1 500 0.7880952 0.5708569
## 0.001 3 50 0.8066667 0.6075716
## 0.001 3 100 0.8066667 0.6075716
## 0.001 3 150 0.8019048 0.5977166
## 0.001 3 200 0.7971429 0.5881288
## 0.001 3 250 0.8066667 0.6070213
## 0.001 3 300 0.8019048 0.5971663
## 0.001 3 350 0.8019048 0.5971663
## 0.001 3 400 0.8019048 0.5971663
## 0.001 3 450 0.8019048 0.5971663
## 0.001 3 500 0.8019048 0.5971663
## 0.001 5 50 0.8069048 0.6083989
## 0.001 5 100 0.8116667 0.6165655
## 0.001 5 150 0.8021429 0.5974306
## 0.001 5 200 0.8021429 0.5974306
## 0.001 5 250 0.8069048 0.6070185
## 0.001 5 300 0.8019048 0.5968642
## 0.001 5 350 0.8019048 0.5968642
## 0.001 5 400 0.8019048 0.5968642
## 0.001 5 450 0.8066667 0.6062801
## 0.001 5 500 0.8066667 0.6062801
## 0.001 7 50 0.7923810 0.5784972
## 0.001 7 100 0.7971429 0.5881392
## 0.001 7 150 0.7973810 0.5881553
## 0.001 7 200 0.7971429 0.5878267
## 0.001 7 250 0.7971429 0.5878267
## 0.001 7 300 0.8019048 0.5968642
## 0.001 7 350 0.8019048 0.5968642
## 0.001 7 400 0.8019048 0.5968642
## 0.001 7 450 0.8019048 0.5968642
## 0.001 7 500 0.8019048 0.5968642
## 0.010 1 50 0.7880952 0.5708569
## 0.010 1 100 0.8076190 0.6077266
## 0.010 1 150 0.8266667 0.6475917
## 0.010 1 200 0.8216667 0.6386188
## 0.010 1 250 0.8216667 0.6372056
## 0.010 1 300 0.8216667 0.6372056
## 0.010 1 350 0.8216667 0.6372056
## 0.010 1 400 0.8121429 0.6188929
## 0.010 1 450 0.8169048 0.6287479
## 0.010 1 500 0.8169048 0.6287479
## 0.010 3 50 0.8019048 0.5971663
## 0.010 3 100 0.7923810 0.5785650
## 0.010 3 150 0.8164286 0.6258180
## 0.010 3 200 0.8211905 0.6353955
## 0.010 3 250 0.8219048 0.6369022
## 0.010 3 300 0.8173810 0.6278326
## 0.010 3 350 0.8319048 0.6574194
## 0.010 3 400 0.8319048 0.6569883
## 0.010 3 450 0.8271429 0.6477129
## 0.010 3 500 0.8319048 0.6569883
## 0.010 5 50 0.8069048 0.6070081
## 0.010 5 100 0.8116667 0.6167372
## 0.010 5 150 0.8116667 0.6161250
## 0.010 5 200 0.8116667 0.6167236
## 0.010 5 250 0.8264286 0.6455803
## 0.010 5 300 0.8169048 0.6255803
## 0.010 5 350 0.8219048 0.6358160
## 0.010 5 400 0.8221429 0.6367722
## 0.010 5 450 0.8269048 0.6471603
## 0.010 5 500 0.8219048 0.6363141
## 0.010 7 50 0.8164286 0.6260126
## 0.010 7 100 0.8214286 0.6352595
## 0.010 7 150 0.8069048 0.6060223
## 0.010 7 200 0.8164286 0.6258403
## 0.010 7 250 0.8071429 0.6073202
## 0.010 7 300 0.8221429 0.6370952
## 0.010 7 350 0.8221429 0.6360178
## 0.010 7 400 0.8221429 0.6362282
## 0.010 7 450 0.8171429 0.6260241
## 0.010 7 500 0.8171429 0.6260241
## 0.100 1 50 0.8169048 0.6287479
## 0.100 1 100 0.8311905 0.6573174
## 0.100 1 150 0.8311905 0.6570721
## 0.100 1 200 0.8359524 0.6662869
## 0.100 1 250 0.8211905 0.6361301
## 0.100 1 300 0.8211905 0.6361301
## 0.100 1 350 0.8259524 0.6449590
## 0.100 1 400 0.8259524 0.6449590
## 0.100 1 450 0.8259524 0.6449590
## 0.100 1 500 0.8259524 0.6449590
## 0.100 3 50 0.8366667 0.6678150
## 0.100 3 100 0.8219048 0.6361847
## 0.100 3 150 0.7873810 0.5670750
## 0.100 3 200 0.7973810 0.5864202
## 0.100 3 250 0.7971429 0.5873725
## 0.100 3 300 0.8021429 0.5971449
## 0.100 3 350 0.8021429 0.5971449
## 0.100 3 400 0.8019048 0.5966840
## 0.100 3 450 0.8066667 0.6071065
## 0.100 3 500 0.8066667 0.6071065
## 0.100 5 50 0.8169048 0.6262773
## 0.100 5 100 0.8121429 0.6183432
## 0.100 5 150 0.8119048 0.6159706
## 0.100 5 200 0.8023810 0.5969630
## 0.100 5 250 0.7973810 0.5869693
## 0.100 5 300 0.7971429 0.5859312
## 0.100 5 350 0.8019048 0.5952426
## 0.100 5 400 0.8019048 0.5952426
## 0.100 5 450 0.8066667 0.6059672
## 0.100 5 500 0.8066667 0.6059672
## 0.100 7 50 0.8171429 0.6250281
## 0.100 7 100 0.8116667 0.6164033
## 0.100 7 150 0.8161905 0.6250178
## 0.100 7 200 0.8116667 0.6154492
## 0.100 7 250 0.8164286 0.6250359
## 0.100 7 300 0.8164286 0.6250359
## 0.100 7 350 0.8164286 0.6250359
## 0.100 7 400 0.8114286 0.6144110
## 0.100 7 450 0.8161905 0.6234818
## 0.100 7 500 0.8161905 0.6233129
## 0.300 1 50 0.8261905 0.6464558
## 0.300 1 100 0.8211905 0.6361301
## 0.300 1 150 0.8211905 0.6361301
## 0.300 1 200 0.8259524 0.6468548
## 0.300 1 250 0.8257143 0.6452691
## 0.300 1 300 0.8109524 0.6162383
## 0.300 1 350 0.8064286 0.6075763
## 0.300 1 400 0.8064286 0.6075763
## 0.300 1 450 0.7916667 0.5791845
## 0.300 1 500 0.7866667 0.5698603
## 0.300 3 50 0.8311905 0.6557125
## 0.300 3 100 0.8214286 0.6363866
## 0.300 3 150 0.8164286 0.6246870
## 0.300 3 200 0.8164286 0.6246870
## 0.300 3 250 0.8164286 0.6246870
## 0.300 3 300 0.8164286 0.6246870
## 0.300 3 350 0.8164286 0.6246870
## 0.300 3 400 0.8164286 0.6246870
## 0.300 3 450 0.8164286 0.6246870
## 0.300 3 500 0.8164286 0.6246870
## 0.300 5 50 0.8119048 0.6158628
## 0.300 5 100 0.8119048 0.6168981
## 0.300 5 150 0.8066667 0.6073191
## 0.300 5 200 0.8066667 0.6073191
## 0.300 5 250 0.8066667 0.6073191
## 0.300 5 300 0.8116667 0.6173127
## 0.300 5 350 0.8164286 0.6266422
## 0.300 5 400 0.8164286 0.6266422
## 0.300 5 450 0.8164286 0.6266422
## 0.300 5 500 0.8164286 0.6266422
## 0.300 7 50 0.8116667 0.6142199
## 0.300 7 100 0.8161905 0.6248093
## 0.300 7 150 0.8066667 0.6055541
## 0.300 7 200 0.8066667 0.6055541
## 0.300 7 250 0.8066667 0.6055333
## 0.300 7 300 0.8114286 0.6149419
## 0.300 7 350 0.8066667 0.6050890
## 0.300 7 400 0.8114286 0.6146664
## 0.300 7 450 0.8114286 0.6146664
## 0.300 7 500 0.8066667 0.6052578
##
## Tuning parameter 'gamma' was held constant at a value of 0
## 
## Tuning parameter 'colsample_bytree' was held constant at a value of 0.6
## 
## Tuning parameter 'min_child_weight' was held constant at a value of 1
## 
## Tuning parameter 'subsample' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 50, max_depth = 3, eta
## = 0.1, gamma = 0, colsample_bytree = 0.6, min_child_weight = 1 and subsample
## = 1.
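The winning combination can also be pulled out of the train object directly, and caret's plot method gives a quick visual summary of accuracy across the grid (both are standard caret usage; output not shown here):
heart_xgb$bestTune # best hyperparameter combination from the grid search
plot(heart_xgb)    # accuracy across the tuning grid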
# variable importance
varimp <- caret::varImp(heart_xgb, scale = FALSE)
plot(varimp)
# MSE
yhat_xgb <- predict(heart_xgb, newdata = heart[test, ], type = "raw")
# the output is a factor variable
# as.numeric(yhat_xgb) - 1 correctly converts back to 0-1
yhat_xgb <- as.numeric(yhat_xgb) - 1
mean((yhat_xgb- heart$AHD[test])^2)
## [1] 0.2
# Root MSE from caret package
caret::RMSE(heart$AHD[test], yhat_xgb)
## [1] 0.4472136
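Because AHD is binary, it may also be informative to report a confusion matrix and accuracy on the test set; a minimal sketch using caret::confusionMatrix with the class predictions computed above:
# confusion matrix and accuracy on the held-out test set
caret::confusionMatrix(
  factor(yhat_xgb, levels = c(0, 1)),
  factor(heart$AHD[test], levels = c(0, 1))
)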