
Adaptive Transformer Modelling of Density Function for Nonparametric Survival Analysis

Xin Zhang, Deval Mehta, Yanan Hu, Chao Zhu, David Darby, Zhen Yu, Daniel Merlo, Melissa Gresle, Anneke van der Walt, Helmut Butzkueven, Zongyuan Ge (2, 3)

1 Department of Electrical and Computer Systems Engineering, Monash University, Melbourne, Australia
2 AIM for Health Lab, Monash University, Melbourne, Australia
3 Department of Data Science and AI, Monash University, Melbourne, Australia
4 Monash Centre for Health Research and Implementation, Monash University, Melbourne, Australia
5 Department of Neuroscience, Monash University, Melbourne, Australia
6 Department of Neurology, Monash University, Melbourne, Australia
Abstract

Survival analysis plays a crucial role across diverse disciplines, such as economics, engineering and healthcare. It enables researchers to analyze both time-invariant and time-varying data, covering phenomena such as customer churn, material degradation and various medical outcomes. Given the complexity and heterogeneity of such data, recent work has successfully integrated deep learning to address limitations of conventional statistical approaches. However, current methods typically produce cluttered probability density functions (PDFs), are less sensitive when predicting censored cases, model only static datasets, or rely solely on recurrent neural networks for dynamic modelling. In this paper, we propose UniSurv, a novel survival regression method that produces high-quality unimodal PDFs without any prior distributional assumption by optimizing a novel Margin-Mean-Variance loss, and that leverages the flexibility of the Transformer to handle both temporal and non-temporal data. Extensive experiments on several datasets demonstrate that UniSurv places a significantly higher emphasis on censoring than other methods.

Keywords: Survival analysis, Transformer, Margin-Mean-Variance loss, Deep learning

1 Introduction

The primary task of survival analysis is to determine the timing of one or more events, which, depending on the specific circumstance, can signify the moment a mechanical system malfunctions, the transition of a company from deficit to surplus, the death of a patient, and so on [19]. Among these scenarios, survival analysis of medical data poses the most severe challenges [1]. Some medical datasets are longitudinal, as exemplified by electronic health records (EHRs), in which multiple observations of each patient's covariates are recorded over time. Survival models must be able to handle such measurements and learn from their temporal trends. Moreover, observations in longitudinal data are often sparse, so any reliable survival model must handle missing values effectively, even when missing rates are exceedingly high [32]. Additionally, censoring is a fundamental aspect of survival data: it refers to cases in which the survival time or event occurrence of a subject is not fully observed within the study period [20]. For a censored subject the exact event time is unknown, so there is no ground truth to learn from, which poses significant challenges for deep survival learning. Existing deep learning approaches typically mitigate this issue by ensuring that no event is predicted before the censoring time. However, when the event occurs after the censoring time is rarely modelled explicitly.

Developing survival analysis models requires regressing the probability of survival over a defined period, and a high-quality estimate of the probability distribution is essential for time-to-event prediction. Parametric survival models can generate high-quality probability density functions (PDFs) or survival curves by predetermining a stochastic distribution; however, their precision depends on the validity of all underlying assumptions. In contrast, non-parametric models do not presume any prior distribution of events, but they struggle to predict PDFs accurately over extended time spans in medical datasets and consequently yield PDFs or survival curves of comparatively lower quality.

To address these challenges and mitigate the limitations of existing models, we propose UniSurv, a non-parametric model based on the Transformer architecture. In particular, UniSurv can: 1) generate higher-quality PDFs resembling a normal distribution without any prior probability assumption, and significantly improve accuracy in predicting censored cases by integrating a novel Margin-Mean-Variance loss; 2) use distinct embedding branches to extract static and dynamic features separately; 3) effectively handle longitudinal data with high missing rates as well as various data modalities. The superiority of UniSurv is supported by empirical evidence from real and synthetic datasets.

2 Literature

Semi- and fully-parametric models rely heavily on explicit assumptions about the underlying distribution of event times. They provide a structured framework for understanding the relationship between covariates and the occurrence of events over time. However, the strength of these assumptions leads the models to predict overly simplistic probability distributions, and the resulting lack of flexibility renders them impractical in many scenarios. The Cox proportional hazards (CPH) model [2] is a prime example. It estimates the hazard function $\lambda(t \mid X)$ by multiplying a baseline hazard function $\lambda_0(t)$, shared by all subjects, with a learnt representation of the features $g(X)$. Subsequent studies [5, 35, 21] have used more sophisticated models to improve CPH. However, the oversimplified stochastic process continues to constrain their predictive capabilities, and they cannot perform dynamic analysis. Meanwhile, [24] introduced Deep Survival Machines (DSM), which assume that the survival function is a mixture of multiple Weibull and log-normal distributions whose parameters are estimated by a multi-layer perceptron (MLP). Further, [23] proposed Recurrent DSM (RDSM), which incorporates a recurrent neural network (RNN) into DSM to enable dynamic analysis. Nonetheless, DSM models exhibit suboptimal accuracy in predicting event times, and their loss function frequently diverges during training, contributing to overfitting.
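To make the proportional-hazards structure concrete, the NumPy sketch below assumes the conventional CPH choice $g(X) = \exp(\beta^T X)$ together with a hypothetical baseline hazard and hypothetical coefficients. It shows that the hazard ratio between two subjects is the same at every time point, because $\lambda_0(t)$ cancels; this rigidity is the kind of assumption discussed above.

```python
import numpy as np

def baseline_hazard(t):
    """Hypothetical baseline hazard lambda_0(t) (Weibull-shaped), for illustration only."""
    return 0.15 * np.sqrt(t)

def cph_hazard(t, x, beta):
    """lambda(t | x) = lambda_0(t) * g(x), with the conventional CPH choice g(x) = exp(beta^T x)."""
    return baseline_hazard(t) * np.exp(np.dot(beta, x))

beta = np.array([0.8, -0.3])   # hypothetical coefficients
x_a = np.array([1.0, 0.0])
x_b = np.array([0.0, 1.0])

# lambda_0(t) cancels in the ratio, leaving exp(beta^T (x_a - x_b)) = exp(1.1) at every t.
for t in (1.0, 5.0, 20.0):
    print(t, cph_hazard(t, x_a, beta) / cph_hazard(t, x_b, beta))
```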

Some recent works have concentrated on static analysis. For example, DeepSurv [13] employs an MLP in place of the parametric form of the hazard function in the conventional CPH, yielding a semi-parametric variant of the CPH model. The neural network makes the model more flexible by enabling it to learn nonlinear relationships from the covariates. Deep Cox Mixtures (DCM) [25] relies on the same proportional hazards assumption but additionally assumes the presence of latent groups: it uses Variational Autoencoders for clustering and assumes that proportional hazards hold within each latent group. Moreover, [11] proposed an extension of the random forest (RF) algorithm, named Random Survival Forest (RSF), which was among the first methods to move beyond the inherent assumptions of CPH. It computes risk scores by building Nelson-Aalen estimators within the partitions established by the RF. RSF assumes independence among the trees in the forest, which might not always hold; this can degrade its performance, particularly when correlations or dependencies exist among the survival trees.

Several studies have explored dynamic analysis. [17] proposed DeepHit, a non-parametric model for competing-risk events. The encoder of DeepHit is a joint MLP, while its decoder employs a series of distinct MLPs, one per event, producing a separate PDF for each event. [18] further extended it into Dynamic DeepHit (DDH) by replacing the encoder with RNNs followed by an attention mechanism to process longitudinal data. A primary limitation is the arbitrary fluctuation between adjacent predictions in the output layer, which introduces noise into the final PDFs; this becomes particularly pronounced when forecasting over a long time horizon. Survival SEQ2SEQ (SS2S) [30] addresses this issue by using RNN cells in its decoder to generate smoother PDFs. However, its handling of censoring is overly simplistic: it only enforces that events should not occur before the censoring time, without addressing what happens after it. The Transformer-based Deep Survival Model (TDSM) [9] shares a similar weakness in handling censoring, as is evident in the loss-function designs of both models. Previous Transformer architectures for survival analysis include TDSM and SurvTRACE [36], but neither has been extended to dynamic analysis.

In summary, while most existing methods can perform static analysis, only three of them (DDH, RDSM and SS2S) can handle longitudinal data. Despite these advances, there is no universal model that jointly handles missing data, processes various input formats and produces well-organized PDFs.

3 Method

In this section, we introduce our framework UniSurv, an adaptive Transformer-based architecture for survival analysis. We assume that the available survival dataset is subject to right censoring.

3.1 Survival Notation

We denote time-invariant and time-varying covariates by $\boldsymbol{x}_n$ and $\boldsymbol{x}_v$, probability by $p$, time by $T$, $t$ or $\tau$, the PDF by $p(t)$ and the survival function by $S(t)$. We represent the survival dataset as $\{(\boldsymbol{x}_n^i, \boldsymbol{x}_v^i, T^i, \delta^i)\}_{i=1}^{N}$, where, for individual $i$, $\delta^i$ is the event indicator, taking values in $\{0, 1\}$ in the absence of competing risks ($\delta^i = 1$ for an observed event and $\delta^i = 0$ for censoring), and $T^i$ is the event or censoring time depending on $\delta^i$. For simplicity, we omit the explicit dependence on $i$ in this and the next subsection.

We assume that $t \in \{T_0, T_1, \ldots, T_{max}\}$ to fit a discrete survival model, where $t$ is a discrete random variable and the time steps $T_j$ are equally spaced. The cumulative distribution function (CDF) of $t$ follows directly from its PDF:

$$\mathrm{CDF}(T_j \mid \boldsymbol{x}_n, \boldsymbol{x}_v) = \sum_{t=T_0}^{T_j} p_t \tag{1}$$

Given the probability that an event has occurred by $T_j$, the survival function is the probability that the survival time $t$ is at least $T_j$, i.e., the complement of the CDF at the preceding time step:

$$S(T_j \mid \boldsymbol{x}_n, \boldsymbol{x}_v) = 1 - \mathrm{CDF}(T_{j-1} \mid \boldsymbol{x}_n, \boldsymbol{x}_v) = \sum_{t=T_j}^{T_{max}} p_t \tag{2}$$

with $\mathrm{CDF}(T_{-1}) \equiv 0$, so that $S(T_0) = 1$.
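A minimal NumPy sketch of Eqs. 1 and 2, assuming the PDF is given as an array over the grid $T_0, \ldots, T_{max}$ (the example PDF is hypothetical). Computing the survival function as a reversed cumulative sum of non-negative probabilities is what makes it monotonically non-increasing by construction.

```python
import numpy as np

def cdf_and_survival(pdf):
    """Eq. 1 and Eq. 2 for a discrete PDF over T_0, ..., T_max."""
    pdf = np.asarray(pdf, dtype=float)
    cdf = np.cumsum(pdf)                  # CDF(T_j) = sum_{t <= T_j} p_t
    surv = np.cumsum(pdf[::-1])[::-1]     # S(T_j)   = sum_{t >= T_j} p_t
    return cdf, surv

pdf = np.array([0.05, 0.15, 0.40, 0.25, 0.15])   # hypothetical PDF with T_max = 4
cdf, surv = cdf_and_survival(pdf)
# S(T_0) = 1 and S(T_j) = 1 - CDF(T_{j-1}) for j >= 1
print(np.isclose(surv[0], 1.0), np.allclose(surv[1:], 1.0 - cdf[:-1]))
```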

3.2 Model Description

Figure 1: (a) UniSurv framework, illustrating the architecture of the UniSurv model; (b) Overall schemata, a schematic representation of UniSurv during the training and testing stages

Fig. 1(a) presents a comprehensive illustration of the UniSurv model, which integrates a novel survival loss to enable seamless end-to-end learning. It comprises dynamic and static extraction components and a Transformer encoder module, followed by a softmax layer at the output. Fig. 1(b) depicts the conceptual process of UniSurv during both the training and testing stages.

3.2.1 Static and Dynamic Extractions

We adopt a variant of the last-observation-carried-forward (LOCF) method to handle missing data. It copies the last observed value forward to replace subsequent missing values for up to $\Delta_\tau = 3$ time steps; values still missing after LOCF are imputed by the mean (continuous covariates) or mode (binary covariates) of all previous points, up to $T_{max}$, ensuring that no data are missing at any time point $T_j$ [18]. Next, we extract latent representations of the time-invariant features $\boldsymbol{x}_n$ and the time-varying features $\boldsymbol{x}_v$ with the static (-s) and tabular-data-based dynamic (-d1) extraction modules, respectively. These modules are MLPs designed for numerical tabular data. The representation of $\boldsymbol{x}_n$ is replicated $T_{max}+1$ times so as to cover $T_0$ as well. For static modelling, these copies are passed directly to the encoder; for dynamic modelling, they are first concatenated with the representation of $\boldsymbol{x}_v$ at the respective time points. Alternatively, a convolutional neural network (CNN) variant of the dynamic extraction module (-d2) can be used for image-like input formats. As shown in Fig. 1(a), given the sparsity of the data, each extracted feature can be shared across a predefined time window $T_w$.
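The two-stage imputation can be sketched as follows for a single covariate of one subject. This is a minimal NumPy illustration, not the authors' code: it assumes unit time steps, that "all previous points" means the subject's earlier observed values of that covariate, and that leading gaps (before the first observation) are left unfilled.

```python
import numpy as np

def locf_impute(x, max_carry=3, binary=False):
    """Impute one covariate's series over T_0..T_max (np.nan marks a missing value).

    Step 1: carry the last observation forward for at most `max_carry` steps.
    Step 2: fill remaining gaps with the mean (continuous) or mode (binary)
            of all earlier observed values of the same subject (assumption).
    Gaps with no earlier observation are left as np.nan.
    """
    x = np.asarray(x, dtype=float)
    out = x.copy()
    last_idx = None
    for j in range(len(x)):
        if not np.isnan(x[j]):
            last_idx = j
        elif last_idx is not None and j - last_idx <= max_carry:
            out[j] = x[last_idx]
    for j in range(len(x)):
        if np.isnan(out[j]):
            prev = x[:j][~np.isnan(x[:j])]
            if prev.size == 0:
                continue
            if binary:
                values, counts = np.unique(prev, return_counts=True)
                out[j] = values[np.argmax(counts)]
            else:
                out[j] = prev.mean()
    return out

print(locf_impute([2.0, 4.0, np.nan, np.nan, np.nan, np.nan]))
print(locf_impute([1, 1, 0, np.nan, np.nan, np.nan, np.nan], binary=True))
```

In the first example, the value 4 is carried forward for three steps and the last gap receives the mean of the earlier observations (3); in the second, the last gap receives the mode of the earlier observations (1).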

3.2.2 Transformer Encoder

The core of the encoder module is a Transformer, which treats each patient as a 'sentence' and the embedded features as its 'words'. For an input sample, the number of words corresponds to the duration $t \in \{0, 1, \ldots, T_{max}\}$, where we set $T_0 = 0$ and $T_{max}$ is a hyper-parameter chosen according to the longest temporal record in a dataset. The concatenated representations pass through an MLP followed by layer normalization to obtain the embedded features. Following the conventional Transformer approach, we use the sine-cosine positional encoding [34] as the temporal embedding and add it to the embedded features, whose length is set to the embedding dimension $d_m$ of the Transformer. The Transformer encoder then processes the embedded features and produces $T_{max}+1$ outputs, each of shape $1 \times d_m$. Notably, the self-attention layers in the encoder are modified to prevent each position from attending to subsequent positions: the attention scores of all such illegal connections are masked out by setting them to $-\infty$ [34]. Next, the outputs at all time points are fed into a dedicated two-layer MLP. The first layer, of shape $d_m \times \frac{d_m}{2}$, is followed by a rectified linear unit (ReLU) and layer normalization. The second layer, of shape $\frac{d_m}{2} \times 1$, maps each time point to a scalar, and a softmax across the $T_{max}+1$ time points produces the individual estimated PDF. The estimated survival function $\hat{S}(t \mid \boldsymbol{x}_n, \boldsymbol{x}_v)$ can then be calculated from Eq. 2, which guarantees that it is monotonically non-increasing. Moreover, in discrete survival analysis, the mean lifetime $\mu$ can be approximated by the sum of the survival probabilities up to $T_{max}$. Applying Eq. 2, the estimated mean lifetime is

$$\hat{\mu} = \sum_{t=T_0}^{T_{max}} \hat{S}(t) = \sum_{t=T_0}^{T_{max}} \sum_{\tau=t}^{T_{max}} \hat{p}_\tau \approx \sum_{t=T_0}^{T_{max}} t \cdot \hat{p}_t \tag{3}$$

where the approximation is due to the time point $T_0 = 0$: with unit time steps, the double sum equals $\sum_{\tau} (\tau + 1)\,\hat{p}_\tau = \sum_{t} t \cdot \hat{p}_t + 1$, because the sum includes $\hat{S}(T_0) = 1$. The variance of the distribution is computed as

$$v = \sum_{t=T_0}^{T_{max}} \hat{p}_t \cdot (t - \hat{\mu})^2 \tag{4}$$
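The NumPy sketch below illustrates three pieces of this subsection under simplifying assumptions: the sine-cosine positional encoding, the additive causal mask that blocks attention to later positions, and the output head that turns one scalar per time point into a PDF, from which $\hat{\mu}$ and $v$ follow. The Transformer layers themselves are omitted, random numbers stand in for attention scores and head outputs, and unit time steps with $T_0 = 0$ are assumed. The final check confirms that the two sides of Eq. 3 differ by exactly 1.

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions, d_model):
    """Sine-cosine positional encoding [34]; assumes an even d_model."""
    pos = np.arange(n_positions)[:, None]
    two_i = np.arange(0, d_model, 2)[None, :]
    angles = pos / np.power(10000.0, two_i / d_model)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def causal_mask(n):
    """Additive mask: -inf above the diagonal stops position j attending to positions > j."""
    return np.triu(np.full((n, n), -np.inf), k=1)

def softmax(z, axis=-1):
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

def head_to_pdf_stats(scores):
    """One scalar per time point T_0..T_max -> PDF (softmax over time), Eq. 3 and Eq. 4."""
    pdf = softmax(scores)
    t = np.arange(len(pdf))                 # T_0 = 0 with unit time steps
    surv = np.cumsum(pdf[::-1])[::-1]       # Eq. 2
    mu = np.sum(t * pdf)                    # right-hand side of Eq. 3
    var = np.sum(pdf * (t - mu) ** 2)       # Eq. 4
    return pdf, surv, mu, var

n_steps, d_m = 6, 8                         # hypothetical T_max + 1 and d_m
rng = np.random.default_rng(0)
pe = sinusoidal_positional_encoding(n_steps, d_m)
attn = softmax(rng.standard_normal((n_steps, n_steps)) + causal_mask(n_steps))
pdf, surv, mu, var = head_to_pdf_stats(rng.standard_normal(n_steps))
# Summing S(t) includes S(T_0) = 1, so it exceeds sum_t t * p_t by exactly 1.
print(np.isclose(surv.sum(), mu + 1.0))
```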

3.3 Loss Function

To robustly estimate the uncensored survival time via distribution learning and to generate smooth PDFs, we adopt a variant of the Mean-Variance loss [26] in UniSurv. This loss requires every training sample to have an event-time label, but censored samples in a survival dataset have only a censoring time, not an event time. Using the censoring time as the label would mislead the model and bias its predictions. To overcome this, we employ the "margin time" concept [8], which assigns a "best guess" event time to each censored subject based on the non-parametric population Kaplan-Meier (KM) [12] estimator. For a subject $i$ censored at time $T^i$, the margin event time is

$$e_m^i = T^i + \frac{\int_{T^i}^{T_{max}} S_{km(D_{tr})}(t)\,dt}{S_{km(D_{tr})}(T^i)} \tag{5}$$

where $S_{km(D_{tr})}$ is the KM estimate derived from the training dataset $D_{tr}$. Note that, unlike [8], we integrate only up to $T_{max}$; since $S_{km}$ is non-increasing, the integral is at most $(T_{max} - T^i)\,S_{km(D_{tr})}(T^i)$, so $e_m^i$ stays within $T_{max}$. In contrast, [8] extends the KM curve indefinitely through risky extrapolation beyond the observed values.
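A minimal NumPy sketch of the KM estimate and Eq. 5, assuming integer time steps. The integral is approximated by a left Riemann sum, which preserves the bound $e_m \le T_{max}$ because $S_{km}$ is non-increasing. The toy training data and the fallback used when $S_{km}(T^i) = 0$ are assumptions for illustration.

```python
import numpy as np

def kaplan_meier(times, events, t_max):
    """KM survival estimate S_km(t) on the integer grid t = 0, ..., t_max."""
    times, events = np.asarray(times), np.asarray(events)
    surv = np.empty(t_max + 1)
    s = 1.0
    for t in range(t_max + 1):
        at_risk = np.sum(times >= t)
        deaths = np.sum((times == t) & (events == 1))
        if at_risk > 0:
            s *= 1.0 - deaths / at_risk
        surv[t] = s
    return surv

def margin_time(t_cens, s_km, t_max):
    """Eq. 5 with the integral from T^i to T_max replaced by a left Riemann sum (unit steps)."""
    t_cens = int(t_cens)
    if s_km[t_cens] == 0:
        return float(t_cens)   # no survivors left: fall back to the censoring time (assumption)
    tail = np.sum(s_km[t_cens:t_max])
    return t_cens + tail / s_km[t_cens]

times = [1, 2, 2, 3, 4]    # hypothetical training data
events = [1, 1, 0, 1, 0]   # 1 = event, 0 = censored
s_km = kaplan_meier(times, events, t_max=5)
print(s_km)                                                # approx. [1.0, 0.8, 0.6, 0.3, 0.3, 0.3]
print(margin_time(2, s_km, 5), margin_time(4, s_km, 5))    # approx. 4.0 and 5.0
```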

Let $T^i$ denote the ground-truth event or censoring time of individual $i$. With the estimated mean lifetime $\hat{\mu}^i$, the Margin-Mean loss is computed as

$$\mathcal{L}_{mm} = \frac{1}{2} \sum_{i=1}^{N} \Big( \delta^i \cdot (\hat{\mu}^i - T^i)^2 + (1 - \delta^i) \cdot \omega^i \cdot (\hat{\mu}^i - e_m^i)^2 \Big) \tag{6}$$

where $\omega^i = 1 - S_{km(D_{tr})}(T^i)$ assigns a high confidence weight to late censoring times and a low weight to early ones. The Margin-Mean loss thus penalizes the discrepancy between the estimated mean lifetime and the actual event time (for uncensored subjects) or the margin event time (for censored subjects). The Variance loss is calculated as

$$\mathcal{L}_v = \sum_{i=1}^{N} v^i \tag{7}$$

which regulates the spread of the estimated survival distribution, concentrating it in a narrow range around the mean. In combination with Eq. 4, $\mathcal{L}_v$ drives the probabilities at time points far from $\hat{\mu}^i$ towards 0. The softmax loss, also known as the cross-entropy loss, is computed as

$$\mathcal{L}_s = -\sum_{i=1}^{N} \log p_{i,T^i} \tag{8}$$

which helps training converge in its early stages, as the Margin-Mean-Variance loss alone may fluctuate substantially [26].
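Eqs. 6 to 8 can be sketched in NumPy as below. This is an illustrative reading of the equations, not the released implementation: the PDFs are assumed to be a subjects-by-time array, $T^i$ is assumed to be an integer index into that array, and $\hat{\mu}$ inside the variance uses the right-hand form of Eq. 3. The toy values are hypothetical.

```python
import numpy as np

def margin_mean_loss(mu_hat, t_obs, delta, e_m, s_km_at_t):
    """Eq. 6; e_m is only used for censored subjects (delta = 0)."""
    omega = 1.0 - s_km_at_t                                   # omega^i = 1 - S_km(T^i)
    per_subject = (delta * (mu_hat - t_obs) ** 2
                   + (1 - delta) * omega * (mu_hat - e_m) ** 2)
    return 0.5 * np.sum(per_subject)

def variance_loss(pdfs):
    """Eq. 7: sum of per-subject variances (Eq. 4); each row of `pdfs` is a PDF over T_0..T_max."""
    t = np.arange(pdfs.shape[1])
    mu = pdfs @ t                                             # sum_t t * p_t (Eq. 3, right-hand side)
    return np.sum(pdfs * (t[None, :] - mu[:, None]) ** 2)

def softmax_loss(pdfs, t_obs):
    """Eq. 8: negative log-probability at each subject's time index T^i."""
    return -np.sum(np.log(pdfs[np.arange(len(t_obs)), t_obs]))

# Hypothetical toy values: subject 0 uncensored at T = 1, subject 1 censored at T = 2.
mu_hat = np.array([2.0, 3.0])
t_obs = np.array([1, 2])
delta = np.array([1, 0])
e_m = np.array([1.0, 4.0])          # margin time from Eq. 5 (first entry unused)
s_km_at_t = np.array([0.8, 0.6])    # S_km(T^i) from the training set
print(margin_mean_loss(mu_hat, t_obs, delta, e_m, s_km_at_t))   # approx. 0.7 = 0.5 * (1 + 0.4 * 1)
```

A one-hot PDF centred on $T^i$ gives zero variance loss and zero softmax loss, which is the behaviour the combination of these terms rewards.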

Finally, for uncensored cases, we use a discordant loss:

$$\mathcal{L}_d = \sum_{i=1}^{N} \delta^i \cdot \max\{0, (T^k - T^i) - (\hat{\mu}^k - \hat{\mu}^i)\} \tag{9}$$

which penalizes randomly sampled discordant pairs to improve the model's pairwise ranking ability. The procedure resembles a randomized algorithm: for each uncensored individual $i$, an individual $k$ with $T^k > T^i$ is sampled at random with replacement, and the difference between their estimated mean lifetimes should not be smaller than the difference between their ground-truth times. $\mathcal{L}_d$ is needed because, when $T^k$ and $T^i$ are close, the Margin-Mean-Variance loss cannot effectively discriminate discordant pairs and may fall into a local optimum.
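The randomized pairing behind Eq. 9 can be sketched as follows. It assumes that a partner $k$ may be any subject (censored or not) with $T^k > T^i$, that one partner is drawn per uncensored subject, and that subjects with no later partner contribute nothing; these choices are assumptions, not details given in the text.

```python
import numpy as np

def discordant_loss(mu_hat, t_obs, delta, rng):
    """Eq. 9 with one randomly sampled partner k (T^k > T^i) per uncensored subject i."""
    mu_hat, t_obs, delta = map(np.asarray, (mu_hat, t_obs, delta))
    total = 0.0
    for i in np.flatnonzero(delta == 1):
        candidates = np.flatnonzero(t_obs > t_obs[i])
        if candidates.size == 0:
            continue                      # no later subject to pair with
        k = rng.choice(candidates)        # independent draws, so partners may repeat across i
        total += max(0.0, (t_obs[k] - t_obs[i]) - (mu_hat[k] - mu_hat[i]))
    return total

rng = np.random.default_rng(0)
# Subject 0 (T = 1) pairs with subject 1 (T = 5): penalty (5 - 1) - (3 - 2) = 3.
print(discordant_loss([2.0, 3.0], [1, 5], [1, 1], rng))
```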

The total loss for training UniSurv combines the above four losses:

$$\mathcal{L}_{total} = \mathcal{L}_s + \lambda_m \mathcal{L}_{mm} + \lambda_v \mathcal{L}_v + \lambda_d \mathcal{L}_d \tag{10}$$

where $\lambda_m$, $\lambda_v$ and $\lambda_d$ are the weights of the corresponding losses.

4 Experiments

In this section, we demonstrate the effectiveness of UniSurv by comparing it with other benchmarks on real and synthetic datasets in both static and dynamic settings.

4.1 Datasets

To reflect the right-skewed nature of survival data, we use three real-world datasets and two long-tailed synthetic datasets.

4.1.1 Static Datasets

The Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT) [14] is a large static survival dataset of seriously ill hospitalized adults. The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC) [3] is a static breast cancer dataset aimed at distinguishing breast cancer subtypes based on molecular characteristics. Pre-processing for both follows DeepSurv [13].

We also generate a static synthetic dataset (SYNTH-s) in the style of [17] but without competing risks. The dataset contains $N = 15{,}100$ examples drawn from the stochastic process

$$\begin{aligned} &\boldsymbol{x}_n^i \sim \mathcal{N}(0, \mathbf{I}) \\ &T^i \sim \exp(\boldsymbol{\gamma}_n^T \boldsymbol{x}_n^i) \end{aligned} \tag{11}$$

where $\boldsymbol{x}_n^i$ is a 4-dimensional vector and $\boldsymbol{\gamma}_n = \boldsymbol{10}$. We randomly select 50% of the patients to be right-censored, with censoring times drawn uniformly from $[0, T^i]$. More details are listed in Tab. 1.
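The censoring step can be sketched as below. Because the exact parameterization of $\exp(\boldsymbol{\gamma}_n^T \boldsymbol{x}_n^i)$ in Eq. 11 is not spelled out here, the sketch takes event times as input and uses placeholder exponential draws in the usage example; the function name, scale and seeds are hypothetical.

```python
import numpy as np

def apply_random_right_censoring(event_times, frac=0.5, seed=0):
    """Censor a random `frac` of subjects, drawing each censoring time from Uniform[0, T^i]."""
    rng = np.random.default_rng(seed)
    event_times = np.asarray(event_times, dtype=float)
    n = len(event_times)
    observed = event_times.copy()
    delta = np.ones(n, dtype=int)                       # 1 = event observed
    idx = rng.choice(n, size=int(round(frac * n)), replace=False)
    observed[idx] = rng.uniform(0.0, event_times[idx])  # censoring time <= event time
    delta[idx] = 0
    return observed, delta

# Placeholder event times (hypothetical, NOT the Eq. 11 generator):
t_event = np.random.default_rng(1).exponential(scale=20.0, size=1000)
t_obs, delta = apply_random_right_censoring(t_event)
print(delta.mean(), np.all(t_obs <= t_event))   # 0.5 True
```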

Table 1: Descriptive statistics of the three real-world medical datasets and two synthetic datasets
Dataset     Longitudinal  Uncensored   Censored     Static features  Dynamic features  Event time (min / max / mean)  Censoring time (min / max / mean)
SUPPORT     No            6036 (68%)   2837 (32%)   14               -                 0 / 65 / 6.85                  12 / 68 / 35.33
METABRIC    No            1103 (58%)   801 (42%)    9                -                 0 / 355 / 99.95                0 / 337 / 159.55
SYNTH-s     No            7600 (50%)   7500 (50%)   4                -                 0 / 192 / 22.45                0 / 165 / 10.80
MSReactor   Yes           148 (20%)    598 (80%)    8                90                1 / 68 / 17.54                 0 / 80 / 45.43
SYNTH-d     Yes           7462 (49%)   7638 (51%)   4                20                0 / 199 / 57.96                0 / 195 / 28.59

4.1.2 Dynamic Datasets

Building on SYNTH-s, we further generate a dynamic synthetic dataset (SYNTH-d) by adding dynamic variables that follow a Weibull distribution and introducing temporal noise so that they vary over time:

$$\begin{aligned} &\boldsymbol{x}_v^i(t) \sim \frac{a}{\beta}\Big(\frac{\boldsymbol{x}}{\beta}\Big)^{a-1} \exp\Big(-\Big(\frac{\boldsymbol{x}}{\beta}\Big)^{a}\Big) + \mathcal{N}(0, \mathbf{I}) \\ &T^i \sim \exp\Big(\boldsymbol{\gamma}_{v_1}^T \cdot \max\big(\boldsymbol{x}_{v_1}^i(t)\big) + \boldsymbol{\gamma}_{v_2}^T \cdot \min\big(\boldsymbol{x}_{v_2}^i(t)\big) + \boldsymbol{\gamma}_n^T \boldsymbol{x}_n^i\Big) \end{aligned} \tag{12}$$

where $\boldsymbol{x}_v^i$ is a $20 \times T_{max}$ matrix of dynamic variables over all time points, $a$ is the shape parameter, $\beta$ is the scale parameter, $\boldsymbol{\gamma}_{v_1} = \boldsymbol{\gamma}_{v_2} = \boldsymbol{5}$, and $\max(\cdot)$ and $\min(\cdot)$ operate along the temporal dimension. $v_1$ and $v_2$ are two randomly selected subsets of the dynamic variables satisfying $v_1 \cap v_2 = \varnothing$ and $v_1 \cup v_2 = v$. $T^i$ is then resampled, and censored cases are introduced in the same way as before.
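The covariate side of Eq. 12 can be sketched as below: Weibull draws plus Gaussian noise, a random disjoint split into $v_1$ and $v_2$, and the exponent that combines temporal maxima, temporal minima and the static term. The values of $a$, $\beta$, the number of time steps and the equal-size split are hypothetical, and the final sampling of $T^i$ is omitted because its parameterization is not specified here.

```python
import numpy as np

def synth_d_covariates(x_n, n_dyn=20, n_steps=50, a=1.5, beta=1.0,
                       gamma_n=10.0, gamma_v=5.0, seed=0):
    """Dynamic covariates and the exponent of Eq. 12 (a, beta, n_steps are hypothetical)."""
    rng = np.random.default_rng(seed)
    n = x_n.shape[0]
    # Weibull(shape a, scale beta) draws plus standard Gaussian noise: (subjects, variables, time)
    x_v = beta * rng.weibull(a, size=(n, n_dyn, n_steps)) + rng.standard_normal((n, n_dyn, n_steps))
    # Random disjoint split of the dynamic variables; together v1 and v2 cover all variables
    perm = rng.permutation(n_dyn)
    v1, v2 = perm[: n_dyn // 2], perm[n_dyn // 2:]
    # gamma^T x with a constant gamma reduces to gamma * sum(x); max/min act on the time axis
    eta = (gamma_v * x_v[:, v1, :].max(axis=2).sum(axis=1)
           + gamma_v * x_v[:, v2, :].min(axis=2).sum(axis=1)
           + gamma_n * x_n.sum(axis=1))
    return x_v, v1, v2, eta

x_n = np.random.default_rng(2).standard_normal((100, 4))   # static part, as in SYNTH-s
x_v, v1, v2, eta = synth_d_covariates(x_n)
print(x_v.shape, eta.shape)   # (100, 20, 50) (100,)
```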

Moreover, the MSReactor [22] dataset is a quantifiable, objective collection of cognitive measurements from longitudinal computerized tests for Multiple Sclerosis (MS), together with 8 static covariates. In each test, patients are instructed to respond as quickly as possible to on-screen stimuli, and their reaction times are recorded in milliseconds (ms). The test includes 3 tasks assessing psychomotor function, attention and working memory. Each patient takes the test several times after diagnosis and before the event or censoring, with at least a one-month interval between adjacent tests. The survival event is EDSS progression confirmed by the six-month disability worsening confirmation rule [10]. Numerous studies have indicated that reaction-time data could offer a more responsive approach to detecting subclinical cognitive impairment than current cognitive assessment methods [6, 27, 37].

Figure 2: Illustration of the reaction tensor representation of a single individual in MSReactor

The longitudinal reaction tests are treated as time-varying covariates. However, MSReactor contains certain redundancies: adjacent columns are strongly correlated, unlike conventional tabular data in which each column carries compact, distinct information. Without specialized preprocessing and suitable model architectures, existing survival models may struggle to extract meaningful latent patterns from these data. We therefore transform the longitudinal tabular part into a composite "reaction tensor" after monthly imputation, and process it with the image-based dynamic extraction module (-d2). Specifically, each patient has a unique 3-dimensional reaction tensor (more details are given in Appendix A), as shown in Fig. 2. Its Z-axis corresponds to the 3 tasks. The X-axis is the response dimension with a fixed length of 30, corresponding to the 30 responses the patient gives in each of the three tasks per test. The Y-axis corresponds to the monthly tests the patient has undergone, with a fixed length from the start time $T_0$ to the end time $T_{max}$. The reaction tensor is divided into several smaller tensors along the Y-axis by $T_w$, as in Fig. 1(a).
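A small sketch of the windowing step, assuming the tensor is stored with axes ordered as (tasks Z, months Y, responses X) and that a final window shorter than $T_w$ is kept as is; the axis order and the values of $T_{max}$ and $T_w$ are assumptions, and random numbers stand in for reaction times.

```python
import numpy as np

def split_reaction_tensor(tensor, t_w):
    """Split a (tasks, months, responses) reaction tensor into windows of t_w months along Y."""
    n_months = tensor.shape[1]
    return [tensor[:, start:start + t_w, :] for start in range(0, n_months, t_w)]

t_max = 11                                                      # hypothetical; Y covers T_0..T_max
tensor = np.random.default_rng(0).random((3, t_max + 1, 30))    # Z = 3 tasks, X = 30 responses
windows = split_reaction_tensor(tensor, t_w=3)
print(len(windows), windows[0].shape)   # 4 (3, 3, 30)
```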

4.2 Evaluation Metrics

As evaluation metrics for all experiments, we use ranking measures, namely the concordance index (C-index) from the lifelines [4] library and the mean cumulative area under the ROC curve (mAUC) from the scikit-survival [29] library, and accuracy measures based on the mean absolute error (MAE).

4.2.1 Concordance

The C-index [33] estimates ranking ability by comparing relative risks across all comparable pairs in the test set:

$$\text{C-index} = \frac{\sum_{i,k} \delta^i \cdot \mathbb{I}(T^i < T^k) \cdot \mathbb{I}(\hat{\mu}^i < \hat{\mu}^k)}{\sum_{i,k} \delta^i \cdot \mathbb{I}(T^i < T^k)} \tag{13}$$

where $\mathbb{I}(\cdot)$ is the indicator function, and $\delta^i = 1$ if $T^i$ is uncensored and 0 otherwise.
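A direct NumPy transcription of Eq. 13, suitable for small test sets (its cost is quadratic in the number of subjects). This strict version gives no credit to tied predictions, whereas lifelines' `concordance_index` counts ties as half-concordant, so the two can differ slightly.

```python
import numpy as np

def c_index(t_obs, delta, mu_hat):
    """Eq. 13: among pairs with delta^i = 1 and T^i < T^k, the fraction with mu^i < mu^k."""
    t_obs, delta, mu_hat = map(np.asarray, (t_obs, delta, mu_hat))
    comparable = (delta[:, None] == 1) & (t_obs[:, None] < t_obs[None, :])
    concordant = comparable & (mu_hat[:, None] < mu_hat[None, :])
    return concordant.sum() / comparable.sum()

print(c_index([1, 2, 3], [1, 1, 0], [1.0, 2.0, 3.0]))   # 1.0 (perfect ordering)
print(c_index([1, 2, 3], [1, 1, 0], [3.0, 2.0, 1.0]))   # 0.0 (fully reversed)
```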

4.2.2 MAE-Uncensored

MAE-Uncensored (MAE-U) compensates for the inability of the C-index to measure the absolute error of the predicted survival time. It is computed as

$$\text{MAE-U} = \frac{\sum_i \delta^i \cdot |T^i - \hat{\mu}^i|}{\sum_i \delta^i} \tag{14}$$

4.2.3 MAE-Hinge

MAE-Hinge (MAE-H) is a one-sided MAE computed only over censored cases, as opposed to MAE-U, which uses only uncensored cases. It penalizes a prediction only if the predicted time $\hat{\mu}$ is earlier than the censoring time $T$:

$$\text{MAE-H} = \frac{\sum_i (1 - \delta^i) \cdot \max\{T^i - \hat{\mu}^i, 0\}}{\sum_i (1 - \delta^i)} \tag{15}$$
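Eqs. 14 and 15 translate directly into NumPy; the toy values below are hypothetical. The second censored subject is predicted to survive beyond its censoring time and therefore contributes nothing to MAE-H.

```python
import numpy as np

def mae_uncensored(t_obs, delta, mu_hat):
    """Eq. 14: mean absolute error over uncensored subjects only."""
    t_obs, delta, mu_hat = map(np.asarray, (t_obs, delta, mu_hat))
    return np.sum(delta * np.abs(t_obs - mu_hat)) / np.sum(delta)

def mae_hinge(t_obs, delta, mu_hat):
    """Eq. 15: one-sided error over censored subjects, penalizing only mu_hat < T."""
    t_obs, delta, mu_hat = map(np.asarray, (t_obs, delta, mu_hat))
    return np.sum((1 - delta) * np.maximum(t_obs - mu_hat, 0.0)) / np.sum(1 - delta)

t_obs, delta, mu_hat = [10, 20, 30], [1, 0, 0], [12.0, 15.0, 35.0]
print(mae_uncensored(t_obs, delta, mu_hat), mae_hinge(t_obs, delta, mu_hat))   # 2.0 2.5
```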

4.2.4 Mean Cumulative Area Under ROC Curve

The area under the ROC curve for survival analysis treats the survival problem as binary classification at various quantiles of event times, defining sensitivity and specificity as time-dependent measures [16]. The cumulative AUC measures a model's ability to discriminate subjects who fail by a specified time $t$ ($T^i \le t$) from subjects who fail after it ($T^i > t$). We compute the mAUC by integrating the cumulative AUC over the whole time range, i.e., across the intervals $(T_j, T_{j+1})$.

4.3 Experimental Setting

We compare UniSurv with five static benchmarks (CPH, DeepSurv, DeepHit, DSM and TDSM) and two dynamic benchmarks (DDH and RDSM). We do not compare with SS2S, as its code has not been made publicly available at the time of writing. Since static datasets have no longitudinal covariates, the dynamic extraction module of UniSurv is deactivated for them; we call this variant UniSurv-s. For MSReactor, the dynamic extraction module has two variants based on different data representations: UniSurv-d1 for the tabular representation and UniSurv-d2 for the image-like representation. More implementation and hyperparameter details are given in Appendix B. Code is available at https://github.com/XinZ0419/UniSurv.

For a fair comparison, we use the C-index as the early-stopping criterion for all approaches, as it covers more subjects than MAE. We report cross-validated results, randomly splitting each dataset 5 times into training, validation and test sets with a ratio of 7:1:2. All experiments are implemented in PyTorch 2.0.1 in the same environment with a fixed random seed.
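The repeated random splitting can be sketched as below. The seeds and the use of NumPy's `Generator.permutation` are assumptions rather than details of the released code, and the example uses the METABRIC size from Tab. 1 (1103 + 801 = 1904 subjects).

```python
import numpy as np

def random_split(n, ratios=(7, 1, 2), seed=0):
    """One random train/validation/test split of n indices with the given ratio."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n)
    n_train = n * ratios[0] // sum(ratios)
    n_val = n * ratios[1] // sum(ratios)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

# Five repetitions with different (hypothetical) seeds.
splits = [random_split(1904, seed=s) for s in range(5)]
print([len(part) for part in splits[0]])   # [1332, 190, 382]
```

Integer division sends any rounding remainder to the test set, so the three parts always cover every subject exactly once.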

4.4 Benchmarking Results

Table 2: Benchmarking on three static datasets. "†" denotes P-value < 0.05; "w/o mask" means "without masking" and is excluded from the comparison. Higher (↑) C-index and mAUC and lower (↓) MAE-U and MAE-H are better. "(1st)" and "(2nd)" mark the best and second-best results in each column.

SUPPORT
Model               C-index ↑              MAE-U ↓                MAE-H ↓                mAUC ↑
CPH                 0.585 ± .012           19.25 ± 1.23           17.85 ± 3.12           0.715 ± .013
DeepSurv            0.610 ± .017 (1st)     18.76 ± 1.12           14.93 ± 2.58 (2nd)     0.789 ± .016 (1st)
DeepHit             0.601 ± .014           17.45 ± 2.57           21.25 ± 3.43           0.738 ± .018
DSM                 0.602 ± .005           17.58 ± 2.19           19.69 ± 3.74           0.742 ± .007
TDSM                0.603 ± .007           8.67† ± 0.73 (1st)     29.32 ± 3.41           0.762 ± .011
UniSurv-s           0.604 ± .007 (2nd)     17.35 ± 1.55 (2nd)     12.92† ± 2.61 (1st)    0.767 ± .012 (2nd)
UniSurv-s w/o mask  0.603 ± .008           17.42 ± 1.56           12.80 ± 2.63           0.769 ± .010

METABRIC
Model               C-index ↑              MAE-U ↓                MAE-H ↓                mAUC ↑
CPH                 0.633 ± .023           81.17 ± 4.13           33.27 ± 2.10           0.808 ± .021
DeepSurv            0.642 ± .020 (1st)     77.68 ± 5.12           34.81 ± 1.85           0.822 ± .016
DeepHit             0.636 ± .021           78.25 ± 5.14           31.65 ± 1.75 (2nd)     0.811 ± .018
DSM                 0.633 ± .028           75.18 ± 4.31           32.21 ± 2.13           0.805 ± .011
TDSM                0.637 ± .018           55.97† ± 5.83 (1st)    62.33 ± 2.91           0.824 ± .017 (2nd)
UniSurv-s           0.638 ± .021 (2nd)     71.30 ± 6.23 (2nd)     23.06† ± 2.68 (1st)    0.826 ± .014 (1st)
UniSurv-s w/o mask  0.638 ± .020           73.22 ± 5.45           22.32 ± 2.29           0.826 ± .014

SYNTH-s
Model               C-index ↑              MAE-U ↓                MAE-H ↓                mAUC ↑
CPH                 0.702 ± .007           22.85 ± 3.36           5.48 ± 0.25            0.838 ± .007
DeepSurv            0.701 ± .009           21.74 ± 3.06           5.85 ± 0.46            0.835 ± .008
DeepHit             0.723 ± .006 (2nd)     22.37 ± 3.19           5.45 ± 0.14            0.859 ± .005 (2nd)
DSM                 0.685 ± .010           24.17 ± 3.74           5.36 ± 0.26 (2nd)      0.823 ± .010
TDSM                0.718 ± .007           11.25 ± 3.57 (1st)     6.47 ± 0.34            0.850 ± .006
UniSurv-s           0.731 ± .008 (1st)     19.15 ± 3.23 (2nd)     1.55† ± 0.04 (1st)     0.866 ± .006 (1st)
UniSurv-s w/o mask  0.732 ± .007           19.23 ± 3.28           1.51 ± 0.04            0.868 ± .005

Performance comparisons for all datasets are summarized in Tab. 2 and Tab. 3. We bold the best and underline the second best. Statistical significance is determined by a paired t-test between the best result and each of the other results individually.

4.4.1 Static Modelling Results

In terms of C-index, our UniSurv-s ranks first on SYNTH-s and second on both SUPPORT and METABRIC. It also achieves the best mAUC on METABRIC and SYNTH-s and the second best on SUPPORT. DeepSurv shows comparable ranking performance on the two real-world datasets. This suggests that parametric models still hold a slight advantage over non-parametric models, owing to their strong probability distribution assumptions. Meanwhile, the performance of the other four models varies, creating a competitive landscape in which it is difficult to judge them definitively by any single ranking metric.
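For readers less familiar with the ranking metric, the sketch below computes Harrell's concordance index from observed times, event indicators and risk scores. This is the basic, unweighted form. The C-index used in the experiments may be a different variant, such as Uno's inverse-probability-weighted version [33]. For a PDF-based model like UniSurv, using the negated predicted mean lifetime as the risk score is an assumption made here for illustration.

```python
def harrell_c_index(times: list[float], events: list[int], risks: list[float]) -> float:
    """Fraction of comparable pairs whose risk scores are ordered consistently with their times.

    A pair (i, j) is comparable when subject i has an observed event (events[i] == 1)
    and times[i] < times[j]; it is concordant when risks[i] > risks[j].
    Tied risk scores count as half-concordant.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

For example, given hypothetical predicted mean lifetimes `means`, the call `harrell_c_index(times, events, [-m for m in means])` scores a model that predicts earlier events for higher-risk subjects.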

Meanwhile, UniSurv performs well in MAE-U and is notably superior in MAE-H, with statistical significance compared to the other models. Only DeepSurv on SUPPORT is comparable with ours in both MAEs. Conversely, TDSM excels in MAE-U but lags notably behind in MAE-H. This is because the loss design of TDSM leads to overfitting on uncensored data throughout learning, so it fails to capture the fact that most censored samples have longer survival times. The inadequacy of TDSM's predictions for censored cases is also evident in Fig. 3. There, red lines mark the difference between the true censoring time and the estimated mean lifetime for a number of censored cases. We show the METABRIC results of TDSM, UniSurv-s and the second-best MAE-U model, DeepHit. The more numerous and longer the red lines, the less sensitive a model is to censoring. UniSurv provides accurate predictions for the majority of censored cases. This can be attributed to the MAE-margin concept incorporated in the Margin-Mean loss $\mathcal{L}_{mm}$ in Eq. 6. This concept leverages prior knowledge from the training dataset to effectively "enforce" predicted survival times that exceed the censoring time. DeepHit, on the other hand, is markedly ineffective at forecasting longer censoring times. As with TDSM, this is due to the absence of constraints in its loss design beyond the censoring time, which may cause a systematic bias when predicting censored cases.
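The red lines in Fig. 3 correspond to a hinge-type error on censored cases. The sketch below shows one way to compute both error metrics from discrete PDFs. It rests on three assumptions:
- the estimated lifetime $\hat{\mu}^{i}$ is the mean of the predicted PDF over integer time bins $0,\dots,T_{max}$;
- MAE-U is the mean absolute error on uncensored samples;
- MAE-H penalizes a censored sample only when $\hat{\mu}^{i}<T^{i}$, following the hinge definition of Haider et al. [8].

These definitions and the function names are assumptions for illustration and may differ in detail from the evaluation behind Tab. 2 and Tab. 3.

```python
def expected_time(pdf: list[float]) -> float:
    """Mean of a discrete PDF whose k-th entry is the probability of the event in time bin k."""
    return sum(k * p for k, p in enumerate(pdf))


def mae_uncensored_and_hinge(pdfs: list[list[float]], times: list[float],
                             events: list[int]) -> tuple[float, float]:
    """Return (MAE-U, MAE-H) under the assumed definitions described above."""
    preds = [expected_time(pdf) for pdf in pdfs]
    # Uncensored: plain absolute error between predicted mean and observed event time.
    unc = [abs(m - t) for m, t, e in zip(preds, times, events) if e]
    # Censored: only predictions falling below the censoring time are penalized.
    hinge = [max(0.0, t - m) for m, t, e in zip(preds, times, events) if not e]
    mae_u = sum(unc) / len(unc) if unc else float("nan")
    mae_h = sum(hinge) / len(hinge) if hinge else float("nan")
    return mae_u, mae_h
```

Under these definitions, a prediction beyond the censoring time costs nothing, which matches the fact that the true event time of a censored case is only known to lie after $T^{i}$.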

4.4.2 Dynamic Modelling Results

As shown in Tab. 3, UniSurv-d1 outperforms the other two models (DDH and RDSM) on both longitudinal datasets, with higher C-index and mAUC and lower values in both MAEs. However, all three methods perform suboptimally on MSReactor, where their C-index values remain below 0.6. This likely arises because the temporal data in MSReactor diverge from conventional survival tabular data: they come from reaction tests administered to MS patients. Traditional models struggle to extract meaningful insights from such intricate and redundant information. Notably, when we preprocess the computerized test data into a "reaction tensor" and employ a CNN to extract latent features, UniSurv-d2 surpasses the others with statistically significant improvements. However, this tensor approach is not effective on SYNTH-d. The main reason is that each variable $x_{v}$ is drawn from an isotropic distribution during data generation, which makes the variables mutually independent and uncorrelated.

Figure 3: The difference between the estimated lifetime $\hat{\mu}^{i}$ (blue dot) and the true censoring time $T^{i}$ (green square) for TDSM, DeepHit and UniSurv-s on METABRIC. Each red line shows the difference when $\hat{\mu}^{i}<T^{i}$ for individual $i$; no line is drawn otherwise.
Table 3: Benchmarking on two dynamic datasets
Model MSReactor SYNTH-d
C-index ↑ MAE-U ↓ MAE-H ↓ mAUC ↑ C-index ↑ MAE-U ↓ MAE-H ↓ mAUC ↑
DDH $0.521_{.064}$ $38.27_{10.32}$ $16.73_{5.36}$ $0.698_{.051}$ $0.725_{.009}$ $31.36_{4.56}$ $5.45_{0.36}$ $0.822_{.007}$
RDSM $0.527_{.077}$ $35.25_{8.12}$ $17.92_{7.64}$ $0.714_{.055}$ $0.703_{.004}$ $32.18_{3.67}$ $5.32_{0.29}$ $0.804_{.005}$
UniSurv-d1 $\underline{0.547}_{.032}$ $\underline{34.12}_{10.57}$ $\underline{10.15}_{5.12}$ $\underline{0.729}_{.048}$ $\underline{0.737}_{.008}$ $\underline{23.74}_{5.52}$ $\boldsymbol{3.40}_{0.13}$ $\underline{0.875}_{.007}$
UniSurv-d2 $\boldsymbol{0.634}^{\dagger}_{.047}$ $\boldsymbol{29.73}_{5.29}$ $\boldsymbol{6.48}^{\dagger}_{3.55}$ $\boldsymbol{0.793}^{\dagger}_{.046}$ $\boldsymbol{0.739}_{.007}$ $\boldsymbol{23.56}_{5.73}$ $\underline{3.52}_{0.08}$ $\boldsymbol{0.876}_{.008}$
UniSurv-d1 w/o mask $0.562_{.030}$ $33.53_{11.96}$ $11.91_{8.25}$ $0.732_{.058}$ $0.739_{.007}$ $22.47_{5.47}$ $3.74_{0.12}$ $0.876_{.006}$
UniSurv-d2 w/o mask $0.642_{.059}$ $28.87_{6.67}$ $7.13_{4.47}$ $0.801_{.042}$ $0.741_{.006}$ $23.87_{5.37}$ $3.13_{0.10}$ $0.878_{.007}$

4.4.3 The Implication of Data Distribution

Figure 4: The time-dependent AUC. The dashed line shows mAUC corresponding to each colored curve

As shown in Fig. 4, all five histograms show survival times skewed towards the early part of the time horizon, while censoring times tend to cluster in the latter half, especially in SUPPORT, SYNTH-s, MSReactor and SYNTH-d. This makes it difficult for survival models to maintain predictive accuracy over time, as evidenced by the time-dependent AUC (TD-AUC). For example, UniSurv-s, DeepHit and DSM, as well as their dynamic counterparts (UniSurv-d2, DDH and RDSM), all exhibit a consistent decline in TD-AUC as time progresses. Nevertheless, UniSurv still outperforms the others, especially on the two dynamic datasets. METABRIC has a relatively low censoring rate and evenly distributed censored cases, so all three models maintain their TD-AUC quite well there, and some even show an upward trend, particularly UniSurv. This indicates that a Transformer encoder trained with the Margin-Mean-Variance loss can effectively alleviate the challenges posed by survival datasets with long-tail distributions.
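The time-dependent AUC can be illustrated with the cumulative/dynamic definition [16]. At time $t$, cases are subjects with an observed event at or before $t$, and controls are subjects still event-free after $t$. The sketch below computes this unweighted version and averages it over a grid of times to obtain an mAUC-style summary. It omits the inverse-probability-of-censoring weights that estimators such as scikit-survival's `cumulative_dynamic_auc` [29] apply, so it is only a simplified illustration. The function names are hypothetical.

```python
def cumulative_dynamic_auc(times: list[float], events: list[int],
                           risks: list[float], t: float) -> float:
    """Unweighted cumulative/dynamic AUC at time t (a higher risk should mean an earlier event)."""
    cases = [i for i in range(len(times)) if events[i] and times[i] <= t]
    controls = [j for j in range(len(times)) if times[j] > t]
    if not cases or not controls:
        raise ValueError(f"no cases or no controls at t={t}")
    score = 0.0
    for i in cases:
        for j in controls:
            if risks[i] > risks[j]:
                score += 1.0
            elif risks[i] == risks[j]:
                score += 0.5  # ties count as half-correct
    return score / (len(cases) * len(controls))


def mean_auc(times: list[float], events: list[int], risks: list[float],
             eval_times: list[float]) -> float:
    """Average of the time-dependent AUC over the evaluation times (an mAUC-style summary)."""
    values = [cumulative_dynamic_auc(times, events, risks, t) for t in eval_times]
    return sum(values) / len(values)
```

When censoring clusters late in the horizon, as in most of the datasets above, the pool of controls at late times shrinks quickly. An unweighted estimate then becomes noisy, which is one reason practical estimators add censoring weights.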

Figure 5: The PDFs estimated by DDH and UniSurv for five randomly selected uncensored cases in MSReactor. Each color represents an individual.

4.5 Importance of Masked Attention Mechanism

When a Transformer is used for inference, masking in the attention mechanism is optional. Whether it is needed depends on whether each output should draw on all inputs or only on specific designated ones. For static survival data, the latent features of UniSurv at different time points differ only in their temporal embedding. Hence, there is no risk of information leakage, and the masking mechanism is inconsequential; for instance, TDSM does not employ it. As shown in Tab. 2, removing the masking mechanism does not affect the overall performance of UniSurv-s, and the slight fluctuations are negligible. With dynamic survival data, however, missing data are inevitable, and imputations following the event or censoring times may raise retroactive-prediction concerns. The masking mechanism therefore becomes imperative in such scenarios. In Tab. 3, removing the mask changes the performance of both UniSurv-d1 and UniSurv-d2 to a similar degree across the two datasets, as evidenced by their ranking ability.
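To make the masking argument concrete, the sketch below applies scaled dot-product attention [34] over a sequence of time steps. It blocks every key position after an individual's last valid time step, so imputed values following the event or censoring time cannot contribute to any output. It is written in plain Python for readability. How UniSurv actually constructs its mask (for example, whether queries are masked as well as keys) is not specified here, so `build_time_mask` and `masked_attention` are hypothetical helpers.

```python
import math


def build_time_mask(num_steps: int, last_valid: int) -> list[list[bool]]:
    """mask[q][k] is True when query step q may attend to key step k (k <= last_valid)."""
    if not 0 <= last_valid < num_steps:
        raise ValueError("last_valid must index an existing time step")
    return [[k <= last_valid for k in range(num_steps)] for _ in range(num_steps)]


def masked_attention(q: list[list[float]], k: list[list[float]],
                     v: list[list[float]], mask: list[list[bool]]) -> list[list[float]]:
    """Scaled dot-product attention; masked scores are set to -inf before the softmax."""
    d = len(q[0])
    out = []
    for qi, row_mask in zip(q, mask):
        scores = [
            sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) if allowed else -math.inf
            for kj, allowed in zip(k, row_mask)
        ]
        peak = max(scores)  # finite, because key step 0 is always allowed
        weights = [math.exp(s - peak) for s in scores]  # exp(-inf) == 0.0
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out
```

Because masked positions receive exactly zero weight, a value at an imputed step has no influence on the output however large it is. Without the mask, the same value would leak into every position.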

4.6 Comparison of PDF Visualizations

In addition to predictive accuracy, the quality of the estimated individual PDFs is another crucial consideration when comparing non-parametric survival models. The shape of the PDF generated by UniSurv is governed specifically by the Margin-Mean-Variance loss and is unaffected by the choice of feature extraction module. In Fig. 5, we compare the PDF outputs for five randomly selected uncensored cases in MSReactor. We contrast DDH and UniSurv because neither makes assumptions about the shape of the PDF, whereas RDSM relies on strong assumptions tied to the Weibull and log-normal distributions. As described in the previous sections and shown in Fig. 5, our $\mathcal{L}_{mm}$ penalizes the dissimilarity between the peak of the PDF and the ground truth. Moreover, unlike the disordered PDFs of DDH, $\mathcal{L}_{v}$ regulates the spread of the PDF and confines it to a distinct, organized pattern. DDH uses the same MLP and softmax output layer as UniSurv, so the high fluctuations of its PDFs can be attributed to shortcomings in its loss function design.
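As a loose illustration of how a variance term can regulate the spread of a discrete PDF, the sketch below first turns output logits into a PDF with a softmax. It then computes the PDF's mean and variance over integer time bins. Penalizing this variance, as in the mean-variance loss of Pan et al. [26], favors a concentrated PDF. Treating $\mathcal{L}_{v}$ as behaving similarly is an assumption; its exact form and weighting in UniSurv are given by the paper's own equations and are not reproduced here.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax turning output logits into a discrete PDF."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pdf_mean_variance(pdf: list[float]) -> tuple[float, float]:
    """Mean and variance of a discrete PDF over time bins 0, 1, ..., len(pdf) - 1."""
    mean = sum(k * p for k, p in enumerate(pdf))
    variance = sum(p * (k - mean) ** 2 for k, p in enumerate(pdf))
    return mean, variance
```

A sharply peaked PDF such as `softmax([0.0, 5.0, 0.0])` has a much smaller variance than a flat one. For that reason, weighting a variance term too heavily can push predictions toward the over-concentrated PDFs discussed in this section.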

The unimodal nature of the survival PDF offers several advantages. For example, it better reflects the time to event and naturally calibrates the median survival time of the corresponding survival curve; [31], for instance, employed several predefined unimodal distributions for survival modelling. UniSurv, however, dispenses with such distributional assumptions and achieves the same objective through its distinctive loss design. The current PDF is over-concentrated and therefore not optimal, so $\mathcal{L}_{v}$ will need to be adjusted appropriately to relax its constraints on the shape.

Figure 6: Comparison results of the ablation study and the effectiveness analysis. Panels: (a) loss combination; (b) $T_{w}$ selection; (c) $\lambda_{m}$ selection; (d) $\lambda_{v}$ selection; (e) averaged PDF shape; (f) PDF shape; (g) PDF sensitivity. All experiments use the UniSurv-d2 setting on the MSReactor dataset.

4.7 Ablation Study

We further conduct an ablation study of the losses on MSReactor using UniSurv-d2, to demonstrate the contribution of each loss. Fig. 6(a) shows that an incomplete loss combination can sometimes lead to a lower MAE-U or MAE-H. However, this often corresponds to a local optimum, which is reflected in the shape of the PDF. In Fig. 6(f), we compare the PDFs under selected scenarios: $\mathcal{L}_{mm}$ only, $\mathcal{L}_{v}$ only and $\mathcal{L}_{total}$. Relying solely on $\mathcal{L}_{mm}$, without the constraint of $\mathcal{L}_{v}$, tends to produce a probability distribution that is close to uniform around the ground truth. Training solely with $\mathcal{L}_{v}$, on the other hand, generates irregular PDFs and fails to acquire meaningful information. This explains why these scenarios do not yield the optimal C-index. Including $\mathcal{L}_{d}$ further enhances performance by reducing the occurrence of discordant pairs. In addition, incorporating $\mathcal{L}_{s}$ results in faster convergence and a significant performance improvement, particularly in C-index. However, $\mathcal{L}_{s}$ also rapidly concentrates all probability distributions near the event time, which can yield overly concentrated PDFs and potential calibration issues. In Fig. 6(e), we compare the averaged PDF shape with and without $\mathcal{L}_{s}$ for the combinations presented in Fig. 6(a). The averaged shape is calculated by aligning all PDF peaks with the Dynamic Time Warping technique [7] and then averaging them over a normalized horizon. Without $\mathcal{L}_{s}$, the PDF is clearly multimodal, smoother and more realistic. Hence, including $\mathcal{L}_{s}$ involves a trade-off between ranking and calibration ability.
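For reference, the sketch below implements the textbook Dynamic Time Warping recurrence with an absolute-difference cost and backtracks the optimal alignment path. It illustrates the general technique only. It is not the dtw package described in [7], and the subsequent peak alignment and averaging over a normalized horizon are not shown. The function name is hypothetical.

```python
def dtw(a: list[float], b: list[float]) -> tuple[float, list[tuple[int, int]]]:
    """Classic DTW: returns the minimal cumulative |a_i - b_j| cost and the warping path."""
    if not a or not b:
        raise ValueError("both sequences must be non-empty")
    n, m = len(a), len(b)
    inf = float("inf")
    # acc[i][j] = cost of the best alignment of a[:i] with b[:j]
    acc = [[inf] * (m + 1) for _ in range(n + 1)]
    acc[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = abs(a[i - 1] - b[j - 1])
            acc[i][j] = cost + min(acc[i - 1][j], acc[i][j - 1], acc[i - 1][j - 1])
    # Backtrack from the end, always stepping to the cheapest predecessor.
    i, j = n, m
    path = [(i - 1, j - 1)]
    while (i, j) != (1, 1):
        _, i, j = min((acc[i - 1][j - 1], i - 1, j - 1),
                      (acc[i - 1][j], i - 1, j),
                      (acc[i][j - 1], i, j - 1))
        path.append((i - 1, j - 1))
    path.reverse()
    return acc[n][m], path
```

The returned path pairs each index of one sequence with one or more indices of the other. This is what lets differently located PDF peaks be brought into correspondence before averaging.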

4.8 Effectiveness Analysis

4.8.1 Sensitivity of Time Window $T_{w}$

Fig. 6(b) shows the effect of $T_{w}$. The model achieves the highest C-index and the lowest MAE-H when $T_{w}=8$, which is associated with the progression rate of MS. At the same setting, however, MAE-U is at its worst. The fluctuations in MAE-U and MAE-H also follow opposite patterns. This disparity can be attributed to the different distributions of censored and uncensored cases in MSReactor, as shown in Fig. 4. It also indicates that there is room to improve the robustness of UniSurv.

4.8.2 Sensitivity of Loss Weights $\lambda_{m}$ and $\lambda_{v}$

As the number of losses increases, finding the optimal weight combination becomes challenging, but a grid search can handle this. The four losses do not need to be standardized to a similar magnitude, and the unique characteristics of different datasets can lead to different optimal weights. We assessed the sensitivity to $\lambda_{m}$ and $\lambda_{v}$ on MSReactor in Fig. 6(c) and Fig. 6(d), and show selected PDFs in Fig. 6(g). The model is robust to small variations in $\lambda_{m}$ or $\lambda_{v}$, as performance near their optimal values does not degrade significantly; in some cases, the two MAEs even improve. This is attributed to the opposing fluctuation trends of the two MAEs, which indicate a trade-off made by UniSurv during training. Notably, the C-index appears more sensitive to changes in $\lambda_{v}$ than to changes in $\lambda_{m}$. As the deviations of both weights from their optimal values grow, the PDF gradually deteriorates: its peak drifts further away from the actual event time and it takes on irregular shapes.
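A minimal sketch of such a grid search is given below, using the $\lambda_{m}$, $\lambda_{v}$ and $\lambda_{d}$ search spaces listed in Tab. 5. The `validation_score` callback is hypothetical. In practice it would train the model with the given weights and return a validation metric where higher is better, such as the C-index; here a toy function stands in for it.

```python
from itertools import product
from typing import Callable

# Search spaces taken from Tab. 5.
LAMBDA_M = [0.01, 0.1, 1, 10]
LAMBDA_V = [0.001, 0.01, 0.1, 1]
LAMBDA_D = [0, 1]


def grid_search(validation_score: Callable[[float, float, float], float]):
    """Try every (lambda_m, lambda_v, lambda_d) combination and keep the best-scoring one."""
    best_cfg, best_score = None, float("-inf")
    for lam_m, lam_v, lam_d in product(LAMBDA_M, LAMBDA_V, LAMBDA_D):
        score = validation_score(lam_m, lam_v, lam_d)
        if score > best_score:
            best_cfg, best_score = (lam_m, lam_v, lam_d), score
    return best_cfg, best_score


def toy_score(lam_m: float, lam_v: float, lam_d: float) -> float:
    """Hypothetical stand-in for a real train-and-validate run; peaks at (1, 0.01, 1)."""
    return -(abs(lam_m - 1) + abs(lam_v - 0.01) + abs(lam_d - 1))
```

These three weights alone give 4 × 4 × 2 = 32 training runs. Every additional weighted loss multiplies this count, which is the practical cost of adding losses.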

4.8.3 Sensitivity to Larger and Noisy Synthetic Datasets

To demonstrate UniSurv's reliability on higher-dimensional datasets and its robustness to data noise, we expanded the number of features in the existing synthetic datasets SYNTH-s and SYNTH-d without altering the event or censoring settings. This yields new datasets SYNTH-sk and SYNTH-dk. Here k indicates that the dimension of $\boldsymbol{x}_{n}^{i}$ in Eq. 11 and Eq. 12 is increased from 4 to $4\cdot 5^{k}$, and the dimension of $\boldsymbol{x}_{v}^{i}$ in Eq. 12 from 20 to $20\cdot 5^{k}$. Additionally, we added noise $\boldsymbol{\epsilon}^{i}\sim\epsilon_{0}\cdot\mathcal{N}(0,\mathbf{I})$ to $\boldsymbol{x}_{n}^{i}$ and $\boldsymbol{x}_{v}^{i}$ separately in all datasets. As the results in Tab. 4 show, UniSurv performs well on high-dimensional datasets and is robust to small levels of noise.
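The two formulas above can be sketched with small helpers. The feature dimensions grow as $4\cdot 5^{k}$ and $20\cdot 5^{k}$, and each feature vector is perturbed by $\boldsymbol{\epsilon}^{i}\sim\epsilon_{0}\cdot\mathcal{N}(0,\mathbf{I})$. The sketch reproduces only these two formulas. How the additional features themselves are generated follows Eq. 11 and Eq. 12, which are not repeated here, and the function names are hypothetical.

```python
import random


def expanded_dim(base_dim: int, k: int) -> int:
    """Feature dimension of SYNTH-sk / SYNTH-dk: base_dim * 5**k (base_dim is 4 or 20)."""
    return base_dim * 5 ** k


def add_gaussian_noise(x: list[float], eps0: float, rng: random.Random) -> list[float]:
    """Return x + eps0 * z with z ~ N(0, I), drawn independently for each feature."""
    return [xi + eps0 * rng.gauss(0.0, 1.0) for xi in x]
```

Passing an explicit `random.Random` instance keeps each noisy dataset reproducible from its seed. This matters when the same noise level $\epsilon_{0}$ is compared across datasets of different sizes, as in Tab. 4.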

Table 4: Noise sensitivity analysis on different sizes of synthetic datasets
$\boldsymbol{\epsilon_{0}}$ UniSurv-s UniSurv-d2
SYNTH-s SYNTH-s1 (k=1) SYNTH-s2 (k=2) SYNTH-d SYNTH-d1 (k=1) SYNTH-d2 (k=2)
C-index MAE-H C-index MAE-H C-index MAE-H C-index MAE-H C-index MAE-H C-index MAE-H
$\boldsymbol{0}$ $0.731_{.008}$ $1.55_{0.04}$ $0.733_{.007}$ $1.53_{0.05}$ $0.732_{.008}$ $1.54_{0.04}$ $0.739_{.007}$ $3.52_{0.08}$ $0.740_{.008}$ $3.51_{0.08}$ $0.738_{.009}$ $3.53_{0.09}$
$\boldsymbol{0.1}$ $0.731_{.008}$ $1.56_{0.04}$ $0.733_{.008}$ $1.53_{0.05}$ $0.732_{.009}$ $1.55_{0.04}$ $0.739_{.007}$ $3.53_{0.08}$ $0.740_{.008}$ $3.51_{0.09}$ $0.737_{.010}$ $3.53_{0.09}$
$\boldsymbol{0.3}$ $0.729_{.010}$ $1.56_{0.05}$ $0.731_{.009}$ $1.54_{0.06}$ $0.729_{.009}$ $1.57_{0.05}$ $0.738_{.006}$ $3.54_{0.09}$ $0.738_{.007}$ $3.53_{0.08}$ $0.735_{.009}$ $3.56_{0.07}$

5 Conclusion And Discussion

In this paper, we propose a non-parametric discrete survival model named UniSurv. Unlike existing models that use RNNs to process longitudinal data, we employ a Transformer to handle dynamic analysis. Specifically, our survival framework first applies imputation to handle missing data, then uses separate embedding branches to extract time-varying and time-invariant features. The Transformer encoder takes the merged features as input and outputs the individual PDF. We also demonstrated how to process image-like data with module variants, and how to select a time window for sharing information based on the progression speed of the disease. This is particularly beneficial in medicine, as obtaining regular time-series medical images in the real world is challenging.

Furthermore, our novel Margin-Mean-Variance loss effectively produces smooth, unimodal PDFs, demonstrating clear superiority over other discrete models. Importantly, the proposed loss can be seamlessly embedded into various discrete survival models. Moreover, it significantly enhances prediction accuracy, particularly for patients with extended censoring times. Applying poorly performing models in such scenarios could disrupt physicians' judgments and place unnecessary burdens on both society and healthcare institutions, which makes this a valuable contribution. Our current PDFs may appear overly concentrated around event times, akin to those of many models that rely on strong probability assumptions, and this results in unconventional survival curves. In future work, we intend to modify $\mathcal{L}_{s}$ and $\mathcal{L}_{v}$ to relax certain constraints. This adjustment aims to yield a more elegant PDF, with a smoother and less abrupt distribution, while maintaining overall performance. Adapting UniSurv to other censoring scenarios, such as left truncation and interval-censored data, is another interesting direction for future research. It will also be necessary to extend the framework with post-processing statistics for interpreting risk predictions in both static and dynamic analyses of disease progression. For example, individual explanations of predicted probabilities can be obtained by generating SHapley Additive exPlanations (SHAP) [28, 15]. This approach is expected to lead to more effective health care.


Acknowledgements

X.Z. receives support from the Australian Government Research Training Program (RTP) Scholarship.

Declarations

Availability of data and materials We are restricted from making MSReactor data available to the public for the moment. All the other data are publicly available.
Competing interests Not applicable.
Ethics approval Not applicable.
Consent for participation Not applicable.
Consent for publication Not applicable.

Appendix A MSReactor

A.1 Missing Details

For each patient, temporal tests were conducted roughly once every half-year. The interval between two adjacent tests ranged from 0 to 44 months, with a mean of 4.68 months. The number of yearly follow-ups ranged from 1 to 6, with a mean of 2.20 tests per patient.

A.2 Min-Max Values Selection of Reaction Tensor Representation

In our proposed reaction tensor representation, illustrated in Fig. 2, the minimum and maximum values for the MSReactor dataset are not taken directly from the original tabular data but are determined by threshold selection. Specifically, we sort an individual patient's data for a particular task $c$ and take the 2nd and 98th percentiles of the sorted values as the minimum and maximum values, denoted $\alpha_{min\_thr,c}$ and $\alpha_{max\_thr,c}$, respectively. Outliers beyond these thresholds are set to the minimum or maximum value. We use the same strategy for the SYNTH-d dataset.

This approach is motivated by the inherent instability of the reaction times recorded in such tests. They are sometimes affected by individual patient idiosyncrasies or external environmental interference, which makes these reaction times outliers in our analysis. For example, a patient may spend time adjusting their sitting posture, be distracted by ambient noise, or touch the screen rapidly and inadvertently, all of which can produce aberrant reaction times.
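The threshold selection for one patient and one task $c$ can be sketched in three steps: compute the 2nd and 98th percentiles of the sorted values, clip values outside that range, and rescale to $[0,1]$. Two details are assumptions made for this illustration: linear interpolation between order statistics, and the final rescaling to $[0,1]$. The appendix specifies only the percentiles and the clipping, and the function names are hypothetical.

```python
import math


def percentile(sorted_vals: list[float], q: float) -> float:
    """q-th percentile (0-100) of already-sorted values, linearly interpolated."""
    pos = q * (len(sorted_vals) - 1) / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def thresholds(values: list[float], low_q: float = 2.0,
               high_q: float = 98.0) -> tuple[float, float]:
    """alpha_min_thr and alpha_max_thr for one patient's values on one task."""
    s = sorted(values)
    return percentile(s, low_q), percentile(s, high_q)


def clip_and_scale(value: float, alpha_min: float, alpha_max: float) -> float:
    """Clip outliers to the thresholds, then map the result to [0, 1]."""
    if alpha_max == alpha_min:
        raise ValueError("degenerate thresholds: all retained values are equal")
    clipped = min(max(value, alpha_min), alpha_max)
    return (clipped - alpha_min) / (alpha_max - alpha_min)
```

Because the thresholds come from each patient's own distribution, a single distracted trial shifts at most the clipped extremes rather than the scale applied to every other trial.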

Appendix B Hyperparameter Information

Tab. 5 shows the hyperparameter spaces and the optimal choices we used in UniSurv-s for the SUPPORT, METABRIC and SYNTH-s datasets, and in UniSurv-d2 for SYNTH-d and MSReactor. We deliberately shrink the Transformer encoder, since survival datasets are much smaller than standard natural language processing (NLP) or computer vision (CV) datasets and have few features. In addition, $T_{max}$ is set according to the maximum survival/censoring time in the corresponding dataset, and $T_{w}$ is chosen directly from the factors of $T_{max}+1$ for convenience when training on dynamic datasets.

Table 5: The hyperparameter spaces for five datasets. We bold the optimal choice
Hyperparameter SUPPORT METABRIC SYNTH-s MSReactor SYNTH-d
$T_{max}$ 80 400 200 95 199
$T_{w}$ - - - {4, 6, 8, 12, 24, 32} {4, 5, 8, 25, 40}
$\lambda_{m}$ {0.01, 0.1, 1, 10} {0.01, 0.1, 1, 10} {0.01, 0.1, 1, 10} {0.01, 0.1, 1, 10} {0.01, 0.1, 1, 10}
$\lambda_{v}$ {0.001, 0.01, 0.1, 1} {0.001, 0.01, 0.1, 1} {0.001, 0.01, 0.1, 1} {0.001, 0.01, 0.1, 1} {0.001, 0.01, 0.1, 1}
$\lambda_{d}$ {0, 1} {0, 1} {0, 1} {0, 1} {0, 1}
Epochs 400 200 200 200 200
Batch size 16 16 {4, 8, 16, 32} {4, 8, 16, 32} {4, 8, 16, 32}
Dropout rate 0.1 0.1 {0.0, 0.1, 0.3} {0.0, 0.1, 0.3} {0.0, 0.1, 0.3}
Number of heads 4 4 {1, 2, 4, 8} {1, 2, 4, 8} {1, 2, 4, 8}
Embedding dimension 512 512 {256, 512} {256, 512} {256, 512}
Number of attention layers 4 4 {1, 2, 3, 4} {1, 2, 3, 4} {1, 2, 3, 4}
Adam optimizer with fixed learning rate 1e-4 1e-4 {1e-4, 1e-3} {1e-4, 1e-3} {1e-4, 1e-3}
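The rule for choosing $T_{w}$ can be checked with a short sketch: every factor of $T_{max}+1$ tiles the discrete horizon $0,\dots,T_{max}$ into equal windows. The helper names are illustrative assumptions, as is the idea of mapping a time step to its window index by integer division; neither is the exact training code.

```python
def candidate_windows(t_max: int) -> list[int]:
    """All factors of T_max + 1, i.e. window lengths that split the horizon evenly."""
    n = t_max + 1
    return [w for w in range(1, n + 1) if n % w == 0]


def window_index(t: int, t_w: int) -> int:
    """Index of the window of length t_w that contains discrete time step t."""
    return t // t_w
```

For MSReactor ($T_{max}=95$), the factors of 96 include every $T_{w}$ value in Tab. 5, and $T_{w}=8$ gives 12 windows of equal length. Likewise, the SYNTH-d values are all factors of 200.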

References

  • Collett [2023] Collett D (2023) Modelling survival data in medical research. CRC press
  • Cox [1972] Cox DR (1972) Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological) 34(2):187–202
  • Curtis et al [2012] Curtis C, Shah SP, Chin SF, et al (2012) The genomic and transcriptomic architecture of 2,000 breast tumours reveals novel subgroups. Nature 486(7403):346–352
  • Davidson-Pilon [2019] Davidson-Pilon C (2019) lifelines: survival analysis in python. Journal of Open Source Software 4(40):1317. 10.21105/joss.01317, URL https://doi.org/10.21105/joss.01317
  • Faraggi and Simon [1995] Faraggi D, Simon R (1995) A neural network model for survival data. Statistics in medicine 14(1):73–82
  • Foong et al [2023] Foong YC, Bridge F, Merlo D, et al (2023) Smartphone monitoring of cognition in people with multiple sclerosis: A systematic review. Multiple Sclerosis and Related Disorders p 104674
  • Giorgino [2009] Giorgino T (2009) Computing and visualizing dynamic time warping alignments in r: the dtw package. Journal of statistical Software 31:1–24
  • Haider et al [2020] Haider H, Hoehn B, Davis S, et al (2020) Effective ways to build and evaluate individual survival distributions. The Journal of Machine Learning Research 21(1):3289–3351
  • Hu et al [2021] Hu S, Fridgeirsson E, van Wingen G, et al (2021) Transformer-based deep survival analysis. In: Survival Prediction-Algorithms, Challenges and Applications, PMLR, pp 132–148
  • Hunter et al [2021] Hunter SF, Aburashed RA, Alroughani R, et al (2021) Confirmed 6-month disability improvement and worsening correlate with long-term disability outcomes in alemtuzumab-treated patients with multiple sclerosis: Post hoc analysis of the care-ms studies. Neurology and therapy 10(2):803–818
  • Ishwaran et al [2008] Ishwaran H, Kogalur UB, Blackstone EH, et al (2008) Random survival forests. The annals of applied statistics 2(3):841–860
  • Kaplan and Meier [1958] Kaplan EL, Meier P (1958) Nonparametric estimation from incomplete observations. Journal of the American statistical association 53(282):457–481
  • Katzman et al [2018] Katzman JL, Shaham U, Cloninger A, et al (2018) Deepsurv: personalized treatment recommender system using a cox proportional hazards deep neural network. BMC medical research methodology 18(1):1–12
  • Knaus et al [1995] Knaus WA, Harrell FE, Lynn J, et al (1995) The support prognostic model: Objective estimates of survival for seriously ill hospitalized adults. Annals of internal medicine 122(3):191–203
  • Krzyziński et al [2023] Krzyziński M, Spytek M, Baniecki H, et al (2023) Survshap (t): time-dependent explanations of machine learning survival models. Knowledge-Based Systems 262:110234
  • Lambert and Chevret [2016] Lambert J, Chevret S (2016) Summary measure of discrimination in survival models based on cumulative/dynamic time-dependent roc curves. Statistical methods in medical research 25(5):2088–2102
  • Lee et al [2018] Lee C, Zame W, Yoon J, et al (2018) Deephit: A deep learning approach to survival analysis with competing risks. In: Proceedings of the AAAI conference on artificial intelligence
  • Lee et al [2019] Lee C, Yoon J, Van Der Schaar M (2019) Dynamic-deephit: A deep learning approach for dynamic survival analysis with competing risks based on longitudinal data. IEEE Transactions on Biomedical Engineering 67(1):122–133
  • Lee and Whitmore [2006] Lee MLT, Whitmore G (2006) Threshold regression for survival analysis: Modeling event times by a stochastic process reaching a boundary. Statist Sci 21(1):501–513
  • Leung et al [1997] Leung KM, Elashoff RM, Afifi AA (1997) Censoring issues in survival analysis. Annual review of public health 18(1):83–104
  • Luck et al [2017] Luck M, Sylvain T, Cardinal H, et al (2017) Deep learning for patient-specific kidney graft survival analysis. arXiv preprint arXiv:1705.10245
  • Merlo et al [2021] Merlo D, Stankovich J, Bai C, et al (2021) Association between cognitive trajectories and disability progression in patients with relapsing-remitting multiple sclerosis. Neurology 97(20):e2020–e2031
  • Nagpal et al [2021a] Nagpal C, Jeanselme V, Dubrawski A (2021a) Deep parametric time-to-event regression with time-varying covariates. In: Survival Prediction-Algorithms, Challenges and Applications, PMLR, pp 184–193
  • Nagpal et al [2021b] Nagpal C, Li X, Dubrawski A (2021b) Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. IEEE Journal of Biomedical and Health Informatics 25(8):3163–3175
  • Nagpal et al [2021c] Nagpal C, Yadlowsky S, Rostamzadeh N, et al (2021c) Deep cox mixtures for survival regression. In: Machine Learning for Healthcare Conference, PMLR, pp 674–708
  • Pan et al [2018] Pan H, Han H, Shan S, et al (2018) Mean-variance loss for deep age estimation from a face. In: Proceedings of the IEEE conference on computer vision and pattern recognition, pp 5285–5294
  • Pham et al [2021] Pham L, Harris T, Varosanec M, et al (2021) Smartphone-based symbol-digit modalities test reliably captures brain damage in multiple sclerosis. NPJ digital medicine 4(1):36
  • Pieszko et al [2023] Pieszko K, Shanbhag AD, Singh A, et al (2023) Time and event-specific deep learning for personalized risk assessment after cardiac perfusion imaging. npj Digital Medicine 6(1):78
  • Pölsterl [2020] Pölsterl S (2020) scikit-survival: A library for time-to-event analysis built on top of scikit-learn. Journal of Machine Learning Research 21(212):1–6. URL http://jmlr.org/papers/v21/20-729.html
  • Pourjafari et al [2022] Pourjafari E, Ziaei N, Rezaei MR, et al (2022) Survival seq2seq: A survival model based on sequence to sequence architecture. In: Machine Learning for Healthcare Conference, PMLR, pp 79–100
  • Rindt et al [2022] Rindt D, Hu R, Steinsaltz D, et al (2022) Survival regression with proper scoring rules and monotonic neural networks. In: International Conference on Artificial Intelligence and Statistics, PMLR, pp 1190–1205
  • Singer and Willett [1991] Singer JD, Willett JB (1991) Modeling the days of our lives: using survival analysis when designing and analyzing longitudinal studies of duration and the timing of events. psychological Bulletin 110(2):268
  • Uno et al [2011] Uno H, Cai T, Pencina MJ, et al (2011) On the c-statistics for evaluating overall adequacy of risk prediction procedures with censored survival data. Statistics in medicine 30(10):1105–1117
  • Vaswani et al [2017] Vaswani A, Shazeer N, Parmar N, et al (2017) Attention is all you need. Advances in neural information processing systems 30
  • Vinzamuri and Reddy [2013] Vinzamuri B, Reddy CK (2013) Cox regression with correlation based regularization for electronic health records. In: 2013 IEEE 13th International Conference on Data Mining, IEEE, pp 757–766
  • Wang and Sun [2022] Wang Z, Sun J (2022) Survtrace: Transformers for survival analysis with competing events. In: Proceedings of the 13th ACM International Conference on Bioinformatics, Computational Biology and Health Informatics, pp 1–9
  • Whitehouse et al [2019] Whitehouse CE, Fisk JD, Bernstein CN, et al (2019) Comorbid anxiety, depression, and cognition in ms and other immune-mediated disorders. Neurology 92(5):e406–e417