Introduction

Tree-based methods in statistical learning are closely related to decision tree models. Trees and their extensions are widely used to predict outcome variables. Let’s start by reviewing the advantages and disadvantages of this method:

Benefits:

Disadvantages:

However, if we grow a good forest and bag/boost the models to combine the fits from different tree models, we can improve the prediction power.

Regression Trees

It is much easier to explain tree models with an example. Let’s start with the Boston housing data that we used before in the \(kNN\) section. As before, I want to predict the median housing price in a neighborhood, \({\tt medv}\), using the percentage of lower-status population, \({\tt lstat}\).

library(MASS) 
attach(Boston)
n = nrow(Boston)

ddf = data.frame(lstat,medv)

# Sort the data here
oo=order(ddf$lstat)

ddf = ddf[oo,]


library(rpart)

m_tree=rpart(medv~lstat,data=ddf)

#plot the estimated tree

summary(m_tree)
## Call:
## rpart(formula = medv ~ lstat, data = ddf)
##   n= 506 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.44236500      0 1.0000000 1.0027671 0.08308804
## 2 0.15283400      1 0.5576350 0.6203675 0.05364975
## 3 0.06275014      2 0.4048010 0.4638841 0.04157803
## 4 0.01374053      3 0.3420509 0.4056533 0.03797861
## 5 0.01200979      4 0.3283103 0.3803722 0.03722361
## 6 0.01159492      5 0.3163005 0.3833462 0.03747779
## 7 0.01000000      6 0.3047056 0.3766747 0.03722780
## 
## Variable importance
## lstat 
##   100 
## 
## Node number 1: 506 observations,    complexity param=0.442365
##   mean=22.53281, MSE=84.41956 
##   left son=2 (294 obs) right son=3 (212 obs)
##   Primary splits:
##       lstat < 9.725  to the right, improve=0.442365, (0 missing)
## 
## Node number 2: 294 observations,    complexity param=0.06275014
##   mean=17.34354, MSE=23.83089 
##   left son=4 (144 obs) right son=5 (150 obs)
##   Primary splits:
##       lstat < 16.085 to the right, improve=0.3825785, (0 missing)
## 
## Node number 3: 212 observations,    complexity param=0.152834
##   mean=29.72925, MSE=79.31047 
##   left son=6 (162 obs) right son=7 (50 obs)
##   Primary splits:
##       lstat < 4.65   to the right, improve=0.3882819, (0 missing)
## 
## Node number 4: 144 observations,    complexity param=0.01374053
##   mean=14.26181, MSE=18.74458 
##   left son=8 (75 obs) right son=9 (69 obs)
##   Primary splits:
##       lstat < 19.9   to the right, improve=0.2174498, (0 missing)
## 
## Node number 5: 150 observations
##   mean=20.302, MSE=10.84406 
## 
## Node number 6: 162 observations,    complexity param=0.01159492
##   mean=26.6463, MSE=42.74335 
##   left son=12 (134 obs) right son=13 (28 obs)
##   Primary splits:
##       lstat < 5.495  to the right, improve=0.07152825, (0 missing)
## 
## Node number 7: 50 observations,    complexity param=0.01200979
##   mean=39.718, MSE=67.21788 
##   left son=14 (32 obs) right son=15 (18 obs)
##   Primary splits:
##       lstat < 3.325  to the right, improve=0.1526421, (0 missing)
## 
## Node number 8: 75 observations
##   mean=12.32533, MSE=16.18776 
## 
## Node number 9: 69 observations
##   mean=16.36667, MSE=13.01729 
## 
## Node number 12: 134 observations
##   mean=25.84701, MSE=40.14324 
## 
## Node number 13: 28 observations
##   mean=30.47143, MSE=37.49776 
## 
## Node number 14: 32 observations
##   mean=37.31562, MSE=65.92382 
## 
## Node number 15: 18 observations
##   mean=43.98889, MSE=41.01765
summary(ddf$medv)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    5.00   17.02   21.20   22.53   25.00   50.00
plot(m_tree)

text(m_tree)

However, this base-graphics plot rarely looks good; it is usually a mess! To plot a nicer tree, I suggest using the \({\tt rpart.plot}\) package.

library(rpart.plot)

rpart.plot(m_tree)

At each interior node, there is a decision rule of the form \(x \geq c\). If the rule holds, you follow the branch on the left; otherwise, you follow the branch on the right. This continues until you reach a bottom (terminal) node, which is also known as a leaf of the tree.

plot(ddf$lstat,ddf$medv, col='blue')

lines(ddf$lstat,predict(m_tree,ddf),col='maroon',lwd=3)


for (i in m_tree$splits[,4] ){
abline(v=i, col="orange",lty=2)
}

The set of bottom nodes gives us a partition of the predictor \(x\) space into disjoint regions. In the plot above, the dashed vertical lines display the partition. With just one \(x\), this is simply a set of intervals.

Within each region (interval), we compute the average of the \(y\) values for the subset of training data in the region. This gives us the step function, which is our \(\hat{f}\). The \(\bar{y}\) values are also printed at the bottom nodes. To predict, we simply evaluate this step-function estimate of \(f(x)\).
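The step function described above can be rebuilt by hand. The following is a minimal sketch, assuming the `MASS` and `rpart` packages are installed; the names `cuts`, `region`, `step_fit`, and `max_gap` are introduced here for illustration only. It reads the split points from the fitted tree, bins `lstat` into the resulting intervals, and averages `medv` within each interval.

```r
library(MASS)    # Boston data
library(rpart)

# Same single-predictor tree as above
ddf <- Boston[order(Boston$lstat), c("lstat", "medv")]
m_tree <- rpart(medv ~ lstat, data = ddf)

# Split points of the fitted tree ("index" column of the splits matrix)
cuts <- sort(unname(m_tree$splits[, "index"]))

# Assign every observation to the interval it falls in
region <- cut(ddf$lstat, breaks = c(-Inf, cuts, Inf))

# The step function: average medv within each interval
step_fit <- ave(ddf$medv, region)

# Largest difference from rpart's own fitted values
max_gap <- max(abs(step_fit - predict(m_tree, ddf)))
```

If the hand-built averages match the tree, `max_gap` should be zero up to rounding error, and `region` should have one level per leaf.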

A Tree with Two Explanatory Variables

Now, let’s estimate a tree model using two features: \(x=(lstat, dis)\) and \(y=medv\).

First, let’s take a look at scatter plot of these variables:

attach(Boston)

ddf = data.frame(medv,rm,nox,lstat, dis)

# Sort the data here
oo=order(ddf$lstat)

ddf = ddf[oo,]

library(ggplot2)

ggplot(ddf, aes(lstat, dis, colour = medv)) + geom_point()

m_tree2=rpart(medv~lstat+dis,data=ddf)


rpart.plot(m_tree2, box.palette = "Grays")

print(m_tree2)
## n= 506 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 506 42716.3000 22.53281  
##    2) lstat>=9.725 294  7006.2830 17.34354  
##      4) lstat>=16.085 144  2699.2200 14.26181  
##        8) dis< 2.0037 75  1114.6990 11.96933 *
##        9) dis>=2.0037 69   761.9316 16.75362 *
##      5) lstat< 16.085 150  1626.6090 20.30200 *
##    3) lstat< 9.725 212 16813.8200 29.72925  
##      6) lstat>=4.65 162  6924.4230 26.64630  
##       12) dis>=2.4501 144  3493.3720 25.67986  
##         24) lstat>=5.495 118  2296.9260 24.71695 *
##         25) lstat< 5.495 26   590.4850 30.05000 *
##       13) dis< 2.4501 18  2220.5910 34.37778 *
##      7) lstat< 4.65 50  3360.8940 39.71800  
##       14) dis>=3.20745 38  2087.6880 37.00789  
##         28) lstat>=3.99 13   179.3908 32.06154 *
##         29) lstat< 3.99 25  1424.8400 39.58000 *
##       15) dis< 3.20745 12   110.3000 48.30000 *
summary(m_tree2)
## Call:
## rpart(formula = medv ~ lstat + dis, data = ddf)
##   n= 506 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.44236500      0 1.0000000 1.0037152 0.08298818
## 2 0.15283400      1 0.5576350 0.6110681 0.05115273
## 3 0.06275014      2 0.4048010 0.4755064 0.04467493
## 4 0.02833720      3 0.3420509 0.4055890 0.04151633
## 5 0.02722395      4 0.3137137 0.4149826 0.04196536
## 6 0.01925703      5 0.2864897 0.4002919 0.04106525
## 7 0.01418570      6 0.2672327 0.3852009 0.04096723
## 8 0.01131786      7 0.2530470 0.3566216 0.04021798
## 9 0.01000000      8 0.2417291 0.3483610 0.03874767
## 
## Variable importance
## lstat   dis 
##    72    28 
## 
## Node number 1: 506 observations,    complexity param=0.442365
##   mean=22.53281, MSE=84.41956 
##   left son=2 (294 obs) right son=3 (212 obs)
##   Primary splits:
##       lstat < 9.725   to the right, improve=0.4423650, (0 missing)
##       dis   < 2.5977  to the left,  improve=0.1169235, (0 missing)
##   Surrogate splits:
##       dis < 4.48025 to the left,  agree=0.737, adj=0.373, (0 split)
## 
## Node number 2: 294 observations,    complexity param=0.06275014
##   mean=17.34354, MSE=23.83089 
##   left son=4 (144 obs) right son=5 (150 obs)
##   Primary splits:
##       lstat < 16.085  to the right, improve=0.3825785, (0 missing)
##       dis   < 2.0754  to the left,  improve=0.3369205, (0 missing)
##   Surrogate splits:
##       dis < 2.2166  to the left,  agree=0.724, adj=0.438, (0 split)
## 
## Node number 3: 212 observations,    complexity param=0.152834
##   mean=29.72925, MSE=79.31047 
##   left son=6 (162 obs) right son=7 (50 obs)
##   Primary splits:
##       lstat < 4.65    to the right, improve=0.3882819, (0 missing)
##       dis   < 2.16475 to the right, improve=0.1528193, (0 missing)
##   Surrogate splits:
##       dis < 1.48495 to the right, agree=0.769, adj=0.02, (0 split)
## 
## Node number 4: 144 observations,    complexity param=0.01925703
##   mean=14.26181, MSE=18.74458 
##   left son=8 (75 obs) right son=9 (69 obs)
##   Primary splits:
##       dis   < 2.0037  to the left,  improve=0.3047506, (0 missing)
##       lstat < 19.9    to the right, improve=0.2174498, (0 missing)
##   Surrogate splits:
##       lstat < 19.73   to the right, agree=0.771, adj=0.522, (0 split)
## 
## Node number 5: 150 observations
##   mean=20.302, MSE=10.84406 
## 
## Node number 6: 162 observations,    complexity param=0.0283372
##   mean=26.6463, MSE=42.74335 
##   left son=12 (144 obs) right son=13 (18 obs)
##   Primary splits:
##       dis   < 2.4501  to the right, improve=0.17481020, (0 missing)
##       lstat < 5.495   to the right, improve=0.07152825, (0 missing)
## 
## Node number 7: 50 observations,    complexity param=0.02722395
##   mean=39.718, MSE=67.21788 
##   left son=14 (38 obs) right son=15 (12 obs)
##   Primary splits:
##       dis   < 3.20745 to the right, improve=0.3460110, (0 missing)
##       lstat < 3.325   to the right, improve=0.1526421, (0 missing)
##   Surrogate splits:
##       lstat < 1.95    to the right, agree=0.8, adj=0.167, (0 split)
## 
## Node number 8: 75 observations
##   mean=11.96933, MSE=14.86266 
## 
## Node number 9: 69 observations
##   mean=16.75362, MSE=11.04249 
## 
## Node number 12: 144 observations,    complexity param=0.0141857
##   mean=25.67986, MSE=24.25952 
##   left son=24 (118 obs) right son=25 (26 obs)
##   Primary splits:
##       lstat < 5.495   to the right, improve=0.17346010, (0 missing)
##       dis   < 4.39775 to the right, improve=0.07588735, (0 missing)
## 
## Node number 13: 18 observations
##   mean=34.37778, MSE=123.3662 
## 
## Node number 14: 38 observations,    complexity param=0.01131786
##   mean=37.00789, MSE=54.93915 
##   left son=28 (13 obs) right son=29 (25 obs)
##   Primary splits:
##       lstat < 3.99    to the right, improve=0.2315753, (0 missing)
##       dis   < 6.40075 to the right, improve=0.1173026, (0 missing)
## 
## Node number 15: 12 observations
##   mean=48.3, MSE=9.191667 
## 
## Node number 24: 118 observations
##   mean=24.71695, MSE=19.46548 
## 
## Node number 25: 26 observations
##   mean=30.05, MSE=22.71096 
## 
## Node number 28: 13 observations
##   mean=32.06154, MSE=13.79929 
## 
## Node number 29: 25 observations
##   mean=39.58, MSE=56.9936
library(tree)

m_tree2=tree(medv~lstat+dis,data=ddf)

summary(m_tree2)
## 
## Regression tree:
## tree(formula = medv ~ lstat + dis, data = ddf)
## Number of terminal nodes:  9 
## Residual mean deviance:  20.78 = 10330 / 497 
## Distribution of residuals:
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## -14.780  -2.713  -0.302   0.000   2.298  15.620
partition.tree(m_tree2)

attach(Boston)
mid<-mean(medv)
ggplot(ddf, aes(lstat, rm, colour = medv)) + geom_point()+scale_color_gradient2(midpoint=mid, low="blue", mid="cyan",
                     high="maroon", space ="Lab" )

m_tree31=rpart(medv~lstat+rm+dis,data=ddf)

summary(m_tree31)
## Call:
## rpart(formula = medv ~ lstat + rm + dis, data = ddf)
##   n= 506 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.45274420      0 1.0000000 1.0014734 0.08290984
## 2 0.17117244      1 0.5472558 0.6330082 0.05712185
## 3 0.07165784      2 0.3760834 0.4251218 0.04691289
## 4 0.03616428      3 0.3044255 0.3495625 0.04286299
## 5 0.03336923      4 0.2682612 0.3259299 0.04303387
## 6 0.02311607      5 0.2348920 0.3079637 0.04221997
## 7 0.01585116      6 0.2117759 0.2922020 0.04058005
## 8 0.01000000      7 0.1959248 0.2869241 0.04052948
## 
## Variable importance
##    rm lstat   dis 
##    54    33    13 
## 
## Node number 1: 506 observations,    complexity param=0.4527442
##   mean=22.53281, MSE=84.41956 
##   left son=2 (430 obs) right son=3 (76 obs)
##   Primary splits:
##       rm    < 6.941   to the left,  improve=0.4527442, (0 missing)
##       lstat < 9.725   to the right, improve=0.4423650, (0 missing)
##       dis   < 2.5977  to the left,  improve=0.1169235, (0 missing)
##   Surrogate splits:
##       lstat < 4.83    to the right, agree=0.891, adj=0.276, (0 split)
## 
## Node number 2: 430 observations,    complexity param=0.1711724
##   mean=19.93372, MSE=40.27284 
##   left son=4 (175 obs) right son=5 (255 obs)
##   Primary splits:
##       lstat < 14.4    to the right, improve=0.4222277, (0 missing)
##       dis   < 2.58835 to the left,  improve=0.2036255, (0 missing)
##       rm    < 6.5455  to the left,  improve=0.1442877, (0 missing)
##   Surrogate splits:
##       dis < 2.23935 to the left,  agree=0.781, adj=0.463, (0 split)
##       rm  < 5.858   to the left,  agree=0.688, adj=0.234, (0 split)
## 
## Node number 3: 76 observations,    complexity param=0.07165784
##   mean=37.23816, MSE=79.7292 
##   left son=6 (46 obs) right son=7 (30 obs)
##   Primary splits:
##       rm    < 7.437   to the left,  improve=0.50515690, (0 missing)
##       lstat < 4.68    to the right, improve=0.33189140, (0 missing)
##       dis   < 6.75885 to the right, improve=0.03479472, (0 missing)
##   Surrogate splits:
##       lstat < 3.99    to the right, agree=0.776, adj=0.433, (0 split)
## 
## Node number 4: 175 observations,    complexity param=0.02311607
##   mean=14.956, MSE=19.27572 
##   left son=8 (86 obs) right son=9 (89 obs)
##   Primary splits:
##       dis   < 2.0037  to the left,  improve=0.2927244, (0 missing)
##       lstat < 19.83   to the right, improve=0.2696497, (0 missing)
##       rm    < 5.567   to the left,  improve=0.0750970, (0 missing)
##   Surrogate splits:
##       lstat < 19.73   to the right, agree=0.749, adj=0.488, (0 split)
##       rm    < 5.5505  to the left,  agree=0.634, adj=0.256, (0 split)
## 
## Node number 5: 255 observations,    complexity param=0.03616428
##   mean=23.3498, MSE=26.0087 
##   left son=10 (248 obs) right son=11 (7 obs)
##   Primary splits:
##       dis   < 1.5511  to the right, improve=0.2329242, (0 missing)
##       lstat < 4.91    to the right, improve=0.2208409, (0 missing)
##       rm    < 6.543   to the left,  improve=0.2172099, (0 missing)
## 
## Node number 6: 46 observations,    complexity param=0.01585116
##   mean=32.11304, MSE=41.29592 
##   left son=12 (7 obs) right son=13 (39 obs)
##   Primary splits:
##       lstat < 9.65    to the right, improve=0.35644260, (0 missing)
##       dis   < 3.45845 to the left,  improve=0.08199333, (0 missing)
##       rm    < 7.3     to the right, improve=0.05938584, (0 missing)
##   Surrogate splits:
##       dis < 1.6469  to the left,  agree=0.87, adj=0.143, (0 split)
## 
## Node number 7: 30 observations
##   mean=45.09667, MSE=36.62832 
## 
## Node number 8: 86 observations
##   mean=12.53953, MSE=16.13867 
## 
## Node number 9: 89 observations
##   mean=17.29101, MSE=11.21228 
## 
## Node number 10: 248 observations,    complexity param=0.03336923
##   mean=22.93629, MSE=14.75159 
##   left son=20 (193 obs) right son=21 (55 obs)
##   Primary splits:
##       rm    < 6.543   to the left,  improve=0.38962730, (0 missing)
##       lstat < 7.685   to the right, improve=0.33560120, (0 missing)
##       dis   < 2.8405  to the left,  improve=0.03769213, (0 missing)
##   Surrogate splits:
##       lstat < 5.055   to the right, agree=0.839, adj=0.273, (0 split)
##       dis   < 10.648  to the left,  agree=0.782, adj=0.018, (0 split)
## 
## Node number 11: 7 observations
##   mean=38, MSE=204.1457 
## 
## Node number 12: 7 observations
##   mean=23.05714, MSE=61.85673 
## 
## Node number 13: 39 observations
##   mean=33.73846, MSE=20.24391 
## 
## Node number 20: 193 observations
##   mean=21.65648, MSE=8.23738 
## 
## Node number 21: 55 observations
##   mean=27.42727, MSE=11.69398
rpart.plot(m_tree31, box.palette = "Grays")

m_tree32=rpart(medv~lstat+rm+dis,data=ddf)


rpart.plot(m_tree32, box.palette = "Grays")

What is the loss function and optimization problem in the tree algorithm?

As shown above, we want to split the feature space, i.e. \(X=\{x_1, x_2, \dots, x_k\}\), into smaller rectangles, or boxes. If we find a splitting pattern that has the lowest prediction error on the training data, then we can predict the outcome for a new feature set \(X_p\) by finding the box to which it belongs. Roughly, we can think of the tree algorithm as a \(kNN\) algorithm in which \(k\) is dynamic and changes according to the data.

\[\begin{align} \min_{R_1,R_2,\dots,R_J} \sum_{j=1}^J \sum_{i \in R_j} (y_i-\hat{y}_{R_j})^2 \end{align}\]

This loss has a straightforward intuition: we want to find the \(R_j\)s that minimize the sum of squared deviations of the observations from the mean response within each box, \(\hat{y}_{R_j}\). In theory, there are infinitely many ways to define the boxes, and we would need to evaluate the squared error of every candidate set of boxes to see which one is smallest. This is computationally infeasible.

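To solve this problem, a top-down recursive partitioning approach is used to find good boxes: start with all observations in a single region, choose the predictor and cut point whose split gives the largest reduction in the squared error, and repeat the same search within each resulting region until a stopping rule is met.

This is a greedy approach, since the splitting process focuses only on the current step and picks the split associated with the best fit at that step, rather than looking ahead and picking a split that would lead to a better prediction tree in later steps.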
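A single greedy step can be sketched in a few lines of base R. This is an illustration under simplifying assumptions (one numeric predictor, no minimum node size), not the routine used inside `rpart` or `tree`; the function name `best_split` is hypothetical. It tries every midpoint between adjacent distinct values of `x` and keeps the cut with the smallest total within-region sum of squared errors.

```r
# Exhaustive search for the single best split of one numeric predictor
best_split <- function(x, y) {
  xs <- sort(unique(x))
  # Candidate cut points: midpoints between adjacent distinct values
  cand <- (head(xs, -1) + tail(xs, -1)) / 2
  # Sum of squared deviations from the region mean
  sse <- function(v) sum((v - mean(v))^2)
  loss <- sapply(cand, function(cc) sse(y[x < cc]) + sse(y[x >= cc]))
  list(cut = cand[which.min(loss)], sse = min(loss))
}

# Toy data with an obvious jump between x = 5 and x = 6
x <- 1:10
y <- c(rep(0, 5), rep(10, 5))
toy <- best_split(x, y)
toy$cut   # 5.5
toy$sse   # 0
```

Growing a full tree amounts to calling such a search for every predictor, taking the best split overall, and then repeating the search separately on the two resulting subsets. Applied to the Boston data, `best_split(Boston$lstat, Boston$medv)` should land close to the root split reported by `rpart` above.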

How to fit the data to a tree model?

To execute the above algorithm, the problem is rewritten as follows:

\[\begin{align} C(T,y)=L(T,y)+\alpha |T| \end{align}\]

where \(L(T,y)\) is the loss of fitting outcome \(y\) with tree \(T\). Our goal is always to minimize the loss; that is, a small \(L\) is preferred.

However, we do not want to make the model/tree too complex. A complex model might be good at capturing all of the variation in the training data, but it overfits and would predict poorly on the test data. Thus, we add the number of terminal nodes in tree \(T\), i.e. \(|T|\), multiplied by a penalty parameter \(\alpha\), to the cost function \(C(T,y)\). For a continuous \(y\), we can use a squared-error loss such as RMSE, and for a categorical \(y\), we can use any of the classification measures discussed in the previous section.

What is the best \(\alpha\)? If we pick a big \(\alpha\), then the penalty for having a big and complex tree is large. Thus, the optimization problem returns a small tree, which can lead to a large \(L\) on the training data. On the other hand, with a small \(\alpha\), we allow the model to pick a big tree. Here, \(\alpha\) is analogous to \(k\) in \(kNN\). How do we pick \(\alpha\)? The answer, as in most similar cases in this course, is cross-validation.
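The trade-off can be made concrete with the CP table printed by `summary(m_tree)` earlier. The sketch below is a rough illustration: it copies the `rel error` column from that table, treats it as \(L(T,y)\) (so \(\alpha\) is on the same rescaled scale as the `rpart` complexity parameter), and counts leaves as `nsplit + 1`. The names `cp_tab`, `cost`, and `best_leaves` are introduced here for illustration.

```r
# rel error column of the CP table printed by summary(m_tree) above
cp_tab <- data.frame(
  nsplit    = 0:6,
  rel_error = c(1.0000000, 0.5576350, 0.4048010, 0.3420509,
                0.3283103, 0.3163005, 0.3047056)
)

# C(T, y) = L(T, y) + alpha * |T|, with |T| = nsplit + 1 leaves
cost <- function(alpha) cp_tab$rel_error + alpha * (cp_tab$nsplit + 1)

# Number of leaves of the cost-minimizing subtree for each alpha
best_leaves <- sapply(c(0.2, 0.05, 0.001),
                      function(a) cp_tab$nsplit[which.min(cost(a))] + 1)
best_leaves   # 2 4 7
```

A large penalty (\(\alpha = 0.2\)) keeps only the first split, a moderate one (\(\alpha = 0.05\)) keeps four leaves, and a very small one (\(\alpha = 0.001\)) keeps the full seven-leaf tree.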

Pruning a tree

The idea here is to first let a tree grow as large as possible without worrying about its complexity. Then, using the complexity parameter (CP value), we prune (cut) the tree back to the size associated with the optimal CP value.

attach(Boston)

df=Boston[,c(8,13,14)] # pick dis, lstat, medv

print(names(df))
## [1] "dis"   "lstat" "medv"
# First grow a big tree

tree_ml=tree(medv~., df)

print(tree_ml)
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 506 42720.0 22.53  
##    2) lstat < 9.725 212 16810.0 29.73  
##      4) lstat < 4.65 50  3361.0 39.72  
##        8) dis < 3.20745 12   110.3 48.30 *
##        9) dis > 3.20745 38  2088.0 37.01  
##         18) lstat < 3.99 25  1425.0 39.58 *
##         19) lstat > 3.99 13   179.4 32.06 *
##      5) lstat > 4.65 162  6924.0 26.65  
##       10) dis < 2.4501 18  2221.0 34.38 *
##       11) dis > 2.4501 144  3493.0 25.68  
##         22) lstat < 5.495 26   590.5 30.05 *
##         23) lstat > 5.495 118  2297.0 24.72 *
##    3) lstat > 9.725 294  7006.0 17.34  
##      6) lstat < 16.085 150  1627.0 20.30 *
##      7) lstat > 16.085 144  2699.0 14.26  
##       14) dis < 2.0037 75  1115.0 11.97 *
##       15) dis > 2.0037 69   761.9 16.75 *
plot(tree_ml)
text(tree_ml)

cat("Size of the big tree: \n")
## Size of the big tree:
print(length(unique(tree_ml$where)))
## [1] 9
Boston_fit=predict(tree_ml,df)

# Now, let's prune it down to a tree with 7 leaves/nodes

tree_ml_7=prune.tree(tree_ml,best=7)

cat("Size of the pruned tree: \n")
## Size of the pruned tree:
print(length(unique(tree_ml_7$where)))
## [1] 7
par(mfrow=c(1,2))

plot(tree_ml_7, type='u')
text(tree_ml_7, col='maroon', label=c('yval'), cex=.9)


Boston_fit7=predict(tree_ml_7,df)



##


ResultsMat=cbind(medv,Boston_fit,Boston_fit7)

colnames(ResultsMat)=c("medv","tree","tree7")

pairs(ResultsMat, col='maroon')

print(cor(ResultsMat))
##            medv      tree     tree7
## medv  1.0000000 0.8707875 0.8560183
## tree  0.8707875 1.0000000 0.9830393
## tree7 0.8560183 0.9830393 1.0000000
#Let's use the trained algorithm to predict the value for a specific point

Pred_point=data.frame(lstat=15,dis=2)

yhat=predict(tree_ml,Pred_point)

cat('prediction is: \n')
## prediction is:
print(yhat)
##      1 
## 20.302

A cross-validated problem

attach(Boston)
set.seed(7)

df=Boston[,c(8,13,14)]

print(names(df))
## [1] "dis"   "lstat" "medv"
# Let's fit a single tree and plot the importance of the variables


tree=rpart(medv~., method="anova",data=df,
           control=rpart.control(minsplit = 5,cp=.0005))


Ntree=length(unique(tree$where))

cat("Size of big tree is:",Ntree,"\n")
## Size of big tree is: 90
# Let's check the CV results

plotcp(tree)

best_a=which.min(tree$cptable[,"xerror"])

bestCP=tree$cptable[best_a,"CP"]

bestSize=tree$cptable[best_a,"nsplit"]+1

best_tree=prune(tree,cp=bestCP)

plot(best_tree,uniform=TRUE)
text(best_tree,digits=2, use.n = TRUE)
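The `xerror` column used above is itself a cross-validated error. To show what such a number involves, here is a minimal hand-rolled K-fold sketch, assuming `MASS` and `rpart` are installed. The number of folds `K` and the candidate values in `cp_grid` are arbitrary choices made for illustration, and `xval = 0` switches off the internal cross-validation of `rpart` so that only the manual folds are used.

```r
library(MASS)    # Boston data
library(rpart)

set.seed(7)
df <- Boston[, c("dis", "lstat", "medv")]

K <- 5                                        # number of folds (an assumption)
cp_grid <- c(0.1, 0.03, 0.01, 0.003, 0.001)   # hypothetical candidate CP values
fold <- sample(rep(1:K, length.out = nrow(df)))

# Average held-out MSE for each candidate CP value
cv_mse <- sapply(cp_grid, function(cp) {
  fold_mse <- sapply(1:K, function(k) {
    fit <- rpart(medv ~ ., data = df[fold != k, ], method = "anova",
                 control = rpart.control(minsplit = 5, cp = cp, xval = 0))
    pred <- predict(fit, newdata = df[fold == k, ])
    mean((df$medv[fold == k] - pred)^2)
  })
  mean(fold_mse)
})

best_cp <- cp_grid[which.min(cv_mse)]
```

The value in `best_cp` plays the same role as `bestCP` above, although the two need not coincide because the folds and the candidate grid differ.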

Bagging and Random Forest

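The idea behind Bagging (bootstrap aggregating) is to draw many bootstrap samples from the training data, fit a tree to each sample, and average the predictions of these trees. Averaging over many bootstrapped trees reduces the variance of a single tree and makes it likely that trees with good explanatory power are captured. This approach is computationally intensive and time-consuming.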
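Bagging can be sketched directly with `rpart`. The following is a minimal illustration, assuming `MASS` and `rpart` are installed; the number of bootstrap samples `B` is kept small for speed, and the control settings are arbitrary choices rather than recommended values.

```r
library(MASS)    # Boston data
library(rpart)

set.seed(7)
df <- Boston[, c("dis", "lstat", "medv")]
n <- nrow(df)
B <- 50   # number of bootstrap samples (kept small here)

# One column of predictions per bootstrap tree
pred_mat <- sapply(1:B, function(b) {
  idx <- sample(n, n, replace = TRUE)   # bootstrap sample of the rows
  fit <- rpart(medv ~ ., data = df[idx, ],
               control = rpart.control(minsplit = 5, cp = 0.001))
  predict(fit, newdata = df)
})

# Bagged prediction: average over the B trees
bag_pred <- rowMeans(pred_mat)
```

Each column of `pred_mat` is a rough, high-variance step function; the row means in `bag_pred` are smoother because the individual trees split at different places.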

Random Forests start from Bagging and add another kind of randomization. Instead of searching over all features at each step of the greedy approach, this method randomly samples a subset of \(m\) variables to search over each time we make a split.

With this method, more types of trees are evaluated. Since the method is bootstrapped, the important variables are selected for splits more often, on average, than the others, and thus drive the prediction.

How do we choose the parameters of a Random Forest model? \(B\) is the number of bootstrapped samples, and \(m\) is the number of variables to sample at each split. A common choice for \(m\) is \(\sqrt p\), where \(p\) is the number of features in the model. When we set \(m=p\), Random Forest reduces to Bagging.
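As a small worked example of these rules, the helper below (a hypothetical function, `mtry_rules`, written for illustration) returns the \(\sqrt p\) rule, the default that `randomForest` uses for regression, \(\max(\lfloor p/3 \rfloor, 1)\), and the Bagging case \(m=p\).

```r
# Common choices for the number of variables tried at each split
mtry_rules <- function(p) {
  c(sqrt_rule          = floor(sqrt(p)),
    regression_default = max(floor(p / 3), 1),
    bagging            = p)
}

mtry_rules(9)
# sqrt_rule regression_default bagging
#         3                  3       9
```

In practice, \(m\) can also be treated as a tuning parameter and chosen on a validation sample.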

California Housing Data

library(randomForest)


rawData=read.csv("https://raw.githubusercontent.com/babakrezaee/MethodsCourses/master/DataSets/calhouse.csv")

# First divide the sample into train, validation, and test samples


set.seed(7)


n=nrow(rawData)

n1=floor(n/2)
n2=floor(n/4)
n3=n-n1-n2

#Shuffle the data
ii=sample(1:n,n)

CAtrain=rawData[ii[1:n1],]
CAval=rawData[ii[(n1+1):(n1+n2)],]
CAtest=rawData[ii[(n1+n2+1):n],]

# Fit the RF on the train sample and evaluate its predictions on the validation sample

RF=randomForest(logMedVal~.,data=CAtrain,mtry=3,ntree=500)

RFpred=predict(RF,newdata=CAval)

# RMSE: square the errors first, then average, then take the root
RMSE=sqrt(mean((CAval$logMedVal-RFpred)^2))

cat('RMSE on the validation data for Random Forest is',RMSE,"\n")
getTree(RF, 1, labelVar=TRUE)[1:30,]
##    left daughter right daughter        split var split point status
## 1              2              3     medianIncome    3.558150     -3
## 2              4              5     AveOccupancy    2.284483     -3
## 3              6              7     AveOccupancy    2.583916     -3
## 4              8              9         AveRooms    4.335409     -3
## 5             10             11        AveBedrms    1.127303     -3
## 6             12             13 housingMedianAge   22.500000     -3
## 7             14             15        longitude -117.605000     -3
## 8             16             17 housingMedianAge   22.500000     -3
## 9             18             19     medianIncome    2.509550     -3
## 10            20             21        longitude -118.900000     -3
## 11            22             23       households  612.500000     -3
## 12            24             25     medianIncome    5.144600     -3
## 13            26             27        longitude -122.405000     -3
## 14            28             29     medianIncome    5.653950     -3
## 15            30             31     medianIncome    5.214250     -3
## 16            32             33     medianIncome    3.078150     -3
## 17            34             35       population  314.500000     -3
## 18            36             37         latitude   38.935000     -3
## 19            38             39     medianIncome    3.072900     -3
## 20            40             41        longitude -121.625000     -3
## 21            42             43     medianIncome    2.225750     -3
## 22            44             45         AveRooms    4.328272     -3
## 23            46             47         latitude   38.475000     -3
## 24            48             49        longitude -116.130000     -3
## 25            50             51       households  140.500000     -3
## 26            52             53       households  520.000000     -3
## 27            54             55        longitude -118.605000     -3
## 28            56             57     medianIncome    4.532150     -3
## 29            58             59         latitude   37.965000     -3
## 30            60             61     AveOccupancy    3.219902     -3
##    prediction
## 1    12.08072
## 2    11.77485
## 3    12.40353
## 4    12.05452
## 5    11.69947
## 6    12.57349
## 7    12.32932
## 8    12.20406
## 9    11.88783
## 10   11.73685
## 11   11.56837
## 12   12.38116
## 13   12.66501
## 14   12.36566
## 15   12.08772
## 16   11.94164
## 17   12.30353
## 18   11.57327
## 19   12.12194
## 20   11.56460
## 21   11.89506
## 22   11.51895
## 23   11.81418
## 24   12.26544
## 25   12.65011
## 26   12.84608
## 27   12.63211
## 28   12.19170
## 29   12.72716
## 30   11.93548
pairs(cbind(CAval$logMedVal,RFpred))

print(cor(cbind(CAval$logMedVal,RFpred)))
##                     RFpred
##        1.0000000 0.9122118
## RFpred 0.9122118 1.0000000
varImpPlot(RF)


  1. Copyright 2019. This is an in-progress project, please do not cite or reproduce without my permission.