Chapter 12 Tree-based classification

12.1 Decision trees - recursive partitioning (rpart)

Basic algorithm

  1. Start with all variables in one group
  2. Find the variable/split that best separates the outcomes
  3. Divide the data into two groups (“leaves”) at that split (“node”)
  4. Within each group, find the best variable/split that separates the outcomes
  5. Continue until the groups are too small or sufficiently “pure”
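
The “best split” in step 2 is usually judged by an impurity measure such as the Gini index. A minimal sketch of the idea (a brute-force search over one variable, purely for illustration; rpart does this greedily at every node with more machinery):

```r
# Gini impurity of a set of class labels: 0 when pure, larger when mixed
gini <- function(y) 1 - sum(prop.table(table(y))^2)

# Try every candidate cut point on x and keep the one with the lowest
# weighted impurity of the two resulting groups
best_split <- function(x, y) {
  cuts <- sort(unique(x))[-1]        # skip the minimum so 'left' is never empty
  impurity <- sapply(cuts, function(cut) {
    left <- y[x < cut]; right <- y[x >= cut]
    (length(left) * gini(left) + length(right) * gini(right)) / length(y)
  })
  cuts[which.min(impurity)]
}

data(iris)
best_split(iris$Petal.Length, iris$Species)  # splits setosa off cleanly
```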

Decision trees accept both continuous and categorical independent variables.

We will try to predict the cut of the diamond based on every other variable we know.

library(caret)
library(rattle)
library(ggplot2)
library(dplyr)
data(diamonds)
diamonds = sample_n(diamonds, 1000) # Reduce the dataset size to improve processing
inTrain <- createDataPartition(y=diamonds$cut, p=0.75, list=FALSE)
training <- diamonds[inTrain,]
testing <- diamonds[-inTrain,]

Let us build the model. The method to use is rpart, which is short for recursive partitioning.

fit <- train(cut ~ ., method="rpart", data=training)

Now we plot the tree. We will need the rattle library to do this, so make sure it is installed.

First we look at the text based version of the tree. Then a fancier graphic.

print(fit$finalModel)
## n= 752 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 752 440 Ideal (0.041 0.093 0.21 0.24 0.41)  
##   2) table>=57.5 328 170 Premium (0.055 0.15 0.26 0.48 0.058)  
##     4) depth>=63.05 54  30 Good (0.26 0.44 0.28 0 0.019) *
##     5) depth< 63.05 274 116 Premium (0.015 0.091 0.25 0.58 0.066) *
##   3) table< 57.5 424 131 Ideal (0.031 0.05 0.17 0.059 0.69) *
library(rattle)
fancyRpartPlot(fit$finalModel)

How do you read the numbers in the boxes?

  • The top line is the predicted class
  • The second line is the predicted probability of each category (you can get the order of the categories using the levels command, e.g. levels(diamonds$cut) in this case)
  • The final line is the percentage of observations in that node
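
For instance, the category order for the diamonds data:

```r
library(ggplot2)
data(diamonds)
# cut is an ordered factor; this is the order the probabilities are printed in
levels(diamonds$cut)
# "Fair" "Good" "Very Good" "Premium" "Ideal"
```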

We can verify the class proportions shown at the root node:

prop.table(table(diamonds$cut))
## 
##      Fair      Good Very Good   Premium     Ideal 
##     0.041     0.093     0.207     0.243     0.416

How accurate is this model? We check out the confusion matrix:

confusionMatrix(testing$cut, predict(fit, newdata=testing)) # note: caret expects (predicted, reference); here actual is passed first, which swaps the roles of the rows and columns but leaves overall accuracy unchanged
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Fair Good Very Good Premium Ideal
##   Fair         0    7         0       2     1
##   Good         0    6         0       7    10
##   Very Good    0    6         0      21    24
##   Premium      0    0         0      52     8
##   Ideal        0    1         0       7    96
## 
## Overall Statistics
##                                           
##                Accuracy : 0.621           
##                  95% CI : (0.5574, 0.6816)
##     No Information Rate : 0.5605          
##     P-Value [Acc > NIR] : 0.03122         
##                                           
##                   Kappa : 0.4348          
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: Fair Class: Good Class: Very Good
## Sensitivity                   NA     0.30000               NA
## Specificity              0.95968     0.92544           0.7944
## Pos Pred Value                NA     0.26087               NA
## Neg Pred Value                NA     0.93778               NA
## Prevalence               0.00000     0.08065           0.0000
## Detection Rate           0.00000     0.02419           0.0000
## Detection Prevalence     0.04032     0.09274           0.2056
## Balanced Accuracy             NA     0.61272               NA
##                      Class: Premium Class: Ideal
## Sensitivity                  0.5843       0.6906
## Specificity                  0.9497       0.9266
## Pos Pred Value               0.8667       0.9231
## Neg Pred Value               0.8032       0.7014
## Prevalence                   0.3589       0.5605
## Detection Rate               0.2097       0.3871
## Detection Prevalence         0.2419       0.4194
## Balanced Accuracy            0.7670       0.8086
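
The overall accuracy reported above is simply the proportion of predictions that land on the diagonal of the confusion table. A toy illustration with made-up labels:

```r
# Overall accuracy = correct predictions / all predictions,
# i.e. the sum of the confusion table's diagonal over its total
predicted <- c("A", "A", "B", "B", "B")   # made-up predictions
actual    <- c("A", "B", "B", "B", "A")   # made-up true labels
tab <- table(predicted, actual)
sum(diag(tab)) / sum(tab)                 # 3 of 5 correct: 0.6
```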

12.2 Random Forest

Random forest is essentially bagging applied to decision trees, with one extra twist: at each split, only a random subset of the variables is considered.

Before we talk of bagging, let us look at what bootstrapping means. In bootstrapping, you treat the sample as if it were the population, and draw repeated samples of equal size from it, with replacement. Now imagine that for each of these new samples you calculate a population characteristic, say the median. Because you can draw a very large number of samples (theoretically infinitely many), you get a whole distribution of the median from the original single sample. Without bootstrapping (i.e. resampling from the sample with replacement), we would have only one estimate for the median.
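
A quick sketch of the idea in R (the sample and the number of replicates are arbitrary):

```r
# Bootstrap distribution of the median from a single sample
set.seed(42)                      # arbitrary seed, for reproducibility
x <- rexp(100)                    # pretend this is our one observed sample
boot_medians <- replicate(1000, median(sample(x, replace = TRUE)))
# Instead of a single point estimate, we now have a whole distribution:
quantile(boot_medians, c(0.025, 0.975))
```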

Though it is counter-intuitive to me, it improves the estimation process and reduces variance.

Bagging is a type of ensemble learning. Ensemble learning is where we combine multiple models to produce a better prediction or classification. In bagging, we produce multiple different training sets (called bootstrap samples) by sampling with replacement from the original dataset. Then, for each bootstrap sample, we build a model. This results in an ensemble of models, where each model votes with equal weight. Typically, the goal of this procedure is to reduce the variance of the model of interest (e.g. decision trees).
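
A minimal sketch of bagging with rpart trees (the number of trees, the dataset, and the in-sample evaluation are all illustrative):

```r
# Bagging by hand: grow B trees on bootstrap samples and combine their
# class predictions by majority vote
library(rpart)
data(iris)
set.seed(1)
B <- 25
trees <- lapply(seq_len(B), function(b) {
  boot <- iris[sample(nrow(iris), replace = TRUE), ]  # bootstrap sample
  rpart(Species ~ ., data = boot)                     # one model per sample
})
# Each tree votes with equal weight; the most common class wins
votes <- sapply(trees, function(t) as.character(predict(t, iris, type = "class")))
bagged_pred <- apply(votes, 1, function(v) names(which.max(table(v))))
mean(bagged_pred == iris$Species)   # in-sample agreement of the ensemble
```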

If the above does not make sense, that is okay too.

The caret commands for using random forest are the same as for other models.

Let us solve exactly the same problem as before, but with bagged trees - the algorithm is called random forest. It works as follows:

  1. Bootstrap samples
  2. At each split, bootstrap variables
  3. Grow multiple trees and vote

data(diamonds)
diamonds = sample_n(diamonds, 1000)
inTrain <- createDataPartition(y=diamonds$cut, p=0.75, list=FALSE)
training <- diamonds[inTrain,]
testing <- diamonds[-inTrain,]
fit <- train(cut ~ ., method="rf", data=training)
fit
## Random Forest 
## 
## 752 samples
##   9 predictor
##   5 classes: 'Fair', 'Good', 'Very Good', 'Premium', 'Ideal' 
## 
## No pre-processing
## Resampling: Bootstrapped (25 reps) 
## Summary of sample sizes: 752, 752, 752, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.5962501  0.4015812
##   11    0.6481090  0.4916450
##   20    0.6407784  0.4833075
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 11.

But there is a loss of interpretability, as there is no single ‘tree’ to show. The individual trees can, however, be obtained using the getTree function from the randomForest package.
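
A self-contained sketch (training a small forest directly with the randomForest package rather than reusing the caret fit above):

```r
library(randomForest)
data(iris)
set.seed(1)
rf <- randomForest(Species ~ ., data = iris, ntree = 25)
# Each row of getTree's output is one node of tree k: the split variable,
# the split point, the indices of the two child nodes, and (for terminal
# nodes) the predicted class
head(getTree(rf, k = 1, labelVar = TRUE))
```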