This page was generated from doc/user_guide/random-survival-forest.ipynb.

Using Random Survival Forests

This notebook demonstrates how to use Random Survival Forests introduced in scikit-survival 0.11.

Like its popular counterparts for classification and regression, a Random Survival Forest is an ensemble of tree-based learners. A Random Survival Forest ensures that individual trees are de-correlated by 1) building each tree on a different bootstrap sample of the original training data, and 2) at each node, only evaluating the split criterion for a randomly selected subset of features and thresholds. Predictions are formed by aggregating predictions of individual trees in the ensemble.

To demonstrate Random Survival Forests, we are going to use data from the German Breast Cancer Study Group (GBSG-2) on the treatment of node-positive breast cancer patients. It contains data on 686 women and 8 prognostic factors: 1. age, 2. estrogen receptor (estrec), 3. whether or not a hormonal therapy was administered (horTh), 4. menopausal status (menostat), 5. number of positive lymph nodes (pnodes), 6. progesterone receptor (progrec), 7. tumor size (tsize), 8. tumor grade (tgrade).

The goal is to predict recurrence-free survival time.

[1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

First, we need to load the data and transform it into numeric values.

[2]:
X, y = load_gbsg2()

grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis]
grade_num = OrdinalEncoder(categories=[["I", "II", "III"]]).fit_transform(grade_str)

X_no_grade = X.drop("tgrade", axis=1)
Xt = OneHotEncoder().fit_transform(X_no_grade)
Xt = np.column_stack((Xt.values, grade_num))

feature_names = X_no_grade.columns.tolist() + ["tgrade"]

Next, the data is split into 75% for training and 25% for testing, so we can determine how well our model generalizes.

[3]:
random_state = 20

X_train, X_test, y_train, y_test = train_test_split(
    Xt, y, test_size=0.25, random_state=random_state)

Training

Several split criteria have been proposed in the past, but the most widespread one is based on the log-rank test, which you probably know from comparing survival curves among two or more groups.
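As a brief aside (not part of the original notebook), the log-rank statistic itself can be computed directly. A minimal sketch, assuming a scikit-survival version that provides sksurv.compare.compare_survival, comparing survival between the two hormonal therapy groups of the full data set:

from sksurv.compare import compare_survival

# Log-rank test comparing patients with and without hormonal therapy
# (illustrative only; horTh is one of the 8 prognostic factors above).
chisq, pvalue = compare_survival(y, X.loc[:, "horTh"])
print(f"chi^2 = {chisq:.2f}, p = {pvalue:.4f}")

Using the training data, we fit a Random Survival Forest comprising 1000 trees.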

[4]:
rsf = RandomSurvivalForest(n_estimators=1000,
                           min_samples_split=10,
                           min_samples_leaf=15,
                           max_features="sqrt",
                           n_jobs=-1,
                           random_state=random_state)
rsf.fit(X_train, y_train)
[4]:
RandomSurvivalForest(max_features='sqrt', min_samples_leaf=15,
                     min_samples_split=10, n_estimators=1000, n_jobs=-1,
                     random_state=20)

We can check how well the model performs by evaluating it on the test data.

[5]:
rsf.score(X_test, y_test)
[5]:
0.6759696016771488

This gives a concordance index of 0.68, which is a good value and matches the results reported in the Random Survival Forests paper.
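The score method corresponds to Harrell’s concordance index. For a more explicit view of what it computes, here is a minimal sketch using concordance_index_censored; it assumes, as is the case for load_gbsg2, that the structured array stores the event indicator first and the survival time second:

from sksurv.metrics import concordance_index_censored

# Explicit computation of the concordance index on the test data;
# the field names are taken from the structured array itself.
event_field, time_field = y_test.dtype.names
cindex = concordance_index_censored(
    y_test[event_field], y_test[time_field], rsf.predict(X_test)
)[0]
print(round(cindex, 4))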

Predicting

For prediction, a sample is dropped down each tree in the forest until it reaches a terminal node. Data in each terminal node is used to non-parametrically estimate the survival and cumulative hazard function using the Kaplan-Meier and Nelson-Aalen estimator, respectively. In addition, a risk score can be computed that represents the expected number of events for one particular terminal node. The ensemble prediction is simply the average across all trees in the forest.

Let’s first select a couple of patients from the test data according to the number of positive lymph nodes and age.

[6]:
a = np.empty(X_test.shape[0], dtype=[("age", float), ("pnodes", float)])
a["age"] = X_test[:, 0]
a["pnodes"] = X_test[:, 4]

sort_idx = np.argsort(a, order=["pnodes", "age"])
X_test_sel = pd.DataFrame(
    X_test[np.concatenate((sort_idx[:3], sort_idx[-3:]))],
    columns=feature_names)

X_test_sel
[6]:
age estrec horTh menostat pnodes progrec tsize tgrade
0 33.0 0.0 0.0 0.0 1.0 26.0 35.0 2.0
1 34.0 37.0 0.0 0.0 1.0 0.0 40.0 2.0
2 36.0 14.0 0.0 0.0 1.0 76.0 36.0 1.0
3 65.0 64.0 0.0 1.0 26.0 2.0 70.0 2.0
4 80.0 59.0 0.0 1.0 30.0 0.0 39.0 1.0
5 72.0 1091.0 1.0 1.0 36.0 2.0 34.0 2.0

The predicted risk scores indicate that risk for the last three patients is quite a bit higher than that of the first three patients.

[7]:
pd.Series(rsf.predict(X_test_sel))
[7]:
0     91.477609
1    102.897552
2     75.883786
3    170.502092
4    171.210066
5    148.691835
dtype: float64

We can have a more detailed insight by considering the predicted survival function. It shows that the biggest difference occurs roughly within the first 750 days.

[8]:
surv = rsf.predict_survival_function(X_test_sel, return_array=True)

for i, s in enumerate(surv):
    plt.step(rsf.event_times_, s, where="post", label=str(i))
plt.ylabel("Survival probability")
plt.xlabel("Time in days")
plt.legend()
plt.grid(True)
[Plot: predicted survival functions (survival probability over time in days) for the six selected patients]

Alternatively, we can also plot the predicted cumulative hazard function.

[9]:
surv = rsf.predict_cumulative_hazard_function(X_test_sel, return_array=True)

for i, s in enumerate(surv):
    plt.step(rsf.event_times_, s, where="post", label=str(i))
plt.ylabel("Cumulative hazard")
plt.xlabel("Time in days")
plt.legend()
plt.grid(True)
[Plot: predicted cumulative hazard functions (cumulative hazard over time in days) for the six selected patients]
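The risk scores from predict shown earlier are closely related to these curves: they represent the expected number of events which, if the implementation follows the description above, is the sum of the predicted cumulative hazard function over all unique event times. A minimal consistency check, based on that assumption about the internals rather than anything stated in the original notebook:

# Assumed relationship: risk score == sum of the cumulative hazard function
# evaluated at the unique event times.
chf = rsf.predict_cumulative_hazard_function(X_test_sel, return_array=True)
print(np.allclose(chf.sum(axis=1), rsf.predict(X_test_sel)))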

Permutation-based Feature Importance

The implementation is based on scikit-learn’s Random Forest implementation and inherits many features, such as building trees in parallel. What’s currently missing is feature importances via the feature_importances_ attribute. This is due to the way scikit-learn’s implementation computes importances: it relies on a measure of impurity for each child node and defines importance as the decrease in impurity caused by a split. For traditional regression, impurity would be measured by the variance, but for survival analysis there is no per-node impurity measure due to censoring. Instead, one could use the magnitude of the log-rank test statistic as an importance measure, but scikit-learn’s implementation doesn’t seem to allow this.

Fortunately, this is not a big concern, as scikit-learn’s definition of feature importance is non-standard and differs from what Leo Breiman proposed in the original Random Forest paper. Instead, we can use permutation to estimate feature importance, which is generally preferred over scikit-learn’s impurity-based definition. This is implemented in the ELI5 library, which is fully compatible with scikit-survival.

[10]:
import eli5
from eli5.sklearn import PermutationImportance

perm = PermutationImportance(rsf, n_iter=15, random_state=random_state)
perm.fit(X_test, y_test)
eli5.show_weights(perm, feature_names=feature_names)
[10]:
Weight Feature
0.0676 ± 0.0229 pnodes
0.0206 ± 0.0139 age
0.0177 ± 0.0468 progrec
0.0086 ± 0.0098 horTh
0.0032 ± 0.0198 tsize
0.0032 ± 0.0060 tgrade
-0.0007 ± 0.0018 menostat
-0.0063 ± 0.0207 estrec

The result shows that the number of positive lymph nodes (pnodes) is by far the most important feature. If its relationship to survival time is removed (by random shuffling), the concordance index on the test data drops on average by 0.0676 points. Again, this agrees with the results from the original Random Survival Forests paper.
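ELI5 is not strictly required for this: a similar permutation-based importance can be computed with scikit-learn’s sklearn.inspection.permutation_importance, which by default uses the estimator’s score method (here, the concordance index). A minimal sketch; the exact numbers will differ slightly from the ELI5 output because the shuffling differs:

from sklearn.inspection import permutation_importance

# Permutation importance computed with scikit-learn instead of ELI5.
result = permutation_importance(rsf, X_test, y_test,
                                n_repeats=15, random_state=random_state)
pd.DataFrame(
    {"importances_mean": result.importances_mean,
     "importances_std": result.importances_std},
    index=feature_names,
).sort_values(by="importances_mean", ascending=False)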