For this lab we will be working with simulated data and the heart dataset, which you can download from here. You should install and load rpart (trees), rpart.plot (tree plots), randomForest (random forests), gbm (gradient boosting), and xgboost (extreme gradient boosting).
install.packages(c("rpart", "rpart.plot", "randomForest", "gbm", "xgboost"))
library(tidyverse)
library(rpart)
library(rpart.plot)
library(randomForest)
library(gbm)
# library(xgboost)
heart <- read.csv("https://raw.githubusercontent.com/JSC370/jsc370-2023/main/data/heart/heart.csv") |>
mutate(
AHD = 1 * (AHD == "Yes"),
ChestPain = factor(ChestPain),
Thal = factor(Thal)
)
head(heart)
## Age Sex ChestPain RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca
## 1 63 1 typical 145 233 1 2 150 0 2.3 3 0
## 2 67 1 asymptomatic 160 286 0 2 108 1 1.5 2 3
## 3 67 1 asymptomatic 120 229 0 2 129 1 2.6 2 2
## 4 37 1 nonanginal 130 250 0 0 187 0 3.5 3 0
## 5 41 0 nontypical 130 204 0 2 172 0 1.4 1 0
## 6 56 1 nontypical 120 236 0 0 178 0 0.8 1 0
## Thal AHD
## 1 fixed 0
## 2 normal 1
## 3 reversable 1
## 4 normal 0
## 5 normal 0
## 6 normal 0
Simulate data from the nonlinear model
\[Y_i = \sin(X_i) + \varepsilon_i, \qquad \varepsilon_i \sim N(0, 0.5^2),\]
for \(i = 1, \ldots, 1000\), with \(X_i\) drawn uniformly from \([-5, 5]\).
set.seed(1984)
n <- 1000
x <- runif(n, -5, 5) # uniform distribution on [-5, 5]
error <- rnorm(n, sd = 0.5)
y <- sin(x) + error
nonlin <- data.frame(y = y, x = x)
train_size <- sample(1:1000, size = 500)
nonlin_train <- nonlin[train_size,]
nonlin_test <- nonlin[-train_size,]
ggplot(nonlin, aes(y = y, x = x)) +
geom_point(alpha = .3) +
theme_minimal()
# method = "anova" indicates a regression tree; cp = 0 ensures that binary
# recursive partitioning will not stop early for lack of an RSS improvement
# of at least cp at each step
treefit <- rpart(y ~ x, data = nonlin_train, method = "anova",
control = list(cp = 0))
# note: the heights of the branches are proportional to the improvement in RSS
rpart.plot(treefit)
# plot the cp relative error to determine the optimal complexity parameter
plotcp(treefit)
treefit$cptable
## CP nsplit rel error xerror xstd
## 1 0.1946974705 0 1.0000000 1.0040153 0.05021945
## 2 0.0340611345 3 0.4159076 0.4538104 0.02638272
## 3 0.0273934137 4 0.3818465 0.4005153 0.02381225
## 4 0.0189166164 5 0.3544530 0.3935674 0.02372447
## 5 0.0187413385 6 0.3355364 0.3960865 0.02397866
## 6 0.0121323195 7 0.3167951 0.3733062 0.02312415
## 7 0.0117742571 8 0.3046628 0.3603378 0.02145075
## 8 0.0106806598 9 0.2928885 0.3578205 0.02147093
## 9 0.0073025938 10 0.2822078 0.3534446 0.02167188
## 10 0.0022726695 11 0.2749053 0.3370788 0.02130490
## 11 0.0018682275 13 0.2703599 0.3422712 0.02175280
## 12 0.0017155126 14 0.2684917 0.3401212 0.02164749
## 13 0.0016638943 16 0.2650607 0.3407474 0.02142361
## 14 0.0012684940 17 0.2633968 0.3441431 0.02192873
## 15 0.0012380829 19 0.2608598 0.3455235 0.02190487
## 16 0.0012185543 20 0.2596217 0.3465451 0.02202560
## 17 0.0011560361 21 0.2584031 0.3480298 0.02223123
## 18 0.0011186278 23 0.2560911 0.3482530 0.02221418
## 19 0.0011053218 24 0.2549724 0.3500691 0.02223374
## 20 0.0010825391 25 0.2538671 0.3500691 0.02223374
## 21 0.0008734059 26 0.2527846 0.3538991 0.02235048
## 22 0.0008675315 27 0.2519112 0.3603456 0.02281530
## 23 0.0008005684 28 0.2510436 0.3610295 0.02282622
## 24 0.0007660162 29 0.2502431 0.3612981 0.02270867
## 25 0.0007606718 30 0.2494771 0.3613280 0.02274626
## 26 0.0007542690 31 0.2487164 0.3592501 0.02241081
## 27 0.0007398102 33 0.2472079 0.3591680 0.02233911
## 28 0.0006485116 34 0.2464680 0.3593999 0.02233074
## 29 0.0006005097 36 0.2451710 0.3609370 0.02255022
## 30 0.0004515101 37 0.2445705 0.3623265 0.02269062
## 31 0.0004173810 38 0.2441190 0.3650851 0.02274125
## 32 0.0002682229 39 0.2437016 0.3648361 0.02268516
## 33 0.0001455512 40 0.2434334 0.3653993 0.02276076
## 34 0.0000000000 41 0.2432878 0.3646539 0.02275065
cp_summary <- treefit$cptable # $cptable extracts the table directly
optimal_cp <- cp_summary[9, 1] # row 9 read off the cp plot/table above
tree_pruned <- prune(treefit, cp = optimal_cp)
# plot the pruned tree
rpart.plot(tree_pruned)
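Rather than hardcoding row 9, the optimal cp can be read from the table programmatically. A minimal sketch (optimal_cp_min is a name introduced here; taking the minimum xerror may select a slightly larger tree than the one pruned above):
# select the cp with the smallest cross-validated error instead of reading plotcp() by eye
optimal_cp_min <- cp_summary[which.min(cp_summary[ , "xerror"]), "CP"]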
# summarize the pruned tree object and relate the summary to the plotted tree above
summary(tree_pruned)
## Call:
## rpart(formula = y ~ x, data = nonlin_train, method = "anova",
## control = list(cp = 0))
## n= 500
##
## CP nsplit rel error xerror xstd
## 1 0.194697471 0 1.0000000 1.0040153 0.05021945
## 2 0.034061135 3 0.4159076 0.4538104 0.02638272
## 3 0.027393414 4 0.3818465 0.4005153 0.02381225
## 4 0.018916616 5 0.3544530 0.3935674 0.02372447
## 5 0.018741339 6 0.3355364 0.3960865 0.02397866
## 6 0.012132319 7 0.3167951 0.3733062 0.02312415
## 7 0.011774257 8 0.3046628 0.3603378 0.02145075
## 8 0.010680660 9 0.2928885 0.3578205 0.02147093
## 9 0.007302594 10 0.2822078 0.3534446 0.02167188
##
## Variable importance
## x
## 100
##
## Node number 1: 500 observations, complexity param=0.1946975
## mean=-0.1147421, MSE=0.8380283
## left son=2 (436 obs) right son=3 (64 obs)
## Primary splits:
## x < -3.484362 to the right, improve=0.1407386, (0 missing)
##
## Node number 2: 436 observations, complexity param=0.1946975
## mean=-0.2463199, MSE=0.7852125
## left son=4 (193 obs) right son=5 (243 obs)
## Primary splits:
## x < -0.05508361 to the left, improve=0.1965717, (0 missing)
##
## Node number 3: 64 observations, complexity param=0.01213232
## mean=0.781632, MSE=0.2764066
## left son=6 (43 obs) right son=7 (21 obs)
## Primary splits:
## x < -4.612867 to the right, improve=0.2873718, (0 missing)
##
## Node number 4: 193 observations, complexity param=0.03406113
## mean=-0.6871574, MSE=0.3678353
## left son=8 (169 obs) right son=9 (24 obs)
## Primary splits:
## x < -2.955229 to the right, improve=0.2010375, (0 missing)
##
## Node number 5: 243 observations, complexity param=0.1946975
## mean=0.1038103, MSE=0.839768
## left son=10 (98 obs) right son=11 (145 obs)
## Primary splits:
## x < 3.133467 to the right, improve=0.5805772, (0 missing)
##
## Node number 6: 43 observations
## mean=0.5846748, MSE=0.2062735
##
## Node number 7: 21 observations
## mean=1.184925, MSE=0.1779356
##
## Node number 8: 169 observations, complexity param=0.01177426
## mean=-0.7896347, MSE=0.3017005
## left son=16 (121 obs) right son=17 (48 obs)
## Primary splits:
## x < -0.8897254 to the left, improve=0.09676081, (0 missing)
##
## Node number 9: 24 observations
## mean=0.03445331, MSE=0.2388637
##
## Node number 10: 98 observations, complexity param=0.01874134
## mean=-0.7455278, MSE=0.3121068
## left son=20 (69 obs) right son=21 (29 obs)
## Primary splits:
## x < 3.814829 to the right, improve=0.2567438, (0 missing)
##
## Node number 11: 145 observations, complexity param=0.02739341
## mean=0.6778457, MSE=0.3793274
## left son=22 (21 obs) right son=23 (124 obs)
## Primary splits:
## x < 0.4782685 to the left, improve=0.2086857, (0 missing)
##
## Node number 16: 121 observations, complexity param=0.01068066
## mean=-0.8972479, MSE=0.3048796
## left son=32 (71 obs) right son=33 (50 obs)
## Primary splits:
## x < -2.103778 to the right, improve=0.1213146, (0 missing)
##
## Node number 17: 48 observations
## mean=-0.5183597, MSE=0.1909035
##
## Node number 20: 69 observations
## mean=-0.9290447, MSE=0.2562326
##
## Node number 21: 29 observations
## mean=-0.3088841, MSE=0.1742598
##
## Node number 22: 21 observations
## mean=-0.005837036, MSE=0.2249018
##
## Node number 23: 124 observations, complexity param=0.01891662
## mean=0.7936307, MSE=0.3129137
## left son=46 (26 obs) right son=47 (98 obs)
## Primary splits:
## x < 2.54086 to the right, improve=0.20428, (0 missing)
##
## Node number 32: 71 observations
## mean=-1.058638, MSE=0.2497405
##
## Node number 33: 50 observations
## mean=-0.6680742, MSE=0.2936702
##
## Node number 46: 26 observations
## mean=0.3027775, MSE=0.233719
##
## Node number 47: 98 observations
## mean=0.9238571, MSE=0.2530437
# extract and sort x values at split points
x_splits <- sort(tree_pruned$splits[ , "index"])
# predict the fitted value within each interval; the -999 sentinel stands in
# for any x below the smallest split (the leftmost leaf)
y_splits <- predict(tree_pruned, data.frame(x = c(-999, x_splits)))
data.frame(x_splits_lower = c(-999, x_splits), y_fitted = y_splits)
## x_splits_lower y_fitted
## 1 -999.00000000 1.184925143
## 2 -4.61286694 0.584674844
## 3 -3.48436187 0.034453315
## 4 -2.95522928 -0.668074168
## 5 -2.10377810 -1.058637893
## 6 -0.88972540 -0.518359653
## 7 -0.05508361 -0.005837036
## 8 0.47826850 0.923857058
## 9 2.54085971 0.302777531
## 10 3.13346687 -0.308884110
## 11 3.81482886 -0.929044682
# using base R plot functions
stpfn <- stepfun(x_splits, y_splits)
plot(y ~ x, data = nonlin_train, pch = 20)
plot(stpfn, add = TRUE, lwd = 2, col = '#FF69B4', pch = 19)
## ggplot version
step_df <- data.frame(
# adding trailing lines at both ends
x = c(-999, x_splits, 999),
y = c(y_splits, y_splits[length(y_splits)])
)
ggplot(mapping = aes(x = x, y = y)) +
theme_minimal() +
geom_point(data = nonlin_train, alpha = .5) +
geom_step(data = step_df, colour = "#FF69B4") +
geom_point(data = step_df, colour = "#FF69B4", size = 2) +
coord_cartesian(xlim = c(-5, 5))
# base R
lmfit <- lm(y ~ x, data = nonlin_train)
plot(y ~ x, data = nonlin_train, pch = 20, col = "darkgrey")
abline(lmfit, col = "steelblue", lwd = 2)
plot(stpfn, add = TRUE, lwd = 2, col = '#FF69B4', pch = 19)
# ggplot
ggplot(mapping = aes(x = x, y = y)) +
theme_minimal() +
geom_point(data = nonlin_train, alpha = .5, colour = "darkgrey") +
geom_smooth(data = nonlin_train, colour = "steelblue",
method = "lm", formula = y ~ x, se = FALSE) +
geom_step(data = step_df, colour = "#FF69B4") +
geom_point(data = step_df, colour = "#FF69B4", size = 2) +
coord_cartesian(xlim = c(-5, 5))
Write about the differences between the regression tree (step-function) fit and the linear fit.
Calculate the test MSE for both models, where
\[MSE = \frac{1}{n} \sum_{i=1}^n \left(\hat{y}_i-y_i\right)^2.\]
tree_pred <- predict(tree_pruned, nonlin_test)
lm_pred <- predict(lmfit, nonlin_test)
tibble(tree_pred, lm_pred, y = nonlin_test$y) |>
summarise(
tree_mse = sum((tree_pred - y)^2) / n(),
lm_mse = sum((lm_pred - y)^2) / n()
)
## # A tibble: 1 × 2
## tree_mse lm_mse
## <dbl> <dbl>
## 1 0.319 0.785
Write about the differences in test MSE between the two models.
Split the heart data into training (70%) and testing (30%) sets.
glimpse(heart)
## Rows: 303
## Columns: 14
## $ Age <int> 63, 67, 67, 37, 41, 56, 62, 57, 63, 53, 57, 56, 56, 44, 52, …
## $ Sex <int> 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, …
## $ ChestPain <fct> typical, asymptomatic, asymptomatic, nonanginal, nontypical,…
## $ RestBP <int> 145, 160, 120, 130, 130, 120, 140, 120, 130, 140, 140, 140, …
## $ Chol <int> 233, 286, 229, 250, 204, 236, 268, 354, 254, 203, 192, 294, …
## $ Fbs <int> 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, …
## $ RestECG <int> 2, 2, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, …
## $ MaxHR <int> 150, 108, 129, 187, 172, 178, 160, 163, 147, 155, 148, 153, …
## $ ExAng <int> 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Oldpeak <dbl> 2.3, 1.5, 2.6, 3.5, 1.4, 0.8, 3.6, 0.6, 1.4, 3.1, 0.4, 1.3, …
## $ Slope <int> 3, 2, 2, 3, 1, 1, 3, 1, 2, 3, 2, 2, 2, 1, 1, 1, 3, 1, 1, 1, …
## $ Ca <int> 0, 3, 2, 0, 0, 0, 2, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, …
## $ Thal <fct> fixed, normal, reversable, normal, normal, normal, normal, n…
## $ AHD <dbl> 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, …
set.seed(2023)
train <- sample(1:nrow(heart), round(0.7 * nrow(heart)))
heart_train <- heart[train, ]
heart_test <- heart[-train, ]
heart_tree <- rpart(
AHD ~ ., data = heart_train,
method = "class",
control = list(minsplit = 10, minbucket = 3, cp = 0, xval = 10)
)
rpart.plot(heart_tree)
# plot the cp relative error to determine the optimal complexity parameter
plotcp(heart_tree)
heart_cp_summary <- heart_tree$cptable # $cptable extracts the table directly
# picking the min xerror (the tree isn't too large to begin with)
heart_optimal_cp <- heart_cp_summary[4, 1]
heart_tree_pruned <- prune(heart_tree, cp = heart_optimal_cp)
rpart.plot(heart_tree_pruned)
summary(heart_tree_pruned)
## Call:
## rpart(formula = AHD ~ ., data = heart_train, method = "class",
## control = list(minsplit = 10, minbucket = 3, cp = 0, xval = 10))
## n= 212
##
## CP nsplit rel error xerror xstd
## 1 0.520 0 1.00 1.17 0.07240804
## 2 0.050 1 0.48 0.55 0.06382095
## 3 0.045 3 0.38 0.51 0.06223434
## 4 0.010 5 0.29 0.48 0.06093609
##
## Variable importance
## Thal ChestPain MaxHR Ca ExAng Oldpeak Sex Age
## 24 16 12 12 10 9 8 4
## RestBP Slope
## 3 2
##
## Node number 1: 212 observations, complexity param=0.52
## predicted class=0 expected loss=0.4716981 P(node) =1
## class counts: 112 100
## probabilities: 0.528 0.472
## left son=2 (116 obs) right son=3 (96 obs)
## Primary splits:
## Thal splits as RLR, improve=32.81092, (2 missing)
## Ca < 0.5 to the left, improve=30.33955, (3 missing)
## ChestPain splits as RLLL, improve=28.48156, (0 missing)
## ExAng < 0.5 to the left, improve=19.59505, (0 missing)
## Oldpeak < 1.7 to the left, improve=18.61642, (0 missing)
## Surrogate splits:
## MaxHR < 147.5 to the right, agree=0.695, adj=0.326, (2 split)
## Sex < 0.5 to the left, agree=0.690, adj=0.316, (0 split)
## ExAng < 0.5 to the left, agree=0.681, adj=0.295, (0 split)
## ChestPain splits as RLLL, agree=0.671, adj=0.274, (0 split)
## Oldpeak < 1.7 to the left, agree=0.643, adj=0.211, (0 split)
##
## Node number 2: 116 observations, complexity param=0.05
## predicted class=0 expected loss=0.2241379 P(node) =0.5471698
## class counts: 90 26
## probabilities: 0.776 0.224
## left son=4 (82 obs) right son=5 (34 obs)
## Primary splits:
## Ca < 0.5 to the left, improve=8.882549, (1 missing)
## ChestPain splits as RLLL, improve=6.229038, (0 missing)
## Oldpeak < 2.1 to the left, improve=5.242263, (0 missing)
## MaxHR < 128.5 to the right, improve=4.249828, (0 missing)
## ExAng < 0.5 to the left, improve=4.132062, (0 missing)
## Surrogate splits:
## Age < 66.5 to the left, agree=0.748, adj=0.147, (1 split)
## MaxHR < 128.5 to the right, agree=0.739, adj=0.118, (0 split)
## Oldpeak < 1.7 to the left, agree=0.722, adj=0.059, (0 split)
## RestBP < 151 to the left, agree=0.713, adj=0.029, (0 split)
##
## Node number 3: 96 observations, complexity param=0.045
## predicted class=1 expected loss=0.2291667 P(node) =0.4528302
## class counts: 22 74
## probabilities: 0.229 0.771
## left son=6 (40 obs) right son=7 (56 obs)
## Primary splits:
## Ca < 0.5 to the left, improve=7.559679, (2 missing)
## ChestPain splits as RLLL, improve=7.540488, (0 missing)
## Oldpeak < 0.85 to the left, improve=5.989262, (0 missing)
## ExAng < 0.5 to the left, improve=3.234947, (0 missing)
## MaxHR < 143.5 to the right, improve=2.487101, (0 missing)
## Surrogate splits:
## Age < 53.5 to the left, agree=0.745, adj=0.385, (2 split)
## ChestPain splits as RRLL, agree=0.691, adj=0.256, (0 split)
## Oldpeak < 0.95 to the left, agree=0.660, adj=0.179, (0 split)
## MaxHR < 172 to the right, agree=0.638, adj=0.128, (0 split)
## RestBP < 111 to the left, agree=0.617, adj=0.077, (0 split)
##
## Node number 4: 82 observations
## predicted class=0 expected loss=0.09756098 P(node) =0.3867925
## class counts: 74 8
## probabilities: 0.902 0.098
##
## Node number 5: 34 observations, complexity param=0.05
## predicted class=1 expected loss=0.4705882 P(node) =0.1603774
## class counts: 16 18
## probabilities: 0.471 0.529
## left son=10 (18 obs) right son=11 (16 obs)
## Primary splits:
## ChestPain splits as RLLL, improve=4.843954, (0 missing)
## Sex < 0.5 to the left, improve=3.706089, (0 missing)
## Oldpeak < 0.85 to the left, improve=3.126891, (0 missing)
## Slope < 1.5 to the left, improve=3.126891, (0 missing)
## MaxHR < 119.5 to the right, improve=2.596349, (0 missing)
## Surrogate splits:
## Oldpeak < 0.85 to the left, agree=0.765, adj=0.500, (0 split)
## Slope < 1.5 to the left, agree=0.765, adj=0.500, (0 split)
## RestBP < 129 to the right, agree=0.735, adj=0.438, (0 split)
## MaxHR < 125.5 to the right, agree=0.735, adj=0.438, (0 split)
## ExAng < 0.5 to the left, agree=0.706, adj=0.375, (0 split)
##
## Node number 6: 40 observations, complexity param=0.045
## predicted class=1 expected loss=0.45 P(node) =0.1886792
## class counts: 18 22
## probabilities: 0.450 0.550
## left son=12 (17 obs) right son=13 (23 obs)
## Primary splits:
## ChestPain splits as RLLL, improve=5.856266, (0 missing)
## ExAng < 0.5 to the left, improve=4.150877, (0 missing)
## Oldpeak < 0.05 to the left, improve=4.113480, (0 missing)
## Age < 51.5 to the right, improve=3.200000, (0 missing)
## MaxHR < 144 to the right, improve=3.200000, (0 missing)
## Surrogate splits:
## ExAng < 0.5 to the left, agree=0.750, adj=0.412, (0 split)
## MaxHR < 144 to the right, agree=0.725, adj=0.353, (0 split)
## RestBP < 122 to the left, agree=0.650, adj=0.176, (0 split)
## Age < 51.5 to the right, agree=0.625, adj=0.118, (0 split)
## Oldpeak < 0.7 to the left, agree=0.625, adj=0.118, (0 split)
##
## Node number 7: 56 observations
## predicted class=1 expected loss=0.07142857 P(node) =0.2641509
## class counts: 4 52
## probabilities: 0.071 0.929
##
## Node number 10: 18 observations
## predicted class=0 expected loss=0.2777778 P(node) =0.08490566
## class counts: 13 5
## probabilities: 0.722 0.278
##
## Node number 11: 16 observations
## predicted class=1 expected loss=0.1875 P(node) =0.0754717
## class counts: 3 13
## probabilities: 0.188 0.812
##
## Node number 12: 17 observations
## predicted class=0 expected loss=0.2352941 P(node) =0.08018868
## class counts: 13 4
## probabilities: 0.765 0.235
##
## Node number 13: 23 observations
## predicted class=1 expected loss=0.2173913 P(node) =0.1084906
## class counts: 5 18
## probabilities: 0.217 0.783
Compute the test misclassification error,
\[Err = \frac{1}{n}\sum_{i=1}^n I\left(\hat{y}_i \neq y_i\right).\]
heart_pred <- predict(heart_tree, heart_test)
# a prediction is wrong when a predicted "Yes" (P > 0.5) lines up with AHD == 0,
# or a predicted "No" lines up with AHD == 1; the equality test counts both cases
sum((heart_pred[ , 2] > 0.5) == (heart_test$AHD == 0)) / nrow(heart_test)
## [1] 0.2197802
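The one-liner above counts disagreements via the equality trick; an equivalent, more explicit version (class_pred is a hypothetical intermediate introduced here) thresholds the class-1 probability first:
# threshold P(AHD = 1) at 0.5 to get hard class labels, then average the mismatches
class_pred <- as.numeric(heart_pred[ , 2] > 0.5)
mean(class_pred != heart_test$AHD)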
heart_tree_full <- rpart(
AHD ~ ., data = heart, # full data
method = "class",
control = list(minsplit = 10, minbucket = 3, xval = 10,
cp = heart_optimal_cp) # optimal cp
)
rpart.plot(heart_tree_full)
Compare the performance of classification trees (above), bagging, and random forests for predicting heart disease based on the heart data. Train each of the models on the training data and extract the cross-validation error (or out-of-bag error for bagging and random forests).
For bagging, use randomForest with mtry equal to the number of features (all other parameters at their default values). Generate the variable importance plot using varImpPlot and extract variable importance from the randomForest fitted object using the importance function.
n_features <- ncol(heart_train) - 1 # number of predictors (all columns except AHD)
heart_bagging <- randomForest(as.factor(AHD) ~ . ,
data = heart_train,
mtry = n_features,
na.action = na.omit)
mean(heart_bagging$err.rate[ , 1]) # OOB error is the first column of the matrix
## [1] 0.1949861
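Note that averaging err.rate[, 1] pools the OOB error over ensembles of 1 through 500 trees; the OOB error of the full ensemble is the last row, which is what print(heart_bagging) reports as the OOB estimate (a one-line check on the same fitted object):
# OOB error of the complete ensemble (final row of the error-rate matrix)
heart_bagging$err.rate[nrow(heart_bagging$err.rate), 1]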
varImpPlot(heart_bagging, main = "Variable importance plot (Bagging)")
importance(heart_bagging)
## MeanDecreaseGini
## Age 8.1721939
## Sex 2.1393115
## ChestPain 10.9170764
## RestBP 6.0438493
## Chol 8.1692413
## Fbs 0.4373468
## RestECG 0.9730543
## MaxHR 7.0689581
## ExAng 1.4429492
## Oldpeak 10.9760397
## Slope 2.0876743
## Ca 21.2673363
## Thal 22.9585727
For random forests, use randomForest with the default parameters. Generate the variable importance plot using varImpPlot and extract variable importance from the randomForest fitted object using the importance function.
# with the defaults, mtry = floor(sqrt(n_features)) for classification
heart_forest <- randomForest(as.factor(AHD) ~ . ,
data = heart_train,
# mtry = n_features,
na.action = na.omit)
mean(heart_forest$err.rate[ , 1]) # OOB error is the first column of the matrix
## [1] 0.1812747
varImpPlot(heart_forest, main = "Variable importance plot (Random forest)")
importance(heart_forest)
## MeanDecreaseGini
## Age 9.3909356
## Sex 3.3956807
## ChestPain 12.0217136
## RestBP 6.9772276
## Chol 7.5957596
## Fbs 0.8847653
## RestECG 1.6661209
## MaxHR 10.9485451
## ExAng 4.7193365
## Oldpeak 11.5671639
## Slope 3.3500874
## Ca 14.9527583
## Thal 14.1381410
Comparison of the classification errors (test error for the single tree; OOB error for bagging and the random forest):
errs <- c(
`Single Tree` = sum((heart_pred[ , 2] > 0.5) == (heart_test$AHD == 0)) / nrow(heart_test),
`Bagging` = mean(heart_bagging$err.rate[ , 1]),
`Random Forest` = mean(heart_forest$err.rate[ , 1])
)
data.frame(
errs = errs,
type = names(errs)
) |>
ggplot(aes(x = reorder(type, -errs), y = errs)) +
theme_minimal() +
geom_point(size = 3) +
geom_segment(aes(xend = reorder(type, -errs), yend = 0),
linetype = "dotted") +
labs(x = NULL, y = "0-1 Loss") +
coord_cartesian(ylim = c(0, 1))
For boosting, use gbm with cv.folds = 5 to perform 5-fold cross-validation, and set class.stratify.cv to TRUE so that cross-validation is stratified by AHD (the heart disease outcome). Plot the cross-validation error as a function of the boosting iterations/trees (the $cv.error component of the object returned by gbm) and determine whether additional boosting iterations are warranted. If so, run additional iterations with gbm.more (use the R help to check its syntax). Choose the optimal number of iterations. Use the summary.gbm function to generate the variable importance plot and extract variable importance/influence (summary.gbm does both). Generate 1D and 2D marginal plots with plot.gbm to assess the effect of the top three variables and their 2-way interactions.
heart_boost <- gbm(
AHD ~ ., data = heart_train,
distribution = "bernoulli",
cv.folds = 5, class.stratify.cv = TRUE,
n.trees = 3000
)
plot(heart_boost$cv.error, type = "l", ylim = c(0, 2),
lwd = 2, col = "steelblue")
lines(heart_boost$train.error, col = "#FF69B4", lwd = 2)
# plot(heart_boost)
# ggplot version
data.frame(
n_tree = seq.int(length(heart_boost$cv.error)),
cv = heart_boost$cv.error,
train = heart_boost$train.error
) |>
ggplot(aes(x = n_tree)) +
theme_minimal() +
geom_line(aes(y = cv, colour = "5-Fold Cross Validation Error")) +
geom_line(aes(y = train, colour = "Train")) +
geom_vline(xintercept = which.min(heart_boost$cv.error),
linetype = "dotted", colour = "steelblue",
linewidth = .5) +
annotate("text",
x = which.min(heart_boost$cv.error) + 10,
y = .01,
label = paste("Min CV error at", which.min(heart_boost$cv.error), "trees"),
hjust = 0, vjust = 0, colour = "steelblue") +
scale_colour_manual(values = c("steelblue", "#FF69B4")) +
labs(colour = NULL, y = "Log loss") +
theme(legend.position = "top")
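gbm also ships a helper, gbm.perf, that estimates the optimal iteration directly from the stored cross-validation results, and gbm.more can grow additional trees if the minimum sits at the right edge of the curve. A minimal sketch (best_iter is a name introduced here):
# estimate the optimal number of trees from the 5-fold CV error (also plots both curves)
best_iter <- gbm.perf(heart_boost, method = "cv")
# had the minimum been near the 3000-tree boundary, more iterations could be grown:
# heart_boost_more <- gbm.more(heart_boost, n.new.trees = 1000)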
heart_boost_opt <- gbm(
AHD ~ ., data = heart_train,
distribution = "bernoulli",
cv.folds = 5, class.stratify.cv = TRUE,
n.trees = which.min(heart_boost$cv.error)
)
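Refitting at the chosen number of trees, as above, is one option; predictions can also be drawn from the original 3000-tree fit by passing n.trees to predict, which avoids the refit (assuming best_iter from the gbm.perf sketch above):
# predict(heart_boost, heart_test, n.trees = best_iter, type = "response")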
summary.gbm(heart_boost_opt)
## var rel.inf
## Thal Thal 27.2010273
## Ca Ca 24.5304750
## ChestPain ChestPain 15.3669971
## Oldpeak Oldpeak 9.5926901
## Age Age 5.9828818
## Chol Chol 4.9222494
## MaxHR MaxHR 4.7552005
## RestBP RestBP 2.5600230
## Sex Sex 2.5183908
## ExAng ExAng 2.1264810
## RestECG RestECG 0.2330646
## Slope Slope 0.2105194
## Fbs Fbs 0.0000000
plot.gbm(heart_boost_opt, i.var = c("Thal", "Ca"))
plot.gbm(heart_boost_opt, i.var = c("Ca", "ChestPain"))
plot.gbm(heart_boost_opt, i.var = c("Thal", "ChestPain"))