# sksurv.tree.SurvivalTree

class sksurv.tree.SurvivalTree(splitter='best', max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None)

A survival tree.

The quality of a split is measured by the log-rank splitting rule.

See [1], [2] and [3] for further description.
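As a rough illustration of the log-rank splitting rule, the sketch below computes the textbook two-sample log-rank statistic for one candidate split in plain Python. The function name `logrank_statistic` and the toy data are assumptions made for this example; the library's own splitter is implemented differently and may scale the statistic in another way.

```python
from math import sqrt


def logrank_statistic(times, events, in_left):
    """Textbook two-sample log-rank statistic for one candidate split.

    times   : observed times (event or censoring)
    events  : True where the event occurred, False where censored
    in_left : True where the sample would go to the left child
    """
    event_times = sorted({t for t, e in zip(times, events) if e})
    observed_minus_expected = 0.0
    variance = 0.0
    for t in event_times:
        # Samples still at risk just before time t, overall and in the left child.
        n = sum(1 for u in times if u >= t)
        n_left = sum(1 for u, g in zip(times, in_left) if g and u >= t)
        # Events occurring exactly at time t, overall and in the left child.
        d = sum(1 for u, e in zip(times, events) if e and u == t)
        d_left = sum(
            1 for u, e, g in zip(times, events, in_left) if e and g and u == t
        )
        observed_minus_expected += d_left - d * n_left / n
        if n > 1:
            variance += d * (n_left / n) * (1 - n_left / n) * (n - d) / (n - 1)
    if variance == 0.0:
        return 0.0
    return abs(observed_minus_expected) / sqrt(variance)


# Toy data (hypothetical): six samples, all with an observed event.
times = [1, 2, 3, 4, 5, 6]
events = [True] * 6
separated = logrank_statistic(times, events, [True, True, True, False, False, False])
mixed = logrank_statistic(times, events, [True, False, True, False, True, False])
```

A split that separates early from late events (`separated`) yields a larger statistic than one that mixes them (`mixed`), which is why a larger value indicates a better split under this rule.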

Parameters
• splitter (string, optional, default: "best") – The strategy used to choose the split at each node. Supported strategies are “best” to choose the best split and “random” to choose the best random split.

• max_depth (int or None, optional, default: None) – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.

• min_samples_split (int, float, optional, default: 6) –

The minimum number of samples required to split an internal node:

• If int, then consider min_samples_split as the minimum number.

• If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) is the minimum number of samples for each split.

• min_samples_leaf (int, float, optional, default: 3) –

The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression.

• If int, then consider min_samples_leaf as the minimum number.

• If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) is the minimum number of samples for each node.

• min_weight_fraction_leaf (float, optional, default: 0.) – The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided.

• max_features (int, float, string or None, optional, default: None) –

The number of features to consider when looking for the best split:

• If int, then consider max_features features at each split.

• If float, then max_features is a fraction and int(max_features * n_features) features are considered at each split.

• If “auto”, then max_features=sqrt(n_features).

• If “sqrt”, then max_features=sqrt(n_features).

• If “log2”, then max_features=log2(n_features).

• If None, then max_features=n_features.

Note: the search for a split does not stop until at least one valid partition of the node samples is found, even if this requires inspecting more than max_features features.

• random_state (int, RandomState instance or None, optional, default: None) – If int, random_state is the seed used by the random number generator; if RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random.

• max_leaf_nodes (int or None, optional, default: None) – Grow a tree with max_leaf_nodes in best-first fashion. Best nodes are defined by their relative reduction in impurity. If None, the number of leaf nodes is unlimited.
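The int/float/string rules for min_samples_split, min_samples_leaf and max_features listed above can be sketched in plain Python as follows. The helper names `resolve_min_samples` and `resolve_max_features` are hypothetical and not part of the library; truncating the square root and logarithm to an integer of at least 1 is an assumption of this sketch.

```python
from math import ceil, log2, sqrt


def resolve_min_samples(value, n_samples):
    """Turn min_samples_split / min_samples_leaf into an absolute count."""
    if isinstance(value, int):
        return value                      # int: used as the minimum number
    return ceil(value * n_samples)        # float: fraction of n_samples, rounded up


def resolve_max_features(max_features, n_features):
    """Turn max_features into the number of features inspected per split."""
    if max_features is None:
        return n_features
    if max_features in ("auto", "sqrt"):
        # Assumption: truncate to int and never go below 1.
        return max(1, int(sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(log2(n_features)))
    if isinstance(max_features, int):
        return max_features
    return int(max_features * n_features)  # float: fraction of n_features


split_threshold = resolve_min_samples(0.25, 10)    # ceil(2.5)
features_sqrt = resolve_max_features("sqrt", 14)   # int(sqrt(14))
features_half = resolve_max_features(0.5, 14)      # int(7.0)
```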

event_times_

Unique time points where events occurred.

Type

array of shape = (n_event_times,)

max_features_

The inferred value of max_features.

Type

int

n_features_in_

Number of features seen during fit.

Type

int

feature_names_in_

Names of features seen during fit. Defined only when X has feature names that are all strings.

Type

ndarray of shape (n_features_in_,)

tree_

The underlying Tree object. Please refer to help(sklearn.tree._tree.Tree) for attributes of Tree object.

Type

Tree object

See also

sksurv.ensemble.RandomSurvivalForest

An ensemble of SurvivalTrees.

References

[1] Leblanc, M., & Crowley, J. (1993). Survival Trees by Goodness of Split. Journal of the American Statistical Association, 88(422), 457–467.

[2] Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008). Random survival forests. The Annals of Applied Statistics, 2(3), 841–860.

[3] Ishwaran, H., & Kogalur, U. B. (2007). Random survival forests for R. R News, 7(2), 25–31. https://cran.r-project.org/doc/Rnews/Rnews_2007-2.pdf

__init__(splitter='best', max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None)

Methods

• __init__([splitter, max_depth, ...])

• apply(X[, check_input]) – Return the index of the leaf that each sample is predicted as.

• decision_path(X[, check_input]) – Return the decision path in the tree.

• fit(X, y[, sample_weight, check_input, ...]) – Build a survival tree from the training set (X, y).

• get_params([deep]) – Get parameters for this estimator.

• predict(X[, check_input]) – Predict risk score.

• predict_cumulative_hazard_function(X[, ...]) – Predict cumulative hazard function.

• predict_survival_function(X[, check_input, ...]) – Predict survival function.

• score(X, y) – Returns the concordance index of the prediction.

• set_params(**params) – Set the parameters of this estimator.

apply(X, check_input=True)

Return the index of the leaf that each sample is predicted as.

Parameters
• X (array-like or sparse matrix, shape = (n_samples, n_features)) – The input samples. Internally, they are converted to dtype=np.float32 and, if a sparse matrix is provided, to a sparse csr_matrix.

• check_input (bool, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

Returns

X_leaves – For each datapoint x in X, return the index of the leaf x ends up in. Leaves are numbered within [0; self.tree_.node_count), possibly with gaps in the numbering.

Return type

array-like, shape = (n_samples,)

decision_path(X, check_input=True)

Return the decision path in the tree.

Parameters
• X (array-like or sparse matrix, shape = (n_samples, n_features)) – The input samples. Internally, they are converted to dtype=np.float32 and, if a sparse matrix is provided, to a sparse csr_matrix.

• check_input (bool, default=True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

Returns

indicator – A node indicator CSR matrix in which a nonzero element at position (i, j) indicates that sample i passes through node j.

Return type

sparse matrix, shape = (n_samples, n_nodes)

fit(X, y, sample_weight=None, check_input=True, X_idx_sorted='deprecated')

Build a survival tree from the training set (X, y).

Parameters
• X (array-like or sparse matrix, shape = (n_samples, n_features)) – Data matrix

• y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.

• check_input (boolean, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

• X_idx_sorted (deprecated, default="deprecated") – This parameter is deprecated and has no effect.

Return type

self
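The structured array expected for y can be built directly with NumPy, as in the minimal sketch below. The field names "event" and "time" and the toy values are illustrative assumptions; what matters according to the description above is the order of the fields (event indicator first, time second).

```python
import numpy as np

# Hypothetical toy data: three samples, the second one is censored.
event = [True, False, True]
time = [5.0, 8.0, 12.0]

# First field: binary event indicator; second field: event or censoring time.
y = np.array(
    list(zip(event, time)),
    dtype=[("event", bool), ("time", float)],
)
```

scikit-survival also provides the helper sksurv.util.Surv.from_arrays for the same purpose.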

get_params(deep=True)

Get parameters for this estimator.

Parameters

deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns

params – Parameter names mapped to their values.

Return type

dict

predict(X, check_input=True)

Predict risk score.

The risk score is the total number of events, which can be estimated by the sum of the estimated cumulative hazard function $$\hat{H}_h$$ in terminal node $$h$$.

$\sum_{j=1}^{n(h)} \hat{H}_h(T_{j} \mid x) ,$

where $$n(h)$$ denotes the number of distinct event times of samples belonging to the same terminal node as $$x$$.

Parameters
• X (array-like or sparse matrix, shape = (n_samples, n_features)) – Data matrix.

• check_input (boolean, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

Returns

risk_scores – Predicted risk scores.

Return type

ndarray, shape = (n_samples,)
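The risk score formula for predict can be illustrated for a single terminal node with the plain-Python sketch below. It follows the formula as stated in the text, summing the estimated cumulative hazard over the node's distinct event times, with the cumulative hazard estimated in the Nelson–Aalen manner. The function name `risk_score` and the toy leaves are assumptions for this example, not the library's internals.

```python
def risk_score(leaf_times, leaf_events):
    """Sum of the estimated cumulative hazard over a leaf's distinct event times."""
    distinct_event_times = sorted(
        {t for t, e in zip(leaf_times, leaf_events) if e}
    )
    cum_hazard = 0.0
    score = 0.0
    for t in distinct_event_times:
        at_risk = sum(1 for u in leaf_times if u >= t)
        n_events = sum(
            1 for u, e in zip(leaf_times, leaf_events) if e and u == t
        )
        cum_hazard += n_events / at_risk   # hazard increment at time t
        score += cum_hazard                # add H(t) to the running sum
    return score


# Hypothetical leaves: one with many early events, one with a single late event.
high_risk = risk_score([2, 3, 3, 5, 8], [True, True, False, True, False])
low_risk = risk_score([4, 6, 7, 9, 10], [False, False, True, False, False])
```

Every sample that falls into the same leaf receives the same score, so the score only ranks samples across leaves; a higher value means more expected events.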

predict_cumulative_hazard_function(X, check_input=True, return_array=False)

Predict cumulative hazard function.

The cumulative hazard function (CHF) for an individual with feature vector $$x$$ is computed from all samples of the training data that are in the same terminal node as $$x$$. It is estimated by the Nelson–Aalen estimator.

Parameters
• X (array-like or sparse matrix, shape = (n_samples, n_features)) – Data matrix.

• check_input (boolean, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

• return_array (boolean, default: False) – If set, return an array with the cumulative hazard rate for each self.event_times_, otherwise an array of sksurv.functions.StepFunction.

Returns

cum_hazard – If return_array is set, an array with the cumulative hazard rate for each self.event_times_, otherwise an array of length n_samples of sksurv.functions.StepFunction instances will be returned.

Return type

ndarray
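For intuition, the Nelson–Aalen estimate within one terminal node, and its evaluation as a right-continuous step function, might look like the following plain-Python sketch. The names `nelson_aalen` and `evaluate_step` and the toy leaf are hypothetical; this is not the library's implementation.

```python
from bisect import bisect_right


def nelson_aalen(times, events):
    """Return (event_times, cumulative_hazard) for the samples of one leaf."""
    event_times = sorted({t for t, e in zip(times, events) if e})
    cum_hazard = []
    total = 0.0
    for t in event_times:
        at_risk = sum(1 for u in times if u >= t)
        n_events = sum(1 for u, e in zip(times, events) if e and u == t)
        total += n_events / at_risk
        cum_hazard.append(total)
    return event_times, cum_hazard


def evaluate_step(xs, ys, t, before=0.0):
    """Value at time t of the step function that jumps to ys[i] at xs[i]."""
    idx = bisect_right(xs, t)
    return before if idx == 0 else ys[idx - 1]


# Hypothetical leaf: five samples, two of them censored.
leaf_times = [2, 3, 3, 5, 8]
leaf_events = [True, True, False, True, False]
xs, ys = nelson_aalen(leaf_times, leaf_events)
chf_at_4 = evaluate_step(xs, ys, 4)   # value carried forward from time 3
chf_at_1 = evaluate_step(xs, ys, 1)   # before the first event: 0.0
```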

Examples

>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_whas500
>>> from sksurv.tree import SurvivalTree


Load and prepare the data.

>>> X, y = load_whas500()
>>> X = X.astype(float)


Fit the model.

>>> estimator = SurvivalTree().fit(X, y)


Estimate the cumulative hazard function for the first 5 samples.

>>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])


Plot the estimated cumulative hazard functions.

>>> for fn in chf_funcs:
...    plt.step(fn.x, fn(fn.x), where="post")
...
>>> plt.ylim(0, 1)
>>> plt.show()

predict_survival_function(X, check_input=True, return_array=False)

Predict survival function.

The survival function for an individual with feature vector $$x$$ is computed from all samples of the training data that are in the same terminal node as $$x$$. It is estimated by the Kaplan-Meier estimator.

Parameters
• X (array-like or sparse matrix, shape = (n_samples, n_features)) – Data matrix.

• check_input (boolean, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

• return_array (boolean, default: False) – If set, return an array with the probability of survival for each self.event_times_, otherwise an array of sksurv.functions.StepFunction.

Returns

survival – If return_array is set, an array with the probability of survival for each self.event_times_, otherwise an array of length n_samples of sksurv.functions.StepFunction instances will be returned.

Return type

ndarray
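The Kaplan-Meier estimate within one terminal node can be sketched in plain Python as shown below. The function name `kaplan_meier` and the toy leaf are assumptions for illustration only; the library computes the estimate internally and returns it as step functions or an array.

```python
def kaplan_meier(times, events):
    """Return (event_times, survival_probabilities) for the samples of one leaf."""
    event_times = sorted({t for t, e in zip(times, events) if e})
    survival = []
    prob = 1.0
    for t in event_times:
        at_risk = sum(1 for u in times if u >= t)
        n_events = sum(1 for u, e in zip(times, events) if e and u == t)
        prob *= 1.0 - n_events / at_risk   # conditional survival past time t
        survival.append(prob)
    return event_times, survival


# Hypothetical leaf: five samples, two of them censored.
km_times, km_survival = kaplan_meier(
    [2, 3, 3, 5, 8], [True, True, False, True, False]
)
```

The survival probability only drops at event times; censored samples leave the risk set without causing a drop.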

Examples

>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_whas500
>>> from sksurv.tree import SurvivalTree


Load and prepare the data.

>>> X, y = load_whas500()
>>> X = X.astype(float)


Fit the model.

>>> estimator = SurvivalTree().fit(X, y)


Estimate the survival function for the first 5 samples.

>>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])


Plot the estimated survival functions.

>>> for fn in surv_funcs:
...    plt.step(fn.x, fn(fn.x), where="post")
...
>>> plt.ylim(0, 1)
>>> plt.show()

score(X, y)

Returns the concordance index of the prediction.

Parameters
• X (array-like, shape = (n_samples, n_features)) – Test samples.

• y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.

Returns

cindex – Estimated concordance index.

Return type

float
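As a simplified illustration of the concordance index, the sketch below counts comparable pairs in plain Python. The function name `concordance_index` and the toy data are assumptions; the sketch ignores the special handling of tied times that a full implementation would need, so it may not match the library's result in every case.

```python
def concordance_index(times, events, risk_scores):
    """Simplified Harrell's concordance index (no special handling of tied times)."""
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a pair is comparable only if the earlier sample had an event
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0   # earlier event has the higher risk: concordant
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5   # tied risk scores count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical test samples and predicted risk scores.
cindex = concordance_index(
    times=[2, 3, 5, 8],
    events=[True, True, False, True],
    risk_scores=[0.9, 0.4, 0.6, 0.1],
)
```

A value of 1.0 means every comparable pair is ordered correctly by the risk score, and 0.5 corresponds to random ordering.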

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters

**params (dict) – Estimator parameters.

Returns

self – Estimator instance.

Return type

estimator instance