sksurv.tree.SurvivalTree
class sksurv.tree.SurvivalTree(splitter='best', max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, presort='deprecated')
A survival tree.
The quality of a split is measured by the log-rank splitting rule.
See [1], [2] and [3] for further description.
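To make the log-rank splitting rule more concrete, here is a minimal NumPy sketch of the standardized two-sample log-rank statistic for one candidate split. The function name `logrank_statistic` and its arguments are illustrative assumptions and not part of sksurv; the library's internal implementation may differ in detail.

```python
import numpy as np


def logrank_statistic(time, event, in_left):
    """Absolute standardized log-rank statistic for a two-way split.

    Illustrative helper only (not an sksurv function).
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    in_left = np.asarray(in_left, dtype=bool)

    numerator = 0.0
    variance = 0.0
    for t in np.unique(time[event]):  # loop over distinct event times
        at_risk = time >= t
        n_risk = at_risk.sum()                    # samples at risk in the node
        n_risk_left = (at_risk & in_left).sum()   # at risk in the left child
        died = event & (time == t)
        n_events = died.sum()                     # events in the node at t
        n_events_left = (died & in_left).sum()    # events in the left child at t

        # Observed minus expected events in the left child.
        numerator += n_events_left - n_risk_left * n_events / n_risk
        if n_risk > 1:
            frac_left = n_risk_left / n_risk
            variance += (frac_left * (1.0 - frac_left)
                         * (n_risk - n_events) / (n_risk - 1) * n_events)

    return abs(numerator) / np.sqrt(variance) if variance > 0 else 0.0


# Hypothetical node with six uncensored samples.
time = [1, 2, 3, 4, 5, 6]
event = [True] * 6
separated = logrank_statistic(time, event, [True, True, True, False, False, False])
interleaved = logrank_statistic(time, event, [True, False, True, False, True, False])
print(separated > interleaved)  # True: the first split separates early from late events
```

A larger value indicates that the two child nodes differ more in their survival experience, so a splitter that follows this rule would prefer the first candidate split.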
Parameters:
- splitter (string, optional, default: "best") – The strategy used to choose the split at each node. Supported strategies are “best” to choose the best split and “random” to choose the best random split.
- max_depth (int or None, optional, default: None) – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.
- min_samples_split (int, float, optional, default: 6) –
The minimum number of samples required to split an internal node:
- If int, then consider min_samples_split as the minimum number.
- If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) is the minimum number of samples for each split.
- min_samples_leaf (int, float, optional, default: 3) –
The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression.
- If int, then consider min_samples_leaf as the minimum number.
- If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) is the minimum number of samples for each node.
- min_weight_fraction_leaf (float, optional, default: 0.) – The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided.
- max_features (int, float, string or None, optional, default: None) –
The number of features to consider when looking for the best split:
- If int, then consider max_features features at each split.
- If float, then max_features is a fraction and int(max_features * n_features) features are considered at each split.
- If “auto”, then max_features=sqrt(n_features).
- If “sqrt”, then max_features=sqrt(n_features).
- If “log2”, then max_features=log2(n_features).
- If None, then max_features=n_features.
Note: the search for a split does not stop until at least one valid partition of the node samples is found, even if that requires inspecting more than max_features features.
- random_state (int, RandomState instance or None, optional, default: None) – If int, random_state is the seed used by the random number generator; if RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random.
- max_leaf_nodes (int or None, optional, default: None) – Grow a tree with max_leaf_nodes in best-first fashion. Best nodes are defined by their relative reduction in impurity. If None, the number of leaf nodes is unlimited.
- presort (deprecated, optional, default: 'deprecated') – This parameter is deprecated and will be removed in a future version.
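The int-or-float rules for min_samples_split, min_samples_leaf and max_features described above can be sketched in a few lines. The helper names `resolve_min_samples` and `resolve_max_features` are hypothetical, and truncating the square root and logarithm to an integer of at least 1 is an assumption; this is an illustration of the documented formulas, not the library's code.

```python
import math


def resolve_min_samples(value, n_samples):
    """Turn an int-or-float setting into an absolute sample count."""
    if isinstance(value, int):
        return value
    # Fractions are scaled by the number of samples and rounded up.
    return math.ceil(value * n_samples)


def resolve_max_features(value, n_features):
    """Turn a max_features setting into a number of features."""
    if value is None:
        return n_features
    if isinstance(value, str):
        # Assumption: results are truncated to an integer and kept >= 1.
        if value in ("auto", "sqrt"):
            return max(1, int(math.sqrt(n_features)))
        if value == "log2":
            return max(1, int(math.log2(n_features)))
        raise ValueError(f"unknown max_features: {value!r}")
    if isinstance(value, int):
        return value
    return int(value * n_features)


print(resolve_min_samples(6, 500))        # 6
print(resolve_min_samples(0.25, 10))      # 3, because ceil(2.5) == 3
print(resolve_max_features("sqrt", 14))   # 3
print(resolve_max_features(0.5, 14))      # 7
```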
event_times_
Unique time points where events occurred.
Type: array of shape = (n_event_times,)
max_features_
The inferred value of max_features.
Type: int
n_features_
The number of features when fit is performed.
Type: int
tree_
The underlying Tree object. Please refer to help(sklearn.tree._tree.Tree) for attributes of the Tree object.
Type: Tree object
See also
sksurv.ensemble.RandomSurvivalForest: An ensemble of SurvivalTrees.
References
[1] Leblanc, M., & Crowley, J. (1993). Survival Trees by Goodness of Split. Journal of the American Statistical Association, 88(422), 457–467.
[2] Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008). Random survival forests. The Annals of Applied Statistics, 2(3), 841–860.
[3] Ishwaran, H., & Kogalur, U. B. (2007). Random survival forests for R. R News, 7(2), 25–31. https://cran.r-project.org/doc/Rnews/Rnews_2007-2.pdf
__init__(splitter='best', max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, presort='deprecated')
Initialize self. See help(type(self)) for accurate signature.
Methods
__init__([splitter, max_depth, …]): Initialize self.
fit(X, y[, sample_weight, check_input, …]): Build a survival tree from the training set (X, y).
predict(X[, check_input]): Predict risk score.
predict_cumulative_hazard_function(X[, …]): Predict cumulative hazard function.
predict_survival_function(X[, check_input, …]): Predict survival function.
score(X, y): Returns the concordance index of the prediction.
fit(X, y, sample_weight=None, check_input=True, X_idx_sorted=None)
Build a survival tree from the training set (X, y).
Parameters:
- X (array-like, shape = (n_samples, n_features)) – Data matrix.
- y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.
- check_input (boolean, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.
- X_idx_sorted (array-like, shape = (n_samples, n_features), optional) – The indexes of the sorted training input samples. If many trees are grown on the same dataset, this allows the ordering to be cached between trees. If None, the data will be sorted here. Don’t use this parameter unless you know what you are doing.
Returns: self
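As an illustration of the structured array expected for y, the following sketch builds one with plain NumPy. The toy values and the field names "event" and "time" are assumptions chosen for the example; only the field order (indicator first, time second) comes from the description above.

```python
import numpy as np

# Hypothetical toy data: True means the event was observed, False means censored.
event = np.array([True, False, True, True])
time = np.array([5.0, 12.0, 3.5, 8.0])

# Field names are placeholders; the boolean indicator is the first field
# and the observed time is the second field.
y = np.empty(len(event), dtype=[("event", bool), ("time", float)])
y["event"] = event
y["time"] = time

print(y.dtype.names)  # ('event', 'time')
print(y.shape)        # (4,)
```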
predict(X, check_input=True)
Predict risk score.
The risk score is the total number of events, which can be estimated by summing the estimated cumulative hazard function \(\hat{H}_h\) of terminal node \(h\) over its distinct event times:
\[\sum_{j=1}^{n(h)} \hat{H}_h(T_{j} \mid x),\]
where \(n(h)\) denotes the number of distinct event times of samples belonging to the same terminal node as \(x\).
Parameters:
- X (array-like, shape = (n_samples, n_features)) – Data matrix.
- check_input (boolean, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.
Returns: risk_scores – Predicted risk scores.
Return type: ndarray, shape = (n_samples,)
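The risk score formula above can be sketched for a single terminal node as follows. The helper `nelson_aalen` and the toy node data are hypothetical; the sketch follows the formula as written and is not the library's internal code.

```python
import numpy as np


def nelson_aalen(time, event):
    """Return distinct event times and the Nelson-Aalen CHF at those times."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    event_times = np.unique(time[event])
    # Number of events at each distinct event time.
    n_events = np.array([(event & (time == t)).sum() for t in event_times])
    # Number of samples still at risk just before each event time.
    n_at_risk = np.array([(time >= t).sum() for t in event_times])
    return event_times, np.cumsum(n_events / n_at_risk)


# Hypothetical samples that ended up in the same terminal node.
node_time = [2.0, 3.0, 3.0, 5.0, 8.0]
node_event = [True, True, False, True, False]

times, chf = nelson_aalen(node_time, node_event)
# chf is approximately [0.2, 0.45, 0.95] at times [2.0, 3.0, 5.0]

# Risk score: sum of the CHF over the node's distinct event times.
risk_score = chf.sum()
print(risk_score)
```

Every sample routed to this node would receive the same risk score, which is why a single tree produces only as many distinct scores as it has leaves.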
predict_cumulative_hazard_function(X, check_input=True, return_array='warn')
Predict cumulative hazard function.
The cumulative hazard function (CHF) for an individual with feature vector \(x\) is computed from all samples of the training data that are in the same terminal node as \(x\). It is estimated by the Nelson–Aalen estimator.
Parameters:
- X (array-like, shape = (n_samples, n_features)) – Data matrix.
- check_input (boolean, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.
- return_array (boolean) – If set, return an array with the cumulative hazard rate for each time point in self.event_times_, otherwise an array of sksurv.functions.StepFunction.
Returns: cum_hazard – If return_array is set, an array with the cumulative hazard rate for each time point in self.event_times_; otherwise, an array of sksurv.functions.StepFunction.
Return type: ndarray
Examples
>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_whas500
>>> from sksurv.tree import SurvivalTree
Load and prepare the data.
>>> X, y = load_whas500()
>>> X = X.astype(float)
Fit the model.
>>> estimator = SurvivalTree().fit(X, y)
Estimate the cumulative hazard function for the first 5 samples.
>>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5], return_array=False)
Plot the estimated cumulative hazard functions.
>>> for fn in chf_funcs:
...     plt.step(fn.x, fn(fn.x), where="post")
...
>>> plt.show()
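The array form and the step-function form carry the same information: values that change only at the event times. The sketch below evaluates such a right-continuous step function at arbitrary time points with NumPy. The helper `evaluate_step`, its `before_first` argument and the toy values are assumptions for illustration; sksurv.functions.StepFunction may treat points outside its domain differently.

```python
import numpy as np


def evaluate_step(event_times, values, query_times, before_first=0.0):
    """Evaluate a right-continuous step function given its jump points."""
    event_times = np.asarray(event_times, dtype=float)
    values = np.asarray(values, dtype=float)
    # Index of the last jump point that is <= each query time (-1 if none).
    idx = np.searchsorted(event_times, query_times, side="right") - 1
    # Queries before the first jump point get the value `before_first`.
    return np.where(idx >= 0, values[np.clip(idx, 0, None)], before_first)


# Hypothetical cumulative hazard values at three event times.
event_times = [2.0, 3.0, 5.0]
chf_values = [0.2, 0.45, 0.95]

result = evaluate_step(event_times, chf_values, [1.0, 2.0, 4.0, 9.0])
print(result)  # values at t = 1, 2, 4, 9 are 0.0, 0.2, 0.45, 0.95
```

This matches the `where="post"` drawing style used in the plot: the function keeps its value until the next event time.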
predict_survival_function(X, check_input=True, return_array='warn')
Predict survival function.
The survival function for an individual with feature vector \(x\) is computed from all samples of the training data that are in the same terminal node as \(x\). It is estimated by the Kaplan-Meier estimator.
Parameters:
- X (array-like, shape = (n_samples, n_features)) – Data matrix.
- check_input (boolean, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.
- return_array (boolean) – If set, return an array with the probability of survival for each time point in self.event_times_, otherwise an array of sksurv.functions.StepFunction.
Returns: survival – If return_array is set, an array with the probability of survival for each time point in self.event_times_; otherwise, an array of sksurv.functions.StepFunction.
Return type: ndarray
Examples
>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_whas500
>>> from sksurv.tree import SurvivalTree
Load and prepare the data.
>>> X, y = load_whas500()
>>> X = X.astype(float)
Fit the model.
>>> estimator = SurvivalTree().fit(X, y)
Estimate the survival function for the first 5 samples.
>>> surv_funcs = estimator.predict_survival_function(X.iloc[:5], return_array=False)
Plot the estimated survival functions.
>>> for fn in surv_funcs:
...     plt.step(fn.x, fn(fn.x), where="post")
...
>>> plt.ylim(0, 1)
>>> plt.show()
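For readers who want to see what the Kaplan-Meier estimate of a single terminal node looks like, here is a minimal NumPy sketch. The helper `kaplan_meier` and the toy node data are hypothetical and serve only as an illustration of the estimator named above, not as the library's implementation.

```python
import numpy as np


def kaplan_meier(time, event):
    """Return distinct event times and the Kaplan-Meier survival estimate."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    event_times = np.unique(time[event])
    # Number of events at each distinct event time.
    n_events = np.array([(event & (time == t)).sum() for t in event_times])
    # Number of samples still at risk just before each event time.
    n_at_risk = np.array([(time >= t).sum() for t in event_times])
    # Product of the conditional survival probabilities up to each time.
    return event_times, np.cumprod(1.0 - n_events / n_at_risk)


# Hypothetical samples that ended up in the same terminal node.
node_time = [2.0, 3.0, 3.0, 5.0, 8.0]
node_event = [True, True, False, True, False]

km_times, survival = kaplan_meier(node_time, node_event)
print(km_times)   # event times 2.0, 3.0, 5.0
print(survival)   # approximately 0.8, 0.6, 0.3
```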
score(X, y)
Returns the concordance index of the prediction.
Parameters:
- X (array-like, shape = (n_samples, n_features)) – Test samples.
- y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.
Returns: cindex – Estimated concordance index.
Return type: float
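The concordance index measures how often a sample with an earlier observed event also receives a higher risk score. The sketch below is a simplified pairwise version written for illustration: the function name `concordance_index` is an assumption, pairs with tied times are ignored, and the library's own computation may handle ties differently.

```python
import numpy as np


def concordance_index(time, event, risk):
    """Simplified concordance index over all comparable pairs."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    risk = np.asarray(risk, dtype=float)

    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a censored sample cannot be the earlier member of a pair
        for j in range(len(time)):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0   # earlier event has higher risk
                elif risk[i] == risk[j]:
                    concordant += 0.5   # tied risk scores count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical test data: the third sample is censored.
time = [2.0, 3.0, 5.0, 8.0]
event = [True, True, False, True]

print(concordance_index(time, event, [4, 3, 2, 1]))  # 1.0, perfectly ordered
print(concordance_index(time, event, [1, 1, 1, 1]))  # 0.5, uninformative scores
```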