sksurv.ensemble.RandomSurvivalForest
- class sksurv.ensemble.RandomSurvivalForest(n_estimators=100, *, max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features='sqrt', max_leaf_nodes=None, bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0, warm_start=False, max_samples=None, low_memory=False)
A random survival forest.
A random survival forest is a meta estimator that fits a number of survival trees on various sub-samples of the dataset and uses averaging to improve the predictive accuracy and control over-fitting. If bootstrap=True (default), each tree is fit on a sample drawn with replacement whose size is set by max_samples (by default, the same size as the original input sample); otherwise, the whole dataset is used to build each tree.
In each survival tree, the quality of a split is measured by the log-rank splitting rule.
See the User Guide, [1] and [2] for further description.
- Parameters:
n_estimators (int, optional, default: 100) – The number of trees in the forest.
max_depth (int or None, optional, default: None) – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.
min_samples_split (int, float, optional, default: 6) –
The minimum number of samples required to split an internal node:
If int, then consider min_samples_split as the minimum number.
If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) are the minimum number of samples for each split.
min_samples_leaf (int, float, optional, default: 3) –
The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression.
If int, then consider min_samples_leaf as the minimum number.
If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) are the minimum number of samples for each node.
min_weight_fraction_leaf (float, optional, default: 0.) – The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided.
max_features (int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt') –
The number of features to consider when looking for the best split:
If int, then consider max_features features at each split.
If float, then max_features is a fraction and int(max_features * n_features) features are considered at each split.
If “sqrt”, then max_features=sqrt(n_features).
If “log2”, then max_features=log2(n_features).
If None, then max_features=n_features.
Note: the search for a split does not stop until at least one valid partition of the node samples is found, even if this requires inspecting more than max_features features.
max_leaf_nodes (int or None, optional, default: None) – Grow a tree with max_leaf_nodes in best-first fashion. Best nodes are defined as relative reduction in impurity. If None, then the number of leaf nodes is unlimited.
bootstrap (bool, optional, default: True) – Whether bootstrap samples are used when building trees. If False, the whole dataset is used to build each tree.
oob_score (bool, optional, default: False) – Whether to use out-of-bag samples to estimate the generalization accuracy.
n_jobs (int or None, optional, default: None) – The number of jobs to run in parallel. fit(), predict(), decision_path() and apply() are all parallelized over the trees. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.
random_state (int, RandomState instance or None, optional, default: None) – Controls both the randomness of the bootstrapping of the samples used when building trees (if bootstrap=True) and the sampling of the features to consider when looking for the best split at each node (if max_features < n_features).
verbose (int, optional, default: 0) – Controls the verbosity when fitting and predicting.
warm_start (bool, optional, default: False) – When set to True, reuse the solution of the previous call to fit and add more estimators to the ensemble, otherwise, just fit a whole new forest.
max_samples (int or float, optional, default: None) –
If bootstrap is True, the number of samples to draw from X to train each base estimator.
If None (default), then draw X.shape[0] samples.
If int, then draw max_samples samples.
If float, then draw max_samples * X.shape[0] samples. Thus, max_samples should be in the interval (0.0, 1.0].
low_memory (bool, optional, default: False) – If set, predict() computations use reduced memory but predict_cumulative_hazard_function() and predict_survival_function() are not implemented.
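The int-or-float rules listed for min_samples_split, min_samples_leaf and max_features can be sketched in plain Python. This is an illustration of the rules as described above, not the library's code: the helper names resolve_min_samples and resolve_max_features are hypothetical, and truncating the 'sqrt' and 'log2' results to an integer with a floor of 1 is an assumption.

```python
import math


def resolve_min_samples(value, n_samples):
    """Turn an int-or-float setting into an absolute sample count."""
    if isinstance(value, int):
        return value
    # Float: interpreted as a fraction of the training set, rounded up.
    return math.ceil(value * n_samples)


def resolve_max_features(value, n_features):
    """Number of candidate features per split (hypothetical helper)."""
    if value is None:
        return n_features
    if value == "sqrt":
        # Assumption: truncate to an int and use at least one feature.
        return max(1, int(math.sqrt(n_features)))
    if value == "log2":
        return max(1, int(math.log2(n_features)))
    if isinstance(value, int):
        return value
    # Float: fraction of the available features.
    return max(1, int(value * n_features))


print(resolve_min_samples(6, 137))      # absolute count is used as given
print(resolve_min_samples(0.05, 137))   # ceil(0.05 * 137)
print(resolve_max_features("sqrt", 16))
```

With 137 training samples, a fractional min_samples_split of 0.05 resolves to 7, slightly stricter than the integer default of 6.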
- estimators_
The collection of fitted sub-estimators.
- Type:
list of SurvivalTree instances
- unique_times_
Unique time points.
- Type:
ndarray, shape = (n_unique_times,)
- n_features_in_
Number of features seen during fit.
- Type:
int
- feature_names_in_
Names of features seen during fit. Defined only when X has feature names that are all strings.
- Type:
ndarray, shape = (n_features_in_,)
- oob_score_
Concordance index of the training dataset obtained using an out-of-bag estimate.
- Type:
float
See also
sksurv.tree.SurvivalTree – A single survival tree.
Notes
The default values for the parameters controlling the size of the trees (e.g. max_depth, min_samples_leaf, etc.) lead to fully grown and unpruned trees which can potentially be very large on some data sets. To reduce memory consumption, the complexity and size of the trees should be controlled by setting those parameter values.
Compared to scikit-learn’s random forest models, RandomSurvivalForest currently does not support controlling the depth of a tree based on the log-rank test statistic or its associated p-value, i.e., the parameters min_impurity_decrease or min_impurity_split are absent. In addition, the feature_importances_ attribute is not available. It is recommended to estimate feature importances via sklearn.inspection.permutation_importance().
The features are always randomly permuted at each split. Therefore, the best found split may vary, even with the same training data, max_features=n_features and bootstrap=False, if the improvement of the criterion is identical for several splits enumerated during the search of the best split. To obtain deterministic behavior during fitting, random_state has to be fixed.
References
- __init__(n_estimators=100, *, max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features='sqrt', max_leaf_nodes=None, bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0, warm_start=False, max_samples=None, low_memory=False)
Methods
__init__([n_estimators, max_depth, ...])
apply(X) – Apply trees in the forest to X, return leaf indices.
decision_path(X) – Return the decision path in the forest.
fit(X, y[, sample_weight]) – Build a forest of survival trees from the training set (X, y).
get_metadata_routing() – Get metadata routing of this object.
get_params([deep]) – Get parameters for this estimator.
predict(X) – Predict risk score.
predict_cumulative_hazard_function(X[, ...]) – Predict cumulative hazard function.
predict_survival_function(X[, return_array]) – Predict survival function.
score(X, y) – Returns the concordance index of the prediction.
set_fit_request(*[, sample_weight]) – Configure whether metadata should be requested to be passed to the fit method.
set_params(**params) – Set the parameters of this estimator.
Attributes
estimators_samples_ – The subset of drawn samples for each base estimator.
feature_importances_ – Not implemented
- apply(X)
Apply trees in the forest to X, return leaf indices.
- Parameters:
X ({array-like, sparse matrix} of shape (n_samples, n_features)) – The input samples. Internally, its dtype will be converted to dtype=np.float32. If a sparse matrix is provided, it will be converted into a sparse csr_matrix.
- Returns:
X_leaves – For each datapoint x in X and for each tree in the forest, return the index of the leaf x ends up in.
- Return type:
ndarray of shape (n_samples, n_estimators)
- decision_path(X)
Return the decision path in the forest.
Added in version 0.18.
- Parameters:
X ({array-like, sparse matrix} of shape (n_samples, n_features)) – The input samples. Internally, its dtype will be converted to dtype=np.float32. If a sparse matrix is provided, it will be converted into a sparse csr_matrix.
- Returns:
indicator (sparse matrix of shape (n_samples, n_nodes)) – Return a node indicator matrix where non-zero elements indicate that the sample goes through the corresponding nodes. The matrix is in CSR format.
n_nodes_ptr (ndarray of shape (n_estimators + 1,)) – The columns from indicator[n_nodes_ptr[i]:n_nodes_ptr[i+1]] give the indicator values for the i-th estimator.
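How n_nodes_ptr partitions the columns of the indicator matrix can be illustrated with a small stand-in. The sketch below uses a dense list of lists instead of a CSR matrix, and both the data and the helper names tree_columns and visited_nodes are hypothetical.

```python
# Hypothetical forest with 2 trees: tree 0 has 3 nodes, tree 1 has 5 nodes.
n_nodes_ptr = [0, 3, 8]

# Dense stand-in for the CSR indicator matrix (2 samples x 8 nodes).
indicator = [
    [1, 1, 0, 1, 0, 1, 0, 0],
    [1, 0, 1, 1, 1, 0, 0, 1],
]


def tree_columns(row, ptr, i):
    """Return the part of one sample's indicator row that belongs to tree i."""
    return row[ptr[i]:ptr[i + 1]]


def visited_nodes(row, ptr, i):
    """Node indices (local to tree i) that the sample passed through."""
    return [j for j, flag in enumerate(tree_columns(row, ptr, i)) if flag]


print(visited_nodes(indicator[0], n_nodes_ptr, 0))  # nodes of tree 0
print(visited_nodes(indicator[0], n_nodes_ptr, 1))  # nodes of tree 1
```

With the sparse matrix returned by decision_path, the same column range would be selected with a column slice on the matrix rather than on a Python list.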
- property estimators_samples_
The subset of drawn samples for each base estimator.
Returns a dynamically generated list of indices identifying the samples used for fitting each member of the ensemble, i.e., the in-bag samples.
Note: the list is re-created at each call to the property in order to reduce the object memory footprint by not storing the sampling data. Thus fetching the property may be slower than expected.
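The relation between in-bag and out-of-bag samples can be sketched without the library. The example below uses Python's random module rather than the generator used internally, so the drawn indices will not match those of a fitted forest; draw_bootstrap and out_of_bag are hypothetical helper names.

```python
import random


def draw_bootstrap(n_samples, seed):
    """Draw n_samples indices with replacement (the in-bag samples)."""
    rng = random.Random(seed)
    return [rng.randrange(n_samples) for _ in range(n_samples)]


def out_of_bag(n_samples, in_bag):
    """Indices that were never drawn for this tree."""
    drawn = set(in_bag)
    return [i for i in range(n_samples) if i not in drawn]


in_bag = draw_bootstrap(10, seed=0)
oob = out_of_bag(10, in_bag)
print(sorted(in_bag), oob)
```

Every sample index ends up either in the in-bag list (possibly several times) or in the out-of-bag list, never in both. The out-of-bag samples are the ones available for an estimate such as oob_score_.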
- property feature_importances_
Not implemented
- fit(X, y, sample_weight=None)
Build a forest of survival trees from the training set (X, y).
- Parameters:
X (array-like, shape = (n_samples, n_features)) – Data matrix
y (structured array, shape = (n_samples,)) – A structured array with two fields. The first field is a boolean where True indicates an event and False indicates right-censoring. The second field is a float with the time of event or time of censoring.
- Return type:
self
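A structured array of the form expected for y can be built directly with NumPy. In this minimal sketch the records are made up, and the field names "event" and "time" are illustrative: the description above only fixes the order and types of the two fields.

```python
import numpy as np

# Hypothetical observations: (event observed?, time of event or censoring).
records = [(True, 72.0), (False, 411.0), (True, 228.0), (False, 126.0)]

# Boolean event indicator first, float time second.
y = np.array(records, dtype=[("event", bool), ("time", float)])

print(y.dtype.names)      # field names in order
print(y["event"].sum())   # number of observed events
```

scikit-survival also ships helpers in sksurv.util.Surv for creating such arrays from separate event and time columns.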
- get_metadata_routing()
Get metadata routing of this object.
Please check User Guide on how the routing mechanism works.
- Returns:
routing – A MetadataRequest encapsulating routing information.
- Return type:
MetadataRequest
- get_params(deep=True)
Get parameters for this estimator.
- Parameters:
deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
params – Parameter names mapped to their values.
- Return type:
dict
- predict(X)
Predict risk score.
The ensemble risk score is the total number of events, which can be estimated by the sum of the estimated ensemble cumulative hazard function \(\hat{H}_e\):
\[\sum_{j=1}^{n} \hat{H}_e(T_{j} \mid x) ,\]
where \(n\) denotes the total number of distinct event times in the training data.
- Parameters:
X (array-like, shape = (n_samples, n_features)) – Data matrix.
- Returns:
risk_scores – Predicted risk scores.
- Return type:
ndarray, shape = (n_samples,)
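The risk score formula above can be traced with a few numbers. The sketch below averages hypothetical per-tree cumulative hazard values into an ensemble CHF and sums it over the distinct event times; ensemble_chf and risk_score are illustrative names, not library functions.

```python
def ensemble_chf(per_tree_chf):
    """Average the per-tree CHF values time point by time point."""
    n_trees = len(per_tree_chf)
    return [sum(values) / n_trees for values in zip(*per_tree_chf)]


def risk_score(chf_values):
    """Sum the ensemble CHF over all distinct event times."""
    return sum(chf_values)


# Hypothetical CHF values of one individual at 3 distinct event times,
# as estimated by 2 trees.
per_tree = [
    [0.0, 0.5, 1.0],
    [0.5, 0.5, 2.0],
]
chf = ensemble_chf(per_tree)
score = risk_score(chf)
print(chf, score)  # [0.25, 0.5, 1.5] 2.25
```

A larger score corresponds to a larger expected number of events, i.e., a higher risk.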
- predict_cumulative_hazard_function(X, return_array=False)
Predict cumulative hazard function.
For each tree in the ensemble, the cumulative hazard function (CHF) for an individual with feature vector \(x\) is computed from all samples of the bootstrap sample that are in the same terminal node as \(x\). It is estimated by the Nelson–Aalen estimator. The ensemble CHF at time \(t\) is the average value across all trees in the ensemble at the specified time point.
- Parameters:
X (array-like, shape = (n_samples, n_features)) – Data matrix.
return_array (bool, default: False) –
Whether to return a single array of cumulative hazard values or a list of step functions.
If False, a list of sksurv.functions.StepFunction objects is returned.
If True, a 2d-array of shape (n_samples, n_unique_times) is returned, where n_unique_times is the number of unique event times in the training data. Each row represents the cumulative hazard function of an individual evaluated at unique_times_.
- Returns:
cum_hazard – If return_array is False, an array of n_samples sksurv.functions.StepFunction instances is returned.
If return_array is True, a numeric array of shape (n_samples, n_unique_times) is returned.
- Return type:
ndarray
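The Nelson–Aalen estimate computed within a single terminal node can be sketched in plain Python. This is a minimal illustration on made-up data, evaluated only at the event times of that node, whereas the forest reports values on its own unique_times_ grid; nelson_aalen is a hypothetical helper.

```python
def nelson_aalen(times, events):
    """Return (distinct event times, cumulative hazard at those times)."""
    event_times = sorted({t for t, e in zip(times, events) if e})
    chf, total = [], 0.0
    for t in event_times:
        # Samples still under observation just before time t.
        at_risk = sum(1 for u in times if u >= t)
        # Events observed exactly at time t.
        deaths = sum(1 for u, e in zip(times, events) if e and u == t)
        total += deaths / at_risk
        chf.append(total)
    return event_times, chf


# Hypothetical samples falling into the same terminal node.
times = [2, 2, 4, 4]
events = [True, False, True, True]
print(nelson_aalen(times, events))  # ([2, 4], [0.25, 1.25])
```

At time 2, one of four samples at risk has an event (1/4); at time 4, both remaining samples have an event (2/2), giving a cumulative hazard of 1.25.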
Examples
>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_veterans_lung_cancer
>>> from sksurv.preprocessing import OneHotEncoder
>>> from sksurv.ensemble import RandomSurvivalForest
Load the data and encode categorical features.
>>> X, y = load_veterans_lung_cancer()
>>> Xt = OneHotEncoder().fit_transform(X)
Fit the model.
>>> estimator = RandomSurvivalForest().fit(Xt, y)
Estimate the cumulative hazard function for the first 10 samples.
>>> chf_funcs = estimator.predict_cumulative_hazard_function(Xt.iloc[:10])
Plot the estimated cumulative hazard functions.
>>> for fn in chf_funcs:
...     plt.step(fn.x, fn(fn.x), where="post")
...
[...]
>>> plt.show()
- predict_survival_function(X, return_array=False)
Predict survival function.
For each tree in the ensemble, the survival function for an individual with feature vector \(x\) is computed from all samples of the bootstrap sample that are in the same terminal node as \(x\). It is estimated by the Kaplan-Meier estimator. The ensemble survival function at time \(t\) is the average value across all trees in the ensemble at the specified time point.
- Parameters:
X (array-like, shape = (n_samples, n_features)) – Data matrix.
return_array (bool, default: False) –
Whether to return a single array of survival probabilities or a list of step functions.
If False, a list of sksurv.functions.StepFunction objects is returned.
If True, a 2d-array of shape (n_samples, n_unique_times) is returned, where n_unique_times is the number of unique event times in the training data. Each row represents the survival function of an individual evaluated at unique_times_.
- Returns:
survival – If return_array is False, an array of n_samples sksurv.functions.StepFunction instances is returned.
If return_array is True, a numeric array of shape (n_samples, n_unique_times) is returned.
- Return type:
ndarray
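The per-node Kaplan-Meier estimate and the averaging across trees can be sketched as follows. The data for the two terminal nodes are made up, and kaplan_meier and step_value are hypothetical helpers; the sketch is not the library's implementation.

```python
def kaplan_meier(times, events):
    """Return (distinct event times, survival probability at those times)."""
    event_times = sorted({t for t, e in zip(times, events) if e})
    surv, prob = [], 1.0
    for t in event_times:
        at_risk = sum(1 for u in times if u >= t)
        deaths = sum(1 for u, e in zip(times, events) if e and u == t)
        prob *= 1.0 - deaths / at_risk
        surv.append(prob)
    return event_times, surv


def step_value(event_times, surv, t):
    """Evaluate the step function at time t (1.0 before the first event)."""
    value = 1.0
    for u, s in zip(event_times, surv):
        if u <= t:
            value = s
    return value


# Hypothetical terminal nodes reached by the same individual in two trees.
km_a = kaplan_meier([2, 2, 4, 4], [True, False, True, True])
km_b = kaplan_meier([4, 4, 6, 6], [True, True, False, False])

# Ensemble survival function: average of the per-tree values on a shared grid.
grid = [2, 4]
ensemble = [(step_value(*km_a, t) + step_value(*km_b, t)) / 2 for t in grid]
print(ensemble)  # [0.875, 0.25]
```

Evaluating each per-tree step function on a shared time grid before averaging mirrors the role that unique_times_ plays for the fitted forest.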
Examples
>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_veterans_lung_cancer
>>> from sksurv.preprocessing import OneHotEncoder
>>> from sksurv.ensemble import RandomSurvivalForest
Load the data and encode categorical features.
>>> X, y = load_veterans_lung_cancer()
>>> Xt = OneHotEncoder().fit_transform(X)
Fit the model.
>>> estimator = RandomSurvivalForest().fit(Xt, y)
Estimate the survival function for the first 10 samples.
>>> surv_funcs = estimator.predict_survival_function(Xt.iloc[:10])
Plot the estimated survival functions.
>>> for fn in surv_funcs:
...     plt.step(fn.x, fn(fn.x), where="post")
...
[...]
>>> plt.ylim(0, 1)
(0.0, 1.0)
>>> plt.show()
- score(X, y)
Returns the concordance index of the prediction.
- Parameters:
X (array-like, shape = (n_samples, n_features)) – Test samples.
y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.
- Returns:
cindex – Estimated concordance index.
- Return type:
float
See also
sksurv.metrics.concordance_index_censored – Computes the concordance index.
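The idea behind the concordance index can be sketched in a few lines. This is a simplified version that assumes a pair is comparable only when the sample with the shorter time had an event, counts ties in the predicted risk as half-concordant, and ignores ties in time; sksurv.metrics.concordance_index_censored applies its own, more detailed tie rules, so values may differ on data with ties.

```python
def concordance_index(event, time, risk):
    """Fraction of comparable pairs whose risk ordering matches the outcome."""
    concordant = tied = comparable = 0
    n = len(time)
    for i in range(n):
        for j in range(n):
            # Comparable pair: i had an event strictly before j's time.
            if event[i] and time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1
                elif risk[i] == risk[j]:
                    tied += 1
    return (concordant + 0.5 * tied) / comparable


# Hypothetical test samples and predicted risk scores.
event = [True, True, False, True]
time = [1, 2, 3, 4]
risk = [4.0, 1.0, 2.0, 0.5]
cindex = concordance_index(event, time, risk)
print(cindex)
```

Here five pairs are comparable and four of them are ordered correctly (a higher risk score for the sample that fails earlier), so the index is 4/5.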
- set_fit_request(*, sample_weight: bool | None | str = '$UNCHANGED$') → RandomSurvivalForest
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
sample_weight (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for sample_weight parameter in fit.
- Returns:
self – The updated object.
- Return type:
object
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters:
**params (dict) – Estimator parameters.
- Returns:
self – Estimator instance.
- Return type:
estimator instance