Readings: ISLR Section 8.1

In our last lectures of the semester, we are going to return to prediction and classification problems and see a very different set of tools from what we saw previously.

Learning objectives

After this lecture, you will be able to

  • explain how a decision tree classifies observations by asking a sequence of yes/no questions,
  • describe how a regression tree partitions the predictor space and makes piecewise constant predictions,
  • explain how recursive binary splitting greedily grows a tree by minimizing the RSS, and
  • explain why large trees overfit and how cost complexity pruning addresses this.

Decision Trees and Classification

Suppose we have a collection of animals of four different types: Dogs, cats, lizards and snakes. A simple way to tell these different animals apart is with a decision tree.

A simple example of a decision tree for classifying animals.

Each node of a decision tree has a question associated to it, and the children of each node correspond to different answers to this question (most of the time we just take these to be yes/no questions, i.e., binary conditions we can check).

To classify an animal, we start at the root of the tree, and proceed down the tree to a leaf according to answers.

For example, say we want to classify this handsome fellow:

Credit: Wikipedia

  1. This animal doesn’t have fur, so the question at the root node is answered “No,” and we proceed along the “No” branch to the “Has legs?” node.
  2. This animal does have legs, so we proceed along the “Yes” branch to the leaf node labeled “Lizard”. We conclude that this animal is a lizard.

Observe that if we change the order in which we ask our questions, we get a different tree. Here is the same set of questions as in our animal tree above, but reordered. Notice how the structure of the tree changes.

By the way, this should remind you a lot of the game 20 Questions. We narrow down the space of possibilities one yes/no question at a time.

When plain old regression doesn’t cut it

Well, that’s all fine and good, but how does 20 Questions help us do statistics?

Well, consider the following data (see the demo R file on the course webpage for the source file), which has univariate data x with binary labels y (e.g., cat vs not-cat image labels):

Binary responses y as a function of input x.

This data is saved in binary_response_demo.csv, available from the course github page for this lecture. Let’s try and predict this response with logistic regression.

demo_data <- read.csv('binary_response_demo.csv');

lr1 <- glm( y ~ 1 + x, data=demo_data, family=binomial );

summary(lr1)
## 
## Call:
## glm(formula = y ~ 1 + x, family = binomial, data = demo_data)
## 
## Deviance Residuals: 
##    Min      1Q  Median      3Q     Max  
## -1.232  -1.230   1.125   1.125   1.126  
## 
## Coefficients:
##               Estimate Std. Error z value Pr(>|z|)  
## (Intercept)  0.1241499  0.0633708   1.959   0.0501 .
## x           -0.0003017  0.0212652  -0.014   0.9887  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 1382.4  on 999  degrees of freedom
## Residual deviance: 1382.4  on 998  degrees of freedom
## AIC: 1386.4
## 
## Number of Fisher Scoring iterations: 3

Now let’s assess our model’s prediction performance. Note that this is on the training data– if we can predict well anywhere it should be here!

probabilities <- predict(lr1, type='response', newdata = demo_data );
predictions <- ifelse(probabilities>0.5, 1, 0);
sum(predictions==demo_data$y)/nrow(demo_data)
## [1] 0.531

Just for comparison, here’s chance performance– the accuracy we would get by always predicting the more common label, which is exactly what our logistic regression has ended up doing:

sum(demo_data$y==1)/nrow(demo_data)
## [1] 0.531

Just to drive home what’s going wrong here, let’s look at our model’s output as a function of x and compare it to the actual response y. We’ll plot our model’s predicted probability in blue.

plot(demo_data$x,demo_data$y)
xvals <- seq(min(demo_data$x), max(demo_data$x), 0.1);
# the predict function expects us to pass the new data in a data frame.
dfnew <- data.frame('x'=xvals)
yvals <- predict(lr1, type='response', newdata = dfnew );
lines(xvals,yvals,col='blue')

Hmm… looks like logistic regression basically just learned to always predict about \(0.53\)… What’s up with that?!

Well, logistic regression assumes that our data follow a model like \[ \Pr[ Y = 1 | X=x ] = \frac{ e^{\beta_0 + \beta_1 x} }{ 1 + e^{\beta_0 + \beta_1 x} }. \] But in our data, the responses are (mostly) 1 when \(x \approx 0\) and (mostly) 0 otherwise. The model above is monotone in \(x\), so there is no way for logistic regression to capture this pattern, at least without us transforming the variable x in some clever way (e.g., replacing it with its absolute value).
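
To illustrate that last point, here is a minimal sketch of the “clever transformation” fix, using the demo_data loaded above (the model name lr2 is ours, just for illustration): we hand logistic regression \(|x|\) instead of \(x\), so that the response is roughly monotone in the transformed predictor.

lr2 <- glm( y ~ abs(x), data=demo_data, family=binomial );
probs2 <- predict( lr2, type='response', newdata=demo_data );
# Training accuracy of the transformed model:
sum( ifelse(probs2>0.5, 1, 0)==demo_data$y )/nrow(demo_data)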

On the other hand, here’s a simple classification rule: \[ y = \begin{cases} 1 &\mbox{ if } |x| \le 2 \\ 0 &\mbox{ otherwise. } \end{cases} \]

In other words, we are slicing up the space of x values, and labeling some parts of the space with 1 and other parts with 0.

plot(demo_data$x,demo_data$y)
abline(v=c(-2,2), col='red', lwd=2); # Draw our classification boundaries.

Let’s implement that and compare its prediction accuracy to our broken logistic regression.

# Our hand-crafted classification rule: predict 1 if |x| <= 2, and 0 otherwise.
clever_rule <- function(x) {
  if( abs(x) <= 2 ) {
    return(1);
  } else {
    return(0);
  }
}
# Vectorize so that the rule can be applied to a whole vector of x values at once.
clever_rule <- Vectorize(clever_rule)
predictions <- clever_rule( demo_data$x )
sum(predictions==demo_data$y)/nrow(demo_data)
## [1] 0.924

Compare that to our accuracy of \(\approx 0.53\) using logistic regression, which was no better than chance. Pretty good!

In situations like this, classification rules can save us where regression would otherwise fail. The problem is how to learn the classification rule in the first place– we could just look at the data and eyeball that \(|x| \le 2\) was a good thing to check, but it’s not so simple in practice when our data can be far more complicated.

Regression Trees

So let’s take the intuition we just developed and extend it a bit– we just saw how to use a decision tree for classification, but it’s quite simple to extend the idea to give us a new way to perform regression.

The Hitters data set in the ISLR package is an illustrative example (also used in Section 8.1.1 of ISLR, which this example is based on). This data set records batting statistics and salaries for Major League Baseball players from the 1986 and 1987 seasons.

library(ISLR)
head(Hitters) # See ?Hitters for more information.
##                   AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun
## -Andy Allanson      293   66     1   30  29    14     1    293    66      1
## -Alan Ashby         315   81     7   24  38    39    14   3449   835     69
## -Alvin Davis        479  130    18   66  72    76     3   1624   457     63
## -Andre Dawson       496  141    20   65  78    37    11   5628  1575    225
## -Andres Galarraga   321   87    10   39  42    30     2    396   101     12
## -Alfredo Griffin    594  169     4   74  51    35    11   4408  1133     19
##                   CRuns CRBI CWalks League Division PutOuts Assists Errors
## -Andy Allanson       30   29     14      A        E     446      33     20
## -Alan Ashby         321  414    375      N        W     632      43     10
## -Alvin Davis        224  266    263      A        W     880      82     14
## -Andre Dawson       828  838    354      N        E     200      11      3
## -Andres Galarraga    48   46     33      N        E     805      40      4
## -Alfredo Griffin    501  336    194      A        W     282     421     25
##                   Salary NewLeague
## -Andy Allanson        NA         A
## -Alan Ashby        475.0         N
## -Alvin Davis       480.0         A
## -Andre Dawson      500.0         N
## -Andres Galarraga   91.5         N
## -Alfredo Griffin   750.0         A

This data set has a whole bunch of columns, but let’s just try and predict Salary from Hits (number of hits a player got last season) and Years (the number of years they have been playing). If you are familiar with baseball, you already know the ways in which this could fail, but this is just for the purposes of illustration. We’re not trying to build the best possible model, here.

First things first, let’s look at our data. Here’s what salary looks like.

hist( Hitters$Salary )

Salary is highly right-skewed, so we’re going to apply a log transformation to try and correct that a bit, and predict log(Salary) instead of Salary. Of course, if we want to just predict Salary, we can always exponentiate to get Salary again.

hist( log( Hitters$Salary ) )

Hits and Runs are fairly well-behaved (check for yourself with hist, if you like), so we’ll leave them as is. Below we pull out a few columns just to get a quick numeric summary of the data; note that the example tree we are about to look at (taken from ISLR) splits on Years and Hits.

# There are some NAs in the Hitters data, so we're making sure to drop them.
# Tree-based methods like the ones we're going to talk about today actually
# have reasonably good tools available for handling NAs, but that's outside
# the scope of the lecture.
Hitters_subset <- na.omit( cbind( log(Hitters$Salary),
                                  Hitters$Hits,
                                  Hitters$Runs ) );
# Set the column names ourselves, since cbind on bare vectors doesn't keep useful names.
colnames(Hitters_subset) <- c( 'logSalary', 'Hits', 'Runs' );
summary( Hitters_subset )
##    logSalary          Hits            Runs       
##  Min.   :4.212   Min.   :  1.0   Min.   :  0.00  
##  1st Qu.:5.247   1st Qu.: 71.5   1st Qu.: 33.50  
##  Median :6.052   Median :103.0   Median : 52.00  
##  Mean   :5.927   Mean   :107.8   Mean   : 54.75  
##  3rd Qu.:6.620   3rd Qu.:141.5   3rd Qu.: 73.00  
##  Max.   :7.808   Max.   :238.0   Max.   :130.00

The idea behind regression trees is to create a decision tree just like we did above, but instead of class labels at the leaves (e.g., animal names), we have numerical predictions (i.e., something like regression outputs, hence the name regression tree). So for our Hitters data, here’s one possible tree we might build:

Credit: ISLR Figure 8.1

As before, we start at the root, which says to check whether Years is smaller than 4.5. If it is, we go down the branch to the left. If Years is greater than or equal to 4.5, we continue down the right branch. We continue this until we reach a leaf. Each leaf has a prediction of log(Salary) for all observations that reach that leaf.

Example: consider a hitter with 6 years in the league and 98 hits last season. What does our regression tree predict to be this player’s log-salary next year?

  1. We start at the root. 6 years is bigger than 4.5, so we proceed down the right branch.
  2. Then we reach the Hits < 117.5 node. 98 hits is less than 117.5, so we proceed down the left branch.
  3. We end up at the leaf labeled 6.00, so this is our prediction for the log-salary of this player. Since \(e^6 \approx 403\) and Salary is recorded in thousands of dollars, we predict that this player makes about $400,000 a year (see the code sketch just below).
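
To make this completely concrete, here is the same tree written out as a small R function (the name predict_salary_tree is ours, and the leaf values are read off the figure above), followed by the worked example:

predict_salary_tree <- function( years, hits ) {
  if( years < 4.5 ) {
    return( 5.11 );           # leaf: fewer than 4.5 years in the league
  } else if( hits < 117.5 ) {
    return( 6.00 );           # leaf: 4.5+ years, fewer than 117.5 hits
  } else {
    return( 6.74 );           # leaf: 4.5+ years, at least 117.5 hits
  }
}
predict_salary_tree( 6, 98 )  # the player from our example: predicted log-salary 6.00
exp( 6.00 )                   # approximately 403, i.e., roughly $400,000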

If you think for a minute, you’ll notice that this decision tree really just boils down to partitioning the input space as follows:

Credit: ISLR Figure 8.2

Any points that land in the region labeled R1 are predicted to have a log-salary of 5.11. Any points landing in region R2 are predicted to have a log-salary of 6.00, and any points landing in region R3 are predicted to have a log-salary of 6.74.

Aside: Tree Regression vs Linear Regression

In essence, regression trees predict responses by learning a piecewise constant function.

Our simple examples above already show the advantages and disadvantages of this compared to linear or logistic regression.

Linear and logistic regression can really only learn monotonic functions of the data, in the sense that as any particular predictor increases or decreases, the linear or logistic regression output can only increase or decrease (according to the sign of the coefficient on that predictor). This is because these regression models ultimately depend on the quantity \[ \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_p x_p. \] Said another (slightly simplified) way, linear regression can only learn lines, so it cannot possibly handle situations like our synthetic data example above, where x values close to zero are labeled 1 and x values far from zero are labeled 0.

Regression trees avoid this issue entirely.
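
To see this in action, here is a minimal sketch (using the rpart package, which you’ll meet properly in discussion section) of fitting a small classification tree to the synthetic demo_data from earlier– unlike logistic regression, the tree is free to carve out the region near zero with a couple of splits.

library(rpart)
# method='class' asks rpart to treat this as a classification problem.
tree1 <- rpart( y ~ x, data=demo_data, method='class' );
tree_preds <- predict( tree1, newdata=demo_data, type='class' );
# Training accuracy of the tree:
sum( as.integer(as.character(tree_preds))==demo_data$y )/nrow(demo_data)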

Now, as for the disadvantage of regression trees, well, we’ll get to that.

Learning a regression tree

Regression trees sound pretty great! But we haven’t actually said anything yet about how to actually learn a regression tree from data. That is, analogously to how we learn the coefficients in linear regression, how should we learn the “splits” at the nodes of the tree?

It will be helpful to abstract things a bit. Just like in plain old linear regression, we have \(n\) data points \(X_1,X_2,\dots,X_n \in \mathbb{R}^p\). Ultimately, our regression tree is going to try and partition the space into some number of parts, \(R_1,R_2,\dots,R_J\) (we’ll assume for now that we just know what this number \(J\) should be– we’ll come back to this).

We want to do the following:

  1. Partition the input space (i.e., \(\mathbb{R}^p\)) into \(J\) parts, \(R_1,R_2,\dots,R_J\). Recall that by definition of a partition, we have \(R_i \cap R_j = \emptyset\) for \(i \neq j\) and \(\cup_j R_j = \mathbb{R}^p\).
  2. For each part, assign a value \(\hat{y}_j\). This is our prediction for all points that land in part \(R_j\). Most typically, this will be the mean of the responses of all training data in region \(R_j\), and we’ll use that throughout this lecture just for the sake of concreteness and simplicity, but we stress that there are other things you can do in some special cases.

Example: Suppose that we have two regions \(R_1\) and \(R_2\) with response means \(\hat{y}_1 = 10\) and \(\hat{y}_2 = 20\). Given a data point \(x \in R_1\), we predict its response to be \(\hat{y}_1\), that is, \(10\). If \(x \in R_2\) we predict its response to be \(20\).

Okay, but how do we decide among different choices of regions \(R_1,R_2,\dots,R_J\)?

Well, our old friend least squares is back yet again. We will try to choose the regions to minimize the RSS \[ \sum_{j=1}^J \sum_{i \in R_j} \left( y_i - \hat{y}_{R_j} \right)^2, \] where by the sum over all \(i \in R_j\) we mean the sum over all data points \(X_i\) such that \(X_i \in R_j\), and \(\hat{y}_{R_j}\) is the mean of the responses of the training points in \(R_j\). That is, if \(n_j\) is the number of data points \(X_i\) for which \(X_i \in R_j\), \[ \hat{y}_{R_j} = \frac{1}{n_j} \sum_{i : X_i \in R_j} y_i. \]
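
To make the notation concrete, here is a small sketch computing the \(\hat{y}_{R_j}\) and the RSS on the Hitters_subset data from above, for a hand-picked (and purely illustrative) two-region partition based on whether Hits is below 100:

# Hand-picked, purely illustrative partition: R_1 = {Hits < 100}, R_2 = {Hits >= 100}.
region <- ifelse( Hitters_subset[,'Hits'] < 100, 1, 2 );
y <- Hitters_subset[,'logSalary'];
# Per-region mean responses (the y-hats for each region):
yhat_region <- tapply( y, region, mean );
# The RSS for this two-region partition:
sum( ( y - yhat_region[region] )^2 )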

Great! So to find our regression tree we just need to hunt over all possible partitions into \(R_1,R_2,\dots,R_J\), compute this RSS for each one, and pick the best one. Hopefully it’s already clear that this isn’t really an option– it’s computationally infeasible, to say the least. After all, there are infinitely many ways to partition the data space into \(J\) parts. Instead of searching over all partitions, we will use a greedy approach, somewhat reminiscent of forward stepwise selection.

At each step of our procedure, we will recursively split our data space, choosing each split in a greedy manner. That is, at each step we will pick the best split of our data space right now, rather than trying to cleverly look ahead. Again, the similarity to stepwise selection should be clear– there the risk was that our greedy selection of variables might miss the globally best solution. There is a similar risk here.

Recursive Binary Splitting

Things are even worse, though– even if we just ask about splitting our data space in half at each step, there are a whole lot of different ways to split a space– we could make all kinds of weird curves or zig-zags or… Instead of opening that can of worms, we will make splits that lie along the axes of the data space. That is, to make a split, we will choose a predictor (i.e., one of our dimensions \(j=1,2,\dots,p\)), and make a split along that predictor’s dimension by choosing a split point \(s\). Each time we do this, we split one of our existing regions into two new regions. Hence this is called recursive binary splitting.

To start, we just want to split the space in half, into regions \[ \{ z \in \mathbb{R}^p : z_j < s \} ~\text{ and } ~ \{ z \in \mathbb{R}^p : z_j \ge s \}, \] and we will always do this in a way that results in the smallest RSS among the possible splits.

So, all in all, defining \[ R_1(j,s) = \{ z \in \mathbb{R}^p : z_j < s \} ~\text{ and } ~ R_2(j,s) = \{ z \in \mathbb{R}^p : z_j \ge s \}, \] we begin by partitioning the data space into two regions by choosing \(j \in \{1,2,\dots,p\}\) and \(s \in \mathbb{R}\) so as to minimize \[ \sum_{i : x_i \in R_1(j,s)} \left( y_i - \hat{y}_{R_1} \right)^2 + \sum_{i : x_i \in R_2(j,s)} \left( y_i - \hat{y}_{R_2} \right)^2 \] where again the \(\hat{y}\) terms are the means of the observations in the two regions.
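
To demystify this first step, here is a brute-force sketch on the Hitters_subset data from above (far less efficient than what real implementations do, and just for illustration): it searches over both available predictors and over every observed value as a candidate split point \(s\), keeping the split with the smallest RSS.

y <- Hitters_subset[,'logSalary'];
best <- list( rss=Inf );
for( predictor in c('Hits','Runs') ) {
  x <- Hitters_subset[,predictor];
  for( s in sort(unique(x)) ) {
    left <- ( x < s );
    # Skip degenerate "splits" that put every observation on one side.
    if( !any(left) || all(left) ) { next; }
    rss <- sum( (y[left] - mean(y[left]))^2 ) + sum( (y[!left] - mean(y[!left]))^2 );
    if( rss < best$rss ) { best <- list( predictor=predictor, s=s, rss=rss ); }
  }
}
best  # the greedy first split: which predictor, where, and the resulting RSS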

Now, having done the above, we will have split the data space into two regions. It’s time to recurse– we will consider each of \(R_1\) and \(R_2\) separately, and consider all possible splits of each of these regions. Note that now we are not trying to split the whole space in half– we will only split one of the two regions, either \(R_1\) or \(R_2\). Again, we consider all the splits of either \(R_1\) or \(R_2\) and we carry out the split that results in the smallest RSS (this time the RSS is a sum of three sums over regions, rather than two like the equation above).

We keep repeating this procedure until some appropriate stopping criterion is reached. Most often this stopping criterion takes the form of something like “stop when no region has more than \(C\) data points in it”, where \(C\) is some small integer (say, 5 or 10). Another common choice is to stop when none of the available splits decrease the RSS by more than some threshold amount.
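
For what it’s worth, stopping rules like these are exactly the kind of thing controlled by rpart.control in the rpart package that you’ll use in discussion section. The particular values below are just for illustration:

library(rpart)
ctrl <- rpart.control( minsplit=10,  # don't try to split a region containing fewer than 10 points
                       cp=0.01 );    # don't make a split unless it improves the fit "enough"
fit <- rpart( log(Salary) ~ Years + Hits, data=na.omit(Hitters), control=ctrl );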

Illustration: result of recursive binary splitting with five regions

Hopefully it is clear how this recursive binary splitting procedure yields a tree. To drive things home, here’s an illustration.

  • On the left is one possible way of partitioning a two-dimensional data space into five regions (four splits– the first split creates two regions, each additional split turns one region into two, so after \(m\) splits we have \(m+1\) regions).

  • On the right is the tree corresponding to this partition. You can see, either by looking at the partition or at the tree, the order that the splits happened (with the exception that it is impossible to tell if the \(t_1\) or \(t_3\) split happened first just from looking at the data space– only the tree tells us that).

Credit: adapted from ISLR Figure 8.3

Pruning the tree

Now, the process described above usually produces very good RSS on training data, but if you use a train-test split, you’ll see that the above procedure overfits very easily, a phenomenon that you’ll explore in discussion section.

Just as having more predictors in linear regression gives us more flexibility to describe the data (and hence more opportunity to overfit), having a lot of regions means that the piecewise constant function that we use for prediction has a lot of freedom to really closely follow the training data. The way to avoid this is to make a smaller tree with fewer splits (equivalently, to have fewer regions). Just as with the LASSO and other regularization procedures, the hope is that by biasing ourselves toward simpler trees, we avoid getting fooled by the variance of the training data (there’s our old friend the bias-variance tradeoff again!).

What we would like to do is build a really big tree (“big” meaning that it has lots of nodes– lots of splits), and then simplify it by “pruning” it, perhaps using CV to compare all the different subtrees of the same size. Unfortunately, considering all possible subtrees runs into computational issues again– there are just too many different subtrees to consider.

Instead, we will take a different approach, called cost complexity pruning, so called because we penalize the RSS with a term that measures the complexity of the model, as measured by the number of terminal nodes (leaves) in the tree, which we denote \(|T|\). For a given tuning parameter \(\alpha \ge 0\), our new cost is \[ \sum_{m=1}^{|T|} \sum_{i : x_i \in R_m} \left(y_i - \hat{y}_{R_m} \right)^2 + \alpha |T|. \]

This should look a lot like the LASSO or ridge regression– we have the RSS, and we have added to it a quantity that is bigger when our model is “more complex”, with a tuning parameter (\(\lambda\) in the LASSO and ridge, \(\alpha\) here) that lets us specify how much we want to favor good RSS versus low complexity– when \(\alpha\) is bigger, we pay a higher price for more complex models, so a big tree needs to have better prediction to “pay for” its complexity.

As we increase \(\alpha\) from \(0\) to infinity, this corresponds to “pruning” branches from the tree. Equivalently, as \(\alpha\) increases, we merge more and more regions (i.e., undo their splits). As with the LASSO, we can choose a good value of \(\alpha\) using CV.

You’ll get some practice with actually doing this in R in discussion section using the rpart package (rpart is short for recursive partitioning– hopefully it’s clear why that’s a relevant name!). We should mention that there is another very popular package for regression trees and related models, called tree, which is worth looking at. Each package has its advantages and disadvantages (e.g., I think tree is a bit easier to use, but it has fewer fun bells and whistles), so it’s worth looking at them both to see which one you prefer to have as your go-to.
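
Just to give the flavor ahead of discussion section, here is a minimal sketch of cost complexity pruning with rpart on the Hitters data. Note that rpart parameterizes the penalty through a “complexity parameter” cp rather than through \(\alpha\) directly, and the particular values below are just for illustration.

library(rpart)
# Grow a deliberately large tree by making the complexity penalty very small.
big_tree <- rpart( log(Salary) ~ Years + Hits, data=na.omit(Hitters),
                   control=rpart.control( cp=0.001, minsplit=5 ) );
printcp( big_tree );  # table of subtree sizes with cross-validated error (xerror)
# Prune back to a simpler subtree; in practice, choose cp based on the CV output above.
pruned_tree <- prune( big_tree, cp=0.02 );
plot( pruned_tree ); text( pruned_tree );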

Trees vs Linear Models

Linear regression and regression trees both try to solve the same problem– minimize a squared error between observed responses and our predictions. They differ in the kind of functions that they learn to do this. Linear regression learns a prediction function of the form \[ f(X) = \beta_0 + \sum_{j=1}^p \beta_j X_j. \]

That is, it models the response as a linear function of the data.

In contrast, regression trees learn piecewise constant functions. These have the form \[ f(X) = \sum_{j=1}^J a_j \mathbb{1}_{R_j}(X), \] where \(a_1,a_2,\dots,a_J \in \mathbb{R}\) and \[ \mathbb{1}_{C}(z) = \begin{cases} 1 &\mbox{ if } z \in C, \\ 0 &\mbox{otherwise.} \end{cases} \] Here’s a good illustration of how each of these classes of functions can succeed or fail, depending on the data. This example is technically for classification, not regression, but it illustrates the basic problem nonetheless.

The green and yellow regions represent different classes, i.e., different parts of the data space that we want our model to be able to separate. Speaking somewhat cartoonishly, the problem is that linear regression only knows how to draw lines, and regression trees only know how to make boxes. Credit: ISLR Figure 8.7.

Advantages of Tree-based Methods

Tree-based methods have a number of advantages to them, mostly due to their simplicity of interpretation. For example, it is very easy to visualize a regression tree by, well, just drawing the tree (you’ll get practice with that in discussion section). Furthermore, tree models are very interpretable– decision trees are very easy things to think about, and they mirror human decision making in a lot of ways (or so people claim!).

Two major drawbacks of simple tree-based models:

  1. They tend to be rather fragile– small changes in the data can result in very different predictions.
  2. They tend not to perform as well as other similarly simple prediction methods, owing to the simplicity of the piecewise constant functions they learn.

In our last lecture, we will see an adaptation of these methods that solves both of these problems– random forests and related methods are competitive with even the most cutting-edge neural nets on many tasks, while retaining a lot of the nice interpretability and visualization properties of regression trees.