
In lecture, we saw our first example of tree-based methods: regression trees. This discussion section will give you practice building regression trees using the built-in tools in R. The notes largely follow the lab in Section 8.3.2 in ISLR.

The two most commonly used packages in R for building regression trees are tree and rpart (short for “recursive partitioning”; remember that regression trees really just partition the data space). rpart seems to be eclipsing tree in popularity, so we’ll use that. See here to read about the tree package if you’re interested. The lab in ISLR uses the tree package, if you want a nice example of how to use it.

The documentation for rpart is here. It should be available in your R installation, but if not, you can install it in the usual way.

library(rpart)
# See documentation with ?rpart.

The Boston Housing data set is a famous data set that contains median housing prices for approximately 500 towns in the Boston area (collected in the 1970s; these houses are a lot more expensive now). It is included in the MASS library. Let’s load it, and take a few minutes to read the documentation to make sure we understand what all the variables mean.

library(MASS) # Boston lives in the MASS package, so attach that first.
data(Boston)
?Boston

We’re going to try to predict the price, which is stored in Boston$medv, short for median value, the median price of all homes in the town, in thousands of dollars.

hist( Boston$medv)

Question: looking at this histogram, should we consider a variable transformation? Why or why not? Discuss.

You’ll see some analyses of this data set out there that take logarithms of these prices. We’ll leave them as is, if for no other reason than for keeping our analysis roughly similar to the worked example in your textbook, but this kind of thing is always worth considering when you have price or salary data.
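One quick way to explore the question above is to draw the raw and log-transformed prices side by side. This is only a sketch for eyeballing skewness; the two-panel layout and the plot titles are choices made here, not something taken from the ISLR lab.

```r
library(MASS)  # provides the Boston data set

# Draw the two histograms next to each other.
op <- par(mfrow = c(1, 2))
hist(Boston$medv, main = "medv", xlab = "median value ($1000s)")
hist(log(Boston$medv), main = "log(medv)", xlab = "log median value")
par(op)  # restore the previous plotting layout

# Count how many towns sit exactly at the largest recorded value.
# More than one suggests the prices may have been capped when recorded.
sum(Boston$medv == max(Boston$medv))
```

If many towns share the maximum, keep that pile-up in mind when interpreting either histogram: a transformation changes the shape of the distribution but cannot undo a cap.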

Setting up the data

We’ll split the data into a train and test set, to demonstrate effects of overfitting. We’ll set aside one fifth of the data as test data, and the rest will be train.

test_inds <- sample(1:nrow(Boston), nrow(Boston)/5);
Boston.test <- Boston[test_inds,];
Boston.train <- Boston[-test_inds,];
head(Boston.test)
##        crim zn indus chas   nox    rm  age    dis rad tax ptratio  black lstat
## 262 0.53412 20  3.97    0 0.647 7.520 89.4 2.1398   5 264    13.0 388.37  7.26
## 335 0.03738  0  5.19    0 0.515 6.310 38.5 6.4584   5 224    20.2 389.40  6.75
## 249 0.16439 22  5.86    0 0.431 6.433 49.1 7.8265   7 330    19.1 374.71  9.52
## 287 0.01965 80  1.76    0 0.385 6.230 31.5 9.0892   1 241    18.2 341.60 12.93
## 42  0.12744  0  6.91    0 0.448 6.770  2.9 5.7209   3 233    17.9 385.41  4.84
## 452 5.44114  0 18.10    0 0.713 6.655 98.2 2.3552  24 666    20.2 355.29 17.73
##     medv
## 262 43.1
## 335 20.7
## 249 24.5
## 287 20.1
## 42  26.6
## 452 15.2
head(Boston.train)
##      crim   zn indus chas   nox    rm  age    dis rad tax ptratio  black lstat
## 1 0.00632 18.0  2.31    0 0.538 6.575 65.2 4.0900   1 296    15.3 396.90  4.98
## 2 0.02731  0.0  7.07    0 0.469 6.421 78.9 4.9671   2 242    17.8 396.90  9.14
## 4 0.03237  0.0  2.18    0 0.458 6.998 45.8 6.0622   3 222    18.7 394.63  2.94
## 5 0.06905  0.0  2.18    0 0.458 7.147 54.2 6.0622   3 222    18.7 396.90  5.33
## 6 0.02985  0.0  2.18    0 0.458 6.430 58.7 6.0622   3 222    18.7 394.12  5.21
## 7 0.08829 12.5  7.87    0 0.524 6.012 66.6 5.5605   5 311    15.2 395.60 12.43
##   medv
## 1 24.0
## 2 21.6
## 4 33.4
## 5 36.2
## 6 28.7
## 7 22.9
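Because sample draws rows at random, the rows and numbers you see will differ from the output shown here every time you rerun the split. If you want a split you can reproduce, one option is to fix the random seed first. The sketch below assumes an arbitrary seed of 1 and rounds the test-set size down explicitly; neither choice comes from the original notes.

```r
library(MASS)  # provides the Boston data set

set.seed(1)  # arbitrary seed, chosen here only for reproducibility
n_test <- floor(nrow(Boston) / 5)  # one fifth of the rows, rounded down
test_inds <- sample(seq_len(nrow(Boston)), n_test)
Boston.test  <- Boston[test_inds, ]
Boston.train <- Boston[-test_inds, ]

# Sanity checks: the sizes add up and no row appears in both sets.
c(train = nrow(Boston.train), test = nrow(Boston.test))
length(intersect(rownames(Boston.train), rownames(Boston.test)))
```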

Fitting the tree

Fitting a regression tree with rpart is largely analogous to fitting logistic regression with glm. We specify a formula, a data set, and a method (really a loss function).

rt <- rpart( medv ~ ., data=Boston.train, method='anova')
# method='anova' tells rpart to use the squared errors loss we saw in lecture.

# If we just naively print out the tree, we see a text representation of its node splits.
rt
## n= 405 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 405 32961.0200 22.31309  
##    2) lstat>=9.725 241  5911.4170 17.37095  
##      4) lstat>=19.23 67  1280.6070 12.44776  
##        8) nox>=0.603 49   508.2253 10.72245 *
##        9) nox< 0.603 18   229.4644 17.14444 *
##      5) lstat< 19.23 174  2381.5670 19.26667  
##       10) lstat>=15 68   739.9024 17.02353 *
##       11) lstat< 15 106  1080.0170 20.70566 *
##    3) lstat< 9.725 164 12513.2000 29.57561  
##      6) rm< 6.941 114  3578.7230 25.53596  
##       12) age< 90.25 107  1642.4970 24.74019  
##         24) rm< 6.531 67   619.3346 22.89104 *
##         25) rm>=6.531 40   410.3338 27.83750 *
##       13) age>=90.25 7   832.7200 37.70000 *
##      7) rm>=6.941 50  2832.5800 38.78600  
##       14) rm< 7.437 28   622.8011 34.33214 *
##       15) rm>=7.437 22   947.4345 44.45455 *

Okay, you can probably figure out what’s going on there if you stare at it enough, but I would much rather just look at the tree itself…

plot(rt)

Okay, but where are the labels? If you think back to our discussion of dendrograms and hierarchical clustering, you’ll recall that R is generally bad at handling trees. You need to tell R’s plotting functions about the text labels separately.

# uniform=TRUE helps to prevent the node labels from bumping up against
# each other.
plot(rt, uniform=TRUE, main='Regression tree for Boston housing prices')
text(rt, all=TRUE, use.n = TRUE, cex=.8 )

Okay, still hard to read, but hopefully you get the idea by looking at this plot and the node splits printed out by R.

rt
## n= 405 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 405 32961.0200 22.31309  
##    2) lstat>=9.725 241  5911.4170 17.37095  
##      4) lstat>=19.23 67  1280.6070 12.44776  
##        8) nox>=0.603 49   508.2253 10.72245 *
##        9) nox< 0.603 18   229.4644 17.14444 *
##      5) lstat< 19.23 174  2381.5670 19.26667  
##       10) lstat>=15 68   739.9024 17.02353 *
##       11) lstat< 15 106  1080.0170 20.70566 *
##    3) lstat< 9.725 164 12513.2000 29.57561  
##      6) rm< 6.941 114  3578.7230 25.53596  
##       12) age< 90.25 107  1642.4970 24.74019  
##         24) rm< 6.531 67   619.3346 22.89104 *
##         25) rm>=6.531 40   410.3338 27.83750 *
##       13) age>=90.25 7   832.7200 37.70000 *
##      7) rm>=6.941 50  2832.5800 38.78600  
##       14) rm< 7.437 28   622.8011 34.33214 *
##       15) rm>=7.437 22   947.4345 44.45455 *
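To connect the printout to the data: the yval reported for each terminal node is the mean of medv over the training rows that land in that leaf. The sketch below checks this on a freshly fitted tree. It assumes an arbitrary seed, so the tree it builds will not match the one printed above.

```r
library(MASS)
library(rpart)

set.seed(1)  # arbitrary seed; your tree will differ from the one printed above
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

# rt$where gives, for each training row, the row of rt$frame holding its leaf.
leaf_means <- tapply(Boston.train$medv, rt$where, mean)

# Look up the yval column of rt$frame for those same leaves.
leaf_yval <- rt$frame$yval[as.integer(names(leaf_means))]

# The two should agree up to floating-point error.
all.equal(as.numeric(leaf_means), leaf_yval)
```

This is the piecewise constant function from lecture: every observation in a leaf gets the same prediction, namely that leaf's training mean.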

Did we overfit?

Okay, we fit a tree to the training data. Let’s see how we did. First of all, let’s check our MSE on the training data. We’ll compute our model’s predicted output on each of the training instances. The predict function works with the rpart objects just like with other models you’ve seen this semester.

yhat_train <-predict( rt, Boston.train);
train_resids <- yhat_train - Boston.train$medv; # compute the residuals
# Now square them and take the mean to get MSE.
# Note that if we just computed the RSS (i.e., no mean, just a sum), it would be
# harder to directly compare to the test set, which is of a different size.
mean(train_resids^2)
## [1] 14.7907
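The comment about RSS versus MSE can be illustrated on made-up numbers. The sketch below uses simulated errors (not the Boston data) with the same spread but different sample sizes; the seed and the sizes 400 and 100 are arbitrary choices.

```r
set.seed(1)  # arbitrary seed for reproducibility

# Made-up errors with the same spread, for a "big" and a "small" data set.
errs_big   <- rnorm(400)
errs_small <- rnorm(100)

# RSS grows with the number of observations...
c(RSS_big = sum(errs_big^2), RSS_small = sum(errs_small^2))

# ...while MSE stays on the same scale regardless of sample size.
c(MSE_big = mean(errs_big^2), MSE_small = mean(errs_small^2))
```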

Now do the test set.

Modify the code above to compute the MSE of our regression tree on the test set.

#TODO: code goes here.

Unless something really weird happened, you should see that the test MSE is a lot worse than the training MSE. Looks like we might have overfit…

Pruning the tree

Lucky for us, regression trees have a natural regularization method: we can make the tree less bushy by pruning it. In essence, this serves to “smooth out” the piecewise constant function that our tree encodes. We do that with the prune function. Note that there are other functions called prune (you’ll see this if you type ?prune), so you’re best off specifying that you mean rpart::prune.

To prune a tree, we just pass our tree into this function, and specify the complexity parameter cp. Smaller values of this number correspond to more splits, i.e., more complicated models.

So let’s first make sure we understand that parameter. Every rpart object has a component called cptable, which shows how the number of splits (that is, the number of internal nodes) varies with this cp parameter. The CP value in each row is the smallest cp for which pruning returns the tree in that row; prune with any smaller cp and you can afford at least one more split with your complexity “budget”. Note that the table also includes information about each model’s error, as well as information for doing CV (see Section 4 of the rpart long-form documentation; we return to this point briefly below).

rt$cptable
##           CP nsplit rel error    xerror       xstd
## 1 0.44101794      0 1.0000000 1.0081323 0.09436808
## 2 0.18512472      1 0.5589821 0.6586532 0.06538595
## 3 0.06823948      2 0.3738573 0.4914515 0.05572899
## 4 0.03829810      3 0.3056179 0.4310189 0.05381044
## 5 0.03347910      4 0.2673198 0.4135346 0.05543222
## 6 0.01859253      5 0.2338407 0.3944587 0.05543053
## 7 0.01703975      6 0.2152481 0.3751225 0.05330331
## 8 0.01647150      7 0.1982084 0.3727647 0.05331989
## 9 0.01000000      8 0.1817369 0.3461705 0.05294787
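The rel error column is the tree's training RSS divided by the RSS of the root node (the tree with no splits). So the last row of the table, rescaled by the root deviance and the number of training rows, should reproduce the training MSE. The sketch below checks this on a freshly fitted tree; the seed is an arbitrary assumption, so its numbers will not match the table above.

```r
library(MASS)
library(rpart)

set.seed(1)  # arbitrary seed; numbers will differ from the table above
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

root_dev <- rt$frame$dev[1]  # deviance (RSS) at the root node
last_rel <- rt$cptable[nrow(rt$cptable), "rel error"]  # full tree's rel error

# Undo the scaling: rel error * root RSS = tree RSS; divide by n for MSE.
mse_from_table <- last_rel * root_dev / nrow(Boston.train)
mse_direct <- mean((predict(rt, Boston.train) - Boston.train$medv)^2)

c(from_table = mse_from_table, direct = mse_direct)
```

The xerror column is on the same relative scale, but it is estimated by cross-validation rather than on the training data, which is why it is larger than rel error in every row with at least one split.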

As cp gets smaller, the number of splits (nsplit) increases. If we ask for a tree with a particular cp value, rpart keeps only the splits whose complexity exceeds that value. In terms of the cptable, we get the tree from the row with the fewest splits among those whose CP is no larger than the cp we asked for. So, for example, if we ask for cp=0.4, we just get back the tree with a single split (i.e., a root node with two children):

plot( rpart::prune(rt, cp=0.4) )

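If we set cp=0.035, which sits between the CP values in rows 4 and 5 of the table above, we get back 4 splits (your own cutoffs will differ, since the train-test split is random):

plot( rpart::prune(rt, cp=0.035) )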

Looking at that tree, you can see that there are indeed four splits (i.e., four internal nodes: the root and three more below it). Exactly which nodes get split depends on your random train-test split.
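Rather than counting nodes in the plot, we can count splits directly and compare against the cptable. The sketch below prunes at a cp halfway between each pair of adjacent CP values and checks that the number of splits matches the nsplit column of the lower row. The helper count_splits and the seed are assumptions made for this illustration, not part of rpart.

```r
library(MASS)
library(rpart)

set.seed(1)  # arbitrary seed; any value works
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

# Hypothetical helper: count internal (non-leaf) nodes, i.e. splits.
count_splits <- function(tree) sum(tree$frame$var != "<leaf>")

cpt <- rt$cptable
rows <- 2:nrow(cpt)

# A cp halfway between each row's CP and the CP of the row above it.
cp_mid <- (cpt[rows, "CP"] + cpt[rows - 1, "CP"]) / 2

# Prune at each of those cp values and count the surviving splits.
splits <- sapply(cp_mid, function(cp) count_splits(rpart::prune(rt, cp = cp)))

data.frame(cp = cp_mid,
           splits_after_pruning = splits,
           table_nsplit = cpt[rows, "nsplit"])
```

The last two columns should agree row by row, which is a handy way to pick a cp that yields a specific number of splits on your own data.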

Now, we can use this to do CV just like we’ve seen previously (if we had \(K\) folds instead of just a train-test split), or we could use the one-standard-error (“1-SE”) rule; see the end of Section 6.1 in your textbook or Section 4 of the rpart documentation.
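As a rough sketch of the second option: rpart runs its own cross-validation while fitting (10 folds by default) and stores the results in the xerror and xstd columns of the cptable. One common reading of the rule is to take the simplest tree whose xerror is within one xstd of the minimum. The seed and variable names below are assumptions for illustration, and the folds are random, so the chosen row can change from run to run.

```r
library(MASS)
library(rpart)

set.seed(1)  # arbitrary seed; it also fixes rpart's internal CV folds
test_inds <- sample(seq_len(nrow(Boston)), floor(nrow(Boston) / 5))
Boston.train <- Boston[-test_inds, ]
rt <- rpart(medv ~ ., data = Boston.train, method = "anova")

cpt <- rt$cptable
best <- which.min(cpt[, "xerror"])  # row with the lowest CV error
cutoff <- cpt[best, "xerror"] + cpt[best, "xstd"]  # one standard error above it

# Rows are ordered from fewest to most splits, so the first row under the
# cutoff is the simplest tree that qualifies.
row_1se <- which(cpt[, "xerror"] <= cutoff)[1]
cp_1se <- cpt[row_1se, "CP"]

rt_1se <- rpart::prune(rt, cp = cp_1se)
c(chosen_row = row_1se, cp = cp_1se, nsplit = cpt[row_1se, "nsplit"])
```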

Just for the sake of illustration, though, let’s just try something reasonable and use the 4-split tree.

Modify the code above to measure the MSE of this pruned tree on the training and test data. Assess the presence or absence of overfitting. Is the overfitting at least less bad than with the tree we started with?

# TODO: code goes here.

Do the same with the 5-split tree and compare to four splits.

#TODO: code goes here.