sksurv.tree.ExtraSurvivalTree

class sksurv.tree.ExtraSurvivalTree(*, splitter='random', max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, low_memory=False)

An Extremely Randomized Survival Tree.

This class implements an Extremely Randomized Tree for survival analysis. It differs from SurvivalTree in how splits are chosen: instead of searching for the optimal threshold of each candidate feature, it draws a single random threshold for each candidate feature (a random subset of max_features features) and picks the best split among these random candidates.
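The random split rule described above can be sketched in NumPy as follows. This is a simplified illustration, not scikit-survival's implementation: the helper names are hypothetical, and scoring candidates with the standardized log-rank statistic is an assumption chosen for the example. Each candidate feature receives one threshold drawn uniformly between its minimum and maximum within the node, and the highest-scoring candidate wins.

```python
import numpy as np


def logrank_statistic(time, event, in_left):
    """Absolute standardized log-rank statistic comparing left vs right child."""
    observed_minus_expected = 0.0
    variance = 0.0
    for t in np.unique(time[event]):
        at_risk = time >= t
        n = at_risk.sum()
        n_left = (at_risk & in_left).sum()
        died = (time == t) & event
        d = died.sum()
        d_left = (died & in_left).sum()
        # Observed minus expected number of events in the left child at time t.
        observed_minus_expected += d_left - d * n_left / n
        if n > 1:
            # Hypergeometric variance of d_left.
            variance += d * (n_left / n) * (1 - n_left / n) * (n - d) / (n - 1)
    if variance == 0:
        return 0.0
    return abs(observed_minus_expected) / np.sqrt(variance)


def extra_random_split(X, time, event, max_features, rng):
    """Return (feature, threshold, score) of the best random candidate split."""
    features = rng.choice(X.shape[1], size=max_features, replace=False)
    best = (None, None, -np.inf)
    for f in features:
        lo, hi = X[:, f].min(), X[:, f].max()
        if lo == hi:
            continue  # constant feature: no threshold can split this node
        threshold = rng.uniform(lo, hi)  # a single random threshold per feature
        score = logrank_statistic(time, event, X[:, f] <= threshold)
        if score > best[2]:
            best = (f, threshold, score)
    return best


# Toy node: feature 0 orders the survival times, feature 1 is constant.
rng = np.random.default_rng(0)
X = np.column_stack([np.arange(6.0), np.ones(6)])
time = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
event = np.ones(6, dtype=bool)
feature, threshold, score = extra_random_split(X, time, event, max_features=2, rng=rng)
```

Unlike the behavior described under max_features, this sketch does not keep searching when every drawn feature is constant; it simply returns no split.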

Parameters:
  • splitter ({'best', 'random'}, default: 'random') – The strategy used to choose the split at each node. Supported strategies are ‘best’ to choose the best split and ‘random’ to choose the best random split.

  • max_depth (int or None, optional, default: None) – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.

  • min_samples_split (int, float, optional, default: 6) –

    The minimum number of samples required to split an internal node:

    • If int, then consider min_samples_split as the minimum number.

    • If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) is the minimum number of samples required for each split.

  • min_samples_leaf (int, float, optional, default: 3) –

    The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression.

    • If int, then consider min_samples_leaf as the minimum number.

    • If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) is the minimum number of samples for each node.

  • min_weight_fraction_leaf (float, optional, default: 0.) – The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided.

  • max_features (int, float or {'sqrt', 'log2'} or None, optional, default: None) –

    The number of features to consider when looking for the best split:

    • If int, then consider max_features features at each split.

    • If float, then max_features is a fraction and max(1, int(max_features * n_features_in_)) features are considered at each split.

    • If “sqrt”, then max_features=sqrt(n_features).

    • If “log2”, then max_features=log2(n_features).

    • If None, then max_features=n_features.

    Note: the search for a split does not stop until at least one valid partition of the node samples is found, even if this requires inspecting more than max_features features.

  • random_state (int, RandomState instance or None, optional, default: None) – Controls the randomness of the estimator. The features are always randomly permuted at each split, even if splitter is set to "best". When max_features < n_features, the algorithm selects max_features features at random at each split before finding the best split among them. Even if max_features=n_features, the best split found may vary across runs: this happens when several splits yield an identical improvement of the criterion and one of them has to be selected at random. With splitter='random', random_state also governs the random thresholds drawn for each candidate feature. To obtain deterministic behavior during fitting, random_state has to be fixed to an integer.

  • max_leaf_nodes (int or None, optional, default: None) – Grow a tree with max_leaf_nodes in best-first fashion. Best nodes are defined by their relative reduction in impurity. If None, the number of leaf nodes is unlimited.

  • low_memory (bool, optional, default: False) – If set, predict() computations use reduced memory but predict_cumulative_hazard_function() and predict_survival_function() are not implemented.
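The int/float/string rules for min_samples_split, min_samples_leaf and max_features listed above can be sketched as two small helper functions. The helper names are hypothetical and not part of scikit-survival; truncating the "sqrt" and "log2" results to an integer of at least 1 is an assumption that follows scikit-learn's convention rather than something stated above.

```python
import math


def resolve_min_samples(value, n_samples):
    """Convert a min_samples_split / min_samples_leaf setting to a count."""
    if isinstance(value, float):
        # A float is a fraction of the training samples, rounded up.
        return math.ceil(value * n_samples)
    return value  # an int is used as-is


def resolve_max_features(max_features, n_features):
    """Convert max_features to the number of features drawn at each split."""
    if max_features is None:
        return n_features
    if max_features == "sqrt":
        return max(1, int(math.sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(math.log2(n_features)))
    if isinstance(max_features, float):
        return max(1, int(max_features * n_features))
    return max_features  # an int is used as-is
```

For example, with 10 features, max_features="sqrt" resolves to 3 candidate features per split, and max_features=0.01 still resolves to 1.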

unique_times_

Unique time points.

Type:

ndarray, shape = (n_unique_times,), dtype = float

max_features_

The inferred value of max_features.

Type:

int

n_features_in_

Number of features seen during fit.

Type:

int

feature_names_in_

Names of features seen during fit. Defined only when X has feature names that are all strings.

Type:

ndarray, shape = (n_features_in_,), dtype = object

tree_

The underlying Tree object. Please refer to help(sklearn.tree._tree.Tree) for the attributes of the Tree object.

Type:

Tree object

See also

sksurv.ensemble.ExtraSurvivalTrees

An ensemble of ExtraSurvivalTrees.

__init__(*, splitter='random', max_depth=None, min_samples_split=6, min_samples_leaf=3, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, low_memory=False)

Methods

__init__(*[, splitter, max_depth, ...])

apply(X[, check_input])

Return the index of the leaf that each sample is predicted as.

decision_path(X[, check_input])

Return the decision path in the tree.

fit(X, y[, sample_weight, check_input])

Build a survival tree from the training set (X, y).

get_metadata_routing()

Get metadata routing of this object.

get_params([deep])

Get parameters for this estimator.

predict(X[, check_input])

Predict risk score.

predict_cumulative_hazard_function(X[, ...])

Predict cumulative hazard function.

predict_survival_function(X[, check_input, ...])

Predict survival function.

score(X, y)

Returns the concordance index of the prediction.

set_fit_request(*[, check_input, sample_weight])

Configure whether metadata should be requested to be passed to the fit method.

set_params(**params)

Set the parameters of this estimator.

set_predict_request(*[, check_input])

Configure whether metadata should be requested to be passed to the predict method.

Attributes

criterion

apply(X, check_input=True)

Return the index of the leaf that each sample is predicted as.

Parameters:
  • X (array-like or sparse matrix, shape = (n_samples, n_features)) – The input samples. Internally, it will be converted to dtype=np.float32, and a sparse matrix will be converted to a sparse csr_matrix. If splitter='best', X is allowed to contain missing values and decisions are made as described in Missing Values Support.

  • check_input (bool, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

Returns:

X_leaves – For each datapoint x in X, return the index of the leaf x ends up in. Leaves are numbered within [0; self.tree_.node_count), possibly with gaps in the numbering.

Return type:

ndarray, shape = (n_samples,), dtype=int

decision_path(X, check_input=True)

Return the decision path in the tree.

Parameters:
  • X (array-like or sparse matrix, shape = (n_samples, n_features)) – The input samples. Internally, it will be converted to dtype=np.float32, and a sparse matrix will be converted to a sparse csr_matrix. If splitter='best', X is allowed to contain missing values and decisions are made as described in Missing Values Support.

  • check_input (bool, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

Returns:

indicator – A node indicator CSR matrix in which nonzero elements indicate that the sample passes through the corresponding node.

Return type:

sparse matrix, shape = (n_samples, n_nodes)

fit(X, y, sample_weight=None, check_input=True)

Build a survival tree from the training set (X, y).

If splitter='best', X is allowed to contain missing values. In addition to evaluating each potential threshold on the non-missing data, the splitter will evaluate the split with all the missing values going to the left node or the right node. See Missing Values Support for details.

Parameters:
  • X (array-like or sparse matrix, shape = (n_samples, n_features)) – Data matrix

  • y (structured array, shape = (n_samples,)) – A structured array with two fields. The first field is a boolean where True indicates an event and False indicates right-censoring. The second field is a float with the time of event or time of censoring.

  • check_input (bool, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.
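A minimal sketch of building the structured array y expected by fit, using plain NumPy. The field names "event" and "time" are illustrative choices; the description above only fixes the field order and types.

```python
import numpy as np

# Illustrative data: three samples, the second one right-censored.
event = np.array([True, False, True])
time = np.array([5.0, 8.0, 3.0])

# Field order matters: boolean event indicator first, float time second.
y = np.empty(len(event), dtype=[("event", bool), ("time", float)])
y["event"] = event
y["time"] = time
```

scikit-survival also provides sksurv.util.Surv.from_arrays(event, time), which, to our understanding, builds an equivalent array with less boilerplate.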

Return type:

self

get_metadata_routing()

Get metadata routing of this object.

Please check the User Guide on how the routing mechanism works.

Returns:

routing – A MetadataRequest encapsulating routing information.

Return type:

MetadataRequest

get_params(deep=True)

Get parameters for this estimator.

Parameters:

deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:

params – Parameter names mapped to their values.

Return type:

dict

predict(X, check_input=True)

Predict risk score.

The risk score is the expected total number of events, which can be estimated by summing the estimated cumulative hazard function \(\hat{H}_h\) of terminal node \(h\) over its distinct event times:

\[\sum_{j=1}^{n(h)} \hat{H}_h(T_{j} \mid x) ,\]

where \(n(h)\) denotes the number of distinct event times of samples belonging to the same terminal node as \(x\).
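Within one terminal node, the formula above reduces to accumulating Nelson–Aalen increments and summing the running totals. The sketch below is a hypothetical helper that operates on the training samples of a single leaf; every sample that falls into that leaf receives the same score. The library may evaluate \(\hat{H}_h\) on a different time grid, so treat this as an illustration of the formula rather than a reimplementation of predict().

```python
import numpy as np


def leaf_risk_score(time, event):
    """Sum of the Nelson-Aalen CHF over the distinct event times of one leaf."""
    chf = 0.0
    risk = 0.0
    for t in np.unique(time[event]):
        n_at_risk = np.sum(time >= t)
        n_events = np.sum((time == t) & event)
        chf += n_events / n_at_risk  # Nelson-Aalen increment at time t
        risk += chf                  # add H_h(T_j) to the running sum
    return risk


# Training samples in one leaf; the sample censored at t=2 is still at risk at t=2.
time = np.array([1.0, 2.0, 2.0, 3.0, 5.0])
event = np.array([True, True, False, True, False])
score = leaf_risk_score(time, event)  # H(1) + H(2) + H(3) = 0.2 + 0.45 + 0.95
```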

Parameters:
  • X (array-like or sparse matrix, shape = (n_samples, n_features)) – Data matrix. If splitter='best', X is allowed to contain missing values and decisions are made as described in Missing Values Support.

  • check_input (bool, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

Returns:

risk_scores – Predicted risk scores.

Return type:

ndarray, shape = (n_samples,), dtype=float

predict_cumulative_hazard_function(X, check_input=True, return_array=False)

Predict cumulative hazard function.

The cumulative hazard function (CHF) for an individual with feature vector \(x\) is computed from all samples of the training data that are in the same terminal node as \(x\). It is estimated by the Nelson–Aalen estimator.
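The Nelson–Aalen estimator adds \(d_j / n_j\) (events over samples at risk) at each distinct event time, which yields a right-continuous step function. The sketch below computes it for the samples of one leaf and evaluates it at arbitrary times, roughly what a list of step functions or a return_array=True grid represents. Both helpers are hypothetical; evaluate_step mimics the where="post" convention used in the plotting example below.

```python
import numpy as np


def nelson_aalen(time, event):
    """Return (event_times, cumulative_hazard) for the samples of one leaf."""
    event_times = np.unique(time[event])
    n_at_risk = np.array([np.sum(time >= t) for t in event_times])
    n_events = np.array([np.sum((time == t) & event) for t in event_times])
    return event_times, np.cumsum(n_events / n_at_risk)


def evaluate_step(x, y, t):
    """Evaluate a right-continuous step function at times t (0 before x[0])."""
    idx = np.searchsorted(x, t, side="right") - 1
    return np.where(idx >= 0, y[np.clip(idx, 0, None)], 0.0)


time = np.array([1.0, 2.0, 2.0, 3.0, 5.0])
event = np.array([True, True, False, True, False])
x, chf = nelson_aalen(time, event)
values = evaluate_step(x, chf, np.array([0.0, 1.0, 2.5, 10.0]))
```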

Parameters:
  • X (array-like or sparse matrix, shape = (n_samples, n_features)) – Data matrix. If splitter='best', X is allowed to contain missing values and decisions are made as described in Missing Values Support.

  • check_input (bool, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

  • return_array (bool, default: False) –

    Whether to return a single array of cumulative hazard values or a list of step functions.

    If False, a list of sksurv.functions.StepFunction objects is returned.

    If True, a 2d-array of shape (n_samples, n_unique_times) is returned, where n_unique_times is the number of unique event times in the training data. Each row represents the cumulative hazard function of an individual evaluated at unique_times_.

Returns:

cum_hazard – If return_array is False, an array of n_samples sksurv.functions.StepFunction instances is returned.

If return_array is True, a numeric array of shape (n_samples, n_unique_times) is returned.

Return type:

ndarray

Examples

>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_veterans_lung_cancer
>>> from sksurv.preprocessing import OneHotEncoder
>>> from sksurv.tree import ExtraSurvivalTree

Load the data and encode categorical features.

>>> X, y = load_veterans_lung_cancer()
>>> Xt = OneHotEncoder().fit_transform(X)

Fit the model.

>>> estimator = ExtraSurvivalTree().fit(Xt, y)

Estimate the cumulative hazard function for the first 10 samples.

>>> chf_funcs = estimator.predict_cumulative_hazard_function(Xt.iloc[:10])

Plot the estimated cumulative hazard functions.

>>> for fn in chf_funcs:
...     plt.step(fn.x, fn(fn.x), where="post")
...
[...]
>>> plt.show()  
[Figure: estimated cumulative hazard functions of the first 10 samples]
predict_survival_function(X, check_input=True, return_array=False)

Predict survival function.

The survival function for an individual with feature vector \(x\) is computed from all samples of the training data that are in the same terminal node as \(x\). It is estimated by the Kaplan-Meier estimator.
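The Kaplan–Meier estimator multiplies the conditional survival probabilities \(1 - d_j / n_j\) over the distinct event times up to \(t\). A minimal sketch for the samples of one leaf follows; the helper name is hypothetical and the data are illustrative.

```python
import numpy as np


def kaplan_meier(time, event):
    """Return (event_times, survival_probabilities) for the samples of one leaf."""
    event_times = np.unique(time[event])
    n_at_risk = np.array([np.sum(time >= t) for t in event_times])
    n_events = np.array([np.sum((time == t) & event) for t in event_times])
    # Running product of conditional survival probabilities.
    return event_times, np.cumprod(1.0 - n_events / n_at_risk)


time = np.array([1.0, 2.0, 2.0, 3.0, 5.0])
event = np.array([True, True, False, True, False])
x, surv = kaplan_meier(time, event)  # 4/5, then 4/5 * 3/4, then 4/5 * 3/4 * 1/2
```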

Parameters:
  • X (array-like or sparse matrix, shape = (n_samples, n_features)) – Data matrix. If splitter='best', X is allowed to contain missing values and decisions are made as described in Missing Values Support.

  • check_input (bool, default: True) – Allows bypassing several input checks. Don’t use this parameter unless you know what you are doing.

  • return_array (bool, default: False) –

    Whether to return a single array of survival probabilities or a list of step functions.

    If False, a list of sksurv.functions.StepFunction objects is returned.

    If True, a 2d-array of shape (n_samples, n_unique_times) is returned, where n_unique_times is the number of unique event times in the training data. Each row represents the survival function of an individual evaluated at unique_times_.

Returns:

survival – If return_array is False, an array of n_samples sksurv.functions.StepFunction instances is returned.

If return_array is True, a numeric array of shape (n_samples, n_unique_times) is returned.

Return type:

ndarray

Examples

>>> import matplotlib.pyplot as plt
>>> from sksurv.datasets import load_veterans_lung_cancer
>>> from sksurv.preprocessing import OneHotEncoder
>>> from sksurv.tree import ExtraSurvivalTree

Load the data and encode categorical features.

>>> X, y = load_veterans_lung_cancer()
>>> Xt = OneHotEncoder().fit_transform(X)

Fit the model.

>>> estimator = ExtraSurvivalTree().fit(Xt, y)

Estimate the survival function for the first 10 samples.

>>> surv_funcs = estimator.predict_survival_function(Xt.iloc[:10])

Plot the estimated survival functions.

>>> for fn in surv_funcs:
...     plt.step(fn.x, fn(fn.x), where="post")
...
[...]
>>> plt.ylim(0, 1)
(0.0, 1.0)
>>> plt.show()  
[Figure: estimated survival functions of the first 10 samples]
score(X, y)

Returns the concordance index of the prediction.

Parameters:
  • X (array-like, shape = (n_samples, n_features)) – Test samples.

  • y (structured array, shape = (n_samples,)) – A structured array containing the binary event indicator as first field, and time of event or time of censoring as second field.

Returns:

cindex – Estimated concordance index.

Return type:

float

See also

sksurv.metrics.concordance_index_censored

Computes the concordance index.
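To show what score measures, here is a simplified sketch of Harrell's concordance index. A pair is comparable when the sample with the shorter time had an event; it is concordant when that sample also has the higher risk score, and tied risk scores count as one half. The helper is hypothetical; its argument order mirrors sksurv.metrics.concordance_index_censored(event_indicator, event_time, estimate), but it only compares pairs with strictly ordered times, so its treatment of ties may differ from that function.

```python
import numpy as np


def harrell_c_index(event, time, risk):
    """Simplified Harrell's concordance index for right-censored data."""
    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a censored sample cannot anchor a comparable pair
        later = time > time[i]  # samples known to outlive sample i
        comparable += later.sum()
        concordant += np.sum(risk[i] > risk[later])
        concordant += 0.5 * np.sum(risk[i] == risk[later])
    return concordant / comparable


event = np.array([True, True, False, True])
time = np.array([1.0, 2.0, 3.0, 4.0])
# Higher risk for earlier events: every comparable pair is concordant.
cindex = harrell_c_index(event, time, np.array([4.0, 3.0, 2.0, 1.0]))
```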

set_fit_request(*, check_input: bool | None | str = '$UNCHANGED$', sample_weight: bool | None | str = '$UNCHANGED$') → ExtraSurvivalTree

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • check_input (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for check_input parameter in fit.

  • sample_weight (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for sample_weight parameter in fit.

Returns:

self – The updated object.

Return type:

object

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:

**params (dict) – Estimator parameters.

Returns:

self – Estimator instance.

Return type:

estimator instance

set_predict_request(*, check_input: bool | None | str = '$UNCHANGED$') → ExtraSurvivalTree

Configure whether metadata should be requested to be passed to the predict method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to predict.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:

check_input (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for check_input parameter in predict.

Returns:

self – The updated object.

Return type:

object