In lecture, we saw our first example of tree-based methods: regression trees. This discussion section will give you practice building regression trees using the built-in tools in R. The notes largely follow the lab in Section 8.3.2 in ISLR.
The two most commonly used packages in R for building regression trees are tree and rpart (short for “recursive partitioning”; remember that regression trees really just partition the data space). rpart seems to be eclipsing tree in terms of popularity, so we’ll use that. See here to read about the tree package if you’re interested. The lab in ISLR uses the tree package, if you want a nice example of how to use it.
The documentation for rpart is here. It should be available in your R installation, but if not, you can install it in the usual way.
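If rpart is missing, or MASS (which holds the data set we use below), here is a minimal sketch that installs only what is needed and then loads both packages. It assumes a reachable CRAN mirror. Both are “recommended” packages that normally ship with R, so the install step will usually be skipped.

```r
# Install each package only if it is not already available, then load both.
for (pkg in c("rpart", "MASS")) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)  # assumes a CRAN mirror is reachable
  }
}
library(rpart)
library(MASS)
```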
library(rpart)
# See documentation with ?rpart.
The Boston Housing data set is a famous data set that contains median housing prices for approximately 500 towns in the Boston area (collected in the 1970s; these houses are a lot more expensive now). It is included in the MASS package. Let’s load it, and take a few minutes to read the documentation to make sure we understand what all the variables mean.
library(MASS) # Boston lives in the MASS package, so load it first
data(Boston)
?Boston
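Alongside ?Boston, a quick structural summary lists every variable with its storage type and first few values. This is a minimal sketch using only base R.

```r
library(MASS)
# One line per variable: name, storage type, and the first few values.
str(Boston)
# Rows (observations) and columns (13 predictors plus the response medv).
dim(Boston)
```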
We’re going to try to predict the price, which is stored in Boston$medv, short for “median value”: the median value of owner-occupied homes in the town, in thousands of dollars.
hist(Boston$medv)
Question: looking at this histogram, should we consider a variable transformation? Why or why not? Discuss.
You’ll see some analyses of this data set out there that take logarithms of these prices. We’ll leave them as is, if only to keep our analysis roughly similar to the worked example in your textbook, but this kind of transformation is always worth considering when you have price or salary data.
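To see what a log transformation would do to the shape of the distribution, one minimal sketch is to draw the two histograms side by side. The transformed variable is for comparison only; the rest of these notes keep medv on its original scale.

```r
library(MASS)
# Two panels side by side; keep the old settings so we can restore them.
old_par <- par(mfrow = c(1, 2))
hist(Boston$medv, main = "medv", xlab = "Median value ($1000s)")
hist(log(Boston$medv), main = "log(medv)", xlab = "log of median value")
par(old_par)
```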
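We’ll split the data into a training set and a test set to demonstrate the effects of overfitting. We’ll set aside one fifth of the data as test data, and the rest will be training data.
# Randomly choose one fifth of the rows (floor(506/5) = 101) for the test set.
# The split is random, so your numbers will differ somewhat from those shown below.
test_inds <- sample(1:nrow(Boston), floor(nrow(Boston)/5))
Boston.test <- Boston[test_inds, ]
Boston.train <- Boston[-test_inds, ]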
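Because sample() is random, every run gives a different split, and so slightly different numbers from the output below. Here is a minimal sketch of a reproducible version that calls set.seed() first. The seed value 1 is an arbitrary choice and does not come from the original notes.

```r
library(MASS)
set.seed(1)  # arbitrary seed so the split is reproducible
n <- nrow(Boston)
test_inds <- sample(seq_len(n), size = floor(n / 5))
Boston.test <- Boston[test_inds, ]
Boston.train <- Boston[-test_inds, ]
# Sizes of the two pieces: 405 training rows and 101 test rows.
c(train = nrow(Boston.train), test = nrow(Boston.test))
```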
head(Boston.test)
## crim zn indus chas nox rm age dis rad tax ptratio black lstat
## 262 0.53412 20 3.97 0 0.647 7.520 89.4 2.1398 5 264 13.0 388.37 7.26
## 335 0.03738 0 5.19 0 0.515 6.310 38.5 6.4584 5 224 20.2 389.40 6.75
## 249 0.16439 22 5.86 0 0.431 6.433 49.1 7.8265 7 330 19.1 374.71 9.52
## 287 0.01965 80 1.76 0 0.385 6.230 31.5 9.0892 1 241 18.2 341.60 12.93
## 42 0.12744 0 6.91 0 0.448 6.770 2.9 5.7209 3 233 17.9 385.41 4.84
## 452 5.44114 0 18.10 0 0.713 6.655 98.2 2.3552 24 666 20.2 355.29 17.73
## medv
## 262 43.1
## 335 20.7
## 249 24.5
## 287 20.1
## 42 26.6
## 452 15.2
head(Boston.train)
## crim zn indus chas nox rm age dis rad tax ptratio black lstat
## 1 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98
## 2 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14
## 4 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94
## 5 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33
## 6 0.02985 0.0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12 5.21
## 7 0.08829 12.5 7.87 0 0.524 6.012 66.6 5.5605 5 311 15.2 395.60 12.43
## medv
## 1 24.0
## 2 21.6
## 4 33.4
## 5 36.2
## 6 28.7
## 7 22.9
Fitting a regression tree with rpart is largely analogous to fitting logistic regression with glm. We specify a formula, a data set, and a method (really a loss function).
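With method='anova', each split is chosen to make the total squared error of the two child nodes as small as possible. As a sanity check, the following minimal sketch brute-forces the best single split on lstat over the full data set and compares it with a one-split rpart tree. It rests on two assumptions: each child must hold at least 7 observations (rpart’s default minbucket, round(minsplit/3) with minsplit = 20), and rpart places the threshold midway between adjacent observed values.

```r
library(rpart)
library(MASS)

# Sort by the candidate predictor so every split is a prefix/suffix of the data.
ord <- order(Boston$lstat)
x <- Boston$lstat[ord]
y <- Boston$medv[ord]
n <- length(y)
sse <- function(v) sum((v - mean(v))^2)  # squared-error loss of one node

min_node <- 7  # assumed rpart default minbucket = round(20 / 3)
best_threshold <- NA
best_sse <- Inf
for (i in min_node:(n - min_node)) {
  if (x[i] == x[i + 1]) next  # no threshold separates tied values
  total <- sse(y[1:i]) + sse(y[(i + 1):n])
  if (total < best_sse) {
    best_sse <- total
    best_threshold <- (x[i] + x[i + 1]) / 2  # midpoint between neighbours
  }
}

# A depth-1 tree (a "stump") that may only split on lstat.
stump <- rpart(medv ~ lstat, data = Boston, method = "anova",
               control = rpart.control(maxdepth = 1))
c(brute_force = best_threshold, rpart = unname(stump$splits[1, "index"]))
```

The two thresholds should agree. The full tree below applies the same search to every predictor at every node.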
rt <- rpart( medv ~ ., data=Boston.train, method='anova')
# method='anova' tells rpart to use the squared errors loss we saw in lecture.
# If we just naively print out the tree, we see a text representation of its node splits.
rt
## n= 405
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 405 32961.0200 22.31309
## 2) lstat>=9.725 241 5911.4170 17.37095
## 4) lstat>=19.23 67 1280.6070 12.44776
## 8) nox>=0.603 49 508.2253 10.72245 *
## 9) nox< 0.603 18 229.4644 17.14444 *
## 5) lstat< 19.23 174 2381.5670 19.26667
## 10) lstat>=15 68 739.9024 17.02353 *
## 11) lstat< 15 106 1080.0170 20.70566 *
## 3) lstat< 9.725 164 12513.2000 29.57561
## 6) rm< 6.941 114 3578.7230 25.53596
## 12) age< 90.25 107 1642.4970 24.74019
## 24) rm< 6.531 67 619.3346 22.89104 *
## 25) rm>=6.531 40 410.3338 27.83750 *
## 13) age>=90.25 7 832.7200 37.70000 *
## 7) rm>=6.941 50 2832.5800 38.78600
## 14) rm< 7.437 28 622.8011 34.33214 *
## 15) rm>=7.437 22 947.4345 44.45455 *
Okay, you can probably figure out what’s going on there if you stare at it enough, but I would much rather just look at the tree itself…
plot(rt)
Okay, but where are the labels? If you think back to our discussion of dendrograms and hierarchical clustering, you’ll recall that R is generally bad at plotting trees. You need to add the text labels separately, with the text function.
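# uniform=TRUE spaces the levels of the tree evenly, which helps keep the node
# labels from bumping up against each other.
plot(rt, uniform=TRUE, main='Regression tree for Boston housing prices')
# all=TRUE labels every node (not just the leaves), use.n=TRUE adds each
# node's sample size, and cex=.8 shrinks the text.
text(rt, all=TRUE, use.n=TRUE, cex=.8)
Okay, it’s still hard to read, but hopefully you get the idea by looking at this plot alongside the node splits printed out above.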
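If the labels still collide, one option is to draw the tree on a larger page and leave some white space around it with plot’s margin argument. Here is a minimal sketch that writes to a PDF. The file name boston_tree.pdf, the page size, and the seed are arbitrary choices made for illustration.

```r
library(rpart)
library(MASS)
set.seed(1)  # arbitrary seed; your split (and tree) will differ from the notes
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

# A 12 x 8 inch page gives the labels more room than a small screen device.
pdf("boston_tree.pdf", width = 12, height = 8)
plot(rt, uniform = TRUE, margin = 0.05,
     main = "Regression tree for Boston housing prices")
text(rt, all = TRUE, use.n = TRUE, cex = 0.8)
dev.off()
```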
Okay, we fit a tree to the training data. Let’s see how we did. First of all, let’s check our MSE on the training data. We’ll compute our model’s predicted output on each of the training instances. The predict function works with rpart objects just like with other models you’ve seen this semester.
yhat_train <- predict(rt, Boston.train)
train_resids <- yhat_train - Boston.train$medv # compute the residuals
# Now square them and take the mean to get the MSE.
# Note that if we just computed the RSS (i.e., no mean, just a sum), it would be
# harder to compare directly to the test set, which is of a different size.
mean(train_resids^2)
## [1] 14.7907
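The predictions above are simply leaf averages. For method='anova', a node’s yval in the printed tree is the mean of medv among the training rows in that node. Its deviance is the sum of squared deviations from that mean. Here is a minimal sketch that checks both claims for the root node and confirms that the training predictions equal the leaf means. It uses a seeded split, and the seed is an arbitrary choice.

```r
library(rpart)
library(MASS)
set.seed(1)  # arbitrary seed; numbers will differ from the notes' output
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")
y <- Boston.train$medv

# Root node (first row of rt$frame): yval is the mean response and dev is
# the sum of squared deviations from it, i.e., the squared-error loss.
root <- rt$frame[1, ]
c(yval = root$yval, mean_y = mean(y))
c(dev = root$dev, sum_sq = sum((y - mean(y))^2))

# rt$where maps each training row to the frame row of its leaf, so the
# training predictions are just the leaf means.
all.equal(unname(predict(rt, Boston.train)), rt$frame$yval[rt$where])
```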
Now do the test set.
Modify the code above to compute the MSE of our regression tree on the test set.
#TODO: code goes here.
Unless something really weird happened, you should see that the test MSE is a lot worse than the training MSE. Looks like we might have overfit…
Lucky for us, regression trees have a natural regularization method: we can make the tree less bushy by pruning it. In essence, this serves to “smooth out” the piecewise constant function that our tree encodes by merging some of its pieces. We do that with the prune function. Note that there are other functions called prune (you’ll see this if you type ?prune), so you’re best off specifying that you mean rpart::prune.
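To see the “fewer pieces” effect directly, here is a minimal sketch that counts how many distinct values the full tree and a pruned tree predict on the training data. The value cp = 0.05 is an arbitrary illustration (the cp parameter is explained next), and the seed is likewise arbitrary.

```r
library(rpart)
library(MASS)
set.seed(1)  # arbitrary seed
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

pruned <- rpart::prune(rt, cp = 0.05)  # cp = 0.05 chosen only for illustration
# Number of distinct fitted values = number of "pieces" actually used.
n_pieces <- function(tree) length(unique(predict(tree, Boston.train)))
c(full = n_pieces(rt), pruned = n_pieces(pruned))
```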
To prune a tree, we just pass our tree into this function and specify the complexity parameter cp. Smaller values of this number correspond to more splits, i.e., more complicated models.
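Here is a minimal sketch of that relationship. It prunes one fitted tree at a decreasing sequence of cp values and counts the splits (internal nodes) that survive. The cp grid and the seed are arbitrary choices.

```r
library(rpart)
library(MASS)
set.seed(1)  # arbitrary seed
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

cps <- c(0.5, 0.1, 0.05, 0.02, 0.01)  # decreasing cp values
n_splits <- sapply(cps, function(cp) {
  pruned <- rpart::prune(rt, cp = cp)
  sum(pruned$frame$var != "<leaf>")  # internal nodes = splits
})
data.frame(cp = cps, splits = n_splits)
```

Reading down the table, the number of splits never decreases as cp shrinks.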
So let’s first make sure we understand that parameter. Every rpart object has an attribute called cptable, which shows how the number of splits (that’s the number of internal, i.e., non-terminal, nodes) varies with the cp parameter. Each row describes one tree in a nested sequence of pruned trees. Roughly speaking, cp is the price of a split: a split is kept only if it reduces the relative training error (rel error) by more than cp. The CP value in each row is the smallest cp that still gives a tree with that row’s number of splits; any smaller, and the next split becomes affordable within your complexity “budget”. Note that the table also includes each tree’s training error relative to the tree with no splits (rel error), as well as information for doing CV (xerror and xstd; see Section 4 of the rpart long-form documentation; we return to this point briefly below).
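The rel error column is each subtree’s training RSS divided by the RSS of the tree with no splits, which is the total sum of squares. Here is a minimal sketch that checks this for the last row of cptable, the full unpruned tree. The seed is an arbitrary choice.

```r
library(rpart)
library(MASS)
set.seed(1)  # arbitrary seed
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

cpt <- rt$cptable
y <- Boston.train$medv
rss <- sum((y - predict(rt, Boston.train))^2)  # training RSS of the full tree
tss <- sum((y - mean(y))^2)                    # RSS of the root-only tree
c(from_table = unname(cpt[nrow(cpt), "rel error"]), by_hand = rss / tss)
```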
rt$cptable
## CP nsplit rel error xerror xstd
## 1 0.44101794 0 1.0000000 1.0081323 0.09436808
## 2 0.18512472 1 0.5589821 0.6586532 0.06538595
## 3 0.06823948 2 0.3738573 0.4914515 0.05572899
## 4 0.03829810 3 0.3056179 0.4310189 0.05381044
## 5 0.03347910 4 0.2673198 0.4135346 0.05543222
## 6 0.01859253 5 0.2338407 0.3944587 0.05543053
## 7 0.01703975 6 0.2152481 0.3751225 0.05330331
## 8 0.01647150 7 0.1982084 0.3727647 0.05331989
## 9 0.01000000 8 0.1817369 0.3461705 0.05294787
As cp gets smaller, the number of splits (nsplit) increases. If we ask for a tree with a particular cp value, rpart will go to our tree’s cptable, find the row with the largest CP value that does not exceed our cp, and give us back the tree with that row’s number of splits. So, for example, if we ask for cp=0.4, the matching row is row 2 (CP ≈ 0.185), and we just get back the tree with a single split (i.e., a root node with two children):
plot( rpart::prune(rt, cp=0.4) )
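The lookup rule just described can be written out directly. Here is a minimal sketch built on a hypothetical helper, nsplit_for_cp, which is not part of rpart. It compares the helper’s answer with the number of splits that rpart::prune actually keeps. The cp values tested are midpoints between adjacent CP entries, chosen to avoid exact ties, and the seed is arbitrary.

```r
library(rpart)
library(MASS)
set.seed(1)  # arbitrary seed
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

# Hypothetical helper (not part of rpart): number of splits that
# prune(tree, cp) should keep, read off the cptable.
nsplit_for_cp <- function(cptable, cp) {
  rows <- which(cptable[, "CP"] <= cp)  # rows are sorted by decreasing CP
  if (length(rows) == 0) return(max(cptable[, "nsplit"]))  # nothing pruned
  cptable[min(rows), "nsplit"]  # row with the largest CP not exceeding cp
}
count_splits <- function(tree) sum(tree$frame$var != "<leaf>")

cpt <- rt$cptable
cps <- (cpt[-1, "CP"] + cpt[-nrow(cpt), "CP"]) / 2  # midpoints between rows
from_table <- sapply(cps, function(cp) nsplit_for_cp(cpt, cp))
from_prune <- sapply(cps, function(cp) count_splits(rpart::prune(rt, cp = cp)))
cbind(cp = cps, from_table = from_table, from_prune = from_prune)
```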
If we set cp=0.035, the matching row is row 5 (CP ≈ 0.0335), so we get back 4 splits:
plot( rpart::prune(rt, cp=0.035) )
Looking at that tree, you can see that there are indeed four splits (i.e., four internal nodes: the root, both of its children, and one grandchild).
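Now, we could use this to choose cp by cross-validation, just as we’ve seen previously (with \(K\) folds instead of just a single train-test split), or we could use the “one-standard-error” rule, about which see the end of Section 6.1 in your textbook or refer to Section 4 of the rpart documentation. (The xerror column of cptable already comes from rpart’s built-in cross-validation, which uses 10 folds by default.)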
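Since cptable already stores cross-validated error (xerror) and its standard error (xstd), both selection rules can be read straight off the table. Here is a minimal sketch. The fold assignment behind xerror is random, so a seed is set first, and its value is arbitrary.

```r
library(rpart)
library(MASS)
set.seed(1)  # arbitrary seed; also fixes rpart's random CV folds
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

cpt <- rt$cptable
best <- which.min(cpt[, "xerror"])  # row with the lowest CV error

# Rule 1: the cp with the lowest cross-validated error.
cp_min <- cpt[best, "CP"]

# Rule 2 (one-standard-error rule): the simplest tree (fewest splits)
# whose CV error is within one standard error of the minimum.
threshold <- cpt[best, "xerror"] + cpt[best, "xstd"]
best_1se <- which(cpt[, "xerror"] <= threshold)[1]  # rows sorted by nsplit
cp_1se <- cpt[best_1se, "CP"]

rt_min <- rpart::prune(rt, cp = cp_min)
rt_1se <- rpart::prune(rt, cp = cp_1se)
c(splits_min = unname(cpt[best, "nsplit"]),
  splits_1se = unname(cpt[best_1se, "nsplit"]))
```

The one-standard-error rule never picks a bigger tree than the minimum-error rule. It trades a little estimated accuracy for a simpler model.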
Just for the sake of illustration, though, let’s try something reasonable and use the 4-split tree.
Modify the code above to measure the MSE of this pruned tree on the training and test data. Assess the presence or absence of overfitting. Is the overfitting at least less bad than with the tree we started with?
# TODO: code goes here.
Do the same with the 5-split tree and compare to four splits.
#TODO: code goes here.