Learning goals

Lab description

For this lab we will work with simulated data and with the heart dataset, which you can download from here

Setup packages

You should install and load rpart (trees), rpart.plot (tree plots), randomForest (random forests), gbm (gradient boosting) and xgboost (extreme gradient boosting).

install.packages(c("rpart", "rpart.plot", "randomForest", "gbm", "xgboost"))

Load packages and data

library(tidyverse)
library(rpart)
library(rpart.plot)
library(randomForest)
library(gbm)
# library(xgboost)
heart <- read.csv("https://raw.githubusercontent.com/JSC370/jsc370-2023/main/data/heart/heart.csv") |>
  mutate(
    AHD = 1 * (AHD == "Yes"),
    ChestPain = factor(ChestPain),
    Thal = factor(Thal)
  )
head(heart)
##   Age Sex    ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca
## 1  63   1      typical    145  233   1       2   150     0     2.3     3  0
## 2  67   1 asymptomatic    160  286   0       2   108     1     1.5     2  3
## 3  67   1 asymptomatic    120  229   0       2   129     1     2.6     2  2
## 4  37   1   nonanginal    130  250   0       0   187     0     3.5     3  0
## 5  41   0   nontypical    130  204   0       2   172     0     1.4     1  0
## 6  56   1   nontypical    120  236   0       0   178     0     0.8     1  0
##         Thal AHD
## 1      fixed   0
## 2     normal   1
## 3 reversable   1
## 4     normal   0
## 5     normal   0
## 6     normal   0
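Before modelling, it is worth a quick check for missing values (a minimal sketch below). In this version of the data, Ca and Thal contain a handful of NAs (the rpart summaries in Question 2 report them as "missing"), which is why na.action = na.omit is used with randomForest in Question 3.

colSums(is.na(heart)) # count of NAs per column; expect a few in Ca and Thal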

Questions

Question 1: Trees with simulated data

  • Simulate data with x drawn from a uniform distribution on [-5, 5] and normally distributed errors (s.d. = 0.5)
  • Create a non-linear relationship y = sin(x) + error
  • Split the data into training and test sets (500 points each) and plot the data

\[Y_i = \sin(X_i) + \varepsilon_i\]

\[\varepsilon_i \sim N(0, 0.5^2)\] for \(i=1,\ldots,1000\).

set.seed(1984)
n <- 1000
x <- runif(n, -5,5) # uniform distribution [-5, 5]
error <- rnorm(n, sd = 0.5)
y <- sin(x) + error 
nonlin <- data.frame(y = y, x = x)

train_idx <- sample(1:n, size = 500) # indices of the 500 training observations
nonlin_train <- nonlin[train_idx, ]
nonlin_test <- nonlin[-train_idx, ]

ggplot(nonlin, aes(y = y,x = x)) +
  geom_point(alpha = .3) +
  theme_minimal()

  • Fit a regression tree using the training set, plot it
# method = "anova" fits a regression tree; cp = 0 grows the full tree, so that
# recursive partitioning does not stop early when a split improves the RSS by
# less than a fraction cp
treefit <- rpart(y ~ x, data = nonlin_train, method = "anova", 
                 control = list(cp = 0))
# note: the heights of the branches are proportional to the improvement in RSS
rpart.plot(treefit)

  • Determine the optimal complexity parameter (cp) to prune the tree
# plot the cp relative error to determine the optimal complexity parameter
plotcp(treefit)

treefit$cptable
##              CP nsplit rel error    xerror       xstd
## 1  0.1946974705      0 1.0000000 1.0040153 0.05021945
## 2  0.0340611345      3 0.4159076 0.4538104 0.02638272
## 3  0.0273934137      4 0.3818465 0.4005153 0.02381225
## 4  0.0189166164      5 0.3544530 0.3935674 0.02372447
## 5  0.0187413385      6 0.3355364 0.3960865 0.02397866
## 6  0.0121323195      7 0.3167951 0.3733062 0.02312415
## 7  0.0117742571      8 0.3046628 0.3603378 0.02145075
## 8  0.0106806598      9 0.2928885 0.3578205 0.02147093
## 9  0.0073025938     10 0.2822078 0.3534446 0.02167188
## 10 0.0022726695     11 0.2749053 0.3370788 0.02130490
## 11 0.0018682275     13 0.2703599 0.3422712 0.02175280
## 12 0.0017155126     14 0.2684917 0.3401212 0.02164749
## 13 0.0016638943     16 0.2650607 0.3407474 0.02142361
## 14 0.0012684940     17 0.2633968 0.3441431 0.02192873
## 15 0.0012380829     19 0.2608598 0.3455235 0.02190487
## 16 0.0012185543     20 0.2596217 0.3465451 0.02202560
## 17 0.0011560361     21 0.2584031 0.3480298 0.02223123
## 18 0.0011186278     23 0.2560911 0.3482530 0.02221418
## 19 0.0011053218     24 0.2549724 0.3500691 0.02223374
## 20 0.0010825391     25 0.2538671 0.3500691 0.02223374
## 21 0.0008734059     26 0.2527846 0.3538991 0.02235048
## 22 0.0008675315     27 0.2519112 0.3603456 0.02281530
## 23 0.0008005684     28 0.2510436 0.3610295 0.02282622
## 24 0.0007660162     29 0.2502431 0.3612981 0.02270867
## 25 0.0007606718     30 0.2494771 0.3613280 0.02274626
## 26 0.0007542690     31 0.2487164 0.3592501 0.02241081
## 27 0.0007398102     33 0.2472079 0.3591680 0.02233911
## 28 0.0006485116     34 0.2464680 0.3593999 0.02233074
## 29 0.0006005097     36 0.2451710 0.3609370 0.02255022
## 30 0.0004515101     37 0.2445705 0.3623265 0.02269062
## 31 0.0004173810     38 0.2441190 0.3650851 0.02274125
## 32 0.0002682229     39 0.2437016 0.3648361 0.02268516
## 33 0.0001455512     40 0.2434334 0.3653993 0.02276076
## 34 0.0000000000     41 0.2432878 0.3646539 0.02275065
cp_summary <- treefit$cptable # $cptable extracts the table directly
# row 9 (10 splits) is chosen by inspecting the cp plot/table above;
# a programmatic alternative is sketched below
optimal_cp <- cp_summary[9, 1]
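If you prefer not to hard-code the row, the cp can also be chosen programmatically; a minimal sketch is below, using the cptable columns shown above (min_row and one_se_row are illustrative names). Note that the smallest-xerror rule and the 1-SE rule may select slightly different rows than the manual choice of row 9.

cpt <- treefit$cptable
min_row <- which.min(cpt[, "xerror"])                          # smallest cross-validated error
one_se_threshold <- cpt[min_row, "xerror"] + cpt[min_row, "xstd"]
one_se_row <- min(which(cpt[, "xerror"] <= one_se_threshold))  # smallest tree within 1 SE of the minimum
cpt[c(min_row, one_se_row), "CP"]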
  • Plot the pruned tree and summarize
tree_pruned <- prune(treefit, cp = optimal_cp)
# plot the pruned tree
rpart.plot(tree_pruned)

# summarize the pruned tree object and relate the summary to the plotted tree above
summary(tree_pruned)
## Call:
## rpart(formula = y ~ x, data = nonlin_train, method = "anova", 
##     control = list(cp = 0))
##   n= 500 
## 
##            CP nsplit rel error    xerror       xstd
## 1 0.194697471      0 1.0000000 1.0040153 0.05021945
## 2 0.034061135      3 0.4159076 0.4538104 0.02638272
## 3 0.027393414      4 0.3818465 0.4005153 0.02381225
## 4 0.018916616      5 0.3544530 0.3935674 0.02372447
## 5 0.018741339      6 0.3355364 0.3960865 0.02397866
## 6 0.012132319      7 0.3167951 0.3733062 0.02312415
## 7 0.011774257      8 0.3046628 0.3603378 0.02145075
## 8 0.010680660      9 0.2928885 0.3578205 0.02147093
## 9 0.007302594     10 0.2822078 0.3534446 0.02167188
## 
## Variable importance
##   x 
## 100 
## 
## Node number 1: 500 observations,    complexity param=0.1946975
##   mean=-0.1147421, MSE=0.8380283 
##   left son=2 (436 obs) right son=3 (64 obs)
##   Primary splits:
##       x < -3.484362   to the right, improve=0.1407386, (0 missing)
## 
## Node number 2: 436 observations,    complexity param=0.1946975
##   mean=-0.2463199, MSE=0.7852125 
##   left son=4 (193 obs) right son=5 (243 obs)
##   Primary splits:
##       x < -0.05508361 to the left,  improve=0.1965717, (0 missing)
## 
## Node number 3: 64 observations,    complexity param=0.01213232
##   mean=0.781632, MSE=0.2764066 
##   left son=6 (43 obs) right son=7 (21 obs)
##   Primary splits:
##       x < -4.612867   to the right, improve=0.2873718, (0 missing)
## 
## Node number 4: 193 observations,    complexity param=0.03406113
##   mean=-0.6871574, MSE=0.3678353 
##   left son=8 (169 obs) right son=9 (24 obs)
##   Primary splits:
##       x < -2.955229   to the right, improve=0.2010375, (0 missing)
## 
## Node number 5: 243 observations,    complexity param=0.1946975
##   mean=0.1038103, MSE=0.839768 
##   left son=10 (98 obs) right son=11 (145 obs)
##   Primary splits:
##       x < 3.133467    to the right, improve=0.5805772, (0 missing)
## 
## Node number 6: 43 observations
##   mean=0.5846748, MSE=0.2062735 
## 
## Node number 7: 21 observations
##   mean=1.184925, MSE=0.1779356 
## 
## Node number 8: 169 observations,    complexity param=0.01177426
##   mean=-0.7896347, MSE=0.3017005 
##   left son=16 (121 obs) right son=17 (48 obs)
##   Primary splits:
##       x < -0.8897254  to the left,  improve=0.09676081, (0 missing)
## 
## Node number 9: 24 observations
##   mean=0.03445331, MSE=0.2388637 
## 
## Node number 10: 98 observations,    complexity param=0.01874134
##   mean=-0.7455278, MSE=0.3121068 
##   left son=20 (69 obs) right son=21 (29 obs)
##   Primary splits:
##       x < 3.814829    to the right, improve=0.2567438, (0 missing)
## 
## Node number 11: 145 observations,    complexity param=0.02739341
##   mean=0.6778457, MSE=0.3793274 
##   left son=22 (21 obs) right son=23 (124 obs)
##   Primary splits:
##       x < 0.4782685   to the left,  improve=0.2086857, (0 missing)
## 
## Node number 16: 121 observations,    complexity param=0.01068066
##   mean=-0.8972479, MSE=0.3048796 
##   left son=32 (71 obs) right son=33 (50 obs)
##   Primary splits:
##       x < -2.103778   to the right, improve=0.1213146, (0 missing)
## 
## Node number 17: 48 observations
##   mean=-0.5183597, MSE=0.1909035 
## 
## Node number 20: 69 observations
##   mean=-0.9290447, MSE=0.2562326 
## 
## Node number 21: 29 observations
##   mean=-0.3088841, MSE=0.1742598 
## 
## Node number 22: 21 observations
##   mean=-0.005837036, MSE=0.2249018 
## 
## Node number 23: 124 observations,    complexity param=0.01891662
##   mean=0.7936307, MSE=0.3129137 
##   left son=46 (26 obs) right son=47 (98 obs)
##   Primary splits:
##       x < 2.54086     to the right, improve=0.20428, (0 missing)
## 
## Node number 32: 71 observations
##   mean=-1.058638, MSE=0.2497405 
## 
## Node number 33: 50 observations
##   mean=-0.6680742, MSE=0.2936702 
## 
## Node number 46: 26 observations
##   mean=0.3027775, MSE=0.233719 
## 
## Node number 47: 98 observations
##   mean=0.9238571, MSE=0.2530437
  • Based on the plot and/or summary of the pruned tree, create a vector of the (ordered) split points for variable x and a vector of fitted values for the intervals determined by those split points.
# extract and sort x values at split points
x_splits <- sort(tree_pruned$splits[ , "index"])
# predict the fitted value for each interval: x = -999 stands in for -Inf, and
# each split point is the lower endpoint of the next interval
y_splits <- predict(tree_pruned, data.frame(x = c(-999, x_splits)))
data.frame(x_splits_lower = c(-999, x_splits), y_fitted = y_splits)
##    x_splits_lower     y_fitted
## 1   -999.00000000  1.184925143
## 2     -4.61286694  0.584674844
## 3     -3.48436187  0.034453315
## 4     -2.95522928 -0.668074168
## 5     -2.10377810 -1.058637893
## 6     -0.88972540 -0.518359653
## 7     -0.05508361 -0.005837036
## 8      0.47826850  0.923857058
## 9      2.54085971  0.302777531
## 10     3.13346687 -0.308884110
## 11     3.81482886 -0.929044682
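As a cross-check (a sketch, not required by the question), the same fitted values can be read directly from the pruned tree's frame: leaf rows have var == "<leaf>" and yval holds the leaf mean, although the rows appear in node order rather than in order of x.

leaf_means <- tree_pruned$frame[tree_pruned$frame$var == "<leaf>", "yval"]
sort(leaf_means) # should contain the same values as y_fitted above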
  • Plot the step function corresponding to the fitted (pruned) tree
# using base R plot functions
stpfn <- stepfun(x_splits, y_splits)
plot(y ~ x, data = nonlin_train, pch = 20)
plot(stpfn, add = TRUE, lwd = 2, col = '#FF69B4', pch = 19)

## ggplot version
step_df <- data.frame(
  # adding trailing lines at both ends
  x = c(-999, x_splits, 999), 
  y = c(y_splits, y_splits[length(y_splits)])
)
ggplot(mapping = aes(x = x, y = y)) +
  theme_minimal() +
  geom_point(data = nonlin_train, alpha = .5) +
  geom_step(data = step_df, colour = "#FF69B4") +
  geom_point(data = step_df, colour = "#FF69B4", size = 2) +
  coord_cartesian(xlim = c(-5, 5))

  • Fit a linear model to the training data and plot the regression line.
# base R
lmfit <- lm(y ~ x, data = nonlin_train)
plot(y ~ x, data = nonlin_train, pch = 20, col = "darkgrey")
abline(lmfit, col = "steelblue", lwd = 2)
plot(stpfn, add = TRUE, lwd = 2, col = '#FF69B4', pch = 19)

# ggplot
ggplot(mapping = aes(x = x, y = y)) +
  theme_minimal() +
  geom_point(data = nonlin_train, alpha = .5, colour = "darkgrey") +
  geom_smooth(data = nonlin_train, colour = "steelblue",
              method = "lm", formula = y ~ x, se = FALSE) +
  geom_step(data = step_df, colour = "#FF69B4") +
  geom_point(data = step_df, colour = "#FF69B4", size = 2) +
  coord_cartesian(xlim = c(-5, 5))

  • Contrast the quality of the fit of the tree model vs. linear regression by inspection of the plot

Write about the differences

  • Compute the test MSE of the pruned tree and the linear regression model

\[MSE = \frac{1}{n} \sum_{i=1}^n \left(\hat{y}_i-y_i\right)^2\]

tree_pred <- predict(tree_pruned, nonlin_test)
lm_pred <- predict(lmfit, nonlin_test)
tibble(tree_pred, lm_pred, y = nonlin_test$y) |>
  summarise(
    tree_mse = sum((tree_pred - y)^2) / n(),
    lm_mse = sum((lm_pred - y)^2) / n()
  ) 
## # A tibble: 1 × 2
##   tree_mse lm_mse
##      <dbl>  <dbl>
## 1    0.319  0.785
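For context (a sketch, not part of the question), the test MSE of the true regression function sin(x) estimates the irreducible error, which by construction is Var(ε) = 0.5² = 0.25; neither model can do better than this floor.

mean((sin(nonlin_test$x) - nonlin_test$y)^2) # should be close to 0.25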
  • Is the lm or regression tree better at fitting a non-linear function?

Write about the differences


Question 2: Analysis of Real Data

  • Split the heart data into training and testing (70-30%)
glimpse(heart)
## Rows: 303
## Columns: 14
## $ Age       <int> 63, 67, 67, 37, 41, 56, 62, 57, 63, 53, 57, 56, 56, 44, 52, …
## $ Sex       <int> 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, …
## $ ChestPain <fct> typical, asymptomatic, asymptomatic, nonanginal, nontypical,…
## $ RestBP    <int> 145, 160, 120, 130, 130, 120, 140, 120, 130, 140, 140, 140, …
## $ Chol      <int> 233, 286, 229, 250, 204, 236, 268, 354, 254, 203, 192, 294, …
## $ Fbs       <int> 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, …
## $ RestECG   <int> 2, 2, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, …
## $ MaxHR     <int> 150, 108, 129, 187, 172, 178, 160, 163, 147, 155, 148, 153, …
## $ ExAng     <int> 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Oldpeak   <dbl> 2.3, 1.5, 2.6, 3.5, 1.4, 0.8, 3.6, 0.6, 1.4, 3.1, 0.4, 1.3, …
## $ Slope     <int> 3, 2, 2, 3, 1, 1, 3, 1, 2, 3, 2, 2, 2, 1, 1, 1, 3, 1, 1, 1, …
## $ Ca        <int> 0, 3, 2, 0, 0, 0, 2, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Thal      <fct> fixed, normal, reversable, normal, normal, normal, normal, n…
## $ AHD       <dbl> 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, …
set.seed(2023)
train <- sample(1:nrow(heart), round(0.7 * nrow(heart)))
heart_train <- heart[train, ]
heart_test <- heart[-train, ]
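A simple random split is used above. An alternative (shown only as an optional sketch, not used in the rest of the lab) is to stratify the split by the outcome so that the proportion of AHD = 1 is similar in the training and test sets; train_strat is an illustrative name.

set.seed(2023)
train_strat <- unlist(lapply(
  split(seq_len(nrow(heart)), heart$AHD),                       # row indices grouped by outcome
  function(idx) sample(idx, size = round(0.7 * length(idx)))    # 70% within each group
))
prop.table(table(heart$AHD[train_strat]))                       # class balance in the stratified training set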
  • Fit a classification tree using rpart, plot the full tree
heart_tree <- rpart(
  AHD ~ ., data = heart_train,
  method = "class", 
  control = list(minsplit = 10, minbucket = 3, cp = 0, xval = 10)
)
rpart.plot(heart_tree)

  • Plot the complexity parameter table for an rpart fit and prune the tree
# plot the cp relative error to determine the optimal complexity parameter
plotcp(heart_tree)

heart_cp_summary <- heart_tree$cptable # $cptable extracts the table directly
# pick the row with the smallest xerror (row 4 here; the tree isn't too large to begin with)
heart_optimal_cp <- heart_cp_summary[4, 1] 
heart_tree_pruned <- prune(heart_tree, cp = heart_optimal_cp)
rpart.plot(heart_tree_pruned)

summary(heart_tree_pruned)
## Call:
## rpart(formula = AHD ~ ., data = heart_train, method = "class", 
##     control = list(minsplit = 10, minbucket = 3, cp = 0, xval = 10))
##   n= 212 
## 
##      CP nsplit rel error xerror       xstd
## 1 0.520      0      1.00   1.17 0.07240804
## 2 0.050      1      0.48   0.55 0.06382095
## 3 0.045      3      0.38   0.51 0.06223434
## 4 0.010      5      0.29   0.48 0.06093609
## 
## Variable importance
##      Thal ChestPain     MaxHR        Ca     ExAng   Oldpeak       Sex       Age 
##        24        16        12        12        10         9         8         4 
##    RestBP     Slope 
##         3         2 
## 
## Node number 1: 212 observations,    complexity param=0.52
##   predicted class=0  expected loss=0.4716981  P(node) =1
##     class counts:   112   100
##    probabilities: 0.528 0.472 
##   left son=2 (116 obs) right son=3 (96 obs)
##   Primary splits:
##       Thal      splits as  RLR,       improve=32.81092, (2 missing)
##       Ca        < 0.5   to the left,  improve=30.33955, (3 missing)
##       ChestPain splits as  RLLL,      improve=28.48156, (0 missing)
##       ExAng     < 0.5   to the left,  improve=19.59505, (0 missing)
##       Oldpeak   < 1.7   to the left,  improve=18.61642, (0 missing)
##   Surrogate splits:
##       MaxHR     < 147.5 to the right, agree=0.695, adj=0.326, (2 split)
##       Sex       < 0.5   to the left,  agree=0.690, adj=0.316, (0 split)
##       ExAng     < 0.5   to the left,  agree=0.681, adj=0.295, (0 split)
##       ChestPain splits as  RLLL,      agree=0.671, adj=0.274, (0 split)
##       Oldpeak   < 1.7   to the left,  agree=0.643, adj=0.211, (0 split)
## 
## Node number 2: 116 observations,    complexity param=0.05
##   predicted class=0  expected loss=0.2241379  P(node) =0.5471698
##     class counts:    90    26
##    probabilities: 0.776 0.224 
##   left son=4 (82 obs) right son=5 (34 obs)
##   Primary splits:
##       Ca        < 0.5   to the left,  improve=8.882549, (1 missing)
##       ChestPain splits as  RLLL,      improve=6.229038, (0 missing)
##       Oldpeak   < 2.1   to the left,  improve=5.242263, (0 missing)
##       MaxHR     < 128.5 to the right, improve=4.249828, (0 missing)
##       ExAng     < 0.5   to the left,  improve=4.132062, (0 missing)
##   Surrogate splits:
##       Age     < 66.5  to the left,  agree=0.748, adj=0.147, (1 split)
##       MaxHR   < 128.5 to the right, agree=0.739, adj=0.118, (0 split)
##       Oldpeak < 1.7   to the left,  agree=0.722, adj=0.059, (0 split)
##       RestBP  < 151   to the left,  agree=0.713, adj=0.029, (0 split)
## 
## Node number 3: 96 observations,    complexity param=0.045
##   predicted class=1  expected loss=0.2291667  P(node) =0.4528302
##     class counts:    22    74
##    probabilities: 0.229 0.771 
##   left son=6 (40 obs) right son=7 (56 obs)
##   Primary splits:
##       Ca        < 0.5   to the left,  improve=7.559679, (2 missing)
##       ChestPain splits as  RLLL,      improve=7.540488, (0 missing)
##       Oldpeak   < 0.85  to the left,  improve=5.989262, (0 missing)
##       ExAng     < 0.5   to the left,  improve=3.234947, (0 missing)
##       MaxHR     < 143.5 to the right, improve=2.487101, (0 missing)
##   Surrogate splits:
##       Age       < 53.5  to the left,  agree=0.745, adj=0.385, (2 split)
##       ChestPain splits as  RRLL,      agree=0.691, adj=0.256, (0 split)
##       Oldpeak   < 0.95  to the left,  agree=0.660, adj=0.179, (0 split)
##       MaxHR     < 172   to the right, agree=0.638, adj=0.128, (0 split)
##       RestBP    < 111   to the left,  agree=0.617, adj=0.077, (0 split)
## 
## Node number 4: 82 observations
##   predicted class=0  expected loss=0.09756098  P(node) =0.3867925
##     class counts:    74     8
##    probabilities: 0.902 0.098 
## 
## Node number 5: 34 observations,    complexity param=0.05
##   predicted class=1  expected loss=0.4705882  P(node) =0.1603774
##     class counts:    16    18
##    probabilities: 0.471 0.529 
##   left son=10 (18 obs) right son=11 (16 obs)
##   Primary splits:
##       ChestPain splits as  RLLL,      improve=4.843954, (0 missing)
##       Sex       < 0.5   to the left,  improve=3.706089, (0 missing)
##       Oldpeak   < 0.85  to the left,  improve=3.126891, (0 missing)
##       Slope     < 1.5   to the left,  improve=3.126891, (0 missing)
##       MaxHR     < 119.5 to the right, improve=2.596349, (0 missing)
##   Surrogate splits:
##       Oldpeak < 0.85  to the left,  agree=0.765, adj=0.500, (0 split)
##       Slope   < 1.5   to the left,  agree=0.765, adj=0.500, (0 split)
##       RestBP  < 129   to the right, agree=0.735, adj=0.438, (0 split)
##       MaxHR   < 125.5 to the right, agree=0.735, adj=0.438, (0 split)
##       ExAng   < 0.5   to the left,  agree=0.706, adj=0.375, (0 split)
## 
## Node number 6: 40 observations,    complexity param=0.045
##   predicted class=1  expected loss=0.45  P(node) =0.1886792
##     class counts:    18    22
##    probabilities: 0.450 0.550 
##   left son=12 (17 obs) right son=13 (23 obs)
##   Primary splits:
##       ChestPain splits as  RLLL,      improve=5.856266, (0 missing)
##       ExAng     < 0.5   to the left,  improve=4.150877, (0 missing)
##       Oldpeak   < 0.05  to the left,  improve=4.113480, (0 missing)
##       Age       < 51.5  to the right, improve=3.200000, (0 missing)
##       MaxHR     < 144   to the right, improve=3.200000, (0 missing)
##   Surrogate splits:
##       ExAng   < 0.5   to the left,  agree=0.750, adj=0.412, (0 split)
##       MaxHR   < 144   to the right, agree=0.725, adj=0.353, (0 split)
##       RestBP  < 122   to the left,  agree=0.650, adj=0.176, (0 split)
##       Age     < 51.5  to the right, agree=0.625, adj=0.118, (0 split)
##       Oldpeak < 0.7   to the left,  agree=0.625, adj=0.118, (0 split)
## 
## Node number 7: 56 observations
##   predicted class=1  expected loss=0.07142857  P(node) =0.2641509
##     class counts:     4    52
##    probabilities: 0.071 0.929 
## 
## Node number 10: 18 observations
##   predicted class=0  expected loss=0.2777778  P(node) =0.08490566
##     class counts:    13     5
##    probabilities: 0.722 0.278 
## 
## Node number 11: 16 observations
##   predicted class=1  expected loss=0.1875  P(node) =0.0754717
##     class counts:     3    13
##    probabilities: 0.188 0.812 
## 
## Node number 12: 17 observations
##   predicted class=0  expected loss=0.2352941  P(node) =0.08018868
##     class counts:    13     4
##    probabilities: 0.765 0.235 
## 
## Node number 13: 23 observations
##   predicted class=1  expected loss=0.2173913  P(node) =0.1084906
##     class counts:     5    18
##    probabilities: 0.217 0.783
  • Compute the test misclassification error

\[Err = \frac{1}{n}\sum_{i=1}^n I\left(\hat{y}_i \neq y_i\right)\]

heart_pred <- predict(heart_tree, heart_test) # class probabilities from the unpruned tree
# a prediction is wrong exactly when "predicted AHD = 1" and "observed AHD = 0" are
# either both TRUE or both FALSE, so == counts the misclassified observations
# (a more direct version is sketched after the output below)
sum((heart_pred[ , 2] > 0.5) == (heart_test$AHD == 0)) / nrow(heart_test)
## [1] 0.2197802
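The same error can be computed more directly with predict(type = "class"); the sketch below also reports the test error of the pruned tree for comparison (the object names are illustrative, and the unpruned-tree value should reproduce the number above).

full_class   <- predict(heart_tree, heart_test, type = "class")
pruned_class <- predict(heart_tree_pruned, heart_test, type = "class")
mean(full_class != heart_test$AHD)   # misclassification rate of the unpruned tree
mean(pruned_class != heart_test$AHD) # misclassification rate of the pruned tree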
  • Fit the tree with the optimal complexity parameter to the full data (training + testing)
heart_tree_full <- rpart(
  AHD ~ ., data = heart, # full data
  method = "class", 
  control = list(minsplit = 10, minbucket = 3, xval = 10,
                 cp = heart_optimal_cp) # optimal cp
)
rpart.plot(heart_tree_full)


Question 3: Bagging, Random Forest

  • Compare the performance of the classification tree (above), bagging, and random forests for predicting heart disease based on the heart data. Train each model on the training data and extract the cross-validation error (or the out-of-bag error for bagging and random forests).

  • For bagging use randomForest with mtry equal to the number of features (all other parameters at their default values). Generate the variable importance plot using varImpPlot and extract variable importance from the randomForest fitted object using the importance function.

n_features <- ncol(heart_train) - 1 # number of predictors (all columns except AHD)

heart_bagging <- randomForest(as.factor(AHD) ~ . , 
                              data = heart_train, 
                              mtry = n_features,
                              na.action = na.omit)
mean(heart_bagging$err.rate[ , 1]) # average OOB error over the tree sequence (first column of err.rate)
## [1] 0.1949861
varImpPlot(heart_bagging, main = "Variable importance plot (Bagging)")

importance(heart_bagging)
##           MeanDecreaseGini
## Age              8.1721939
## Sex              2.1393115
## ChestPain       10.9170764
## RestBP           6.0438493
## Chol             8.1692413
## Fbs              0.4373468
## RestECG          0.9730543
## MaxHR            7.0689581
## ExAng            1.4429492
## Oldpeak         10.9760397
## Slope            2.0876743
## Ca              21.2673363
## Thal            22.9585727
  • For random forests use randomForest with the default parameters. Generate the variable importance plot using varImpPlot and extract variable importance from the randomForest fitted object using the importance function.
# the default mtry for classification is floor(sqrt(n_features))
heart_forest <- randomForest(as.factor(AHD) ~ . , 
                             data = heart_train, 
                             # mtry = n_features,
                             na.action = na.omit)
mean(heart_forest$err.rate[ , 1]) # average OOB error over the tree sequence (first column of err.rate)
## [1] 0.1812747
varImpPlot(heart_forest, main = "Variable importance plot (Random forest)")

importance(heart_forest)
##           MeanDecreaseGini
## Age              9.3909356
## Sex              3.3956807
## ChestPain       12.0217136
## RestBP           6.9772276
## Chol             7.5957596
## Fbs              0.8847653
## RestECG          1.6661209
## MaxHR           10.9485451
## ExAng            4.7193365
## Oldpeak         11.5671639
## Slope            3.3500874
## Ca              14.9527583
## Thal            14.1381410
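Averaging err.rate over the whole tree sequence also includes the early, noisier iterations; another common summary (shown here as a sketch) is the OOB error at the final iteration, i.e. the last row of err.rate.

heart_bagging$err.rate[heart_bagging$ntree, "OOB"] # OOB error using all bagging trees
heart_forest$err.rate[heart_forest$ntree, "OOB"]   # OOB error using all random forest trees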

Comparison of the classification errors (test error for the single tree; OOB error for bagging and the random forest)

errs <- c(
  `Single Tree` = sum((heart_pred[ , 2] > 0.5) == (heart_test$AHD == 0)) / nrow(heart_test),
  `Bagging` = mean(heart_bagging$err.rate[ , 1]),
  `Random Forest` = mean(heart_forest$err.rate[ , 1])
)
data.frame(
  errs = errs,
  type = names(errs)
) |>
  ggplot(aes(x = reorder(type, -errs), y = errs)) +
  theme_minimal() +
  geom_point(size = 3) +
  geom_segment(aes(xend = reorder(type, -errs), yend = 0), 
               linetype = "dotted") +
  labs(x = NULL, y = "0-1 Loss") +
  coord_cartesian(ylim = c(0, 1))
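Because the single-tree number is a test error while the other two are OOB errors, the comparison above mixes error types. As an optional check (a sketch only), all three models can be evaluated on the held-out set; randomForest cannot predict with missing values, so the complete cases of the test set are used, and heart_test_cc is an illustrative name.

heart_test_cc <- na.omit(heart_test) # drop test rows with missing Ca or Thal
test_errs <- c(
  `Single Tree`   = mean(predict(heart_tree_pruned, heart_test_cc, type = "class") != heart_test_cc$AHD),
  `Bagging`       = mean(predict(heart_bagging, heart_test_cc) != heart_test_cc$AHD),
  `Random Forest` = mean(predict(heart_forest, heart_test_cc) != heart_test_cc$AHD)
)
test_errs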


Question 4: Boosting

  • For boosting, use gbm with cv.folds = 5 to perform 5-fold cross-validation, and set class.stratify.cv = TRUE so that cross-validation is stratified by AHD (the heart disease outcome). Plot the cross-validation error as a function of the number of boosting iterations/trees (the $cv.error component of the object returned by gbm) and determine whether additional boosting iterations are warranted. If so, run additional iterations with gbm.more (use the R help to check its syntax). Choose the optimal number of iterations. Use the summary.gbm function to generate the variable importance plot and to extract the variable importance/influence (summary.gbm does both). Generate 1D and 2D marginal plots with plot.gbm to assess the effect of the top three variables and their 2-way interactions.
heart_boost <- gbm(
  AHD ~ ., data = heart_train, 
  distribution = "bernoulli",
  cv.folds = 5, class.stratify.cv = TRUE,
  n.trees = 3000
)
plot(heart_boost$cv.error, type = "l", ylim = c(0, 2), 
     lwd = 2, col = "steelblue")
lines(heart_boost$train.error, col = "#FF69B4", lwd = 2)

# plot(heart_boost)
# ggplot version
data.frame(
  n_tree = seq.int(length(heart_boost$cv.error)),
  cv = heart_boost$cv.error,
  train = heart_boost$train.error
) |>
  ggplot(aes(x = n_tree)) +
  theme_minimal() +
  geom_line(aes(y = cv, colour = "5-Fold Cross Validation Error")) +
  geom_line(aes(y = train, colour = "Train")) +
  geom_vline(xintercept = which.min(heart_boost$cv.error), 
             linetype = "dotted", colour = "steelblue", 
             linewidth = .5) +
  annotate("text", 
           x = which.min(heart_boost$cv.error) + 10, 
           y =  .01,
           label = paste("Min CV error at", which.min(heart_boost$cv.error), "trees"),
           hjust = 0, vjust = 0, colour = "steelblue") +
  scale_colour_manual(values = c("steelblue", "#FF69B4")) +
  labs(colour = NULL, y = "Log loss") +
  theme(legend.position = "top")
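The package's own helper for choosing the number of iterations is gbm.perf; with method = "cv" it should agree with the which.min(heart_boost$cv.error) value used above (a minimal sketch).

best_iter <- gbm.perf(heart_boost, method = "cv", plot.it = FALSE)
best_iter # estimated optimal number of boosting iterations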

heart_boost_opt <- gbm(
  AHD ~ ., data = heart_train, 
  distribution = "bernoulli",
  cv.folds = 5, class.stratify.cv = TRUE,
  n.trees = which.min(heart_boost$cv.error)
)
summary.gbm(heart_boost_opt)

##                 var    rel.inf
## Thal           Thal 27.2010273
## Ca               Ca 24.5304750
## ChestPain ChestPain 15.3669971
## Oldpeak     Oldpeak  9.5926901
## Age             Age  5.9828818
## Chol           Chol  4.9222494
## MaxHR         MaxHR  4.7552005
## RestBP       RestBP  2.5600230
## Sex             Sex  2.5183908
## ExAng         ExAng  2.1264810
## RestECG     RestECG  0.2330646
## Slope         Slope  0.2105194
## Fbs             Fbs  0.0000000
plot.gbm(heart_boost_opt, i.var = c("Thal", "Ca"))

plot.gbm(heart_boost_opt, i.var = c("Ca", "ChestPain"))

plot.gbm(heart_boost_opt, i.var = c("Thal", "ChestPain"))


Deliverables

  1. Questions 1-4 answered; PDF or HTML output uploaded to Quercus