Posts | Sebastian Pölsterl (https://k-d-w.org/post/) · © Sebastian Pölsterl 2021
<h1>scikit-survival 0.15 Released</h1>
<p><em>Sat, 20 Mar 2021</em> · <a href="https://k-d-w.org/blog/2021/03/scikit-survival-0.15-released/" target="_blank">k-d-w.org</a></p>
<p>I am proud to announce the release of version 0.15.0 of <a href="https://github.com/sebp/scikit-survival" target="_blank">scikit-survival</a>,
which brings support for scikit-learn 0.24 and Python 3.9.
Moreover, if you fit a gradient boosting model with <code>loss='coxph'</code>,
you can now predict the survival and cumulative hazard function using the
<em>predict_survival_function</em> and <em>predict_cumulative_hazard_function</em> methods.</p>
<p>The other enhancement is that
<a href="https://scikit-survival.readthedocs.io/en/v0.15.0/api/generated/sksurv.metrics.cumulative_dynamic_auc.html" target="_blank">cumulative_dynamic_auc</a>
now supports evaluating time-dependent predictions.
For instance, you can now evaluate the predicted time-dependent risk of a
<a href="https://scikit-survival.readthedocs.io/en/v0.15.0/api/generated/sksurv.ensemble.RandomSurvivalForest.html" target="_blank">RandomSurvivalForest</a>
rather than just evaluating the predicted total number of events per instance,
which is what
<a href="https://scikit-survival.readthedocs.io/en/v0.15.0/api/generated/sksurv.ensemble.RandomSurvivalForest.html#sksurv.ensemble.RandomSurvivalForest.predict" target="_blank">RandomSurvivalForest.predict</a>
returns.</p>
<p>All you have to do is create an array where the columns are the
predictions at the time points you want to evaluate. The snippet
below summarizes the idea:</p>
<pre><code class="language-python">import numpy as np

from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import cumulative_dynamic_auc

rsf = RandomSurvivalForest()
rsf.fit(X_train, y_train)

chf_funcs = rsf.predict_cumulative_hazard_function(X_test)

time_points = np.array([30, 60, …])
risk_scores = np.row_stack([
    chf(time_points) for chf in chf_funcs
])

aucs, mean_auc = cumulative_dynamic_auc(
    y_train, y_test, risk_scores, time_points
)
</code></pre>
<p>For a complete example, please have a look at the
<a href="https://scikit-survival.readthedocs.io/en/v0.15.0/user_guide/evaluating-survival-models.html#Using-Time-dependent-Risk-Scores" target="_blank">User Guide</a>.</p>
<p>If you want to know about all changes in scikit-survival 0.15.0, please see the
<a href="https://scikit-survival.readthedocs.io/en/v0.15.0/release_notes.html" target="_blank">release notes</a>.</p>
<p>As usual, pre-built conda packages are available for Linux, macOS, and Windows via</p>
<pre><code class="language-bash"> conda install -c sebp scikit-survival
</code></pre>
<p>Alternatively, scikit-survival can be installed from source following
<a href="https://scikit-survival.readthedocs.io/en/v0.15.0/install.html#from-source" target="_blank">these instructions</a>.</p>
<h1>scikit-survival 0.14 with Improved Documentation Released</h1>
<p><em>Wed, 07 Oct 2020</em> · <a href="https://k-d-w.org/blog/2020/10/scikit-survival-0.14-with-improved-documentation-released/" target="_blank">k-d-w.org</a></p>
<p>Today marks the release of version 0.14.0 of <a href="https://github.com/sebp/scikit-survival" target="_blank">scikit-survival</a>.
The biggest change in this release is actually <em>not</em> in the code, but in the documentation.
This release features a complete overhaul of the <a href="https://scikit-survival.readthedocs.io/en/v0.14.0/index.html" target="_blank">documentation</a>. Most importantly, the documentation has a more modern feel to it, thanks to the visually pleasing <a href="https://github.com/pandas-dev/pydata-sphinx-theme" target="_blank">pydata Sphinx theme</a>, which also powers pandas.</p>
<p>Moreover, the documentation now contains a <a href="https://scikit-survival.readthedocs.io/en/v0.14.0/user_guide/index.html" target="_blank">User Guide</a> section that bundles several topics surrounding the use of scikit-survival. Some of these were available as separate Jupyter notebooks previously, such as the guide on <a href="https://scikit-survival.readthedocs.io/en/v0.14.0/user_guide/evaluating-survival-models.html" target="_blank">Evaluating Survival Models</a>.
There are two new guides: The first one is on <a href="https://scikit-survival.readthedocs.io/en/v0.14.0/user_guide/coxnet.html" target="_blank">penalized Cox models</a>. It provides a hands-on introduction to Cox’s proportional hazards model with $\ell_2$ (Ridge) and $\ell_1$ (LASSO) penalty.
The second guide is on <a href="https://scikit-survival.readthedocs.io/en/v0.14.0/user_guide/boosting.html" target="_blank">Gradient Boosted Models</a> and covers how gradient boosting can be used to obtain a non-linear proportional hazards model or a non-linear accelerated failure time model by using regression tree base learners. The second part of this guide
covers a variant of gradient boosting that is most suitable for high-dimensional data and is based on component-wise least squares base learners.</p>
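<p>For reference, both penalties are special cases of the standard elastic-net form (notation mine, not taken from the guides): with partial likelihood $PL(\beta)$, regularization strength $\alpha$, and mixing parameter $r \in [0, 1]$, a penalized Cox model minimizes</p>
<p>$$
-\log PL(\beta) + \alpha \left( r \Vert \beta \Vert_1 + \frac{1 - r}{2} \Vert \beta \Vert_2^2 \right) ,
$$</p>
<p>so $r = 1$ recovers the pure $\ell_1$ (LASSO) penalty and $r = 0$ the pure $\ell_2$ (Ridge) penalty.</p>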
<p>To make it easier to get started, all notebooks can now be run in a Jupyter notebook, right from your browser, just by clicking on <a href="https://mybinder.org/v2/gh/sebp/scikit-survival/master?urlpath=lab/tree/notebooks/" target="_blank"><img src="https://mybinder.org/badge_logo.svg" alt="" /></a></p>
<p>In addition to the vastly improved documentation, this release includes important bug fixes. It fixes several bugs in <a href="https://scikit-survival.readthedocs.io/en/v0.14.0/api/generated/sksurv.linear_model.CoxnetSurvivalAnalysis.html#sksurv.linear_model.CoxnetSurvivalAnalysis" target="_blank">CoxnetSurvivalAnalysis</a>, where <code>predict</code>, <code>predict_survival_function</code>, and <code>predict_cumulative_hazard_function</code> returned wrong values if features of the training data were not centered. Moreover, the score function of <a href="https://scikit-survival.readthedocs.io/en/v0.14.0/api/generated/sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis.html#sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis" target="_blank">ComponentwiseGradientBoostingSurvivalAnalysis</a> and <a href="https://scikit-survival.readthedocs.io/en/v0.14.0/api/generated/sksurv.ensemble.GradientBoostingSurvivalAnalysis.html#sksurv.ensemble.GradientBoostingSurvivalAnalysis" target="_blank">GradientBoostingSurvivalAnalysis</a> will now correctly compute the concordance index if <code>loss='ipcwls'</code> or <code>loss='squared'</code>.</p>
<p>For a full list of changes in scikit-survival 0.14.0, please see the
<a href="https://scikit-survival.readthedocs.io/en/v0.14.0/release_notes.html" target="_blank">release notes</a>.</p>
<p>Pre-built conda packages are available for Linux, macOS, and Windows via</p>
<pre><code class="language-bash"> conda install -c sebp scikit-survival
</code></pre>
<p>Alternatively, scikit-survival can be installed from source following
<a href="https://scikit-survival.readthedocs.io/en/v0.14.0/install.html#from-source" target="_blank">these instructions</a>.</p>
<h1>scikit-survival 0.13 Released</h1>
<p><em>Sun, 28 Jun 2020</em> · <a href="https://k-d-w.org/blog/2020/06/scikit-survival-0.13-released/" target="_blank">k-d-w.org</a></p>
<p>Today, I released version 0.13.0 of
<a href="https://github.com/sebp/scikit-survival" target="_blank">scikit-survival</a>.
Most notably, this release adds
<a href="https://scikit-survival.readthedocs.io/en/v0.13.0/generated/sksurv.metrics.brier_score.html#sksurv.metrics.brier_score" target="_blank">sksurv.metrics.brier_score</a> and <a href="https://scikit-survival.readthedocs.io/en/v0.13.0/generated/sksurv.metrics.integrated_brier_score.html#sksurv.metrics.integrated_brier_score" target="_blank">sksurv.metrics.integrated_brier_score</a>,
an updated PEP 517/518 compatible build system,
and support for scikit-learn 0.23.</p>
<p>For a full list of changes in scikit-survival 0.13.0, please see the
<a href="https://scikit-survival.readthedocs.io/en/latest/release_notes.html" target="_blank">release notes</a>.</p>
<p>Pre-built conda packages are available for Linux, macOS, and Windows via</p>
<pre><code class="language-bash"> conda install -c sebp scikit-survival
</code></pre>
<p>Alternatively, scikit-survival can be installed from source following
<a href="https://scikit-survival.readthedocs.io/en/v0.13.0/install.html" target="_blank">these instructions</a>.</p>
<h2 id="the-time-dependent-brier-score">The time-dependent Brier score</h2>
<p>The time-dependent Brier score is an extension of the mean squared
error to <a href="https://scikit-survival.readthedocs.io/en/latest/understanding_predictions.html" target="_blank">right censored data</a>:</p>
<p>$$
\mathrm{BS}^c(t) = \frac{1}{n} \sum_{i=1}^n I(y_i \leq t \land \delta_i = 1)
\frac{(0 - \hat{\pi}(t | \mathbf{x}_i))^2}{\hat{G}(y_i)} + I(y_i > t)
\frac{(1 - \hat{\pi}(t | \mathbf{x}_i))^2}{\hat{G}(t)} ,
$$</p>
<p>where $\hat{\pi}(t | \mathbf{x})$ is a model’s predicted probability of
remaining event-free up to time point $t$ for feature vector $\mathbf{x}$, and $1/\hat{G}(t)$ is an inverse probability of censoring weight.</p>
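<p>To make the formula concrete, here is a toy evaluation in plain numpy for a single time point, ignoring censoring so that $\hat{G} \equiv 1$ (all numbers below are made up):</p>

```python
import numpy as np

# Toy check of the time-dependent Brier score with no censoring (G-hat = 1).
y_time = np.array([2.0, 5.0, 9.0, 11.0])  # observed event times (all events)
pi_hat = np.array([0.1, 0.4, 0.7, 0.9])   # predicted P(T > t | x_i) at t = 8
t = 8.0

# subjects with y_i <= t have experienced the event: target probability is 0
event_term = (y_time <= t) * (0.0 - pi_hat) ** 2
# subjects with y_i > t are still event-free: target probability is 1
at_risk_term = (y_time > t) * (1.0 - pi_hat) ** 2

bs = np.mean(event_term + at_risk_term)  # 0.0675
```

<p>A perfect model at $t$ would predict $\hat{\pi} = 0$ for the first two subjects and $\hat{\pi} = 1$ for the remaining two, giving a score of 0.</p>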
<p>The Brier score is often used to assess calibration.
If a model predicts a 10% risk of experiencing an
event at time $t$, the observed frequency in the data
should match this percentage for a well calibrated model.
In addition, the Brier score is also a measure of
discrimination: whether a model is able to predict risk scores
that allow us to correctly determine the order of events.
The concordance index is probably the most common measure
of discrimination. However, the concordance index disregards
the actual values of predicted risk scores
– it is a ranking metric –
and is unable to tell us anything about calibration.</p>
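<p>To see that the concordance index is a pure ranking metric, consider this small sketch on uncensored toy data (the helper function and all numbers are illustrations of mine, not part of scikit-survival): applying a strictly increasing transform to the risk scores leaves the index unchanged.</p>

```python
import numpy as np

def cindex_uncensored(time, risk):
    # Fraction of comparable pairs where the subject with the shorter
    # survival time received the higher risk score (no censoring here).
    concordant, comparable = 0, 0
    n = len(time)
    for i in range(n):
        for j in range(i + 1, n):
            if time[i] == time[j]:
                continue
            comparable += 1
            shorter, longer = (i, j) if time[i] < time[j] else (j, i)
            concordant += risk[shorter] > risk[longer]
    return concordant / comparable

time = np.array([2.0, 4.0, 6.0, 8.0])
risk = np.array([0.9, 0.5, 0.6, 0.1])
c1 = cindex_uncensored(time, risk)                # 5 of 6 pairs concordant
c2 = cindex_uncensored(time, np.exp(3.0 * risk))  # monotone transform
assert c1 == c2  # the index depends only on the ordering of the scores
```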
<p>Let’s consider an example based on data from the
<a href="http://ascopubs.org/doi/abs/10.1200/jco.1994.12.10.2086" target="_blank">German Breast Cancer Study Group 2</a>.</p>
<pre><code class="language-python">from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import encode_categorical
from sklearn.model_selection import train_test_split
X, y = load_gbsg2()
X = encode_categorical(X)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y["cens"], random_state=1)
</code></pre>
<p>We want to train a model on the training data and assess
its discrimination and calibration on the test data.
Here, we consider a <a href="https://k-d-w.org/blog/2019/12/scikit-survival-0.11-featuring-random-survival-forests-released/" target="_blank">Random Survival Forest</a>
and <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.linear_model.CoxnetSurvivalAnalysis.html" target="_blank">Cox’s proportional hazards model with elastic-net penalty</a>.</p>
<pre><code class="language-python">from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxnetSurvivalAnalysis
rsf = RandomSurvivalForest(max_depth=2, random_state=1)
rsf.fit(X_train, y_train)
cph = CoxnetSurvivalAnalysis(l1_ratio=0.99, fit_baseline_model=True)
cph.fit(X_train, y_train)
</code></pre>
<p>First, let’s start with discrimination as measured by the
concordance index.</p>
<pre><code class="language-python">rsf_c = rsf.score(X_test, y_test)
cph_c = cph.score(X_test, y_test)
</code></pre>
<p>The result indicates that both models perform equally well,
achieving a concordance index of 0.688, which is
considerably better than the 0.5 expected of a random model.
Unfortunately, this doesn’t help us decide which model we should
choose. So let’s consider the time-dependent Brier score
as an alternative, which assesses discrimination <em>and</em> calibration.</p>
<p>We first need to determine the time points
$t$ at which to compute the Brier score. We are going to use a data-driven
approach here by selecting all time points between the 10th and 90th
percentiles of the observed times.</p>
<pre><code class="language-python">import numpy as np
lower, upper = np.percentile(y["time"], [10, 90])
times = np.arange(lower, upper + 1)
</code></pre>
<p>This returns 1690 time points, at each of which we need to estimate the
probability of survival, which is given by the survival function.
Thus, we iterate over the predicted survival functions on the test data
and evaluate each at the time points from above.</p>
<pre><code class="language-python">rsf_surv_prob = np.row_stack([
    fn(times)
    for fn in rsf.predict_survival_function(X_test, return_array=False)
])
cph_surv_prob = np.row_stack([
    fn(times)
    for fn in cph.predict_survival_function(X_test)
])
</code></pre>
<p>Note that calling <code>predict_survival_function</code> for RandomSurvivalForest
with <code>return_array=False</code> requires scikit-survival 0.13.</p>
<p>In addition, we want to have a baseline to tell us how much better
our models are from random. A random model would simply predict 0.5
every time.</p>
<pre><code class="language-python">random_surv_prob = 0.5 * np.ones((y_test.shape[0], times.shape[0]))
</code></pre>
<p>Another useful reference is the Kaplan-Meier estimator, which does not consider
any features: it estimates a survival function only from <code>y_test</code>.
We replicate this estimate for all samples in the test data.</p>
<pre><code class="language-python">from sksurv.functions import StepFunction
from sksurv.nonparametric import kaplan_meier_estimator
km_func = StepFunction(*kaplan_meier_estimator(y_test["cens"], y_test["time"]))
km_surv_prob = np.tile(km_func(times), (y_test.shape[0], 1))
</code></pre>
<p>Instead of comparing calibration across all 1690 time points, we’ll be
using the
<a href="https://scikit-survival.readthedocs.io/en/v0.13.0/generated/sksurv.metrics.integrated_brier_score.html" target="_blank">integrated Brier score</a>
(IBS) over all time points, which will
give us a single number to compare the models by.</p>
<pre><code class="language-python">from sksurv.metrics import integrated_brier_score
random_brier = integrated_brier_score(y, y_test, random_surv_prob, times)
km_brier = integrated_brier_score(y, y_test, km_surv_prob, times)
rsf_brier = integrated_brier_score(y, y_test, rsf_surv_prob, times)
cph_brier = integrated_brier_score(y, y_test, cph_surv_prob, times)
</code></pre>
<p>The results are summarized in the table below:</p>
<div class="table-responsive">
<table class="dataframe" style="width: 40%;">
<thead>
<tr style="text-align: right;">
<th></th>
<th>RSF</th>
<th>Coxnet</th>
<th>Random</th>
<th>Kaplan-Meier</th>
</tr>
</thead>
<tbody>
<tr>
<th>c-index</th>
<td>0.688</td>
<td>0.688</td>
<td>0.500</td>
<td></td>
</tr>
<tr>
<th>IBS</th>
<td>0.194</td>
<td>0.188</td>
<td>0.247</td>
<td>0.217</td>
</tr>
</tbody>
</table>
</div>
<p>Despite Random Survival Forest and Cox’s proportional hazards model
performing equally well in terms of discrimination, there seems to be
a notable difference in terms of calibration, with
Cox’s proportional hazards model outperforming Random Survival Forest.</p>
<p>As a final note, I want to clarify that the Brier score is only
applicable for models that are able to estimate a survival function.
Hence, it currently cannot be used with
<a href="https://scikit-survival.readthedocs.io/en/v0.13.0/api.html#survival-support-vector-machine" target="_blank">Survival Support Vector Machines</a>.</p>
<h1>Survival Analysis for Deep Learning Tutorial for TensorFlow 2</h1>
<p><em>Sun, 17 May 2020</em> · <a href="https://k-d-w.org/blog/2020/05/survival-analysis-for-deep-learning-tutorial-for-tensorflow-2/" target="_blank">k-d-w.org</a></p>
<p>A while back, I posted the <a href="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/" target="_blank">Survival Analysis for Deep Learning</a> tutorial.
This tutorial was written for TensorFlow 1 using the <a href="https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/estimators.md" target="_blank">tf.estimators API</a>. The changes between version 1 and the current TensorFlow 2 are quite significant, which is why the code does not run when using
a recent TensorFlow version. Therefore, I created a new version of the tutorial that is compatible with TensorFlow 2.
The text is basically identical, but the training and evaluation procedure changed.</p>
<p>The complete notebook is available on <a href="https://github.com/sebp/survival-cnn-estimator" target="_blank">GitHub</a>, or you can run it directly using
<a href="https://colab.research.google.com/github/sebp/survival-cnn-estimator/blob/master/tutorial_tf2.ipynb" target="_blank">Google Colaboratory</a>.</p>
<h2 id="notes-on-porting-to-tensorflow-2">Notes on porting to TensorFlow 2</h2>
<p>A nice feature of TensorFlow 2 is that in order to write custom metrics (such as concordance index) for
TensorBoard, you don’t need to create a <code>Summary</code> protocol buffer manually; instead,
it suffices to call <code>tf.summary.scalar</code> and pass it a name and a float.
So instead of</p>
<pre><code class="language-python">from sksurv.metrics import concordance_index_censored
from tensorflow.core.framework import summary_pb2

c_index_metric = concordance_index_censored(…)[0]

writer = tf.summary.FileWriterCache.get(output_dir)
buf = summary_pb2.Summary(value=[summary_pb2.Summary.Value(
    tag="c-index", simple_value=c_index_metric)])
writer.add_summary(buf, global_step=global_step)
</code></pre>
<p>you can just do</p>
<pre><code class="language-python">from sksurv.metrics import concordance_index_censored

with tf.summary.create_file_writer(output_dir).as_default():
    c_index_metric = concordance_index_censored(…)[0]
    tf.summary.scalar("c-index", c_index_metric, step=step)
</code></pre>
<p>Another feature that I liked is that you can now iterate over an instance of <code>tf.data.Dataset</code> and
directly access the tensors and their values. This is much more convenient than having to call <code>make_one_shot_iterator</code>
first, which gives you an iterator, which you call <code>get_next()</code> on to get actual tensors.</p>
<p>Unfortunately, I also encountered some negatives when moving to TensorFlow 2.
First of all, there’s currently no officially supported way to produce a view of the executed Graph
that is identical to what you get with TensorFlow 1, unless you use the Keras training loop with
the <code>TensorBoard</code> callback.
There’s <code>tf.summary.trace_export</code>, which, as described in <a href="https://www.tensorflow.org/tensorboard/graphs#graphs_of_tffunctions" target="_blank">this guide</a>,
sounds like it would produce the graph. However, with this approach you can only
view individual operations in TensorBoard; you can’t inspect the size of an operation’s input and output tensors.
After searching for a while, I eventually found the answer in a <a href="https://stackoverflow.com/questions/58843269/display-graph-using-tensorflow-v2-0-in-tensorboard" target="_blank">Stack Overflow post</a>, and, as it turns out, that is exactly what the <code>TensorBoard</code> callback
<a href="https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/callbacks.py#L1787" target="_blank">is doing</a>.</p>
<p>Another thing I found odd is that if you define your custom loss as a subclass of <code>tf.keras.losses.Loss</code>, it insists
that there are only two inputs <code>y_true</code> and <code>y_pred</code>. In the case of
<a href="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/#cox-s-proportional-hazards-model" target="_blank">Cox’s proportional hazards loss</a>
the true label comprises an event indicator and an indicator matrix specifying which pairs in a batch are comparable.
Luckily, the contents of <code>y_true</code> don’t get checked, so you can just pass a list, but I would prefer to write something like</p>
<pre><code class="language-python">loss_fn(y_true_event=y_event, y_true_riskset=y_riskset, y_pred=pred_risk_score)
</code></pre>
<p>instead of</p>
<pre><code class="language-python">loss_fn(y_true=[y_event, y_riskset], y_pred=pred_risk_score)
</code></pre>
<p>Finally, although eager execution is now enabled by default, the code runs <em>significantly</em> faster in graph mode, i.e.
when annotating your model’s call method with <code>@tf.function</code>. I guess you are only supposed to use eager execution for debugging purposes.</p>
<h1>scikit-survival 0.12 Released</h1>
<p><em>Wed, 15 Apr 2020</em> · <a href="https://k-d-w.org/blog/2020/04/scikit-survival-0.12-released/" target="_blank">k-d-w.org</a></p>
<p>Version 0.12 of <a href="https://github.com/sebp/scikit-survival" target="_blank">scikit-survival</a> adds support
for scikit-learn 0.22 and Python 3.8
and comes with two noticeable improvements:</p>
<ol>
<li><a href="https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html#sklearn.pipeline.Pipeline" target="_blank">sklearn.pipeline.Pipeline</a>
will now be automatically patched
to add support for <code>predict_cumulative_hazard_function</code> and <code>predict_survival_function</code>
if the underlying estimator supports it (see
<a href="https://k-d-w.org/blog/2020/04/scikit-survival-0.12-released/#using-pipelines">
first example
</a>
).</li>
<li>The regularization strength of the ridge penalty in
<a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.linear_model.CoxPHSurvivalAnalysis.html#sksurv.linear_model.CoxPHSurvivalAnalysis" target="_blank">sksurv.linear_model.CoxPHSurvivalAnalysis</a>
can now be set per feature (see
<a href="https://k-d-w.org/blog/2020/04/scikit-survival-0.12-released/#per-feature-regularization-strength">
second example
</a>
).</li>
</ol>
<p>For a full list of changes in scikit-survival 0.12, please see the
<a href="https://scikit-survival.readthedocs.io/en/latest/release_notes.html" target="_blank">release notes</a>.</p>
<p>Pre-built conda packages are available for Linux, macOS, and Windows via</p>
<pre><code class="language-bash"> conda install -c sebp scikit-survival
</code></pre>
<p>Alternatively, scikit-survival can be installed from source via pip:</p>
<pre><code class="language-bash"> pip install -U scikit-survival
</code></pre>
<h2 id="using-pipelines">Using pipelines</h2>
<p>You can now create a scikit-learn pipeline and directly
call <code>predict_cumulative_hazard_function</code> and <code>predict_survival_function</code>
if the underlying estimator supports it, such as
<a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.ensemble.RandomSurvivalForest.html#sksurv.ensemble.RandomSurvivalForest" target="_blank">RandomSurvivalForest</a> below.</p>
<pre><code class="language-python">from sklearn.pipeline import make_pipeline
from sksurv.datasets import load_breast_cancer
from sksurv.ensemble import RandomSurvivalForest
from sksurv.preprocessing import OneHotEncoder
X, y = load_breast_cancer()
pipe = make_pipeline(OneHotEncoder(), RandomSurvivalForest())
pipe.fit(X, y)
surv_fn = pipe.predict_survival_function(X)
</code></pre>
<h2 id="per-feature-regularization-strength">Per-feature regularization strength</h2>
<p>If you want to fit Cox’s proportional hazards model to a large
set of features, but only shrink the coefficients for a subset
of features, previously, you had to use
<a href="https://scikit-survival.readthedocs.io/en/stable/generated/sksurv.linear_model.CoxnetSurvivalAnalysis.html#sksurv.linear_model.CoxnetSurvivalAnalysis" target="_blank">CoxnetSurvivalAnalysis</a>
and set the <code>penalty_factor</code> parameter accordingly.
This release adds a similar option to
<a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.linear_model.CoxPHSurvivalAnalysis.html#sksurv.linear_model.CoxPHSurvivalAnalysis" target="_blank">CoxPHSurvivalAnalysis</a>, which uses a ridge penalty exclusively.</p>
<p>For instance, consider the breast cancer data, which comprises
4 established markers (age, tumor size, tumor grade, and estrogen receptor status)
and 76 genetic markers.
It is sensible to fit a model where the established markers enter unpenalized and
only the coefficients of the genetic markers get penalized.
We can achieve this by creating an array for the regularization strength $\alpha$
where the entries corresponding to the established markers are zero.</p>
<pre><code class="language-python">import numpy as np
from sklearn.pipeline import make_pipeline
from sksurv.datasets import load_breast_cancer
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.preprocessing import OneHotEncoder

X, y = load_breast_cancer()

# the last 4 features are: age, er, grade, size
num_genes = X.shape[1] - 4
# add 2, because after one-hot encoding grade becomes three features
alphas = np.ones(X.shape[1] + 2)
# do not penalize established markers
alphas[num_genes:] = 0.0

# fit the model
pipe = make_pipeline(OneHotEncoder(), CoxPHSurvivalAnalysis(alpha=alphas))
pipe.fit(X, y)
</code></pre>
<h1>scikit-survival 0.11 featuring Random Survival Forests released</h1>
<p><em>Sat, 21 Dec 2019</em> · <a href="https://k-d-w.org/blog/2019/12/scikit-survival-0.11-featuring-random-survival-forests-released/" target="_blank">k-d-w.org</a></p>
<p>Today, I released a new version of <a href="https://github.com/sebp/scikit-survival" target="_blank">scikit-survival</a> which
includes an implementation of <a href="https://projecteuclid.org/euclid.aoas/1223908043" target="_blank">Random Survival Forests</a>.
Like its popular counterparts for classification and regression, a Random Survival Forest is an ensemble
of tree-based learners. A Random Survival Forest ensures that individual trees are de-correlated by 1)
building each tree on a different bootstrap sample of the original training data, and 2)
at each node, only evaluating the split criterion for a randomly selected subset of
features and thresholds. Predictions are formed by aggregating the predictions of individual
trees in the ensemble.</p>
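<p>The two de-correlation ingredients can be sketched in a few lines of numpy (a toy illustration of the idea only, not the actual scikit-survival implementation; all data is random):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_trees = 100, 8, 5
X = rng.normal(size=(n_samples, n_features))

bootstrap_rows, split_features = [], []
for _ in range(n_trees):
    # 1) each tree sees a different bootstrap sample (drawn with replacement)
    boot_idx = rng.integers(0, n_samples, size=n_samples)
    bootstrap_rows.append(boot_idx)
    # 2) at a node, only a random subset of features is considered for the split
    k = max(1, int(np.sqrt(n_features)))
    split_features.append(rng.choice(n_features, size=k, replace=False))

# The ensemble prediction is the average over per-tree predictions;
# random stand-ins are used here in place of real tree outputs.
tree_preds = rng.normal(size=(n_trees, n_samples))
ensemble_pred = tree_preds.mean(axis=0)
```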
<p>For a full list of changes in scikit-survival 0.11, please see the
<a href="https://scikit-survival.readthedocs.io/en/latest/release_notes.html" target="_blank">release notes</a>.</p>
<p>The latest version can be downloaded via <em>conda</em> or <em>pip</em>. Pre-built conda packages are available for Linux, macOS, and Windows via</p>
<pre><code class="language-bash"> conda install -c sebp scikit-survival
</code></pre>
<p>Alternatively, scikit-survival can be installed from source via pip:</p>
<pre><code class="language-bash"> pip install -U scikit-survival
</code></pre>
<h2 id="using-random-survival-forests">Using Random Survival Forests</h2>
<p>To demonstrate Random Survival Forest, I’m going to use data from the German Breast Cancer Study Group (GBSG-2) on the treatment of node-positive breast cancer patients. It contains data on 686 women
and 8 prognostic factors:</p>
<ol>
<li>age,</li>
<li>estrogen receptor (<code>estrec</code>),</li>
<li>whether or not a hormonal therapy was administered (<code>horTh</code>),</li>
<li>menopausal status (<code>menostat</code>),</li>
<li>number of positive lymph nodes (<code>pnodes</code>),</li>
<li>progesterone receptor (<code>progrec</code>),</li>
<li>tumor size (<code>tsize</code>),</li>
<li>tumor grade (<code>tgrade</code>).</li>
</ol>
<p>The goal is to predict recurrence-free survival time.</p>
<p>The code to reproduce the results below is available in
<a href="https://github.com/sebp/scikit-survival/blob/master/examples/random-survival-forest.ipynb" target="_blank">this notebook</a>.</p>
<p>First, we need to load the data and transform it into numeric values.</p>
<pre><code class="language-python">import numpy as np
from sklearn.preprocessing import OrdinalEncoder
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder

X, y = load_gbsg2()

grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis]
grade_num = OrdinalEncoder(categories=[["I", "II", "III"]]).fit_transform(grade_str)

X_no_grade = X.drop("tgrade", axis=1)
Xt = OneHotEncoder().fit_transform(X_no_grade)
Xt = np.column_stack((Xt.values, grade_num))

feature_names = X_no_grade.columns.tolist() + ["tgrade"]
</code></pre>
<p>Next, the data is split into 75% for training and 25% for testing, so we can determine
how well our model generalizes.</p>
<pre><code class="language-python">X_train, X_test, y_train, y_test = train_test_split(
    Xt, y, test_size=0.25, random_state=random_state)
</code></pre>
<h3 id="training">Training</h3>
<p>Several split criteria have been proposed in the past, but the most widespread one is based
on the log-rank test, which you probably know from comparing survival curves among two or more
groups. Using the training data, we fit a Random Survival Forest comprising 1000 trees.</p>
<pre><code class="language-python">rsf = RandomSurvivalForest(n_estimators=1000,
min_samples_split=10,
min_samples_leaf=15,
max_features="sqrt",
n_jobs=-1,
random_state=random_state)
rsf.fit(X_train, y_train)
</code></pre>
<p>We can check how well the model performs by evaluating it on the test data.</p>
<pre><code class="language-python">rsf.score(X_test, y_test)
</code></pre>
<p>This gives a concordance index of 0.68, which is a good value and matches the results
reported in the <a href="https://projecteuclid.org/euclid.aoas/1223908043" target="_blank">Random Survival Forests paper</a>.</p>
<h3 id="predicting">Predicting</h3>
<p>For prediction, a sample is dropped down each tree in the forest until it reaches a terminal node.
Data in each terminal node is used to non-parametrically estimate the survival and cumulative hazard
function using the Kaplan-Meier and Nelson-Aalen estimator, respectively. In addition, a risk score
can be computed that represents the expected number of events for one particular terminal node.
The ensemble prediction is simply the average across all trees in the forest.</p>
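<p>The two per-node estimators can be sketched on made-up, fully observed data in a few lines of numpy (this mirrors the definitions only, not the actual scikit-survival code):</p>

```python
import numpy as np

# Distinct event times observed in a terminal node, one event at each time,
# with the number of individuals still at risk just before each time.
times = np.array([3.0, 5.0, 8.0, 12.0])
d = np.array([1, 1, 1, 1])     # events at each time
n = np.array([4, 3, 2, 1])     # at-risk counts just before each time

km_surv = np.cumprod(1.0 - d / n)  # Kaplan-Meier survival function
na_chf = np.cumsum(d / n)          # Nelson-Aalen cumulative hazard
# km_surv -> [0.75, 0.5, 0.25, 0.0]
```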
<p>Let’s first select a couple of patients from the test data
according to the number of positive lymph nodes and age.</p>
<pre><code class="language-python">a = np.empty(X_test.shape[0], dtype=[("age", float), ("pnodes", float)])
a["age"] = X_test[:, 0]
a["pnodes"] = X_test[:, 4]
sort_idx = np.argsort(a, order=["pnodes", "age"])
X_test_sel = pd.DataFrame(
    X_test[np.concatenate((sort_idx[:3], sort_idx[-3:]))],
    columns=feature_names)
</code></pre>
<div class="table-responsive">
<table class="dataframe" style="width: 40%;">
<thead>
<tr>
<th></th>
<th>age</th>
<th>estrec</th>
<th>horTh</th>
<th>menostat</th>
<th>pnodes</th>
<th>progrec</th>
<th>tsize</th>
<th>tgrade</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>33.0</td>
<td>0.0</td>
<td>0.0</td>
<td>0.0</td>
<td>1.0</td>
<td>26.0</td>
<td>35.0</td>
<td>2.0</td>
</tr>
<tr>
<th>1</th>
<td>34.0</td>
<td>37.0</td>
<td>0.0</td>
<td>0.0</td>
<td>1.0</td>
<td>0.0</td>
<td>40.0</td>
<td>2.0</td>
</tr>
<tr>
<th>2</th>
<td>36.0</td>
<td>14.0</td>
<td>0.0</td>
<td>0.0</td>
<td>1.0</td>
<td>76.0</td>
<td>36.0</td>
<td>1.0</td>
</tr>
<tr>
<th>3</th>
<td>65.0</td>
<td>64.0</td>
<td>0.0</td>
<td>1.0</td>
<td>26.0</td>
<td>2.0</td>
<td>70.0</td>
<td>2.0</td>
</tr>
<tr>
<th>4</th>
<td>80.0</td>
<td>59.0</td>
<td>0.0</td>
<td>1.0</td>
<td>30.0</td>
<td>0.0</td>
<td>39.0</td>
<td>1.0</td>
</tr>
<tr>
<th>5</th>
<td>72.0</td>
<td>1091.0</td>
<td>1.0</td>
<td>1.0</td>
<td>36.0</td>
<td>2.0</td>
<td>34.0</td>
<td>2.0</td>
</tr>
</tbody>
</table>
</div>
<p>The predicted risk scores indicate that risk for the last three patients is quite
a bit higher than that of the first three patients.</p>
<pre><code class="language-python">pd.Series(rsf.predict(X_test_sel))
</code></pre>
<pre><code>0 91.477609
1 102.897552
2 75.883786
3 170.502092
4 171.210066
5 148.691835
dtype: float64
</code></pre>
<p>We can have a more detailed insight by considering the predicted survival function.
It shows that the biggest difference occurs roughly within the first 750 days.</p>
<pre><code class="language-python">surv = rsf.predict_survival_function(X_test_sel)
for i, s in enumerate(surv):
    plt.step(rsf.event_times_, s, where="post", label=str(i))
plt.ylabel("Survival probability")
plt.xlabel("Time in days")
plt.grid(True)
plt.legend()
</code></pre>
<figure>
<img src="https://k-d-w.org/blog/2019/12/scikit-survival-0.11-featuring-random-survival-forests-released/img/predicted-survival-function.svg"/>
</figure>
<p>Alternatively, we can plot the predicted cumulative hazard function.</p>
<pre><code class="language-python">surv = rsf.predict_cumulative_hazard_function(X_test_sel)
for i, s in enumerate(surv):
    plt.step(rsf.event_times_, s, where="post", label=str(i))
plt.ylabel("Cumulative hazard")
plt.xlabel("Time in days")
plt.grid(True)
plt.legend()
</code></pre>
<figure>
<img src="https://k-d-w.org/blog/2019/12/scikit-survival-0.11-featuring-random-survival-forests-released/img/predicted-chf.svg"/>
</figure>
<h3 id="permutation-based-feature-importance">Permutation-based Feature Importance</h3>
<p>The implementation is based on scikit-learn’s Random Forest implementation and inherits many
of its features, such as building trees in parallel. What is currently missing is feature importances
via the <code>feature_importances_</code> attribute.
This is due to the way scikit-learn’s implementation computes importances. It relies on
a measure of <em>impurity</em> for each child node, and defines importance as the amount of
decrease in impurity due to a split. For traditional regression, impurity would be measured
by the variance, but for survival analysis there is no per-node impurity measure due to censoring.
Instead, one could use the magnitude of the log-rank test statistic as an importance measure,
but scikit-learn’s implementation doesn’t seem to allow this.</p>
<p>Fortunately, this is not a big concern, because scikit-learn’s definition
of feature importance is non-standard and differs from what Leo Breiman
<a href="https://github.com/scikit-learn/scikit-learn/pull/8027#issuecomment-327595859" target="_blank">proposed in the original Random Forest paper</a>.
Instead, we can use permutation to estimate feature importance, which is
preferred over scikit-learn’s impurity-based definition. This is implemented in the
<a href="https://eli5.readthedocs.io/en/latest/overview.html" target="_blank">ELI5</a> library,
which is fully compatible with scikit-survival.</p>
<pre><code class="language-python">import eli5
from eli5.sklearn import PermutationImportance
perm = PermutationImportance(rsf, n_iter=15, random_state=random_state)
perm.fit(X_test, y_test)
eli5.show_weights(perm, feature_names=feature_names)
</code></pre>
<div class="table-responsive">
<style scoped>
table.eli5-weights {
width: 80%;
margin-left: auto;
margin-right: auto;
display: table;
overflow-x: visible;
-webkit-overflow-scrolling: auto;
}
table.eli5-weights > tbody > tr > td {
background-color: transparent;
}
table.eli5-weights > tbody > tr:hover > td {
background-color: #e5e5e5;
}
.col-weight {
padding: 0 1em 0 0.5em;
text-align: right;
}
.col-feature {
padding: 0 0.5em 0 0.5em;
text-align: left;
}
</style>
<table class="eli5-weights">
<thead>
<tr>
<th class="col-weight">Weight</th>
<th class="col-feature">Feature</th>
</tr>
</thead>
<tbody>
<tr style="background-color: hsl(120, 100.00%, 80.00%); border: none;">
<td class="col-weight">
0.0676 ± 0.0229
</td>
<td class="col-feature">
pnodes
</td>
</tr>
<tr style="background-color: hsl(120, 100.00%, 91.29%); border: none;">
<td class="col-weight">
0.0206 ± 0.0139
</td>
<td class="col-feature">
age
</td>
</tr>
<tr style="background-color: hsl(120, 100.00%, 92.19%); border: none;">
<td class="col-weight">
0.0177 ± 0.0468
</td>
<td class="col-feature">
progrec
</td>
</tr>
<tr style="background-color: hsl(120, 100.00%, 95.29%); border: none;">
<td class="col-weight">
0.0086 ± 0.0098
</td>
<td class="col-feature">
horTh
</td>
</tr>
<tr style="background-color: hsl(120, 100.00%, 97.61%); border: none;">
<td class="col-weight">
0.0032 ± 0.0198
</td>
<td class="col-feature">
tsize
</td>
</tr>
<tr style="background-color: hsl(120, 100.00%, 97.63%); border: none;">
<td class="col-weight">
0.0032 ± 0.0060
</td>
<td class="col-feature">
tgrade
</td>
</tr>
<tr style="background-color: hsl(0, 100.00%, 99.21%); border: none;">
<td class="col-weight">
-0.0007 ± 0.0018
</td>
<td class="col-feature">
menostat
</td>
</tr>
<tr style="background-color: hsl(0, 100.00%, 96.19%); border: none;">
<td class="col-weight">
-0.0063 ± 0.0207
</td>
<td class="col-feature">
estrec
</td>
</tr>
</tbody>
</table>
</div>
<p>The result shows that the number of positive lymph nodes (<code>pnodes</code>) is by far the most important
feature. If its relationship to survival time is removed (by random shuffling),
the concordance index on the test data drops on average by 0.0676 points.
Again, this agrees with the results from the original
<a href="https://projecteuclid.org/euclid.aoas/1223908043" target="_blank">Random Survival Forests paper</a>.</p>
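<p>The idea behind permutation importance is simple enough to sketch in a few lines of plain NumPy, independent of ELI5. The scoring function, toy data, and helper below are hypothetical stand-ins for the concordance index and the fitted forest, not scikit-survival or ELI5 API:</p>

```python
import numpy as np

def permutation_importance(score_fn, X, n_iter=15, rng=None):
    """Mean drop in score after shuffling each column, over n_iter repeats."""
    rng = np.random.default_rng(rng)
    baseline = score_fn(X)
    drops = np.empty((n_iter, X.shape[1]))
    for it in range(n_iter):
        for j in range(X.shape[1]):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # break column j's link to the outcome
            drops[it, j] = baseline - score_fn(X_perm)
    return drops.mean(axis=0), drops.std(axis=0)

# toy "model": the score is the correlation of column 0 with a target,
# so shuffling column 0 destroys the score while column 1 is irrelevant
target = np.arange(100.0)
X = np.column_stack([target, np.ones(100)])
score = lambda X_: float(np.corrcoef(X_[:, 0], target)[0, 1])
mean_drop, std_drop = permutation_importance(score, X, rng=0)
```

<p>A feature whose shuffling barely changes the score, like the constant second column here, receives an importance close to zero, mirroring the near-zero weights for <code>menostat</code> and <code>estrec</code> above.</p>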
<h2 id="scikit-survival-0-10-released"><a href="https://k-d-w.org/blog/2019/09/scikit-survival-0.10-released/" target="_blank">scikit-survival 0.10 released</a> (Mon, 02 Sep 2019)</h2>
<p>This release of <a href="https://github.com/sebp/scikit-survival" target="_blank">scikit-survival</a> adds two features that
are standard in most software for survival analysis, but were missing so far:</p>
<ol>
<li><a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.linear_model.CoxPHSurvivalAnalysis.html#sksurv.linear_model.CoxPHSurvivalAnalysis" target="_blank">CoxPHSurvivalAnalysis</a>
now has a <code>ties</code> parameter that allows you to choose between Breslow’s
and Efron’s likelihood for handling tied event times. Previously, only
Breslow’s likelihood was implemented and it remains the default.
If you have many tied event times in your data, you can now select
Efron’s likelihood with <code>ties="efron"</code> to get better estimates of the
model’s coefficients.</li>
<li>A <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.compare.compare_survival.html#sksurv.compare.compare_survival" target="_blank">compare_survival</a>
function has been added. It can be used to assess whether survival functions across two or more groups differ.</li>
</ol>
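<p>To see why Efron’s likelihood handles ties better, it helps to compare the denominators the two approximations use for a set of tied event times. The sketch below uses hypothetical risk values and is not part of scikit-survival’s API:</p>

```python
import numpy as np

# Toy setting: at one event time, d = 3 subjects with tied event times
# experience an event; `risk` holds exp(x_i' beta) for everyone still at
# risk, with the first three entries belonging to the tied subjects.
risk = np.array([2.0, 1.5, 1.0, 0.5, 0.25])
d = 3
sum_risk = risk.sum()        # total risk of the risk set
sum_tied = risk[:d].sum()    # total risk of the tied subjects

# Breslow keeps the full risk set for all d tied events:
breslow_denom = sum_risk ** d

# Efron removes, on average, a fraction l/d of the tied subjects' risk
# at the l-th of the d events:
efron_denom = np.prod([sum_risk - (l / d) * sum_tied for l in range(d)])
```

<p>Because Efron’s denominator shrinks as tied events are accounted for, it is smaller than Breslow’s whenever ties are present, which yields less biased coefficient estimates for heavily tied data.</p>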
<p>To illustrate the use of
<a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.compare.compare_survival.html#sksurv.compare.compare_survival" target="_blank">compare_survival</a>,
let’s consider the Veterans’ Administration Lung Cancer Trial.
Here, we are considering the <code>Celltype</code> feature and we want to know whether
the tumor type impacts survival. We can visualize the survival function for
each subgroup using the Kaplan-Meier estimator.</p>
<pre><code class="language-python">import matplotlib.pyplot as plt
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator

data_x, data_y = load_veterans_lung_cancer()
group_indicator = data_x.loc[:, "Celltype"]
groups = group_indicator.unique()
for group in groups:
    group_y = data_y[group_indicator == group]
    time, surv_prob = kaplan_meier_estimator(
        group_y["Status"],
        group_y["Survival_in_days"])
    plt.step(time, surv_prob, where="post",
             label="Celltype = {}".format(group))
plt.xlabel("time $t$")
plt.ylabel("est. probability of survival")
plt.ylim(0, 1)
plt.grid(True)
plt.legend()
</code></pre>
<figure>
<img src="https://k-d-w.org/blog/2019/09/scikit-survival-0.10-released/img/kaplan-meier-plot.svg"
alt="Kaplan-Meier estimates of survival function."/> <figcaption>
<p>Kaplan-Meier estimates of survival function.</p>
</figcaption>
</figure>
<p>The figure indicates that patients with adenocarcinoma (green line) do not survive
beyond 200 days, whereas patients with squamous cell lung cancer (blue line) can
survive several years.
We can determine whether this difference is indeed statistically significant by
performing a non-parametric <a href="https://en.wikipedia.org/wiki/Logrank_test" target="_blank">log-rank test</a>.
It groups patients according to cell type and compares the estimated group-specific hazard rate
with the pooled hazard rate. Under the null hypothesis, the hazard rate of groups is
equal for all time points. The alternative hypothesis is that the hazard rate of at
least one group differs from the others at some time.</p>
<pre><code class="language-python">from sksurv.compare import compare_survival

chisq, pvalue, stats, covar = compare_survival(
    data_y, group_indicator, return_stats=True)
</code></pre>
<p>The resulting test statistic is $\chi^2 = 25.40$, which corresponds
to a highly significant P-value of $1.3\cdot{10}^{-5}$.
Since we specified <code>return_stats=True</code>, we can additionally
look at group-specific statistics.</p>
<div class="table-responsive">
<table class="dataframe" style="width: 50%;">
<thead>
<tr>
<th></th>
<th>counts</th>
<th>observed</th>
<th>expected</th>
<th>statistic</th>
</tr>
<tr>
<th>group</th>
<th></th>
<th></th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
<tr>
<th>adeno</th>
<td>27</td>
<td>26</td>
<td>15.69</td>
<td>10.31</td>
</tr>
<tr>
<th>large</th>
<td>27</td>
<td>26</td>
<td>34.55</td>
<td>-8.55</td>
</tr>
<tr>
<th>smallcell</th>
<td>48</td>
<td>45</td>
<td>30.10</td>
<td>14.90</td>
</tr>
<tr>
<th>squamous</th>
<td>35</td>
<td>31</td>
<td>47.65</td>
<td>-16.65</td>
</tr>
</tbody>
</table>
</div>
<p>The column <em>counts</em> lists the size of each group and
is followed by the number of <em>observed</em> and <em>expected</em>
events. The last column <em>statistic</em> is the
difference between the observed and expected number
of events from which the overall $\chi^2$ statistic
is computed.</p>
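<p>The computation behind the <em>observed</em> and <em>expected</em> columns can be sketched on a toy two-group sample. The data below are hypothetical and contain no censoring; at each event time, the expected number of events in a group is the total number of events times that group’s share of subjects still at risk:</p>

```python
import numpy as np

times  = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
groups = np.array([0,   0,   0,   1,   1,   1])

observed = np.zeros(2)
expected = np.zeros(2)
variance = 0.0
for t in np.unique(times):
    at_risk = times >= t
    n = at_risk.sum()                          # subjects still at risk
    n0 = (at_risk & (groups == 0)).sum()       # ... of which in group 0
    d = (times == t).sum()                     # events at time t
    d0 = ((times == t) & (groups == 0)).sum()
    observed[0] += d0
    expected[0] += d * n0 / n                  # hypergeometric mean
    if n > 1:
        variance += d * (n0 / n) * (1 - n0 / n) * (n - d) / (n - 1)
observed[1] = len(times) - observed[0]
expected[1] = len(times) - expected[0]

# for two groups, the log-rank statistic reduces to a single ratio
chisq = (observed[0] - expected[0]) ** 2 / variance
```

<p>With more than two groups, as in the veterans data, the vector of observed-minus-expected differences is combined with its covariance matrix to obtain the overall $\chi^2$ statistic, which is what <code>compare_survival</code> returns.</p>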
<h2 id="download">Download</h2>
<p>The latest version of scikit-survival can be obtained via <em>conda</em> or <em>pip</em>. Pre-built conda packages are available for Linux, macOS, and Windows:</p>
<pre><code class="language-bash"> conda install -c sebp scikit-survival
</code></pre>
<p>Alternatively, you can install it from source via pip:</p>
<pre><code class="language-bash"> pip install -U scikit-survival
</code></pre>
<h2 id="survival-analysis-for-deep-learning"><a href="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/" target="_blank">Survival Analysis for Deep Learning</a> (Mon, 29 Jul 2019)</h2>
<p>Most machine learning algorithms have been developed to perform classification or regression. However, in clinical research we often want to estimate the time to an event, such as death or recurrence of cancer, which leads to a special type of learning task that is distinct from classification and regression. This task is termed <em>survival analysis</em>, but is also referred to as time-to-event analysis or reliability analysis.
Many machine learning algorithms have been adapted to perform survival analysis:
<a href="https://scholar.google.com/scholar?oi=bibs&cluster=18092275419152143443" target="_blank">Support Vector Machines</a>,
<a href="https://scholar.google.com/scholar?cluster=16319510831191377024" target="_blank">Random Forest</a>,
or <a href="https://scholar.google.com/scholar?cluster=14069073471114367075" target="_blank">Boosting</a>.
It has only been recently that survival analysis entered the era of deep learning, which is the focus of this post.</p>
<p>You will learn how to train a convolutional neural network to predict time to a (generated) event from MNIST images, using a loss function specific to survival analysis. The
<a href="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/#primer-on-survival-analysis">
first part
</a>
will cover some basic terms and quantities used in survival analysis (feel free to skip this part if you are already familiar). In the
<a href="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/#generating-synthetic-survival-data-from-mnist">
second part
</a>
, we will generate synthetic survival data from MNIST images and visualize it. In the
<a href="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/#cox-s-proportional-hazards-model">
third part
</a>
, we will briefly revisit the most popular survival model of them all and learn how it can be used as a loss function for training a neural network.
<a href="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/#creating-a-convolutional-neural-network-for-survival-analysis-on-mnist">
Finally
</a>
, we put all the pieces together and train a convolutional neural network on MNIST and predict survival functions on the test data.</p>
<p>The notebook to reproduce the results is available on <a href="https://github.com/sebp/survival-cnn-estimator" target="_blank">GitHub</a>, or you can run it directly using
<a href="https://colab.research.google.com/github/sebp/survival-cnn-estimator/blob/master/tutorial_tf1.ipynb" target="_blank">Google Colaboratory</a>.</p>
<h2 id="primer-on-survival-analysis">Primer on Survival Analysis</h2>
<p>The objective in survival analysis is to establish a connection between covariates and the time of an event. The name <em>survival analysis</em> originates from clinical research, where predicting the time to death, i.e., survival, is often the main objective. Survival analysis is a type of regression problem (one wants to predict a continuous value), but with a twist. It differs from traditional regression by the fact that parts of the training data can only be partially observed – they are <em>censored</em>.</p>
<p>As an example, consider a clinical study that has been carried out over a 1 year period as in the figure below.</p>
<figure>
<img src="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/img/censoring-example.svg"/>
</figure>
<p>Patient A was lost to follow-up after three months with no recorded event, patient B experienced an event four and a half months after enrollment, patient C withdrew from the study two months after enrollment, patient D experienced an event before the end of the study, and patient E did not experience any event before the study ended. Consequently, the <em>exact time</em> of an event could only be recorded for patients B and D; their records are <span style="color:#E41A1C"><em>uncensored</em></span>. For the remaining patients it is unknown whether they did or did not experience an event after termination of the study. The only valid information that is available for patients A, C, and E is that they were event-free up to their last follow-up. Therefore, their records are <span style="color:#377Eb8"><em>censored</em></span>.</p>
<p>Formally, each patient record consists of the time $t>0$ when an event occurred or the time $c>0$ of censoring. Since censoring and experiencing an event are mutually exclusive, it is common to define an event indicator $\delta \in \{0,1\}$ and the observable survival time $y>0$. The observable time $y$ of a right-censored event time is defined as</p>
<p>$$
y = \min(t, c) =
\begin{cases}
t & \text{if } \delta = 1 , \\%
c & \text{if } \delta = 0 .
\end{cases}
$$</p>
<p>Consequently, survival analysis demands models that take partially observed, i.e., censored, event times into account.</p>
<h2 id="basic-quantities">Basic Quantities</h2>
<p>Typically, the survival time is modelled as a continuous non-negative random variable $T$, from which basic quantities for time-to-event analysis can be derived, most importantly, the <em>survival function</em> and the <em>hazard function</em>.</p>
<ul>
<li>The <strong>survival function</strong> $S(t)$ returns the probability of survival beyond time $t$ and is defined as $S(t) = P(T > t)$. It is non-increasing with $S(0) = 1$, and $S(\infty) = 0$.</li>
<li>The <strong>hazard function</strong> $h(t)$ denotes an approximate probability (it is not bounded from above) that an event occurs in the small time interval $[t, t + \Delta t)$, under the condition that an individual would remain event-free up to time $t$:
$$
h(t) = \lim_{\Delta t \rightarrow 0} \frac{P(t \leq T < t + \Delta t \mid T \geq t)}{\Delta t} \geq 0
$$
Alternative names for the hazard function are conditional failure rate, conditional mortality rate, or instantaneous failure rate. In contrast to the survival function, which describes the absence of an event, the hazard function provides information about the occurrence of an event.</li>
</ul>
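<p>For the exponential distribution used later in this post, both quantities have a closed form, which is easy to verify numerically. This is a minimal sketch; the value of $\lambda$ merely matches the mean survival time chosen below:</p>

```python
import numpy as np

# For an exponential survival time with rate lam, the survival function is
# S(t) = exp(-lam * t) and the hazard h(t) = f(t) / S(t) is the constant lam.
lam = 1.0 / 365.0            # corresponds to a mean survival time of 365 days
t = np.linspace(0.0, 2000.0, 2001)
S = np.exp(-lam * t)         # survival function
f = lam * np.exp(-lam * t)   # density
h = f / S                    # hazard function: density / survival
```

<p>The constant hazard is exactly the "no memory" property mentioned below: an event is equally likely in every small interval, regardless of how long an individual has already survived.</p>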
<h2 id="generating-synthetic-survival-data-from-mnist">Generating Synthetic Survival Data from MNIST</h2>
<p>To start off, we are using images from the MNIST dataset and will synthetically generate
survival times based on the digit each image represents.
We associate a survival time (or risk score) with each class of the ten digits in MNIST. First, we randomly assign each class label to one of four overall risk groups, such that some digits will correspond to better and others to worse survival. Next, we generate risk scores that indicate how big the risk of experiencing an event is, relative to each other.</p>
<div class="table-responsive">
<table class="dataframe" style="width: 40%;">
<thead>
<tr>
<th></th>
<th>risk_score</th>
<th>risk_group</th>
</tr>
<tr>
<th>class_label</th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>3.071</td>
<td>3</td>
</tr>
<tr>
<th>1</th>
<td>2.555</td>
<td>2</td>
</tr>
<tr>
<th>2</th>
<td>0.058</td>
<td>0</td>
</tr>
<tr>
<th>3</th>
<td>1.790</td>
<td>1</td>
</tr>
<tr>
<th>4</th>
<td>2.515</td>
<td>2</td>
</tr>
<tr>
<th>5</th>
<td>3.031</td>
<td>3</td>
</tr>
<tr>
<th>6</th>
<td>1.750</td>
<td>1</td>
</tr>
<tr>
<th>7</th>
<td>2.475</td>
<td>2</td>
</tr>
<tr>
<th>8</th>
<td>0.018</td>
<td>0</td>
</tr>
<tr>
<th>9</th>
<td>2.435</td>
<td>2</td>
</tr>
</tbody>
</table>
</div>
<p>We can see that class labels 2 and 8 belong to risk group 0, which has the lowest risk (close to zero). Risk group 1 corresponds to a risk score of about 1.7, risk group 2 of about 2.5, and risk group 3 is the group with the highest risk score of about 3.</p>
<p>To generate survival times from risk scores, we are going to follow the protocol of
<a href="https://scholar.google.com/scholar?cluster=11575471310627475868" target="_blank">Bender et al</a>. We choose the exponential distribution for the survival time. Its probability density function is $f(t\,|\,\lambda) = \lambda \exp(-\lambda t)$, where $\lambda > 0$ is a rate parameter equal to the inverse of the expected survival time: $E(T) = \frac{1}{\lambda}$. The exponential distribution results in a relatively simple time-to-event model with no memory, because the hazard rate is constant: $h(t) = \lambda$. For more complex cases, refer to the paper by <a href="https://scholar.google.com/scholar?cluster=11575471310627475868" target="_blank">Bender et al</a>.</p>
<p>Here, we choose $\lambda$ such that the mean survival time is 365 days. Finally, we randomly censor survival times by drawing censoring times from a uniform distribution, such that approximately 45% of records are censored. The generated survival data comprise an observed time and a boolean event indicator for each MNIST image.</p>
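<p>The generation step can be sketched as follows. The risk scores, the calibration of $\lambda$, and the censoring bound below are hypothetical stand-ins for the values derived in the notebook:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
lam = 1.0 / 365.0                                  # baseline rate (hypothetical calibration)
risk_scores = rng.choice([0.0, 1.7, 2.5, 3.0], size=10000)

# inverse transform sampling: with hazard lam * exp(risk), the survival
# time is T = -log(U) / (lam * exp(risk)) for U ~ Uniform(0, 1)
u = rng.uniform(size=risk_scores.shape[0])
t_event = -np.log(u) / (lam * np.exp(risk_scores))

# censor by drawing censoring times from a uniform distribution; in the
# post, the upper bound is tuned to reach the desired ~45% censoring
c = rng.uniform(0.0, np.quantile(t_event, 0.9), size=t_event.shape[0])
y = np.minimum(t_event, c)        # observed time
delta = t_event <= c              # event indicator
```

<p>Samples with risk score zero follow the baseline hazard, so their mean event time is close to the chosen 365 days, while higher risk scores compress event times accordingly.</p>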
<p>We can use the generated censored data and estimate the survival function $S(t)$ to see what the risk scores actually mean in terms of survival. We stratify the training data by class label, and estimate the corresponding survival function using the non-parametric <a href="https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator" target="_blank">Kaplan-Meier estimator</a>.</p>
<figure>
<img src="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/img/kaplan-meier-plot.svg"/>
</figure>
<p>Classes 0 and 5 (dotted lines) correspond to risk group 3, which has the highest risk score. The corresponding survival functions drop most quickly, which is exactly what we wanted. On the other end of the spectrum are classes 2 and 8 (solid lines) belonging to risk group 0 with the lowest risk.</p>
<h2 id="evaluating-predictions">Evaluating Predictions</h2>
<p>One important aspect for survival analysis is that both the training data and the test data are subject to censoring, because we are unable to observe the exact time of an event no matter how the data was split. Therefore, performance measures need to account for censoring. The most widely used performance measure is Harrell’s concordance index. Given a set of (predicted) risk scores and observed times, it checks whether the ordering by risk scores is concordant with the ordering by actual survival time. While Harrell’s concordance index is widely used, it has its flaws, in particular when data is highly censored. Please refer to my <a href="https://k-d-w.org/blog/111/evaluating-survival-models" target="_blank">previous post on evaluating survival models</a> for more details.</p>
<p>We can take the risk score from which we generated survival times to check how good a model would perform if we knew the actual risk score.</p>
<pre><code class="language-python">cindex = concordance_index_censored(event_test, time_test, risk_scores[y_train.shape[0]:])
print(f"Concordance index on test data with actual risk scores: {cindex[0]:.3f}")
</code></pre>
<pre><code>Concordance index on test data with actual risk scores: 0.705
</code></pre>
<p>Surprisingly, we do not obtain a perfect result of 1.0. The reason is that the generated survival times are drawn at random based on the risk scores, rather than being a deterministic function of them. Therefore, any model trained on this data should not be able to exceed this performance value.</p>
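<p>For intuition, the pairwise comparison behind Harrell’s concordance index can be written down in a few lines. This is a simplified sketch on hypothetical toy data that ignores the tied-time handling of <code>concordance_index_censored</code>:</p>

```python
import numpy as np

def harrell_c(event, time, risk):
    """Fraction of comparable pairs where the higher risk score has the
    shorter observed survival time; ties in risk count as 0.5.
    A pair (i, j) is comparable only if the earlier time is an observed event."""
    concordant = comparable = 0.0
    for i in range(len(time)):
        for j in range(len(time)):
            if event[i] and time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable

event = np.array([True, True, False, True])
time = np.array([10.0, 20.0, 30.0, 40.0])
perfect = np.array([4.0, 3.0, 2.0, 1.0])   # risk perfectly anti-ordered with time
flipped = np.array([1.0, 3.0, 2.0, 4.0])   # mostly wrong ordering
```

<p>Note that the censored third record never contributes a comparable pair as the earlier element, which is exactly why heavy censoring makes the estimate less reliable.</p>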
<h2 id="cox-s-proportional-hazards-model">Cox’s Proportional Hazards Model</h2>
<p>By far the most widely used model to learn from censored survival data is
<a href="https://scholar.google.com/scholar?cluster=17981786408695305487" target="_blank">Cox’s proportional hazards model</a>.
It models the hazard function $h(t_i)$
of the $i$-th subject, conditional on the feature vector $\mathbf{x}_i \in \mathbb{R}^p$,
as the product of an unspecified baseline hazard function $h_0$ (more on that later) and an
exponential function of the linear model $\mathbf{x}_i^\top \mathbf{\beta}$:</p>
<p>$$
h(t | x_{i1}, \ldots, x_{ip}) = h_0(t) \exp \left( \sum_{j=1}^p x_{ij} \beta_j \right)
\Leftrightarrow
\log \frac{h(t | \mathbf{x}_i)}{h_0 (t)} = \mathbf{x}_i^\top \mathbf{\beta} ,
$$</p>
<p>where $\mathbf{\beta} \in \mathbb{R}^p$ are the coefficients associated with each of the
$p$ features, and no intercept term is included in the model.
The key is that the hazard function is split into two parts: the baseline hazard function $h_0$ only depends on the time $t$, whereas the exponential is independent of time and only depends on the covariates $\mathbf{x}_i$.</p>
<p>Cox’s proportional hazards model is fitted by maximizing the partial likelihood function, which is based on the probability that the $i$-th individual experiences
an event at time $t_i$, given that there is one event at time point $t_i$.
As we will see, by specifying the hazard function as above, the baseline hazard function $h_0$
can be eliminated and does not need to be defined for finding the coefficients $\mathbf{\beta}$.
Let $\mathcal{R}_i = \{ j\,|\,y_j \geq y_i \}$
be the risk set, i.e., the set of subjects who remained event-free shortly before time point $y_i$,
and $I(\cdot)$ the indicator function, then we have</p>
<p>$$
\begin{aligned}
&P(\text{subject experiences event at $y_i$} \mid \text{one event at $y_i$}) \\%
=& \frac{P(\text{subject experiences event at $y_i$} \mid \text{event-free up to $y_i$})}
{P (\text{one event at $y_i$} \mid \text{event-free up to $y_i$})} \\%
=& \frac{h(y_i | \mathbf{x}_i)}{ \sum_{j=1}^n I(y_j \geq y_i) h(y_j | \mathbf{x}_j) } \\%
=& \frac{h_0(y_i) \exp(\mathbf{x}_i^\top \mathbf{\beta})}
{ \sum_{j=1}^n I(y_j \geq y_i) h_0(y_j) \exp(\mathbf{x}_j^\top \mathbf{\beta}) } \\%
=& \frac{\exp( \mathbf{x}_i^\top \beta)}{\sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \beta)} .
\end{aligned}
$$</p>
<p>By multiplying the conditional probability from above for all patients who experienced an event, and taking the logarithm, we obtain the <em>partial likelihood function</em>:</p>
<p>$$
\widehat{\mathbf{\beta}} = \arg\max_{\mathbf{\beta}}~
\log\,PL(\mathbf{\beta}) = \sum_{i=1}^n \delta_i \left[ \mathbf{x}_i^\top \mathbf{\beta}
- \log \left( \sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \mathbf{\beta}) \right) \right] .
$$</p>
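<p>Written directly from this formula, the negative partial log-likelihood takes only a few lines of NumPy. This is a quadratic-time sketch on hypothetical toy data; efficient implementations avoid the nested risk-set scan by sorting once:</p>

```python
import numpy as np

def neg_partial_log_likelihood(beta, X, time, event):
    """Negative Cox partial log-likelihood, transcribed from the formula."""
    xbeta = X @ beta
    ll = 0.0
    for i in range(len(time)):
        if event[i]:
            at_risk = time >= time[i]  # risk set R_i
            ll += xbeta[i] - np.log(np.sum(np.exp(xbeta[at_risk])))
    return -ll

# toy data: one feature, two events and one censored record
X = np.array([[1.0], [0.5], [-1.0]])
time = np.array([2.0, 5.0, 7.0])
event = np.array([True, True, False])
loss = neg_partial_log_likelihood(np.array([0.1]), X, time, event)
```

<p>A useful sanity check: at $\mathbf{\beta} = 0$, each uncensored sample contributes $-\log |\mathcal{R}_i|$, so the loss equals the log of the product of risk-set sizes.</p>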
<h2 id="non-linear-survival-analysis-with-neural-networks">Non-linear Survival Analysis with Neural Networks</h2>
<p>Cox’s proportional hazards model as described above is a linear model, i.e., the predicted risk score is a linear combination of features. However, the model can easily be extended to the non-linear case by just replacing the linear predictor with the output of a neural network with parameters $\mathbf{\Theta}$.</p>
<figure>
<img src="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/img/faraggi-simon-model.svg"/>
</figure>
<p>This has been realized early on and was originally proposed in the work of <a href="https://scholar.google.com/scholar?cluster=8523249692591517459" target="_blank">Faraggi and Simon</a> back in 1995. Faraggi and Simon explored multilayer perceptrons, but the same loss can be used in combination with more advanced architectures such as convolutional neural networks or recurrent neural networks.
Therefore, it is natural to also use the same loss function in the era of deep learning.
However, this transition is not as easy as it may seem and comes with some caveats, both for training and for evaluation.</p>
<h3 id="computing-the-loss-function">Computing the Loss Function</h3>
<p>When implementing the Cox PH loss function, the problematic part is the inner sum over the risk set:
$\sum_{j \in \mathcal{R}_i} \exp( \mathbf{x}_j^\top \mathbf{\beta})$. Note that the risk set is defined as $\mathcal{R}_i = \{ j\,|\,y_j \geq y_i \}$, which implies an ordering according to observed times $y_i$, which may lead to quadratic complexity if implemented naively. Ideally, we want to sort the data once in descending order by survival time and then incrementally update the inner sum, which leads to a linear complexity to compute the loss (ignoring the time for sorting).</p>
<p>Another problem is that the risk set for the subject with the smallest uncensored survival time spans the whole dataset. This is usually impractical, because we may not be able to keep the whole dataset in GPU memory. If we use mini-batches instead, as is the norm, (i) we cannot compute the exact loss, because we may not have access to all samples in the risk set, and (ii) we need to sort each mini-batch by observed time, instead of sorting the whole data once.</p>
<p>For practical purposes, computing the Cox PH loss over a mini-batch is usually fine, as long as the batch contains several uncensored samples, because otherwise the outer sum in the partial likelihood function would be over an empty set.
Here, we implement the sum over the risk set by multiplying the exponential of the predictions (as a row vector) by a square boolean matrix that contains each sample’s risk set as its rows. The sum over the risk set for each sample is then equivalent to a row-wise summation.</p>
<pre><code class="language-python">class InputFunction:
    …
    def _get_data_batch(self, index):
        """Compute risk set for samples in batch."""
        time = self.time[index]
        event = self.event[index]
        images = self.images[index]
        labels = {
            "label_event": event.astype(np.int32),
            "label_time": time.astype(np.float32),
            "label_riskset": _make_riskset(time)
        }
        return images, labels
    …

def _make_riskset(time):
    assert time.ndim == 1, "expected 1D array"
    # sort in descending order
    o = np.argsort(-time, kind="mergesort")
    n_samples = len(time)
    risk_set = np.zeros((n_samples, n_samples), dtype=np.bool_)
    for i_org, i_sort in enumerate(o):
        ti = time[i_sort]
        k = i_org
        while k < n_samples and ti == time[o[k]]:
            k += 1
        risk_set[i_sort, o[:k]] = True
    return risk_set

def coxph_loss(event, riskset, predictions):
    # the event indicator needs to be a float tensor before
    # it can be multiplied with the predictions
    event = tf.cast(event, predictions.dtype)
    # move batch dimension to the end so predictions get broadcast
    # row-wise when multiplying by riskset
    pred_t = tf.transpose(predictions)
    # compute log of sum over risk set for each row
    rr = logsumexp_masked(pred_t, riskset, axis=1, keepdims=True)
    losses = tf.multiply(event, rr - predictions)
    loss = tf.reduce_mean(losses)
    return loss

def logsumexp_masked(risk_scores, mask, axis=0, keepdims=None):
    """Compute logsumexp across `axis` for entries where `mask` is true."""
    mask_f = tf.cast(mask, risk_scores.dtype)
    risk_scores_masked = tf.multiply(risk_scores, mask_f)
    # for numerical stability, subtract the maximum value
    # before taking the exponential
    amax = tf.reduce_max(risk_scores_masked, axis=axis, keepdims=True)
    risk_scores_shift = risk_scores_masked - amax
    exp_masked = tf.multiply(tf.exp(risk_scores_shift), mask_f)
    exp_sum = tf.reduce_sum(exp_masked, axis=axis, keepdims=True)
    output = amax + tf.log(exp_sum)
    if not keepdims:
        output = tf.squeeze(output, axis=axis)
    return output
</code></pre>
<p>To monitor the training process, we would like to compute the concordance index with respect to a separate validation set. Similar to the Cox PH loss, the concordance index needs access to predicted risk scores and ground truth of <em>all</em> samples in the validation data. While we had to opt for computing the Cox PH loss over a mini-batch, I would not recommend this for the validation data. For small batch sizes and/or high amount of censoring, the estimated concordance index would be quite volatile, which makes it very hard to interpret. In addition, the validation data is usually considerably smaller than the training data, therefore we can collect predictions for the whole validation data and compute the concordance index accurately.</p>
<h2 id="creating-a-convolutional-neural-network-for-survival-analysis-on-mnist">Creating a Convolutional Neural Network for Survival Analysis on MNIST</h2>
<p>Finally, after many considerations, we can create a convolutional neural network (CNN) to learn a high-level representation from MNIST digits such that we can estimate each image’s survival function. The CNN follows the LeNet architecture, where the last linear layer has one output unit that corresponds to the predicted risk score. The predicted risk score, together with the binary event indicator and risk set, are the input to the Cox PH loss.</p>
<pre><code class="language-python">def model_fn(features, labels, mode, params):
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(6, kernel_size=(5, 5), activation='relu', name='conv_1'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(16, (5, 5), activation='relu', name='conv_2'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(120, activation='relu', name='dense_1'),
        tf.keras.layers.Dense(84, activation='relu', name='dense_2'),
        tf.keras.layers.Dense(1, activation='linear', name='dense_3')
    ])
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    risk_score = model(features, training=is_training)

    if mode == tf.estimator.ModeKeys.TRAIN:
        loss = coxph_loss(
            event=tf.expand_dims(labels["label_event"], axis=1),
            riskset=labels["label_riskset"],
            predictions=risk_score)
        optim = tf.train.AdamOptimizer(learning_rate=params["learning_rate"])
        gs = tf.train.get_or_create_global_step()
        train_op = tf.contrib.layers.optimize_loss(loss, gs,
                                                   learning_rate=None,
                                                   optimizer=optim)
    else:
        loss = None
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions={"risk_score": risk_score})

train_spec = tf.estimator.TrainSpec(
    InputFunction(x_train, time_train, event_train,
                  num_epochs=15, drop_last=True, shuffle=True))
eval_spec = tf.estimator.EvalSpec(
    InputFunction(x_test, time_test, event_test))

params = {"learning_rate": 0.0001, "model_dir": "ckpts-mnist-cnn"}
estimator = tf.estimator.Estimator(model_fn, model_dir=params["model_dir"], params=params)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
</code></pre>
<figure>
<img src="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/img/loss-and-cindex.png"
alt="TensorBoard plots of training loss and concordance index on test data."/> <figcaption>
<p>TensorBoard plots of training loss and concordance index on test data.</p>
</figcaption>
</figure>
<p>We can make a couple of observations:</p>
<ol>
<li>The final concordance index on the validation data is close to the optimal value we computed above using the actual underlying risk scores.</li>
<li>The loss during training is quite volatile, which stems from the small batch size (64) and the varying number of uncensored samples that contribute to the loss in each batch. Increasing the batch size should yield smoother loss curves.</li>
</ol>
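<p>The second observation can be made concrete with a short simulation. Under the simplifying assumption that each sample is uncensored with a fixed probability (the 50% event rate below is my assumption for illustration; the actual rate depends on how the survival times were generated), the number of uncensored samples per batch follows a binomial distribution:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 64
event_rate = 0.5  # assumed fraction of uncensored samples, for illustration

# number of uncensored samples per batch follows Binomial(batch_size, event_rate)
events_per_batch = rng.binomial(batch_size, event_rate, size=10_000)

# relative fluctuation of the per-batch event count
relative_std = events_per_batch.std() / events_per_batch.mean()
```

<p>Since this relative fluctuation scales with $1/\sqrt{\textrm{batch size}}$, quadrupling the batch size halves it, which is why larger batches yield smoother loss curves.</p>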
<h3 id="predicting-survival-functions">Predicting Survival Functions</h3>
<p>For inference, things are much easier: we just pass a batch of images and record the predicted risk score. To estimate individual survival functions, we also need to estimate the baseline hazard function $h_0$, which can be done analogously to the linear Cox PH model by using <a href="https://www.jstor.org/stable/1402659" target="_blank">Breslow’s estimator</a>.</p>
<pre><code class="language-python">from sklearn.model_selection import train_test_split
from sksurv.linear_model.coxph import BreslowEstimator
def make_pred_fn(images, batch_size=64):
    if images.ndim == 3:
        images = images[..., np.newaxis]

    def _input_fn():
        ds = tf.data.Dataset.from_tensor_slices(images)
        ds = ds.batch(batch_size)
        next_x = ds.make_one_shot_iterator().get_next()
        return next_x, None

    return _input_fn


train_pred_fn = make_pred_fn(x_train)
train_predictions = np.array([float(pred["risk_score"])
                              for pred in estimator.predict(train_pred_fn)])

breslow = BreslowEstimator().fit(train_predictions, event_train, time_train)
</code></pre>
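<p>To make explicit what <code>BreslowEstimator</code> computes, here is a naive from-scratch sketch of Breslow’s estimator of the cumulative baseline hazard (the function name and the loop are mine for illustration only; the library version is the one to use in practice):</p>

```python
import numpy as np

def breslow_cumulative_baseline_hazard(risk_score, event, time):
    """Naive sketch of Breslow's estimator of the cumulative baseline hazard.

    At each distinct event time t_k, the cumulative hazard increases by
    d_k / sum(exp(f_j) for all j still at risk at t_k),
    where d_k is the number of events at t_k and f_j are predicted risk scores.
    """
    risk_score = np.asarray(risk_score, dtype=float)
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    exp_f = np.exp(risk_score)

    uniq_event_times = np.unique(time[event])
    cum_hazard = np.empty_like(uniq_event_times)
    h0 = 0.0
    for k, t in enumerate(uniq_event_times):
        d = np.sum(event & (time == t))       # number of events at time t
        at_risk = exp_f[time >= t].sum()      # subjects still under observation
        h0 += d / at_risk
        cum_hazard[k] = h0
    return uniq_event_times, cum_hazard
```

<p>From the cumulative baseline hazard $H_0(t)$, an individual survival function follows as $S(t \mid x) = \exp(-H_0(t)\,e^{\hat{f}(x)})$, which is what the <code>get_survival_function</code> call below returns.</p>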
<p>Once fitted, we can use Breslow’s estimator to obtain estimated survival functions for images in the test data. We randomly draw three sample images for each digit and plot their predicted survival function.</p>
<pre><code class="language-python">sample = train_test_split(x_test, y_test, event_test, time_test,
                          test_size=30, stratify=y_test, random_state=89)
x_sample, y_sample, event_sample, time_sample = sample[1::2]

sample_pred_fn = make_pred_fn(x_sample)
sample_predictions = np.array([float(pred["risk_score"])
                               for pred in estimator.predict(sample_pred_fn)])

sample_surv_fn = breslow.get_survival_function(sample_predictions)
</code></pre>
<figure>
<img src="https://k-d-w.org/blog/2019/07/survival-analysis-for-deep-learning/img/predicted-survival-function.svg"/>
</figure>
<p>Solid lines correspond to images that belong to risk group 0 (with the lowest risk), which the model was able to identify. Samples from the group with the highest risk are shown as dotted lines. Their predicted survival functions have the steepest descent, confirming that the model correctly identified different risk groups from images.</p>
<h2 id="conclusion">Conclusion</h2>
<p>We successfully built, trained, and evaluated a convolutional neural network for survival analysis on MNIST. While MNIST is obviously not a clinical dataset, the exact same approach can be used for clinical data. For instance, <a href="https://www.pnas.org/content/115/13/E2970" target="_blank">Mobadersany et al.</a> used the same approach to predict overall survival of patients diagnosed with brain tumors from microscopic images, and <a href="https://scholar.google.com/scholar?cluster=3381426605939025516" target="_blank">Zhu et al.</a> applied CNNs to predict survival of lung cancer patients from pathological images.</p>
- scikit-survival 0.9 released (https://k-d-w.org/blog/2019/07/scikit-survival-0.9-released/, Sat, 27 Jul 2019 21:11:43 +0200)<p>This release of <a href="https://github.com/sebp/scikit-survival" target="_blank">scikit-survival</a> adds support for scikit-learn 0.21 and pandas 0.24, among a couple of other smaller fixes. Please see the <a href="https://scikit-survival.readthedocs.io/en/latest/release_notes.html" target="_blank">release notes</a> for a full list of changes. If you are using scikit-survival in your research, you can now cite it using a <a href="https://zenodo.org/record/3352343" target="_blank">Digital Object Identifier (DOI)</a>.</p>
<p>As usual, the latest version can be obtained via <em>conda</em> or <em>pip</em>. Pre-built conda packages are available for Linux, OSX and Windows:</p>
<pre><code class="language-bash"> conda install -c sebp scikit-survival
</code></pre>
<p>Alternatively, scikit-survival can be installed from source via pip:</p>
<pre><code class="language-bash"> pip install -U scikit-survival
</code></pre>
- Evaluating Survival Models (https://k-d-w.org/blog/2019/05/evaluating-survival-models/, Sat, 04 May 2019 11:12:05 +0000)<p>The most frequently used evaluation metric for survival models is the concordance index (c-index, c-statistic). It is a measure of rank correlation between predicted risk scores $\hat{f}$ and observed time points $y$ that is closely related to <a href="https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient" target="_blank">Kendall’s τ</a>. It is defined as the ratio of correctly ordered (concordant) pairs to comparable pairs. Two samples $i$ and $j$ are comparable if the sample with the lower observed time $y$ experienced an event, i.e., if $y_j > y_i$ and $\delta_i = 1$, where $\delta_i$ is a binary event indicator. A comparable pair $(i, j)$ is concordant if the risk $\hat{f}$ estimated by a survival model is higher for the subject with the lower survival time, i.e., $\hat{f}_i > \hat{f}_j \land y_j > y_i$; otherwise the pair is discordant. Harrell’s estimator of the c-index is implemented in <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_censored.html#sksurv.metrics.concordance_index_censored" target="_blank">concordance_index_censored</a>.</p>
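<p>The definition can be made concrete with a naive $O(n^2)$ implementation (a hypothetical helper for illustration only; <code>concordance_index_censored</code> is the optimized version to use in practice):</p>

```python
def naive_harrell_c(event, time, risk):
    """Harrell's c-index: fraction of comparable pairs that are concordant.

    A pair (i, j) is comparable if the sample with the lower observed
    time experienced an event; it is concordant if that sample also
    received the higher predicted risk. Ties in risk count as 0.5.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        for j in range(len(time)):
            if time[i] < time[j] and event[i]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable
```

<p>With a perfect ranking (higher risk for earlier events) this returns 1.0, with a perfectly reversed ranking 0.0, and 0.5 for a random ordering on average.</p>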
<p>While Harrell’s concordance index is easy to interpret and compute, it has some shortcomings:</p>
<ol>
<li>it has been shown that it is too optimistic with increasing amount of censoring <a href="https://dx.doi.org/10.1002/sim.4154" target="_blank">[1]</a>,</li>
<li>it is not a useful measure of performance if a specific time range is of primary interest (e.g. predicting death within 2 years).</li>
</ol>
<p>Since version 0.8, <a href="https://github.com/sebp/scikit-survival" target="_blank">scikit-survival</a> supports an alternative estimator of the concordance index from right-censored survival data, implemented in <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_ipcw.html#sksurv.metrics.concordance_index_ipcw" target="_blank">concordance_index_ipcw</a>, that addresses the first issue.</p>
<p>The second point can be addressed by extending the well-known receiver operating characteristic curve (ROC curve) to possibly censored survival times. Given a time point $t$, we can estimate how well a predictive model can distinguish subjects who will experience an event by time $t$ (sensitivity) from those who will not (specificity). The function <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.cumulative_dynamic_auc.html#sksurv.metrics.cumulative_dynamic_auc" target="_blank">cumulative_dynamic_auc</a> implements an estimator of the cumulative/dynamic area under the ROC for a given list of time points.</p>
<p>The
<a href="https://k-d-w.org/blog/2019/05/evaluating-survival-models/#bias-of-harrell-s-concordance-index">
first part
</a>
of this post will illustrate the first issue with simulated survival data, while the
<a href="https://k-d-w.org/blog/2019/05/evaluating-survival-models/#time-dependent-area-under-the-roc">
second part
</a>
will focus on the time-dependent area under the ROC applied to data from a real study.</p>
<p><strong>To see the full source code for producing the figures in this post, please see <a href="https://github.com/sebp/scikit-survival/blob/master/examples/evaluating-survival-models.ipynb" target="_blank">this notebook</a></strong>.</p>
<h2 id="bias-of-harrell-s-concordance-index">Bias of Harrell’s Concordance Index</h2>
<p>Harrell’s concordance index is known to be biased upwards if the amount of censoring in the test data is high <a href="https://dx.doi.org/10.1002/sim.4154" target="_blank">[1]</a>. <a href="https://dx.doi.org/10.1002/sim.4154" target="_blank">Uno et al.</a> proposed an alternative estimator of the concordance index that behaves better in such situations. In this section, we are going to apply <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_censored.html#sksurv.metrics.concordance_index_censored" target="_blank">concordance_index_censored</a> and <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_ipcw.html#sksurv.metrics.concordance_index_ipcw" target="_blank">concordance_index_ipcw</a> to synthetic survival data and compare their results.</p>
<h3 id="simulation-study">Simulation Study</h3>
<p>We generate a synthetic biomarker by sampling from a standard normal distribution. For a given hazard ratio, we compute the associated (actual) survival time by drawing from an exponential distribution. Censoring times are generated from an independent uniform distribution $\textrm{Uniform}(0,\gamma)$, where we choose $\gamma$ to produce different amounts of censoring.</p>
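<p>This setup can be sketched as follows (a hedged illustration with assumed parameter values; the notebook linked above contains the exact simulation code):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 100
hazard_ratio = 2.0
beta = np.log(hazard_ratio)

# synthetic biomarker from a standard normal distribution
x = rng.standard_normal(n_samples)

# actual survival time: exponential with rate exp(beta * x), so the
# hazard doubles for each unit increase of the biomarker
actual_time = rng.exponential(scale=1.0 / np.exp(beta * x))

# independent uniform censoring; gamma controls the amount of censoring
gamma = 2.0
censoring_time = rng.uniform(0.0, gamma, size=n_samples)

observed_time = np.minimum(actual_time, censoring_time)
event = actual_time <= censoring_time  # True if the event was observed
```

<p>Varying <code>gamma</code> while keeping the hazard ratio fixed then produces datasets that differ only in their amount of censoring.</p>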
<p>Since Uno’s estimator is based on inverse probability of censoring weighting, we need to estimate the probability of being censored at a given time point. This probability needs to be non-zero for all observed time points. Therefore, we restrict the test data to all samples with observed time lower than the maximum event time $\tau$. Usually, one would use the <code>tau</code> argument of <code>concordance_index_ipcw</code> for this, but we apply the selection beforehand so that identical inputs are passed to <code>concordance_index_censored</code> and <code>concordance_index_ipcw</code>. The estimates of the concordance index are therefore restricted to the interval $[0, \tau]$.</p>
<p>Let us assume a moderate hazard ratio of 2 and generate a small synthetic dataset of 100 samples from which we estimate the concordance index. We repeat this experiment 200 times and plot mean and standard deviation of the difference between the <em>actual</em> (in the absence of censoring) and <em>estimated</em> concordance index.</p>
<p>Since the hazard ratio remains constant and only the amount of censoring changes, we would want an estimator for which the difference between the actual and estimated c-index remains approximately constant across simulations.</p>
<figure>
<img src="https://k-d-w.org/blog/2019/05/evaluating-survival-models/img/plot_1.svg" width="400"/>
</figure>
<p>We can observe that estimates are on average below the actual value, except for the highest amount of censoring, where Harrell’s c begins overestimating the performance (on average).</p>
<p>With such a small dataset, the variance of differences is quite big, so let us increase the amount of data to 1000 and repeat the simulation.</p>
<figure>
<img src="https://k-d-w.org/blog/2019/05/evaluating-survival-models/img/plot_2.svg" width="400"/>
</figure>
<p>Now we can observe that Harrell’s c begins to overestimate performance starting with approximately 49% censoring while Uno’s c is still underestimating the performance, but is on average very close to the actual performance for large amounts of censoring.</p>
<p>For the final experiment, we double the size of the dataset to 2000 and repeat the analysis.</p>
<figure>
<img src="https://k-d-w.org/blog/2019/05/evaluating-survival-models/img/plot_3.svg" width="400"/>
</figure>
<p>The trend we observed in the previous simulation is now even more pronounced. Harrell’s c is becoming more and more overconfident in the performance of the synthetic marker with increasing amount of censoring, while Uno’s c remains stable.</p>
<p>In summary, while the difference between <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_ipcw.html#sksurv.metrics.concordance_index_ipcw" target="_blank">concordance_index_ipcw</a> and <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_censored.html#sksurv.metrics.concordance_index_censored" target="_blank">concordance_index_censored</a> is negligible for small amounts of censoring, when analyzing survival data with moderate to high amounts of censoring, you might want to consider estimating the performance using <code>concordance_index_ipcw</code> instead of <code>concordance_index_censored</code>.</p>
<h2 id="time-dependent-area-under-the-roc">Time-dependent Area under the ROC</h2>
<p>The area under the <a href="https://en.wikipedia.org/wiki/Receiver_operating_characteristic" target="_blank">receiver operating characteristics curve</a> (ROC curve) is a popular performance measure for binary classification tasks. In the medical domain, it is often used to determine how well estimated risk scores can separate diseased patients (cases) from healthy patients (controls). Given a predicted risk score $\hat{f}$, the ROC curve compares the false positive rate (1 - specificity) against the true positive rate (sensitivity) for each possible value of $\hat{f}$.</p>
<p>When extending the ROC curve to continuous outcomes, in particular survival time, a patient’s disease status is typically not fixed and changes over time: at enrollment a subject is usually healthy, but may be diseased at some later time point. Consequently, sensitivity and specificity become <a href="http://dx.doi.org/10.1111/j.0006-341x.2000.00337.x" target="_blank">time-dependent measures</a>. Here, we consider <em>cumulative cases</em> and <em>dynamic controls</em> at a given time point $t$, which gives rise to the time-dependent cumulative/dynamic ROC at time $t$. Cumulative cases are all individuals that experienced an event prior to or at time $t$ ($t_i \leq t$), whereas dynamic controls are those with $t_i>t$. By computing the area under the cumulative/dynamic ROC at time $t$, we can determine how well a model can distinguish subjects who fail by a given time ($t_i \leq t$) from subjects who fail after this time ($t_i>t$). Hence, it is most relevant if one wants to predict the occurrence of an event in a period up to time $t$ rather than at a specific time point $t$.</p>
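<p>The case/control split at a time point $t$ is easy to express in code (a minimal sketch with made-up arrays; actual evaluation goes through <code>cumulative_dynamic_auc</code>):</p>

```python
import numpy as np

# observed times and event indicators of six hypothetical subjects
time = np.array([2.0, 5.0, 6.0, 8.0, 9.0, 12.0])
event = np.array([True, True, False, True, False, True])

t = 7.0
# cumulative cases: experienced the event prior to or at time t
cases = (time <= t) & event
# dynamic controls: still event-free beyond time t
controls = time > t
# subjects censored before t are neither cases nor controls at t
neither = (time <= t) & ~event

print(cases.tolist())     # [True, True, False, False, False, False]
print(controls.tolist())  # [False, False, False, True, True, True]
```

<p>Note that subjects censored before $t$ contribute to neither group at that time point, which is exactly why the estimator needs inverse probability of censoring weights.</p>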
<p>The <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.cumulative_dynamic_auc.html#sksurv.metrics.cumulative_dynamic_auc" target="_blank">cumulative_dynamic_auc</a> function implements an estimator of the cumulative/dynamic area under the ROC at a given list of time points. To illustrate its use, we are going to use data from a <a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3538473/" target="_blank">study</a> that investigated to which extent the serum immunoglobulin free light chain (FLC) assay can be used to predict overall survival. The dataset has 7874 subjects and 9 features; the endpoint is death, which occurred for 2169 subjects (27.5%).</p>
<p>First, we load the data and split it into train and test sets to evaluate how well markers generalize.</p>
<pre><code class="language-python">x, y = load_flchain()
(x_train, x_test,
y_train, y_test) = train_test_split(x, y, test_size=0.2, random_state=0)
</code></pre>
<p>Serum creatinine measurements are missing for some patients, therefore we are just going to impute these values with the mean using scikit-learn’s <code>SimpleImputer</code>.</p>
<pre><code class="language-python">num_columns = ['age', 'creatinine', 'kappa', 'lambda']
imputer = SimpleImputer().fit(x_train.loc[:, num_columns])
x_train = imputer.transform(x_train.loc[:, num_columns])
x_test = imputer.transform(x_test.loc[:, num_columns])
</code></pre>
<p>Similar to Uno’s estimator of the concordance index described above, we need to be a little bit careful when selecting the test data and time points we want to evaluate the ROC at, due to the estimator’s dependence on inverse probability of censoring weighting. First, we are going to check whether the observed time of the test data lies within the observed time range of the training data.</p>
<pre><code class="language-python">y_events = y_train[y_train['death']]
train_min, train_max = y_events["futime"].min(), y_events["futime"].max()
y_events = y_test[y_test['death']]
test_min, test_max = y_events["futime"].min(), y_events["futime"].max()
assert train_min <= test_min < test_max < train_max, \
    "time range of test data is not within time range of training data."
</code></pre>
<p>When choosing the time points to evaluate the ROC at, it is important to remember to choose the last time point such that the probability of being censored after the last time point is non-zero. In the simulation study above, we set the upper bound to the maximum event time; here we use a more conservative approach by setting the upper bound to the 80th percentile of observed time points, because the censoring rate is quite large at 72.5%. Note that this approach would also be appropriate for choosing <code>tau</code> of <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_ipcw.html#sksurv.metrics.concordance_index_ipcw" target="_blank">concordance_index_ipcw</a>.</p>
<pre><code class="language-python">times = np.percentile(y["futime"], np.linspace(5, 81, 15))
print(times)
</code></pre>
<pre><code class="language-python">[ 470.3 1259. 1998. 2464.82428571 2979.
3401. 3787.99857143 4051. 4249. 4410.17285714
4543. 4631. 4695. 4781. 4844. ]
</code></pre>
<p>We begin by considering individual real-valued features as risk scores without actually fitting a survival model. Hence, we obtain an estimate of how well age, creatinine, kappa FLC, and lambda FLC are able to distinguish cases from controls at each time point.</p>
<figure>
<img src="https://k-d-w.org/blog/2019/05/evaluating-survival-models/img/plot_4.svg" width="400"/>
</figure>
<p>The plot shows the estimated area under the time-dependent ROC at each time point and the average across all time points as dashed line.</p>
<p>We can see that age is overall the most discriminative feature, followed by $\kappa$ and $\lambda$ FLC. The fact that age is the strongest predictor of overall survival in the general population is hardly surprising (we have to die at some point after all). More differences become evident when considering time: the discriminative power of FLC decreases at later time points, while that of age increases. The observation for age again follows common sense. In contrast, FLC seems to be a good predictor of death in the near future, but not so much if it occurs decades later.</p>
<p>Next, we will fit an actual survival model to predict the risk of death using data from the <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.datasets.load_veterans_lung_cancer.html#sksurv.datasets.load_veterans_lung_cancer" target="_blank">Veterans’ Administration Lung Cancer Trial</a>. After fitting a Cox proportional hazards model, we want to assess how well the model can distinguish survivors from deceased in weekly intervals, up to 6 months after enrollment.</p>
<figure>
<img src="https://k-d-w.org/blog/2019/05/evaluating-survival-models/img/plot_5.svg" width="400"/>
</figure>
<p>The plot shows that the model is doing quite well on average with an AUC of ~0.82 (dashed line). However, there is a clear difference in performance between the first and second half of the time range. Performance increases up to about 100 days from enrollment, but quickly drops thereafter. Thus, we can conclude that the model is less effective in predicting death past 100 days.</p>
<h2 id="conclusion">Conclusion</h2>
<p>I hope this post helped you to understand some of the pitfalls when estimating the performance of markers and models from right-censored survival data. We illustrated that <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_censored.html#sksurv.metrics.concordance_index_censored" target="_blank">Harrell’s estimator</a> of the concordance index is biased when the amount of censoring is high, and that <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.concordance_index_ipcw.html#sksurv.metrics.concordance_index_ipcw" target="_blank">Uno’s estimator</a> is more appropriate in this situation. Finally, we demonstrated that the <a href="https://scikit-survival.readthedocs.io/en/latest/generated/sksurv.metrics.cumulative_dynamic_auc.html#sksurv.metrics.cumulative_dynamic_auc" target="_blank">time-dependent area under the ROC</a> is a very useful tool when we want to predict the occurrence of an event in a period up to time $t$ rather than at a specific time point $t$.</p>