Learning goals

In this lab you will fit and tune boosted tree models: evaluating the effect of the key boosting parameters in gbm, and tuning an xgboost model with a cross-validated grid search through caret.

Lab description

For this lab we will be working with the heart dataset, which is read directly from the course GitHub repository in the code below.

Setup packages

You should install and load gbm (gradient boosting) and xgboost (extreme gradient boosting), along with caret, which is used below for the cross-validated grid search.

install.packages(c("gbm", "xgboost", "caret"))

Load packages and data

library(tidyverse)
library(gbm)
heart<-read.csv("https://raw.githubusercontent.com/JSC370/jsc370-2023/main/data/heart/heart.csv") |>
  mutate(
    AHD = 1 * (AHD == "Yes"),
    ChestPain = factor(ChestPain),
    Thal = factor(Thal),
    ChestPain_num = case_match(
      ChestPain,
      "asymptomatic" ~ 1,
      "nonanginal" ~ 2,
      "nontypical" ~ 3,
      .default = 0
    ),
    Thal_num = case_match(
      Thal,
      "fixed" ~ 1, 
      "normal" ~ 2,
      .default = 0
    )
  ) |>
  na.omit()
head(heart)
##   Age Sex    ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca
## 1  63   1      typical    145  233   1       2   150     0     2.3     3  0
## 2  67   1 asymptomatic    160  286   0       2   108     1     1.5     2  3
## 3  67   1 asymptomatic    120  229   0       2   129     1     2.6     2  2
## 4  37   1   nonanginal    130  250   0       0   187     0     3.5     3  0
## 5  41   0   nontypical    130  204   0       2   172     0     1.4     1  0
## 6  56   1   nontypical    120  236   0       0   178     0     0.8     1  0
##         Thal AHD ChestPain_num Thal_num
## 1      fixed   0             0        1
## 2     normal   1             1        2
## 3 reversable   1             1        0
## 4     normal   0             2        2
## 5     normal   0             3        2
## 6     normal   0             3        2
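
A quick cross-tabulation confirms that the case_match() recoding behaved as intended (an optional sanity check):

# each level of the original factor should map to exactly one numeric code
table(heart$ChestPain, heart$ChestPain_num)
table(heart$Thal, heart$Thal_num)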

Questions

Question 1: Gradient Boosting

Evaluate the effect of critical boosting parameters (number of boosting iterations, shrinkage/learning rate, and tree depth/interaction). In gbm the number of iterations is controlled by n.trees (default 100), the shrinkage/learning rate by shrinkage (default 0.1), and the interaction depth by interaction.depth (default 1).

Note that boosting can overfit if the number of trees is too large. The shrinkage parameter controls the rate at which boosting learns; a very small \(\lambda\) can require a very large number of trees to achieve good performance. Finally, the interaction depth controls the interaction order of the boosted model: a value of 1 implies an additive model, a value of 2 allows up to two-way interactions, and so on.

  i. Split the heart data into training and testing sets. The character variables also need to be converted to numeric and missing values removed; both were handled in the data-loading step above.
set.seed(370)
train <- sample(1:nrow(heart), floor(nrow(heart) * 0.7))
test <- setdiff(1:nrow(heart), train)
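
Because AHD is binary, it is worth checking that the random split keeps a similar class balance in both sets; a minimal check:

# proportion of AHD = 0 and AHD = 1 in each split
prop.table(table(heart$AHD[train]))
prop.table(table(heart$AHD[test]))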
  ii. Set the seed and train a boosted classification model with gbm using 10-fold cross-validation (cv.folds = 10) on the training data with n.trees = 5000, shrinkage = 0.001, and interaction.depth = 1. Plot the cross-validation errors as a function of the boosting iteration and calculate the test MSE.
set.seed(2100)
heart_boost <- gbm(
  AHD ~ ., data = heart[train, ],
  distribution = "bernoulli",
  cv.folds = 10,
  n.trees = 5000, 
  interaction.depth = 1,  # maximum depth of individual tree
  shrinkage = .001,       # learning rate
  class.stratify.cv = TRUE
)
# plot
data.frame(
  n_tree = seq.int(length(heart_boost$cv.error)),
  cv = heart_boost$cv.error,
  train = heart_boost$train.error
) |>
  ggplot(aes(x = n_tree)) +
  theme_minimal() +
  geom_line(aes(y = cv, colour = "10-Fold Cross Validation Error")) +
  geom_line(aes(y = train, colour = "Train")) +
  scale_colour_manual(values = c("steelblue", "#FF69B4")) +
  labs(colour = NULL, y = "Log loss") +
  theme(legend.position = "top")

yhat_boost <- predict(heart_boost, heart[test, ], 
                      n.trees = 5000, type = "response")
mean((yhat_boost - heart$AHD[test])^2)
## [1] 0.1303206
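
Note that this test MSE compares predicted probabilities with the 0/1 outcome, i.e. it is a Brier score. Because boosting can overfit when n.trees is too large, gbm also provides gbm.perf() to estimate the CV-optimal number of iterations; a short sketch of predicting at that iteration instead of the full 5000:

# estimate the CV-optimal number of boosting iterations (also plots the curves)
best_iter <- gbm.perf(heart_boost, method = "cv")
yhat_best <- predict(heart_boost, heart[test, ],
                     n.trees = best_iter, type = "response")
mean((yhat_best - heart$AHD[test])^2)  # Brier score at the CV-optimal iteration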
  iii. Repeat ii. using the same seed and n.trees = 5000 with the following three additional combinations of parameters: a) shrinkage = 0.001, interaction.depth = 2; b) shrinkage = 0.01, interaction.depth = 1; c) shrinkage = 0.01, interaction.depth = 2.
set.seed(2100)
heart_boost_2a <- gbm(
  AHD ~ ., data = heart[train, ],
  distribution = "bernoulli",
  cv.folds = 10,
  n.trees = 5000, 
  interaction.depth = 2,  # maximum depth of individual tree
  shrinkage = .001,       # learning rate
  class.stratify.cv = TRUE
)
set.seed(2100)
heart_boost_2b <- gbm(
  AHD ~ ., data = heart[train, ],
  distribution = "bernoulli",
  cv.folds = 10,
  n.trees = 5000, 
  interaction.depth = 1,  # maximum depth of individual tree
  shrinkage = .01,       # learning rate
  class.stratify.cv = TRUE
)
set.seed(2100)
heart_boost_2c <- gbm(
  AHD ~ ., data = heart[train, ],
  distribution = "bernoulli",
  cv.folds = 10,
  n.trees = 5000, 
  interaction.depth = 2,  # maximum depth of individual tree
  shrinkage = .01,       # learning rate
  class.stratify.cv = TRUE
)
# plotting CV vs train errors
err_df_1 <- data.frame(
  n_tree = seq.int(length(heart_boost$cv.error)),
  cv = heart_boost$cv.error,
  train = heart_boost$train.error,
  params = "0.001 shrinkage and 1 max tree depth"
) 
err_df_2a <- data.frame(
  n_tree = seq.int(length(heart_boost_2a$cv.error)),
  cv = heart_boost_2a$cv.error,
  train = heart_boost_2a$train.error,
  params = "0.001 shrinkage and 2 max tree depth"
) 
err_df_2b <- data.frame(
  n_tree = seq.int(length(heart_boost_2b$cv.error)),
  cv = heart_boost_2b$cv.error,
  train = heart_boost_2b$train.error,
  params = "0.01 shrinkage and 1 max tree depth"
) 
err_df_2c <- data.frame(
  n_tree = seq.int(length(heart_boost_2c$cv.error)),
  cv = heart_boost_2c$cv.error,
  train = heart_boost_2c$train.error,
  params = "0.01 shrinkage and 2 max tree depth"
) 
rbind(err_df_1, err_df_2a, err_df_2b, err_df_2c) |>
  ggplot(aes(x = n_tree)) +
  theme_minimal() +
  geom_line(aes(y = cv, colour = "10-Fold Cross Validation Error")) +
  geom_line(aes(y = train, colour = "Train")) +
  scale_colour_manual(values = c("steelblue", "#FF69B4")) +
  labs(colour = NULL, y = "Log loss") +
  theme(legend.position = "top") +
  facet_grid(params ~ .)

# MSEs
yhat_boost_2a <- predict(heart_boost_2a, newdata = heart[test, ],
                        n.trees = 5000, type = "response")
yhat_boost_2b <- predict(heart_boost_2b, newdata = heart[test, ],
                        n.trees = 5000, type = "response")
yhat_boost_2c <- predict(heart_boost_2c, newdata = heart[test, ],
                        n.trees = 5000, type = "response")
data.frame(
  params = c(
    "0.001 shrinkage\n1 max tree depth",
    "0.001 shrinkage\n2 max tree depth",
    "0.01 shrinkage\n1 max tree depth",
    "0.01 shrinkage\n2 max tree depth"
  ),
  mses = c(
    mean((yhat_boost - heart$AHD[test])^2),
    mean((yhat_boost_2a - heart$AHD[test])^2),
    mean((yhat_boost_2b - heart$AHD[test])^2),
    mean((yhat_boost_2c - heart$AHD[test])^2)
  )
) |>
  ggplot(aes(y = params, x = mses)) +
  theme_minimal() +
  geom_point(size = 1.5) +
  geom_segment(aes(xend = 0, yend = params), linetype = "dotted") +
  labs(x = "MSE", y = NULL)

Question 2: Extreme Gradient Boosting

Train an xgboost model with xgboost and perform a grid search to tune the number of trees and the maximum tree depth. Also perform 10-fold cross-validation and determine the variable importance. Finally, compute the test MSE.

See this online tutorial for reference.
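
The solution below tunes xgboost through caret, but it can help to see the native interface first. A minimal sketch, assuming the classic xgboost() interface (the model requires a numeric matrix, so the factor columns are dropped in favour of the *_num recodings created earlier):

library(xgboost)
# numeric feature matrix: drop the outcome and the factor columns
X_train <- as.matrix(heart[train, setdiff(names(heart), c("AHD", "ChestPain", "Thal"))])
fit_native <- xgboost(
  data = X_train,
  label = heart$AHD[train],      # 0/1 outcome
  objective = "binary:logistic",
  nrounds = 100,                 # number of trees
  max_depth = 3,                 # maximum tree depth
  eta = 0.1,                     # learning rate
  verbose = 0
)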

# tell the caret package that we want to conduct a grid search with 10-fold CV
train_control <- caret::trainControl(method = "cv", number = 10, search = "grid")
# parameter grid 
tune_grid <- expand.grid(
  max_depth = c(1, 3, 5, 7),   # max tree depth 
                               # larger makes model more complex 
                               # and potentially overfit
  nrounds = 50 * (1:10),       # number of trees (iterations)
  eta = c(.3, .1, .01, .001),  # learning rate (shrinkage)
  # default values
  gamma = 0,                   # minimum loss reduction required to make 
                               # a further partition on a leaf node of the tree. 
                               # The larger gamma is, the more conservative 
                               # the algorithm will be.
  colsample_bytree = .6,       # the fraction of columns to be sampled 
                               # (borrowing from random forest)
  subsample = 1,
  min_child_weight = 1
)
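
This grid has 4 depths × 10 values of nrounds × 4 learning rates = 160 candidate combinations, each refit 10 times under cross-validation, so the train() call below is the slow step:

nrow(tune_grid)  # 160 candidate parameter combinations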

heart_xgb <- caret::train(
  factor(AHD) ~ .,       # caret expects factor for classification
                         # if you use xgboost, it expects numeric values only
  data = heart[train, ], 
  method = "xgbTree",    # caret calls xgboost in the backend
                         # see https://topepo.github.io/caret/available-models.html
  trControl = train_control,
  tuneGrid = tune_grid,
  verbosity = 0          # don't print messages from xgboost while fitting
)
print(heart_xgb)
## eXtreme Gradient Boosting 
## 
## 207 samples
##  15 predictor
##   2 classes: '0', '1' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 186, 187, 187, 186, 187, 186, ... 
## Resampling results across tuning parameters:
## 
##   eta    max_depth  nrounds  Accuracy   Kappa    
##   0.001  1           50      0.7880952  0.5708569
##   0.001  1          100      0.7880952  0.5708569
##   0.001  1          150      0.7880952  0.5708569
##   0.001  1          200      0.7880952  0.5708569
##   0.001  1          250      0.7880952  0.5708569
##   0.001  1          300      0.7880952  0.5708569
##   0.001  1          350      0.7880952  0.5708569
##   0.001  1          400      0.7880952  0.5708569
##   0.001  1          450      0.7880952  0.5708569
##   0.001  1          500      0.7880952  0.5708569
##   0.001  3           50      0.8066667  0.6075716
##   0.001  3          100      0.8066667  0.6075716
##   0.001  3          150      0.8019048  0.5977166
##   0.001  3          200      0.7971429  0.5881288
##   0.001  3          250      0.8066667  0.6070213
##   0.001  3          300      0.8019048  0.5971663
##   0.001  3          350      0.8019048  0.5971663
##   0.001  3          400      0.8019048  0.5971663
##   0.001  3          450      0.8019048  0.5971663
##   0.001  3          500      0.8019048  0.5971663
##   0.001  5           50      0.8069048  0.6083989
##   0.001  5          100      0.8116667  0.6165655
##   0.001  5          150      0.8021429  0.5974306
##   0.001  5          200      0.8021429  0.5974306
##   0.001  5          250      0.8069048  0.6070185
##   0.001  5          300      0.8019048  0.5968642
##   0.001  5          350      0.8019048  0.5968642
##   0.001  5          400      0.8019048  0.5968642
##   0.001  5          450      0.8066667  0.6062801
##   0.001  5          500      0.8066667  0.6062801
##   0.001  7           50      0.7923810  0.5784972
##   0.001  7          100      0.7971429  0.5881392
##   0.001  7          150      0.7973810  0.5881553
##   0.001  7          200      0.7971429  0.5878267
##   0.001  7          250      0.7971429  0.5878267
##   0.001  7          300      0.8019048  0.5968642
##   0.001  7          350      0.8019048  0.5968642
##   0.001  7          400      0.8019048  0.5968642
##   0.001  7          450      0.8019048  0.5968642
##   0.001  7          500      0.8019048  0.5968642
##   0.010  1           50      0.7880952  0.5708569
##   0.010  1          100      0.8076190  0.6077266
##   0.010  1          150      0.8266667  0.6475917
##   0.010  1          200      0.8216667  0.6386188
##   0.010  1          250      0.8216667  0.6372056
##   0.010  1          300      0.8216667  0.6372056
##   0.010  1          350      0.8216667  0.6372056
##   0.010  1          400      0.8121429  0.6188929
##   0.010  1          450      0.8169048  0.6287479
##   0.010  1          500      0.8169048  0.6287479
##   0.010  3           50      0.8019048  0.5971663
##   0.010  3          100      0.7923810  0.5785650
##   0.010  3          150      0.8164286  0.6258180
##   0.010  3          200      0.8211905  0.6353955
##   0.010  3          250      0.8219048  0.6369022
##   0.010  3          300      0.8173810  0.6278326
##   0.010  3          350      0.8319048  0.6574194
##   0.010  3          400      0.8319048  0.6569883
##   0.010  3          450      0.8271429  0.6477129
##   0.010  3          500      0.8319048  0.6569883
##   0.010  5           50      0.8069048  0.6070081
##   0.010  5          100      0.8116667  0.6167372
##   0.010  5          150      0.8116667  0.6161250
##   0.010  5          200      0.8116667  0.6167236
##   0.010  5          250      0.8264286  0.6455803
##   0.010  5          300      0.8169048  0.6255803
##   0.010  5          350      0.8219048  0.6358160
##   0.010  5          400      0.8221429  0.6367722
##   0.010  5          450      0.8269048  0.6471603
##   0.010  5          500      0.8219048  0.6363141
##   0.010  7           50      0.8164286  0.6260126
##   0.010  7          100      0.8214286  0.6352595
##   0.010  7          150      0.8069048  0.6060223
##   0.010  7          200      0.8164286  0.6258403
##   0.010  7          250      0.8071429  0.6073202
##   0.010  7          300      0.8221429  0.6370952
##   0.010  7          350      0.8221429  0.6360178
##   0.010  7          400      0.8221429  0.6362282
##   0.010  7          450      0.8171429  0.6260241
##   0.010  7          500      0.8171429  0.6260241
##   0.100  1           50      0.8169048  0.6287479
##   0.100  1          100      0.8311905  0.6573174
##   0.100  1          150      0.8311905  0.6570721
##   0.100  1          200      0.8359524  0.6662869
##   0.100  1          250      0.8211905  0.6361301
##   0.100  1          300      0.8211905  0.6361301
##   0.100  1          350      0.8259524  0.6449590
##   0.100  1          400      0.8259524  0.6449590
##   0.100  1          450      0.8259524  0.6449590
##   0.100  1          500      0.8259524  0.6449590
##   0.100  3           50      0.8366667  0.6678150
##   0.100  3          100      0.8219048  0.6361847
##   0.100  3          150      0.7873810  0.5670750
##   0.100  3          200      0.7973810  0.5864202
##   0.100  3          250      0.7971429  0.5873725
##   0.100  3          300      0.8021429  0.5971449
##   0.100  3          350      0.8021429  0.5971449
##   0.100  3          400      0.8019048  0.5966840
##   0.100  3          450      0.8066667  0.6071065
##   0.100  3          500      0.8066667  0.6071065
##   0.100  5           50      0.8169048  0.6262773
##   0.100  5          100      0.8121429  0.6183432
##   0.100  5          150      0.8119048  0.6159706
##   0.100  5          200      0.8023810  0.5969630
##   0.100  5          250      0.7973810  0.5869693
##   0.100  5          300      0.7971429  0.5859312
##   0.100  5          350      0.8019048  0.5952426
##   0.100  5          400      0.8019048  0.5952426
##   0.100  5          450      0.8066667  0.6059672
##   0.100  5          500      0.8066667  0.6059672
##   0.100  7           50      0.8171429  0.6250281
##   0.100  7          100      0.8116667  0.6164033
##   0.100  7          150      0.8161905  0.6250178
##   0.100  7          200      0.8116667  0.6154492
##   0.100  7          250      0.8164286  0.6250359
##   0.100  7          300      0.8164286  0.6250359
##   0.100  7          350      0.8164286  0.6250359
##   0.100  7          400      0.8114286  0.6144110
##   0.100  7          450      0.8161905  0.6234818
##   0.100  7          500      0.8161905  0.6233129
##   0.300  1           50      0.8261905  0.6464558
##   0.300  1          100      0.8211905  0.6361301
##   0.300  1          150      0.8211905  0.6361301
##   0.300  1          200      0.8259524  0.6468548
##   0.300  1          250      0.8257143  0.6452691
##   0.300  1          300      0.8109524  0.6162383
##   0.300  1          350      0.8064286  0.6075763
##   0.300  1          400      0.8064286  0.6075763
##   0.300  1          450      0.7916667  0.5791845
##   0.300  1          500      0.7866667  0.5698603
##   0.300  3           50      0.8311905  0.6557125
##   0.300  3          100      0.8214286  0.6363866
##   0.300  3          150      0.8164286  0.6246870
##   0.300  3          200      0.8164286  0.6246870
##   0.300  3          250      0.8164286  0.6246870
##   0.300  3          300      0.8164286  0.6246870
##   0.300  3          350      0.8164286  0.6246870
##   0.300  3          400      0.8164286  0.6246870
##   0.300  3          450      0.8164286  0.6246870
##   0.300  3          500      0.8164286  0.6246870
##   0.300  5           50      0.8119048  0.6158628
##   0.300  5          100      0.8119048  0.6168981
##   0.300  5          150      0.8066667  0.6073191
##   0.300  5          200      0.8066667  0.6073191
##   0.300  5          250      0.8066667  0.6073191
##   0.300  5          300      0.8116667  0.6173127
##   0.300  5          350      0.8164286  0.6266422
##   0.300  5          400      0.8164286  0.6266422
##   0.300  5          450      0.8164286  0.6266422
##   0.300  5          500      0.8164286  0.6266422
##   0.300  7           50      0.8116667  0.6142199
##   0.300  7          100      0.8161905  0.6248093
##   0.300  7          150      0.8066667  0.6055541
##   0.300  7          200      0.8066667  0.6055541
##   0.300  7          250      0.8066667  0.6055333
##   0.300  7          300      0.8114286  0.6149419
##   0.300  7          350      0.8066667  0.6050890
##   0.300  7          400      0.8114286  0.6146664
##   0.300  7          450      0.8114286  0.6146664
##   0.300  7          500      0.8066667  0.6052578
## 
## Tuning parameter 'gamma' was held constant at a value of 0
## Tuning parameter 'colsample_bytree' was held constant at a value of 0.6
## 
## Tuning parameter 'min_child_weight' was held constant at a value of 1
## 
## Tuning parameter 'subsample' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 50, max_depth = 3, eta
##  = 0.1, gamma = 0, colsample_bytree = 0.6, min_child_weight = 1 and subsample
##  = 1.
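
The winning combination can also be extracted programmatically from the caret object instead of being read off the printout:

heart_xgb$bestTune  # best tuning parameters selected by CV accuracy
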
# variable importance
varimp <- caret::varImp(heart_xgb, scale = FALSE)
plot(varimp)
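
caret stores the fitted booster in heart_xgb$finalModel, so xgboost's own importance measures (gain, cover, frequency) are available too; a sketch, assuming the xgbTree backend:

# gain-based importance from the underlying xgboost booster
xgboost::xgb.importance(model = heart_xgb$finalModel)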

# MSE
yhat_xgb <- predict(heart_xgb, newdata = heart[test, ], type = "raw") 
# the output is a factor variable
# as.numeric(yhat_xgb) - 1 correctly converts back to 0-1
yhat_xgb <- as.numeric(yhat_xgb) - 1
mean((yhat_xgb - heart$AHD[test])^2) 
## [1] 0.2
# Root MSE from the caret package (arguments: predicted values, then observed)
caret::RMSE(yhat_xgb, heart$AHD[test])
## [1] 0.4472136
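
Since yhat_xgb holds hard 0/1 class predictions, the test MSE above is simply the misclassification rate (20%). A confusion matrix gives a fuller picture of the errors; a minimal sketch using caret:

# confusion matrix on the test set (rows = predicted, columns = observed)
caret::confusionMatrix(
  data = factor(yhat_xgb, levels = c(0, 1)),
  reference = factor(heart$AHD[test], levels = c(0, 1))
)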