Random Forests#

Random forests are one of the most flexible and broadly useful supervised AI/ML methods available. They’re sometimes called the “Swiss army knife” of supervised learning because they can handle many kinds of data and are usually highly accurate even when other methods struggle. In this section, we’ll look at how they work and how to apply them to our example dataset.

How do random forests work?#

A random forest is fundamentally a collection of another kind of machine learning model called a decision tree. The idea behind a decision tree is that a set of data can be split into two subsets based on some feature, then the two subsets can each be split on other features, and their subsets can be split in turn, up to some depth. After some number of splits, the subsets are small enough that a simple classification or regression can be fit to the values in each one. These splits are typically chosen by a quick, greedy search for the feature and threshold that best separate the target values (random forests often add extra randomness by limiting each search to a random subset of the features). For example, a decision tree modeling the CA Housing Dataset might initially split the data on the number of bedrooms, with all rows whose average number of bedrooms is below 3 going into one subset and all other rows going into another. (In this section, we’ll use random forests for regression via the sklearn.ensemble.RandomForestRegressor class; a similar class, sklearn.ensemble.RandomForestClassifier, can be used for classification.)
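
To make this concrete, here is a minimal sketch of a single decision tree on a small made-up dataset (not the CA Housing data; the feature names here are hypothetical). The sklearn.tree.export_text helper prints the split rules the tree chose:

import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

rng = np.random.default_rng(0)

# Hypothetical toy data: two features, with the target driven mostly by
# whether the first feature is below 5.
X = rng.uniform(0, 10, size=(100, 2))
y = np.where(X[:, 0] < 5, 1.0, 3.0) + rng.normal(0, 0.1, size=100)

# Fit a shallow tree and print the splits it chose:
toy_tree = DecisionTreeRegressor(max_depth=2).fit(X, y)
print(export_text(toy_tree, feature_names=["feat_a", "feat_b"]))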

One downside of decision trees is that they are prone to overfitting. Random forests address this by training many decision trees, each on a different randomly drawn subsample of the dataset, and averaging their predictions to form the final model. Although each individual tree is likely to be overfit, the trees are unlikely to all be overfit in the same way, so their average is usually less overfit than any single tree.
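
Here is a minimal sketch of this “average many overfit trees” idea, again on a made-up dataset: each tree is trained on its own bootstrap sample (rows drawn with replacement), and the ensemble’s prediction is simply the average of the trees’ predictions. This is essentially what RandomForestRegressor automates for us.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)

# Hypothetical noisy 1D dataset:
x = rng.uniform(0, 6, size=(200, 1))
y = np.sin(x[:, 0]) + rng.normal(0, 0.3, size=200)

# Train several deep (likely overfit) trees, each on a bootstrap sample:
trees = []
for _ in range(25):
    rows = rng.choice(len(x), size=len(x), replace=True)
    trees.append(DecisionTreeRegressor(max_depth=8).fit(x[rows], y[rows]))

# The ensemble prediction is the average of the individual predictions:
x_new = np.array([[1.5], [3.0], [4.5]])
print(np.mean([t.predict(x_new) for t in trees], axis=0))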

Most random forest implementations validate their decision trees internally as they train them, by fitting each tree on a bootstrap sample of the rows and checking it against the left-out (“out-of-bag”) rows, but you should still cross validate the forest as a whole yourself when training one.
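
For example, here is a minimal sketch of cross validating a (small, fast) random forest with scikit-learn’s cross_val_score helper; the hyperparameter values here are arbitrary choices for illustration:

from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score

X, y = fetch_california_housing(return_X_y=True)

# Each fold's score is the coefficient of determination (R^2) on the
# held-out fold:
scores = cross_val_score(
    RandomForestRegressor(n_estimators=20, max_depth=6, random_state=0),
    X, y, cv=5)
print(scores.mean(), scores.std())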

Limitations and Advantages of Random Forests#

Limitations

  • Random forests are very poor at extrapolation: because a forest’s predictions are averages of training-set target values, it can never predict a value outside the range it saw during training.

  • You can use a trained random forest to rank how important each input feature was (see the sketch just after these lists), but not to learn the direction of each feature’s influence. (If you want to know whether a parameter is negatively or positively related to the outputs, then linear regression is a better choice.)

Advantages

  • Random forests tend to be highly accurate for many kinds of data.

  • Random forests tend to be very robust to outliers.

  • Random forests usually handle missing data well.
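
As noted above, a random forest can tell you how important each feature was overall, even though it cannot tell you the direction of its effect. Here is a minimal sketch of reading the impurity-based importance scores via the feature_importances_ attribute (the forest is deliberately small so it trains quickly):

from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import RandomForestRegressor

dataset = fetch_california_housing()
forest = RandomForestRegressor(n_estimators=20, max_depth=6, random_state=0)
forest.fit(dataset['data'], dataset['target'])

# Importance scores sum to 1 across features; they measure how much each
# feature contributed to the trees' splits, not whether its effect is
# positive or negative.
for name, imp in zip(dataset['feature_names'], forest.feature_importances_):
    print(f"{name}: {imp:.3f}")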

Example: the California Housing Dataset#

We’ll use a random forest to try to predict the median housing prices in the CA Housing Dataset. We can start by loading the dataset as usual.

import sklearn as skl
import sklearn.datasets  # `import sklearn` alone does not load the datasets submodule

# We use scikit-learn to download and return the CA housing dataset:
ca_housing_dataset = skl.datasets.fetch_california_housing()

# Extract the actual data rows and the feature names:
ca_housing_featdata = ca_housing_dataset['data']
ca_housing_featnames = ca_housing_dataset['feature_names']

# We also extract the "target" data, since we are using supervised learning:
ca_housing_targdata = ca_housing_dataset['target']
ca_housing_targnames = ca_housing_dataset['target_names']
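
It can be worth a quick sanity check of what we loaded before going further; a short look at the shapes and names:

# One row per CA census block group, one column per feature; the target
# is the median house value:
print(ca_housing_featdata.shape)
print(ca_housing_featnames)
print(ca_housing_targnames)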

As in the previous section, we’ll split the dataset into train and test subdatasets.

import numpy as np
# We set a specific random seed to make sure that this notebook runs the same way each time.
np.random.seed(0)

# Randomly select 75% of the rows to be in the training dataset.
all_rows = np.arange(ca_housing_featdata.shape[0])
n_train = int(round(len(all_rows) * 0.75))
n_test = len(all_rows) - n_train
train_rows = np.random.choice(all_rows, n_train, replace=False)
test_rows = np.setdiff1d(all_rows, train_rows)

# Extract these rows into separate matrices:
train_featdata = ca_housing_featdata[train_rows]
train_targdata = ca_housing_targdata[train_rows]
test_featdata = ca_housing_featdata[test_rows]
test_targdata = ca_housing_targdata[test_rows]
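
If you’d rather not manage the row indices by hand, scikit-learn’s train_test_split helper performs an equivalent split in one call; a minimal sketch:

from sklearn.model_selection import train_test_split

# Equivalent split; random_state makes it reproducible, like the seed above:
(train_featdata, test_featdata,
 train_targdata, test_targdata) = train_test_split(
     ca_housing_featdata, ca_housing_targdata,
     train_size=0.75, random_state=0)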

Next, we can create the random forest model object. We’ll give it a couple of hyperparameters: max_depth, which limits how many successive splits each decision tree can make, and random_state, which ensures that the randomized choices made by the algorithm are repeatable. We’ll use random_state=0 here, but you can try different random states to see how the algorithm varies across random runs.

from sklearn.ensemble import RandomForestRegressor

randforest = RandomForestRegressor(
    max_depth=6,
    random_state=0)

# Next, we train the random forest with our data.
randforest.fit(train_featdata, train_targdata)
RandomForestRegressor(max_depth=6, random_state=0)

There are a variety of other hyperparameters that can be given to the RandomForestRegressor, such as the number of estimators to make and average together (n_estimators), that are outside the scope of this lesson. See the Scikit-learn documentation on random forests for more information on these options.
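
As a sketch of how you might experiment with these options, here is a hypothetical variant with more trees; more estimators generally reduces variance at the cost of training time, though these values are not tuned:

bigger_forest = RandomForestRegressor(
    n_estimators=500,   # the default is 100 trees
    max_depth=6,
    random_state=0)
bigger_forest.fit(train_featdata, train_targdata)
bigger_forest.score(test_featdata, test_targdata)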

For now, let’s see how well the model performs on our test dataset. Like with the LinearRegression type, we can use the score method to obtain the coefficient of determination for the model and the test data.

randforest.score(test_featdata, test_targdata)
0.7094804834967965

The random forest explains about 71% of the variance in the test dataset; that’s somewhat better than the linear regression model we saw earlier.
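
The score method is shorthand for predicting and then computing the coefficient of determination ourselves; a minimal sketch of the equivalent explicit computation:

from sklearn.metrics import r2_score

# This reproduces the value returned by randforest.score above:
predictions = randforest.predict(test_featdata)
r2_score(test_targdata, predictions)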

What kind of data does the RandomForestRegressor provide?#

As we’ve seen with other ML tools in Scikit-learn, the RandomForestRegressor provides us with some data about the trained model. In the LinearRegression type, these data were the coefficients associated with each feature in the regression. For random forests, these data are the individual decision tree estimators. We can see the list of estimators by examining the estimators_ member variable of the randforest object we created.

randforest.estimators_
[DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=209652396),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=398764591),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=924231285),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1478610112),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=441365315),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1537364731),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=192771779),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1491434855),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1819583497),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=530702035),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=626610453),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1650906866),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1879422756),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1277901399),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1682652230),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=243580376),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1991416408),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1171049868),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1646868794),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=2051556033),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1252949478),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1340754471),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=124102743),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=2061486254),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=292249176),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1686997841),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1827923621),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1443447321),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=305097549),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1449105480),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=374217481),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=636393364),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=86837363),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1581585360),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1428591347),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1963466437),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1194674174),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=602801999),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1589190063),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1589512640),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=2055650130),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=2034131043),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1284876248),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1292401841),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1982038771),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=87950109),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1204863635),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=768281747),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=507984782),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=947610023),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=600956192),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=352272321),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=615697673),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=160516793),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1909838463),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1110745632),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=93837855),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=454869706),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1780959476),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=2034098327),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1136257699),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=800291326),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1177824715),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1017555826),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1959150775),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=930076700),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=293921570),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=580757632),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=80701568),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1392175012),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=505240629),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=642848645),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=481447462),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=954863080),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=502227700),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1659957521),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1905883471),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1729147268),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=780912233),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1932520490),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1544074682),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=485603871),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1877037944),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1728073985),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=848819521),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=426405863),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=258666409),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=2017814585),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=716257571),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=657731430),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=732884087),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=734051083),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=903586222),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1538251858),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=553734235),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1076688768),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1354754446),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=463129187),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1562125877),
 DecisionTreeRegressor(max_depth=6, max_features=1.0, random_state=1396067212)]

As we can see, the estimators_ variable is a list of DecisionTreeRegressor objects. This makes sense, since random forests are just collections of decision trees. Let’s take a closer look at one of the trees in our forest. The Scikit-learn library includes a function plot_tree in the sklearn.tree subpackage that can be used to visualize a decision tree as part of a matplotlib figure.

import matplotlib.pyplot as plt
import sklearn.tree  # provides plot_tree; not loaded by `import sklearn` alone

# We'll look at the first tree:
tree = randforest.estimators_[0]

# Make a figure; we have to make the figure quite large in order for all of
# the text and all the nodes in the tree to be visible!
(fig,ax) = plt.subplots(1, 1, figsize=(24,12), dpi=72*8)

# Plot the tree:
skl.tree.plot_tree(tree, ax=ax)

plt.show()
[Figure: the first decision tree in the forest, rendered by plot_tree; each box shows a split rule, a sample count, and a mean target value.]

The tree has very small text in its cells, so you may need to open the image in a new browser tab and zoom in to read it. Each node in the tree shows the condition used to split the data at that point. In the root node, for example, the data are split according to the rule x[0] <= 5.715; x[0] here refers to the feature at index 0 (the median income of a region of CA, for our dataset).
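
If you’d rather read the split rules with the real feature names instead of the x[0]-style indices, the sklearn.tree.export_text helper renders the tree as plain text (plot_tree accepts the same names via its feature_names argument); note that the output is fairly long for a depth-6 tree:

from sklearn.tree import export_text

# Print the same tree with human-readable feature names:
print(export_text(tree, feature_names=list(ca_housing_featnames)))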

One of the nice features of decision trees and random forests is that the trees themselves can be examined and understood—just by looking through the nodes of this tree, we can get a general sense of how the algorithm has decided to calculate a prediction. The Scikit-learn library additionally includes a number of utilities and tutorials related to decision trees and how to evaluate and examine them. In particular, more information on the structure of the decision trees can be found here and more general information on decision trees can be found here.