sksurv.ensemble.ExtraSurvivalTrees
- class sksurv.ensemble.ExtraSurvivalTrees(n_estimators=100, max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features='auto', max_leaf_nodes=None, bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0, warm_start=False, max_samples=None)
An extremely random survival forest.
This class implements a meta estimator that fits a number of randomized survival trees (a.k.a. extra-trees) on various sub-samples of the dataset and uses averaging to improve the predictive accuracy and control over-fitting. If bootstrap=True (default), the samples are drawn with replacement and each sub-sample has the same size as the original input sample unless max_samples is set; if bootstrap=False, the whole dataset is used to build each tree.
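To make the sampling step concrete, here is a minimal pure-Python sketch of drawing one tree's bootstrap sample: n_samples indices are drawn with replacement, and the indices that are never drawn form the out-of-bag set that oob_score relies on. The helper bootstrap_indices is hypothetical and is not part of scikit-survival.

```python
import random


def bootstrap_indices(n_samples, seed=None):
    """Hypothetical helper: draw one bootstrap sample and its out-of-bag set."""
    rng = random.Random(seed)
    # Draw n_samples indices with replacement (same size as the input sample).
    in_bag = [rng.randrange(n_samples) for _ in range(n_samples)]
    drawn = set(in_bag)
    # Samples that were never drawn are "out of bag" for this tree.
    out_of_bag = [i for i in range(n_samples) if i not in drawn]
    return in_bag, out_of_bag


in_bag, out_of_bag = bootstrap_indices(10, seed=0)
```

Because each index is missed with probability (1 - 1/n)^n, which approaches e^(-1) ≈ 0.368 for large n, roughly a third of the samples end up out of bag for each tree.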
In each randomized survival tree, the quality of a split is measured by the log-rank splitting rule.
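The following sketch illustrates the log-rank splitting rule together with the randomized threshold selection described in the next paragraph. It uses the standard two-sample log-rank statistic (observed minus expected events in the left child, divided by the square root of its variance) and, as an assumption, draws one threshold uniformly between the minimum and maximum of each candidate feature within the node. Both function names are hypothetical, and scikit-survival's compiled splitter may differ in details such as tie handling or the fallback described in the note under max_features.

```python
import math
import random


def logrank_statistic(time, event, left_mask):
    """Standardized two-sample log-rank statistic for a candidate split.

    left_mask[i] is True if sample i goes to the left child.
    """
    numerator = 0.0  # sum of observed minus expected events in the left child
    variance = 0.0
    for t in sorted({ti for ti, ei in zip(time, event) if ei}):
        at_risk = [i for i, ti in enumerate(time) if ti >= t]
        n = len(at_risk)
        n_left = sum(1 for i in at_risk if left_mask[i])
        d = sum(1 for i in at_risk if time[i] == t and event[i])
        d_left = sum(1 for i in at_risk if time[i] == t and event[i] and left_mask[i])
        numerator += d_left - d * n_left / n
        if n > 1:
            variance += d * (n_left / n) * (1 - n_left / n) * (n - d) / (n - 1)
    return numerator / math.sqrt(variance) if variance > 0 else 0.0


def extra_tree_split(X, time, event, max_features, min_samples_leaf=1, seed=None):
    """Hypothetical extra-trees split search; returns (score, feature, threshold) or None."""
    rng = random.Random(seed)
    candidates = rng.sample(range(len(X[0])), max_features)  # random feature subset
    best = None
    for f in candidates:
        values = [row[f] for row in X]
        lo, hi = min(values), max(values)
        if lo == hi:
            continue  # constant within this node, cannot be split
        # Assumption: one threshold drawn uniformly between the node's min and max.
        threshold = rng.uniform(lo, hi)
        left = [v <= threshold for v in values]
        n_left = sum(left)
        if n_left < min_samples_leaf or len(values) - n_left < min_samples_leaf:
            continue  # violates min_samples_leaf
        # Keep the random threshold with the largest absolute log-rank statistic.
        score = abs(logrank_statistic(time, event, left))
        if best is None or score > best[0]:
            best = (score, f, threshold)
    return best
```

Only one random threshold per candidate feature is scored here, instead of searching for the most discriminative threshold of each feature.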
Compared to RandomSurvivalForest, randomness goes one step further in the way splits are computed. As in RandomSurvivalForest, a random subset of candidate features is used, but instead of looking for the most discriminative thresholds, thresholds are drawn at random for each candidate feature and the best of these randomly generated thresholds is picked as the splitting rule.
- Parameters
n_estimators (integer, optional, default: 100) – The number of trees in the forest.
max_depth (int or None, optional, default: None) – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.
min_samples_split (int, float, optional, default: 6) –
The minimum number of samples required to split an internal node:
If int, then consider min_samples_split as the minimum number.
If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) is the minimum number of samples for each split.
min_samples_leaf (int, float, optional, default: 3) –
The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression.
If int, then consider min_samples_leaf as the minimum number.
If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) is the minimum number of samples for each node.
min_weight_fraction_leaf (float, optional, default: 0.) – The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided.
max_features (int, float, string or None, optional, default: 'auto') –
The number of features to consider when looking for the best split:
If int, then consider max_features features at each split.
If float, then max_features is a fraction and int(max_features * n_features) features are considered at each split.
If “auto”, then max_features=sqrt(n_features).
If “sqrt”, then max_features=sqrt(n_features).
If “log2”, then max_features=log2(n_features).
If None, then max_features=n_features.
Note: the search for a split does not stop until at least one valid partition of the node samples is found, even if it requires effectively inspecting more than max_features features.
max_leaf_nodes (int or None, optional, default: None) – Grow trees with max_leaf_nodes in best-first fashion. Best nodes are defined as relative reduction in impurity. If None, the number of leaf nodes is unlimited.
bootstrap (boolean, optional, default: True) – Whether bootstrap samples are used when building trees. If False, the whole dataset is used to build each tree.
oob_score (bool, default: False) – Whether to use out-of-bag samples to estimate the generalization performance, reported as a concordance index in oob_score_.
n_jobs (int or None, optional (default=None)) – The number of jobs to run in parallel for both fit and predict. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.
random_state (int, RandomState instance or None, optional, default: None) – If int, random_state is the seed used by the random number generator; if RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random.
verbose (int, optional, default: 0) – Controls the verbosity when fitting and predicting.
warm_start (bool, optional, default: False) – When set to True, reuse the solution of the previous call to fit and add more estimators to the ensemble; otherwise, just fit a whole new forest.
max_samples (int or float, optional, default: None) – If bootstrap is True, the number of samples to draw from X to train each base estimator.
If None (default), then draw X.shape[0] samples.
If int, then draw max_samples samples.
If float, then draw max_samples * X.shape[0] samples. In this case, max_samples should be in the interval (0.0, 1.0].
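The rules above for max_features, min_samples_split, min_samples_leaf and max_samples can be summarized in a few hypothetical helper functions. This is an illustrative sketch that follows the parameter descriptions; the floor of one feature or sample for fractional values is an assumption borrowed from scikit-learn's conventions and is not stated above.

```python
import math


def resolve_max_features(max_features, n_features):
    """Hypothetical helper: number of candidate features per split."""
    if max_features is None:
        return n_features
    if max_features in ("auto", "sqrt"):
        # Assumption: at least one feature is always considered.
        return max(1, int(math.sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(math.log2(n_features)))
    if isinstance(max_features, float):
        return max(1, int(max_features * n_features))
    return max_features  # int: used as is


def resolve_min_samples(value, n_samples):
    """Hypothetical helper for min_samples_split / min_samples_leaf."""
    if isinstance(value, float):
        return math.ceil(value * n_samples)  # float: fraction of n_samples
    return value  # int: used as is


def resolve_max_samples(max_samples, n_samples):
    """Hypothetical helper: bootstrap sample size per tree when bootstrap=True."""
    if max_samples is None:
        return n_samples
    if isinstance(max_samples, float):
        # Assumption: rounded and floored at one sample.
        return max(1, round(max_samples * n_samples))
    return max_samples
```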
- estimators_
The collection of fitted sub-estimators.
- Type
list of SurvivalTree instances
- event_times_
Unique time points where events occurred.
- Type
array of shape = (n_event_times,)
- n_features_in_
The number of features when fit is performed.
- Type
int
- feature_names_in_
Names of features seen during fit. Defined only when X has feature names that are all strings.
- Type
ndarray of shape (n_features_in_,)
- oob_score_
Concordance index of the training dataset obtained using an out-of-bag estimate.
- Type
float
See also
sksurv.tree.SurvivalTree
A single survival tree.
- __init__(n_estimators=100, max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features='auto', max_leaf_nodes=None, bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0, warm_start=False, max_samples=None)
Methods
__init__([n_estimators, max_depth, ...])
apply(X) – Apply trees in the forest to X, return leaf indices.
decision_path(X) – Return the decision path in the forest.
fit(X, y[, sample_weight]) – Build a forest of survival trees from the training set (X, y).
get_params([deep]) – Get parameters for this estimator.
predict(X) – Predict risk score.
predict_cumulative_hazard_function(X[, ...]) – Predict cumulative hazard function.
predict_survival_function(X[, return_array]) – Predict survival function.
score(X, y) – Returns the concordance index of the prediction.
set_params(**params) – Set the parameters of this estimator.
Attributes
feature_importances_ – Not implemented
n_features_ – DEPRECATED: Attribute n_features_ was deprecated in version 1.0 and will be removed in 1.2.
- apply(X)
Apply trees in the forest to X, return leaf indices.
- Parameters
X ({array-like, sparse matrix} of shape (n_samples, n_features)) – The input samples. Internally, its dtype will be converted to dtype=np.float32. If a sparse matrix is provided, it will be converted into a sparse csr_matrix.
- Returns
X_leaves – For each datapoint x in X and for each tree in the forest, return the index of the leaf x ends up in.
- Return type
ndarray of shape (n_samples, n_estimators)
- decision_path(X)
Return the decision path in the forest.
New in version 0.18 of scikit-learn, from which this method is inherited.
- Parameters
X ({array-like, sparse matrix} of shape (n_samples, n_features)) – The input samples. Internally, its dtype will be converted to dtype=np.float32. If a sparse matrix is provided, it will be converted into a sparse csr_matrix.
- Returns
indicator (sparse matrix of shape (n_samples, n_nodes)) – Return a node indicator matrix where non-zero elements indicate that the sample goes through the corresponding node. The matrix is in CSR format.
n_nodes_ptr (ndarray of shape (n_estimators + 1,)) – The columns indicator[:, n_nodes_ptr[i]:n_nodes_ptr[i+1]] give the indicator values for the i-th estimator.
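As an illustration of the n_nodes_ptr layout, assume a hypothetical forest of three trees with 5, 7 and 3 nodes. The pointer array is the cumulative node count with a leading zero, and tree i owns the indicator columns from n_nodes_ptr[i] up to, but excluding, n_nodes_ptr[i + 1].

```python
from itertools import accumulate

# Hypothetical node counts of the trees in a three-tree forest.
n_nodes_per_tree = [5, 7, 3]

# Cumulative node counts with a leading zero, length n_estimators + 1.
n_nodes_ptr = [0] + list(accumulate(n_nodes_per_tree))


def tree_columns(i):
    """Indicator-matrix columns that belong to the i-th tree."""
    return range(n_nodes_ptr[i], n_nodes_ptr[i + 1])
```

With a real fitted estimator, slicing the CSR matrix as indicator[:, n_nodes_ptr[i]:n_nodes_ptr[i + 1]] would select the node columns of the i-th tree.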
- property feature_importances_
Not implemented
- fit(X, y, sample_weight=None)
Build a forest of survival trees from the training set (X, y).
- Parameters
X (array-like, shape = (n_samples, n_features)) – Data matrix
y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.
- Return type
self
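The structured array y can be built with NumPy directly. In this sketch the field names event and time and the values are hypothetical; what matters is that the boolean event indicator comes first and the time second. scikit-survival also provides sksurv.util.Surv.from_arrays for the same purpose.

```python
import numpy as np

# Hypothetical toy data for four samples.
event = np.array([True, False, True, True])  # True = event observed, False = censored
time = np.array([5.0, 8.0, 3.0, 12.0])       # time of event or of censoring

# Structured array: boolean event indicator first, time second.
y = np.empty(len(event), dtype=[("event", bool), ("time", float)])
y["event"] = event
y["time"] = time
```

The result can then be passed as the y argument of fit together with a data matrix X that has four rows.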
- get_params(deep=True)
Get parameters for this estimator.
- Parameters
deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns
params – Parameter names mapped to their values.
- Return type
dict
- property n_features_
Attribute n_features_ was deprecated in version 1.0 and will be removed in 1.2. Use n_features_in_ instead.
Number of features when fitting the estimator.
- Type
DEPRECATED
- predict(X)
Predict risk score.
The ensemble risk score is the expected total number of events, which can be estimated by the sum of the estimated ensemble cumulative hazard function \(\hat{H}_e\) over the event times:
\[\sum_{j=1}^{n} \hat{H}_e(T_{j} \mid x) ,\]
where \(n\) denotes the total number of distinct event times \(T_1, \ldots, T_n\) in the training data.
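A small numeric sketch of this formula: given hypothetical per-tree cumulative hazard values of a single sample at the training event times, average them across trees to obtain \(\hat{H}_e\) and then sum over the event times. The function name and the input layout are assumptions for illustration only.

```python
def ensemble_risk_score(tree_chfs):
    """Hypothetical helper.

    tree_chfs holds one list per tree: that tree's cumulative hazard for a
    single sample, evaluated at each distinct training event time T_1..T_n.
    """
    n_trees = len(tree_chfs)
    # Ensemble CHF at each event time: average over trees.
    ensemble_chf = [sum(values) / n_trees for values in zip(*tree_chfs)]
    # Risk score: sum of the ensemble CHF over all event times.
    return sum(ensemble_chf)


# Two hypothetical trees, three training event times.
score = ensemble_risk_score([[0.1, 0.3, 0.6], [0.2, 0.5, 1.0]])
```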
- Parameters
X (array-like, shape = (n_samples, n_features)) – Data matrix.
- Returns
risk_scores – Predicted risk scores.
- Return type
ndarray, shape = (n_samples,)
- predict_cumulative_hazard_function(X, return_array=False)
Predict cumulative hazard function.
For each tree in the ensemble, the cumulative hazard function (CHF) for an individual with feature vector \(x\) is computed from all samples of the bootstrap sample that are in the same terminal node as \(x\). It is estimated by the Nelson–Aalen estimator. The ensemble CHF at time \(t\) is the average value across all trees in the ensemble at the specified time point.
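The per-tree estimate described above can be sketched with a from-scratch Nelson–Aalen estimator: at each distinct event time, add the number of events divided by the number of samples still at risk, then average the resulting step functions across trees. Function names and toy leaf contents are hypothetical; when return_array is set, the estimator reports values at event_times_, which plays the role of grid here.

```python
def nelson_aalen(time, event, grid):
    """Nelson–Aalen cumulative hazard of one leaf, evaluated at each t in grid."""
    steps = []  # (event time, cumulative hazard just after that time)
    cumulative = 0.0
    for t in sorted({ti for ti, ei in zip(time, event) if ei}):
        n_at_risk = sum(1 for ti in time if ti >= t)
        n_events = sum(1 for ti, ei in zip(time, event) if ei and ti == t)
        cumulative += n_events / n_at_risk
        steps.append((t, cumulative))
    # The CHF is non-decreasing, so the largest value among steps at or
    # before t equals the value of the last such step.
    return [max((h for s, h in steps if s <= t), default=0.0) for t in grid]


def ensemble_chf(leaves, grid):
    """Average per-tree estimates; leaves holds one (time, event) pair per tree."""
    per_tree = [nelson_aalen(time, event, grid) for time, event in leaves]
    return [sum(values) / len(per_tree) for values in zip(*per_tree)]
```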
- Parameters
X (array-like, shape = (n_samples, n_features)) – Data matrix.
return_array (boolean) – If set, return an array with the cumulative hazard for each self.event_times_, otherwise an array of sksurv.functions.StepFunction.
- Returns
cum_hazard – If return_array is set, an array with the cumulative hazard for each self.event_times_, otherwise an array of sksurv.functions.StepFunction will be returned.
- Return type
ndarray
Examples
>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_whas500
>>> from sksurv.ensemble import ExtraSurvivalTrees
Load and prepare the data.
>>> X, y = load_whas500()
>>> X = X.astype(float)
Fit the model.
>>> estimator = ExtraSurvivalTrees().fit(X, y)
Estimate the cumulative hazard function for the first 5 samples.
>>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
Plot the estimated cumulative hazard functions.
>>> for fn in chf_funcs:
...     plt.step(fn.x, fn(fn.x), where="post")
...
>>> plt.show()
- predict_survival_function(X, return_array=False)
Predict survival function.
For each tree in the ensemble, the survival function for an individual with feature vector \(x\) is computed from all samples of the bootstrap sample that are in the same terminal node as \(x\). It is estimated by the Kaplan-Meier estimator. The ensemble survival function at time \(t\) is the average value across all trees in the ensemble at the specified time point.
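Analogously, a from-scratch Kaplan-Meier sketch: at each distinct event time, multiply the running survival probability by one minus the fraction of at-risk samples that had an event, then average across trees. As before, the function names and toy data are hypothetical and this is not scikit-survival's implementation.

```python
def kaplan_meier(time, event, grid):
    """Kaplan-Meier survival probability of one leaf, evaluated at each t in grid."""
    steps = []  # (event time, survival probability just after that time)
    survival = 1.0
    for t in sorted({ti for ti, ei in zip(time, event) if ei}):
        n_at_risk = sum(1 for ti in time if ti >= t)
        n_events = sum(1 for ti, ei in zip(time, event) if ei and ti == t)
        survival *= 1.0 - n_events / n_at_risk
        steps.append((t, survival))
    # Survival is non-increasing, so the smallest value among steps at or
    # before t equals the value of the last such step.
    return [min((p for s, p in steps if s <= t), default=1.0) for t in grid]


def ensemble_survival(leaves, grid):
    """Average per-tree estimates; leaves holds one (time, event) pair per tree."""
    per_tree = [kaplan_meier(time, event, grid) for time, event in leaves]
    return [sum(values) / len(per_tree) for values in zip(*per_tree)]
```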
- Parameters
X (array-like, shape = (n_samples, n_features)) – Data matrix.
return_array (boolean) – If set, return an array with the probability of survival for each self.event_times_, otherwise an array of sksurv.functions.StepFunction.
- Returns
survival – If return_array is set, an array with the probability of survival for each self.event_times_, otherwise an array of sksurv.functions.StepFunction will be returned.
- Return type
ndarray
Examples
>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_whas500
>>> from sksurv.ensemble import ExtraSurvivalTrees
Load and prepare the data.
>>> X, y = load_whas500()
>>> X = X.astype(float)
Fit the model.
>>> estimator = ExtraSurvivalTrees().fit(X, y)
Estimate the survival function for the first 5 samples.
>>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
Plot the estimated survival functions.
>>> for fn in surv_funcs:
...     plt.step(fn.x, fn(fn.x), where="post")
...
>>> plt.ylim(0, 1)
>>> plt.show()
- score(X, y)
Returns the concordance index of the prediction.
- Parameters
X (array-like, shape = (n_samples, n_features)) – Test samples.
y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.
- Returns
cindex – Estimated concordance index.
- Return type
float
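For intuition about the value returned by score, here is a simplified Harrell's concordance index: among pairs where the sample with the shorter time experienced an event, count how often it also received the higher risk score, with ties in risk counting as one half. This sketch ignores tied observation times, which the library's own sksurv.metrics.concordance_index_censored handles explicitly; harrell_cindex is a hypothetical name.

```python
def harrell_cindex(event, time, risk):
    """Simplified Harrell's C: assumes no tied observation times."""
    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a pair is comparable only if the earlier time is an event
        for j in range(len(time)):
            if time[j] > time[i]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # higher risk for the earlier event
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risk scores count as one half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```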
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters
**params (dict) – Estimator parameters.
- Returns
self – Estimator instance.
- Return type
estimator instance