
ClinicalMamba: A Generative Clinical Language Model on Longitudinal Clinical Notes

Zhichao Yang1,  Avijit Mitra1,  Sunjae Kwon1,  Hong Yu1,2
1 College of Information and Computer Sciences, University of Massachusetts Amherst
2 Department of Computer Science, University of Massachusetts Lowell
{zhichaoyang,avijitmitra,sunjaekwon}@umass.edu [email protected]
Abstract

The advancement of natural language processing (NLP) systems in healthcare hinges on language models’ ability to interpret the intricate information contained within clinical notes. This process often requires integrating information from various time points in a patient’s medical history. However, most earlier clinical language models were pretrained with a context length limited to roughly one clinical document. In this study, we introduce ClinicalMamba, a specialized version of the Mamba language model, pretrained on a vast corpus of longitudinal clinical notes to address the unique linguistic characteristics and information processing needs of the medical domain. ClinicalMamba models, with 130 million and 2.8 billion parameters, demonstrate superior performance in modeling clinical language across extended text lengths compared to Mamba and clinical Llama. With few-shot learning, ClinicalMamba achieves notable speed and performance benchmarks, outperforming existing clinical language models and large language models like GPT-4 on longitudinal clinical tasks.

1 Introduction

Clinical narratives, such as patient histories, consultation notes, and discharge summaries, contain detailed and complex information that extends over long text sequences (Wu et al., 2019). To fully understand a patient’s condition, treatments, and outcomes, NLP systems need to integrate information from various parts of these narratives, which often requires understanding the context provided in those long form text (Blumenthal, 2010).

Understanding the sequence of health events is crucial for diagnoses, treatment plans, and patient monitoring (Wang et al., 2024; Yang et al., 2023; Eva, 2005). This often involves putting together information from different time points within a patient’s health history (Gao et al., 2024). Long context enables NLP systems to perform temporal reasoning by tracking events over time longitudinally, which is essential for tasks like predicting disease progression or extracting medical relation (Chen et al., 2023; Jia et al., 2019; Wiegreffe et al., 2019).

It is therefore imperative to design models that can process longer texts (Parmar et al., 2023; Tay et al., 2020). Prior work introduced Mamba (Gu and Dao, 2023), a selective state space model that selects and compresses the necessary information from its context into a latent state and scales linearly with context length. While these advancements have been primarily directed toward general-domain text, the unique linguistic features of clinical narratives differ significantly from the general domain (Lehman et al., 2023), motivating us to develop specialized Mamba models for the clinical domain.

In this work, we build and publicly release ClinicalMamba - a Mamba model pretrained on longitudinal clinical notes. Furthermore, we demonstrate that ClinicalMamba outperforms multiple language models on longitudinal clinical NLP tasks. In particular, our contributions are as follows:

  • We publicly release ClinicalMamba with 130m and 2.8b parameters, trained on MIMIC-III (Johnson et al., 2016) (https://github.com/whaleloops/ClinicalMamba).

  • Through distributed training, the ClinicalMamba-2.8b model was pretrained in under 60 hours on 4 A100 GPUs; it is the first clinical autoregressive language model with a 16k maximum token length.

  • Through few-shot prompt-based finetuning, we demonstrate that both ClinicalMamba models outperform the original Mamba, GPT-4, and other existing clinical long-context language models on well-established long-context clinical information extraction tasks: cohort selection for clinical trials and International Classification of Diseases (ICD) coding.

Figure 1: Perplexity of different generative language models on MIMIC-III when evaluated at various preceding context lengths (1k, 4k, and 16k tokens). The x-axis is on a log scale. The subfigure is a zoomed-out plot with a perplexity range of 0-100. Experiment settings and detailed results are in section 4.

2 Related Work

2.1 Pretraining clinical narratives

The rapid expansion of electronic health records (EHRs) across the healthcare landscape underscores an urgent need for clinical language models (Kang et al., 2019). Previous work, such as (Alsentzer et al., 2019) on Clinical BERT embeddings and (Huang et al., 2019) with ClinicalBERT, adapted general-purpose language models to the clinical domain to enhance performance on clinical tasks. These models have been pivotal in demonstrating the effectiveness of adapting general-purpose NLP tools to the intricacies of clinical text. Similarly, GatorTron (Yang et al., 2022a) scales clinical language models up to billions of parameters, while NYUTron (Jiang et al., 2023) harnesses billions of words of unstructured data found in EHRs. Both underscore the potential of domain-adapted language models to advance clinical NLP by improving performance across various tasks such as concept extraction and outcome prediction (Yang et al., 2022a; Jiang et al., 2023).

To handle complex and nuanced tasks, recent studies investigated training generative models with prompts (Kweon et al., 2023; Peng et al., 2023; Wang et al., 2023; Lu et al., 2022; Wang and Sun, 2022). These models not only excel at classification but also at generating clinically relevant text that can be indistinguishable from human-written notes. Most previous methods focus on pretraining transformer models with a context window of fewer than 2k tokens. In contrast, we pretrain a selective state space model with a 16k-token context window, which covers more than 98% of the visits in MIMIC-III.

2.2 Clinical information extraction on long document

Handling long texts in clinical NLP has always been challenging. Traditional information extraction methods tackle this by marking specific locations within the sentence, but such labeling is not always available, and hiring annotators can be costly (Fu et al., 2020; Kwon et al., 2022; Deshpande et al., 2024). Recent advances in document-level information extraction instead pair labels with entire documents. However, BERT struggles to process these lengthy documents directly.

To address this, prior research introduced Hierarchical-ClinicalRoberta, which involves breaking down long documents into shorter segments of 512 tokens, applying ClinicalRoberta to each segment to obtain embeddings, and then using additional layers to leverage these embeddings for label classification (Huang et al., 2022; Zhang and Jankowski, 2022). However, this method combines information from each segment only at the final layer, which can hinder performance when training data is limited.

To mitigate this issue, ClinicalLongformer was designed to process longer contexts efficiently by employing a windowed self-attention mechanism across all layers, which lets it exchange information densely within a specified contextual range (Li et al., 2022; Ji et al., 2023). This mechanism, while powerful, attends only to a predetermined window of text, restricting its scope to what falls within that window.

To overcome these limitations, the Mamba model takes a different approach. Mamba employs a selective state space model strategy to choose which information to incorporate into its state (Gu and Dao, 2023), thereby enhancing its capability to manage information beyond a fixed self-attention window. In general-domain language modeling, Mamba surpasses Transformers of equivalent size in both task performance and speed.
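For intuition, the sketch below shows a heavily simplified, per-channel selective scan in PyTorch. The projections dt_proj, B_proj, and C_proj and the toy dimensions are illustrative assumptions; the actual Mamba implementation uses a hardware-aware parallel scan rather than this Python loop.

```python
# Simplified selective state space scan, loosely following Gu and Dao (2023).
# The input-dependent step size (delta), B, and C implement the "selection" mechanism.
import torch
import torch.nn as nn
import torch.nn.functional as F

def selective_scan(x, A, dt_proj, B_proj, C_proj):
    """x: (batch, length, d_model). A: (d_model, d_state) with negative entries.
    dt_proj/B_proj/C_proj: linear maps from d_model to d_model / d_state / d_state."""
    batch, length, d_model = x.shape
    d_state = A.shape[1]
    h = torch.zeros(batch, d_model, d_state)              # latent state carried across tokens
    ys = []
    for t in range(length):
        xt = x[:, t]                                       # current token representation
        delta = F.softplus(dt_proj(xt)).unsqueeze(-1)      # input-dependent step size (selection)
        B = B_proj(xt).unsqueeze(1)                        # input-dependent input matrix
        C = C_proj(xt).unsqueeze(1)                        # input-dependent output matrix
        A_bar = torch.exp(delta * A)                       # discretized state transition
        h = A_bar * h + (delta * B) * xt.unsqueeze(-1)     # compress context into the latent state
        ys.append((h * C).sum(-1))                         # read out y_t = C h_t
    return torch.stack(ys, dim=1)                          # (batch, length, d_model)

# Toy usage with random projections.
d_model, d_state = 4, 8
x = torch.randn(2, 16, d_model)
A = -torch.rand(d_model, d_state)                          # negative values keep the recurrence stable
y = selective_scan(x, A, nn.Linear(d_model, d_model),
                   nn.Linear(d_model, d_state), nn.Linear(d_model, d_state))
```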

3 Methods

3.1 Pretraining

We gather 82,178 hospital visits, along with their 2,083,180 deidentified free-text clinical notes, from 46,520 patients in MIMIC-III (Johnson et al., 2016). Rather than breaking the notes into chunks of 512 tokens to serve as individual data instances, we aggregate all notes related to a visit longitudinally. The distribution of token counts per data instance is detailed in Table A.1. For information on our text pre-processing methods, please refer to section A.1.

Following previous works (Li et al., 2019; Alsentzer et al., 2019), we continue to pretrain Mamba on MIMIC-III clinical notes with the causal language modeling objective. This pretraining process utilizes 4 Nvidia A100-80GB GPUs. It is important to note that some of our downstream evaluation tasks use a small subset (6,049) of hospital visits from MIMIC-III, so we exclude them from the pretraining data. A comprehensive training recipe is available in section A.2.
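For illustration only (this is not the exact training stack used in this work), continued pretraining with the causal language modeling objective could be sketched with the Hugging Face Trainer as below. The checkpoint id state-spaces/mamba-130m-hf, the file mimic3_visits.jsonl, and the hyperparameters are placeholders.

```python
# Sketch of continued causal-LM pretraining; checkpoint and file names are illustrative only.
from datasets import load_dataset
from transformers import (AutoTokenizer, MambaForCausalLM,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")  # assumed starting checkpoint
tokenizer.pad_token = tokenizer.eos_token
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

# One training instance = all notes of a hospital visit, concatenated longitudinally.
raw = load_dataset("json", data_files={"train": "mimic3_visits.jsonl"})  # hypothetical file

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=16384)

train = raw["train"].map(tokenize, batched=True, remove_columns=raw["train"].column_names)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)  # causal LM: labels = shifted inputs

trainer = Trainer(
    model=model,
    args=TrainingArguments("clinicalmamba-130m", per_device_train_batch_size=1,
                           learning_rate=1e-3, num_train_epochs=3, bf16=True),
    train_dataset=train,
    data_collator=collator,
)
trainer.train()
```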

3.2 Prompt based fine-tuning

We leverage the inherent capabilities of pre-trained language models by introducing a fine-tuning strategy that aligns with the specific demands of few-shot learning in clinical NLP. Recognizing the limitations of traditional fine-tuning methods when applied to clinical NLP tasks with limited labeled data, we adapt a prompt-based fine-tuning mechanism following previous works (Gao et al., 2021; Yang et al., 2022b; Taylor et al., 2023). Specifically, we first identify a set of representative prompts that encapsulate key aspects of the clinical tasks, such as the patient’s alcohol consumption. These prompts are appended to each input clinical note and incorporated into the fine-tuning phase, where the language model learns to associate them with label tokens (Yes/No) on a limited dataset. The generated label tokens are then mapped to the label space.
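A minimal sketch of this setup follows; the prompt wording, the Yes/No verbalizer, and the checkpoint id are illustrative assumptions, not the exact templates used. During few-shot fine-tuning, cross-entropy loss at the label-token position would update the model; the sketch shows only the scoring side.

```python
# Sketch: turn a binary criterion into label-token generation (illustrative prompt and verbalizer).
import torch
from transformers import AutoTokenizer, MambaForCausalLM

CKPT = "state-spaces/mamba-130m-hf"  # placeholder; the released ClinicalMamba checkpoint would be used
tokenizer = AutoTokenizer.from_pretrained(CKPT)
model = MambaForCausalLM.from_pretrained(CKPT)

PROMPT = "\nDoes the patient drink alcohol in excess? Answer:"  # assumed prompt wording
VERBALIZER = {" Yes": 1, " No": 0}                               # label tokens mapped to the label space
label_ids = [tokenizer(tok, add_special_tokens=False).input_ids[0] for tok in VERBALIZER]

def predict(note_text: str) -> int:
    """Append the prompt to the note and compare next-token logits of the label tokens."""
    enc = tokenizer(note_text + PROMPT, return_tensors="pt", truncation=True, max_length=16384)
    with torch.no_grad():
        logits = model(input_ids=enc.input_ids).logits[0, -1]    # next-token distribution
    scores = logits[label_ids]                                   # restrict to "Yes"/"No"
    return list(VERBALIZER.values())[int(scores.argmax())]
```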

Figure 2: Illustration of Prompt-based fine-tuning.

As shown in Figure 2, we cast the downstream information extraction task as a pretraining-like task: label token generation.

3.3 Fine-tuning tasks

Cohort selection for clinical trials addresses the challenge of interpreting unstructured clinical narratives to streamline the patient selection process. It aims to classify patients based on whether they meet 13 specific eligibility criteria, such as the use of aspirin to prevent myocardial infarction, excessive alcohol consumption, and HbA1c values between 6.5% and 9.5%, among others. The input contains multiple clinical notes with an average total length of 4,924 tokens. This dataset was released as part of the 2018 n2c2 challenge (track 1).

ICD coding interprets complex clinical narratives, translating them into standardized codes that facilitate accurate billing, statistical analysis, and healthcare management. It aims to extract a patient’s disease and procedure codes from clinical text. We followed the general instructions from Mullenbach et al. (2018) in building this task, but instead of using a single discharge summary as input, we used all previous discharge summaries and the ICD code descriptions assigned in previous visits. We further selected 50 infrequent codes as Code-rare and 50 frequent codes as Code-common, following Yang et al. (2022b). The average input length is 4,223 and 7,062 tokens, respectively. Detailed dataset statistics are shown in Table A.1.

For each task, we report micro-averaged precision, recall, and F1 scores, as well as the area under the receiver operating characteristic curve (ROCAUC), on the test dataset.
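For concreteness, these metrics can be computed with scikit-learn as in the sketch below, assuming binary multi-label predictions y_pred and probability scores y_score (one column per label).

```python
# Sketch: micro-averaged metrics for multi-label extraction tasks.
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

def evaluate(y_true: np.ndarray, y_pred: np.ndarray, y_score: np.ndarray) -> dict:
    """y_true/y_pred: binary arrays of shape (num_examples, num_labels); y_score: probabilities."""
    return {
        "precision": precision_score(y_true, y_pred, average="micro", zero_division=0),
        "recall":    recall_score(y_true, y_pred, average="micro", zero_division=0),
        "f1":        f1_score(y_true, y_pred, average="micro", zero_division=0),
        "rocauc":    roc_auc_score(y_true, y_score, average="micro"),
    }
```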

4 Results & Discussions

Model Prec Recall F1 AUC
CLlama2 70.0 79.1 77.7 84.3
Hi-CRoberta 72.4 82.6 79.2 88.1
CLongformer 69.7 78.6 76.1 83.5
GPT-4 88.1 79.9 84.8 -
Mamba-130m 75.4 80.2 77.7 85.7
CMamba-130m 79.0 86.2 82.2 91.8
CMamba-2.8b 88.6 89.5 88.8 95.7
Table 1: Results on the cohort selection task, where the prefix C denotes a model pretrained in the clinical domain.
Model Code-rare Code-common
Prec Recall F1 AUC Prec Recall F1 AUC
MultiResCNN 20.34 2.07 5.19 47.2 70.5 60.78 66.24 92.04
Hi-CRoberta 46.19 10.96 16.74 77.11 73.76 65.01 69.23 93.14
CLongformer 50.27 17.81 28.69 80.52 78.42 64.97 71.14 94.24
GPT-4 30.91 36.12 33.29 - 72.48 62.28 68.19 -
Mamba-130m 57.75 28.08 37.79 84.8 73.71 62.87 68.94 92.75
CMamba-130m 70.97 30.14 42.31 91.08 76.82 68.03 74.34 94.23
CMamba-2.8b 75.28 45.89 56.51 92.75 75.53 72.12 73.64 94.54
Table 2: Results on the ICD coding task, where the prefix C denotes a model pretrained in the clinical domain.

In this section, we first compare the models’ language modeling abilities on MIMIC-III clinical notes. We then describe the evaluation on different clinical information extraction tasks. Finally, we describe the trade-offs between language modeling ability (perplexity) and inference speed (throughput) for several generative models on MIMIC-III notes. Baseline models are detailed in section A.3.

ClinicalMamba stands as the sole model capable of handling clinical notes of up to 16k tokens. As demonstrated in Figure 1, the perplexity of ClinicalMamba-2.8b decreased from 3.11 to 2.61 as the context length expanded from 1k to 16k tokens during inference. This is in contrast to prior clinical autoregressive language models, whose perplexity rose with increased context length. For instance, with ClinicalLlama2-7b, perplexity escalated from 2.82 to 94.02 as the context length grew from 4k to 16k. This limitation arises because these models were trained on contexts not exceeding 4k tokens, impairing their accuracy at next-token prediction when the preceding context exceeds 4k.
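As a rough sketch of how such context-length-dependent perplexity can be measured (a simplification; the exact windowing and truncation choices in our evaluation may differ), a Hugging Face causal language model can be scored as follows.

```python
# Sketch: corpus perplexity of a causal LM at a fixed maximum preceding context length.
import math
import torch

def perplexity(model, tokenizer, texts, context_len=16384, device="cuda"):
    model.eval().to(device)
    nll, n_tokens = 0.0, 0
    for text in texts:
        ids = tokenizer(text, return_tensors="pt",
                        truncation=True, max_length=context_len).input_ids.to(device)
        with torch.no_grad():
            out = model(ids, labels=ids)     # HF causal LMs return mean next-token NLL as .loss
        n = ids.size(1) - 1                  # number of predicted tokens
        nll += out.loss.item() * n
        n_tokens += n
    return math.exp(nll / n_tokens)
```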

In the domain of extracting information from longitudinal clinical records, ClinicalMamba demonstrates superior performance compared to Mamba. ClinicalMamba-130m achieved ROCAUC scores of 91.8, 91.1, and 94.2 on Cohort selection, Code-rare, and Code-common, while Mamba-130m obtained ROCAUC scores of 85.7, 84.8, and 92.8, respectively. ClinicalMamba also outperformed previous long-range clinical language models with a similar number of parameters, exceeding Hierarchical-ClinicalRoberta and ClinicalLongformer by a relative 52.7% and 19.1% on ROCAUC, respectively. This is particularly notable on the Code-rare task with limited training data (5 shots), where ClinicalMamba attained an AUC of 91.1, compared to 77.1 for Hierarchical-ClinicalRoberta and 80.5 for ClinicalLongformer. Surprisingly, ClinicalMamba-2.8b also outperformed zero-shot GPT-4, achieving F1 scores of 88.8, 56.6, and 73.6 on the Cohort selection, Code-rare, and Code-common tasks, whereas GPT-4 obtained F1 scores of 84.8, 33.3, and 68.2, respectively.

ClinicalMamba also offers a favorable trade-off between language modeling ability and inference speed. On one hand, small language models offer only limited language modeling ability; on the other, the remarkable perplexity reduction delivered by large language models comes at a steep increase in computational cost. As shown in Table A.4, the perplexity of ClinicalMamba-2.8b (2.61) is comparable to that of ClinicalLlama2-7b (2.82), trained with the same computation budget. Moreover, the inference speeds of ClinicalMamba-2.8b and ClinicalMamba-130m are 3 to 30 times faster than that of ClinicalLlama2-7b.

5 Conclusion

In this study, we developed and released Mamba models pretrained on a large collection of clinical notes. Our findings demonstrate the superior performance of ClinicalMamba in extracting information from long clinical documents compared to other models. We strongly believe that clinical NLP researchers can benefit from such long-context generative language models, which alleviate the need for substantial computational power without sacrificing performance. Building on the groundwork laid by this study, future work can further refine and expand the capabilities of ClinicalMamba, thereby enhancing the effectiveness of clinical data processing across diverse medical fields.

Limitations

This work has several notable limitations. First, we do not experiment with more recent parameter-efficient fine-tuning strategies such as soft prompting (Lester et al., 2021) and Low-Rank Adaptation (LoRA) (Hu et al., 2021). This may have limited ClinicalMamba’s performance on downstream tasks. Second, our adaptation of the Mamba framework was restricted solely to textual data documented during visits. EHRs are rich with multifaceted information, including but not limited to radiology images taken at different times and electrocardiogram waveforms that span various periods. Future research could develop a multimodal Mamba framework to leverage these other modalities. Third, the MIMIC-III dataset, which serves as the foundation of our study, only includes notes from the intensive care unit of a single hospital within the United States. This limits the generalizability of our findings, as care practices vary significantly across institutions and countries. We did not pretrain on MIMIC-IV because it contains only a limited number of notes per visit, restricted to two note types (discharge summaries and radiology reports). Lastly, the linguistic scope of the MIMIC dataset is limited to English, which presents a barrier to understanding and applying our findings in non-English speaking contexts. Addressing these limitations could substantially broaden the applicability and relevance of our work in future endeavors.

Ethics Statement

In this research, we gained authorized access to the MIMIC and n2c2 datasets and used de-identified clinical notes in accordance with their license agreements and HIPAA regulations. When language models are trained on extensive clinical text, they can inherit biases present in the data. For instance, they might favor inquiries about smoking habits or link specific medical conditions to certain demographic groups. These biases could be mitigated by better aligning models with each patient’s background.

References

Appendix A Appendix

A.1 Text preprocessing

                    Cohort selection   Code-rare   Code-common
shots   mean                      89           5           918
tokens  mean                    4924        4223          7062
        median                  4632        3236          5177
        99%                    10781       14345         13356
        max                    13989       18480         14773
Table A.1: Number of instances per label (shots) and number of tokens per input.

We followed Huang et al. (2019) to format notes during text preprocessing. However, we did not convert text to lowercase because the Mamba tokenizer can process both upper and lower cases. For the notes of each patient’s hospital visit, we sorted the notes by their charted date and concatenated them into one string. We used the string "- - {NoteType} note - -" to separate the notes. Table A.2 shows the comprehensive list of {NoteType} values.
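A minimal sketch of this concatenation step is given below. It assumes the MIMIC-III NOTEEVENTS table with HADM_ID, CHARTDATE, CATEGORY, and TEXT columns and treats CATEGORY as the {NoteType} value, which is an approximation of our actual preprocessing.

```python
# Sketch: build one longitudinal pretraining instance per hospital visit from MIMIC-III NOTEEVENTS.
import pandas as pd

def build_visit_documents(noteevents_csv: str) -> dict:
    notes = pd.read_csv(noteevents_csv)                    # columns include HADM_ID, CHARTDATE, CATEGORY, TEXT
    notes = notes.sort_values(["HADM_ID", "CHARTDATE"])    # sort each visit's notes by charted date
    docs = {}
    for hadm_id, visit in notes.groupby("HADM_ID"):
        parts = [f"- - {row.CATEGORY} note - -\n{row.TEXT}"  # separator string described above
                 for row in visit.itertuples()]
        docs[hadm_id] = "\n".join(parts)                   # one concatenated string per visit
    return docs
```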

For pretraining data, we truncate visits with more than 16k tokens; these account for less than 2% of visits, and a length distribution is provided in Figure A.1. We exclude a small subset (6,049) of hospital visits used in the evaluation of MIMIC ICD coding and MIMIC hospital readmission prediction; the visit ids (hadm_id) are documented in the GitHub repository.

Figure A.1: Long-tail distribution of the number of tokens per visit. The y-axis shows the density (summing to 1.0).
Category Count % Visits Avg. Words
Nursing 506,528 73 241
Radiology 338,834 83.3 449
ECG 123,042 61.3 43
Physician 92,426 18.2 1369
Discharge summary 47,572 96.7 2195
Echo 34,064 45.8 464
Respiratory 32,798 8.1 205
Nutrition 7,971 6.4 602
General 7,710 6.4 290
Rehab Services 5,321 4.6 622
Social Work 2,294 2.8 446
Case Management 939 1.3 260
Pharmacy 97 0.1 512
Consult 78 0.1 1206
Table A.2: Statistics of note events documented in the MIMIC-III dataset. The columns represent a) the number of notes, b) the proportion of visits, and c) the average number of words for each note type.

A.2 Pretraining recipe

ClinicalMamba-2.8b is a selective state space model designed as a replication of the Mamba architecture (Gu and Dao, 2023). ClinicalMamba refers to the class of models, while 2.8b denotes the number of parameters of this particular pretrained model. We also pretrained ClinicalMamba-130m using the pretraining data from the previous section. The specific hyperparameter values are shown in Table A.3. These models were trained on 763 million English tokens over 7000 steps (3 epochs) (Muennighoff et al., 2023), as autoregressive language models using cross-entropy loss (Brown et al., 2020). For learning rate scheduling, we followed Mamba and chose linear learning rate warmup with cosine decay to 1e-5; we found this setting important for avoiding loss overflow (a sketch of this schedule follows Table A.3). It took under 60 hours to pretrain ClinicalMamba-2.8b on 4 Nvidia Tesla A100-80GB GPUs.

Hyperparameter Value
num param 130m/2.8b
num layer 24/64
dim model 768/2560
context len 16k
num vocab 50277
position emb None
optimizer Adam
beta1 0.9
beta2 0.95
epsilon 1e-5
batch size 32
weight decay 0.1
gradient clipping 1.0
peak learning rate 1e-3/6e-4
Table A.3: Hyperparameters used to train ClinicalMamba.
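The warmup-plus-cosine schedule described above can be sketched as follows; the number of warmup steps is an assumption, and the optimizer’s base learning rate is taken to be the peak rate from Table A.3.

```python
# Sketch: linear warmup followed by cosine decay to a floor of 1e-5 (warmup_steps is assumed).
import math
import torch

def warmup_cosine(optimizer, total_steps, warmup_steps=700, min_lr=1e-5, peak_lr=6e-4):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)                               # linear warmup to peak_lr
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))                  # decays 1 -> 0
        return (min_lr + (peak_lr - min_lr) * cosine) / peak_lr              # peak_lr -> min_lr
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Usage: optimizer base lr = peak_lr; call scheduler.step() once per optimization step.
```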

A.3 Baselines

GPT-4 is a large language model designed to understand and generate human-like text based on the input it receives. We applied zero-shot prompting to each downstream task, using the ACAN and original prompts introduced in Wornow et al. (2024). GPT-4 (version 2023-12-01-preview) was accessed securely through the Azure OpenAI API. We set the decoding sampling temperature to 0.1.
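A hedged sketch of such a zero-shot call through the Azure OpenAI client is shown below; the endpoint, deployment name, and prompt wording are placeholders and not the exact prompts from Wornow et al. (2024).

```python
# Sketch of a zero-shot GPT-4 call via Azure OpenAI (endpoint, key, and deployment are placeholders).
from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com/",  # placeholder endpoint
    api_key="<key>",
    api_version="2023-12-01-preview",
)

def classify(note_text: str, criterion_question: str) -> str:
    response = client.chat.completions.create(
        model="gpt-4",                      # Azure deployment name (placeholder)
        temperature=0.1,                    # sampling temperature used in our experiments
        messages=[
            {"role": "system", "content": "You are a clinical information extraction assistant."},
            {"role": "user", "content": f"{note_text}\n\n{criterion_question} Answer Yes or No."},
        ],
    )
    return response.choices[0].message.content
```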

Asclepius-R (Kweon et al., 2023) is a clinical generative language model trained on MIMIC-III discharge summaries and corresponding instruction-answer pairs. It has 7 billion parameters with a maximum input of 4096 tokens.

ClinicalLlama2 (CLlama2) is similar to Asclepius-R, but it was trained on all types of MIMIC-III notes with the same computation budget as ClinicalMamba-2.8b (60 hours on 4 A100 GPUs). It has 7 billion parameters with a 4096-token maximum context length.

ClinicalLongformer (CLongformer) (Li et al., 2022) is a clinical knowledge-enriched version of Longformer that was further pretrained on MIMIC-III clinical notes. It has 149 million parameters with a maximum input of 4096 tokens. We only used local attention and did not apply global attention, for computational efficiency.

Hierarchical ClinicalRoberta (Hi-CRoberta) (Huang et al., 2022) utilizes multiple embeddings from clinical Roberta (Lewis et al., 2020). It first segments clinical notes into chunks of 512 tokens to obtain their embeddings; these embeddings are then pooled by concatenation and passed to a linear classification head during the downstream task. It has 110 million parameters with a maximum of 16384 tokens.
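A simplified sketch of this hierarchical pooling baseline is given below; the encoder name roberta-base and the chunk count are placeholders (32 chunks of 512 tokens gives the 16,384-token maximum), and the pooling details are assumptions based on the description above.

```python
# Sketch of the hierarchical pooling baseline: chunk, encode, concatenate, classify.
import torch
import torch.nn as nn
from transformers import AutoModel

class HierarchicalClassifier(nn.Module):
    def __init__(self, encoder_name="roberta-base", num_chunks=32, num_labels=13):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)   # placeholder for clinical RoBERTa
        hidden = self.encoder.config.hidden_size
        self.head = nn.Linear(hidden * num_chunks, num_labels)   # concatenation pooling + linear head
        self.num_chunks = num_chunks

    def forward(self, input_ids, attention_mask):
        # input_ids: (batch, num_chunks, 512) -- the document pre-split into 512-token chunks
        b, c, l = input_ids.shape
        out = self.encoder(input_ids.view(b * c, l),
                           attention_mask=attention_mask.view(b * c, l))
        cls = out.last_hidden_state[:, 0]                        # [CLS] embedding per chunk
        pooled = cls.view(b, c * cls.size(-1))                   # concatenate chunk embeddings
        return self.head(pooled)                                 # multi-label logits
```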

MultiResCNN (Li and Yu, 2020) encodes free text with a multi-filter residual CNN and applies a label attention mechanism that enables each ICD code to attend to different parts of the document.

Model Context Length Perplexity Throughput
Pythia-130m 1k 34.86 12042
4k 69.79 14859
16k 566.42 15713
Mamba-130m 1k 29.90 24695
4k 26.45 34539
16k 34.11 37009
ClinicalMamba-130m 1k 3.64 24709
4k 3.29 33349
16k 3.08 35486
ClinicalLlama2-7b 1k 3.28 1005
4k 2.82 1013
16k 94.02 951
Asclepius-R-7b 1k 4.01 1066
4k 8.79 1064
16k 97.37 995
ClinicalMamba-2.8b 1k 3.11 2932
4k 2.84 3007
16k 2.61 3027
Table A.4: Trade-off between language modeling ability (perplexity) and inference throughput (tokens/s) for a number of models on MIMIC-III clinical notes. Among the three context-length variations (1k, 4k, and 16k), the variation with the best perplexity is shown in bold for each model.