Which Client is Reliable?: A Reliable and Personalized Prompt-based Federated Learning for Medical Image Question Answering
Faculty of Information Science and Technology, Hokkaido University, Japan
E-mail: {zhu, togo, ogawa, mhaseyama}@lmd.ist.hokudai.ac.jp
Abstract
Conventional medical artificial intelligence (AI) models face barriers in clinical application and ethical issues owing to their inability to handle the privacy-sensitive characteristics of medical data. We present a novel personalized federated learning (pFL) method for medical visual question answering (VQA) models, addressing privacy and reliability challenges in the medical domain. Our method introduces learnable prompts into a Transformer architecture to efficiently train it on diverse medical datasets without massive computational costs. We then introduce a reliable client VQA model that incorporates Dempster-Shafer evidence theory to quantify the uncertainty in predictions, enhancing the model's reliability. Furthermore, we propose a novel inter-client communication mechanism that uses maximum likelihood estimation to balance accuracy and uncertainty, fostering efficient integration of insights across clients. The code will be available soon.
Keywords: Medical VQA · personalized federated learning · large-scale model

1 Introduction
Data heterogeneity and privacy protection [1] are key challenges in the application of deep learning in the healthcare field. Federated Learning (FL) [12], which allows for learning without centralizing data, is considered a promising solution to this issue. FL is a learning approach that enables the training of distributed models by sharing parameters instead of data. In the typical FL problem setting in the medical field [24], a model is trained locally in each hospital using patient data, and only the learned model parameters are sent to a central server. The central server aggregates these parameters to create an improved model. This approach allows the model to be improved using data from multiple hospitals while protecting patient privacy. In recent years, research [14, 17] has been conducted to introduce FL to medical tasks, addressing privacy and data heterogeneity issues. Many studies in medical FL focus on simple tasks such as segmentation [20] and classification [5, 28, 21], using predefined input and output formats. Moreover, most of these studies aim to construct highly versatile models by integrating information from multiple client models. However, the medical field often requires models that are personalized for specific targets rather than models with broad generalization abilities. It can be clinically more beneficial to fine-tune a model using insights from other models trained on vastly different data than to rely solely on a single generalized model. This is especially important in clinical practice because the same medical images or symptoms can often be interpreted differently for each patient, highlighting the need for personalized models. Hence, there remains room for FL in the current medical field to tackle more challenging tasks and problem settings.
Medical visual question answering (VQA) [9] is a multimodal task capable of handling diverse questions and answers and is considered more challenging than tasks such as segmentation or classification. Implementing a VQA system enables the provision of specialized advice to physicians across different departments and facilitates the delivery of medical advice tailored to specific patients. For natural images, Transformer-based baseline models have significantly improved the accuracy of the VQA task [19, 2]. However, their application in the medical field has not been sufficiently considered, as various clinical constraints need to be taken into account for implementation.
Personalized Federated Learning (pFL) [4] is an approach that focuses on the performance of individual clients rather than on the global model of a central server, using protected data to learn personalized models. This approach is well suited to the medical sector, yet its client information aggregation still leaves room for improvement. For example, recent studies [29] based on prompt learning [23, 22] do not consider the relevance between clients, leading to a degradation of model performance. Assessing the reliability of each client is therefore a crucial issue when considering clinical applications of pFL [11].
In this paper, we explore the combination of pFL and VQA, tackling a more clinically challenging problem setting. With advances in large language models (LLMs), the VQA task has the potential to play a significant role in clinical applications. We propose a novel Transformer model that quantifies the uncertainty of each client and efficiently aggregates beneficial information in this new problem setting. Specifically, we simulate different departments in a hospital and set up clients for medical images from various organs, each consisting of a VQA model. To reduce the communication burden, we introduce learnable prompts into the Transformer's multi-head attention (MHA) layers, allowing efficient learning of personalized client data distributions. Furthermore, we propose a Dynamic Likelihood-weighted Uncertainty Calibration (DLUC) process to effectively aggregate information during inter-client communication. This process evaluates the uncertainty of client VQA models using the Dempster-Shafer evidence theory (DST) [16] and dynamically adjusts weights based on this evaluation. This ensures the reliability of the model and supports effective decision-making. Extensive qualitative and quantitative experiments on two medical VQA datasets containing both closed- and open-ended questions demonstrate that our method can efficiently aggregate information for personalized clients.
2 Methodology

2.1 Problem Formulation
An overview of our method is shown in Fig. 1. We define $N$ separate clients $\{C_i\}_{i=1}^{N}$, each of which consists of a VQA model trained on its own distributed dataset under the pFL setting. Given a client model with parameters $\theta_i$ and the evaluation metric $\mathcal{M}$, the final optimization objective can be expressed as optimizing $\mathcal{M}(\hat{a}_i, u_i)$ with respect to $\theta_i$, where $\hat{a}_i$ and $u_i$ denote the generated answer and the uncertainty, respectively. By optimizing this objective, the local training tries to generate $\hat{a}_i$ and $u_i$, which are used as criteria for evaluating the model's performance during the subsequent inter-client communication process. Since the reduction in data available for training each client model inevitably leads to a degradation in performance, sharing learned data patterns among clients facilitates a collaborative training process that is mutually advantageous.
Specifically, for client $C_i$, we define the communication method $G_i$, and the objective is to optimize $G_i$ so that it maximizes the information obtained from the other clients $\{C_j\}_{j \neq i}$ for the local training of client $C_i$. By optimizing the two objectives above, we can train personalized client models for different medical data distributions without data sharing, while maximizing the aggregation of valid information from the other clients for each client $C_i$. The local training and inter-client communication processes are described in Sections 2.2 and 2.3, respectively.
2.2 Local Training for Medical VQA Client Model
For client $C_i$, the dataset $\mathcal{D}_i$ comprises instances that consist of an input question $q$, an image $v$, and the corresponding ground-truth answer $a$. We utilize Transformer-based encoders $E_{\text{img}}$ and $E_{\text{text}}$ to extract features of the input image and question separately. Subsequently, an answering decoder generates the predicted answer.
In the proposed method, we add learnable prompts $P \in \mathbb{R}^{L_p \times d}$ ($L_p$ is the length of the prompt, $d$ is the hidden dimension of the Transformer) to the multi-head attention (MHA) layers of the encoders. Specifically, following Prefix-Tuning [8], we split $P$ into $\{P_K, P_V\}$, which are introduced into the MHA layer as follows:
$$\mathrm{MHA}_{\text{prompt}}(Q, K, V) = \mathrm{Attn}\big(Q,\, [P_K; K],\, [P_V; V]\big), \qquad (1)$$
where $Q$, $K$, and $V$ are the original inputs of the MHA layer, and $[\cdot\,;\cdot]$ denotes concatenation along the sequence dimension.
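The following is a minimal PyTorch sketch of the prefix-tuned attention in Eq. (1), assuming a standard multi-head attention module; the class name `PrefixMHA`, the initialization scale, and the shapes are illustrative rather than the exact implementation.

```python
import torch
import torch.nn as nn

class PrefixMHA(nn.Module):
    """Multi-head attention with learnable key/value prompts prepended (Prefix-Tuning style)."""

    def __init__(self, dim: int, num_heads: int, prompt_len: int):
        super().__init__()
        # In the actual method the backbone attention weights stay frozen;
        # only the prompts P_K and P_V below are updated.
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.p_k = nn.Parameter(torch.randn(prompt_len, dim) * 0.02)
        self.p_v = nn.Parameter(torch.randn(prompt_len, dim) * 0.02)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        b = q.size(0)
        # Prepend the prompts to the original keys and values: [P_K; K] and [P_V; V].
        k = torch.cat([self.p_k.unsqueeze(0).expand(b, -1, -1), k], dim=1)
        v = torch.cat([self.p_v.unsqueeze(0).expand(b, -1, -1), v], dim=1)
        out, _ = self.attn(q, k, v)
        return out
```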
Conventional VQA models output the prediction for the answer probability directly from the hidden state of the decoder output via a softmax activation function. In order to capture the uncertainty for evaluating the model performance, based on DST, we consider the probability of answers to obey a Dirichlet distribution with parameter $\boldsymbol{\alpha} = [\alpha_1, \ldots, \alpha_K]$ ($K$ is the length of the answer list). Specifically, we regard the output hidden states of the decoder as the evidence for each candidate answer. Since evidence must be non-negative, we replace the final softmax layer of conventional VQA models with a non-negative activation layer $g(\cdot)$, and the evidence is calculated as $\mathbf{e} = g(\mathbf{h})$, where $\mathbf{h}$ denotes the decoder output. The Dirichlet parameter is obtained from the evidence as $\alpha_k = e_k + 1$. Based on $\boldsymbol{\alpha}$, the belief mass of the model for each candidate answer is defined as $b_k = e_k / S$, where $S = \sum_{k=1}^{K} \alpha_k$. Unlike a traditional probability distribution, the sum of the belief masses $b_k$ does not equal 1, and the difference from 1 is the uncertainty of the model, which is calculated as $u = K / S$.
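As a concrete illustration of the computation above, here is a minimal PyTorch sketch, assuming a softplus activation as the non-negative evidence function (the specific activation and variable names are assumptions):

```python
import torch
import torch.nn.functional as F

def dirichlet_uncertainty(decoder_scores: torch.Tensor):
    """decoder_scores: (batch, K) decoder outputs over the K candidate answers."""
    evidence = F.softplus(decoder_scores)             # non-negative evidence e_k (softplus assumed)
    alpha = evidence + 1.0                            # Dirichlet parameters alpha_k = e_k + 1
    strength = alpha.sum(dim=-1, keepdim=True)        # S = sum_k alpha_k
    belief = evidence / strength                      # belief mass b_k = e_k / S
    uncertainty = decoder_scores.size(-1) / strength  # u = K / S, so sum_k b_k + u = 1
    return evidence, alpha, belief, uncertainty
```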
To assign more evidence to the correct answer, we employ the following Dirichlet term:
$$D(\mathbf{p} \mid \boldsymbol{\alpha}) = \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^{K} p_k^{\alpha_k - 1}, \qquad (2)$$
where $B(\boldsymbol{\alpha})$ is the multivariate beta function. The Dirichlet term is incorporated into the cross-entropy loss function to obtain $\mathcal{L}_{ce}$ as follows:
$$\mathcal{L}_{ce} = \sum_{k=1}^{K} y_k \big(\psi(S) - \psi(\alpha_k)\big), \qquad (3)$$
where $\mathbf{y} = [y_1, \ldots, y_K]$ is the one-hot vector of the ground-truth answer, and $\psi(\cdot)$ is the digamma function. Note that Eq. (3) aims to allocate as much of the total evidence generated by the predictions as possible to the correct answers, providing positive feedback. However, since this loss does not ensure that incorrect labels yield less evidence, a KL term is introduced to push the evidence for incorrect labels as close to 0 as possible:
$$\mathcal{L}_{KL} = \mathrm{KL}\big[D(\mathbf{p} \mid \tilde{\boldsymbol{\alpha}}) \,\|\, D(\mathbf{p} \mid \mathbf{1})\big] \qquad (4)$$
$$= \log\!\left(\frac{\Gamma\big(\sum_{k=1}^{K} \tilde{\alpha}_k\big)}{\Gamma(K) \prod_{k=1}^{K} \Gamma(\tilde{\alpha}_k)}\right) + \sum_{k=1}^{K} (\tilde{\alpha}_k - 1)\left[\psi(\tilde{\alpha}_k) - \psi\Big(\sum_{j=1}^{K} \tilde{\alpha}_j\Big)\right]. \qquad (5)$$
Here, $\mathbf{1}$ is a parameter vector of $K$ ones, and $\tilde{\boldsymbol{\alpha}} = \mathbf{y} + (1 - \mathbf{y}) \odot \boldsymbol{\alpha}$ is the adjusted Dirichlet parameter that prevents penalizing the evidence of the correct answer (when $\tilde{\boldsymbol{\alpha}} = \mathbf{1}$, the loss becomes 0). Therefore, the final loss function can be calculated as follows:
$$\mathcal{L} = \mathcal{L}_{ce} + \lambda\, \mathcal{L}_{KL}, \qquad (6)$$
where $\lambda$ is a balance parameter. By training with the proposed loss function, the client model outputs both the answer text and the uncertainty $u$. In the inter-client communication process, the uncertainty is used to evaluate the model's performance.
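For reference, the following sketch assembles Eqs. (3)-(6) following the standard evidential formulation of [16]; the function and variable names are illustrative, and the reduction to a batch mean is an assumption:

```python
import torch

def evidential_loss(alpha: torch.Tensor, y: torch.Tensor, lam: float) -> torch.Tensor:
    """alpha: (batch, K) Dirichlet parameters; y: (batch, K) one-hot answers; lam: balance parameter."""
    strength = alpha.sum(dim=-1, keepdim=True)  # S = sum_k alpha_k
    # Eq. (3): evidential cross-entropy expressed with the digamma function.
    ce = (y * (torch.digamma(strength) - torch.digamma(alpha))).sum(dim=-1)

    # Eqs. (4)-(5): KL[ D(p | alpha~) || D(p | 1) ], where alpha~ removes the
    # evidence assigned to the correct answer.
    alpha_t = y + (1.0 - y) * alpha
    s_t = alpha_t.sum(dim=-1)
    k = alpha.new_tensor(float(alpha.size(-1)))
    kl = (torch.lgamma(s_t) - torch.lgamma(k) - torch.lgamma(alpha_t).sum(dim=-1)
          + ((alpha_t - 1.0)
             * (torch.digamma(alpha_t) - torch.digamma(s_t.unsqueeze(-1)))).sum(dim=-1))

    # Eq. (6): total loss with balance parameter lambda, averaged over the batch.
    return (ce + lam * kl).mean()
```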
2.3 DLUC-based Inter-client Communication Process
We define a dynamic weight $w_{ij}$ for client $C_i$ with respect to each other client $C_j$, and the prompt $P_i$ is updated by adding the weighted sum of the prompts of the other clients $\{P_j\}_{j \neq i}$. Given the data $\mathcal{D}_i$ and the clients' parameters $\{\theta_j\}_{j=1}^{N}$, the conditional expectation, as a lower bound on the true likelihood, is computed as
$$Q_i(\mathbf{w}_i) = \sum_{j \neq i} w_{ij}\, \log p\big(\mathcal{D}_i \mid \theta_i,\, P_i + \gamma P_j\big), \quad \text{s.t.}\ \sum_{j \neq i} w_{ij} = 1, \qquad (7)$$
where $\gamma$ is an aggregation rate and $\mathbf{w}_i = \{w_{ij}\}_{j \neq i}$ are the weights optimized by the EM method. Since the prompts represent the client data schema, we use a neural network to optimize the objective of Eq. (7). We take the uncertainties as the network input and the weights as the output. During training, we set the local training step $T_l$ and the inter-client communication step $T_c$ separately, each representing the number of iterations the corresponding process is executed, and alternate between these two processes. DLUC aims to find the configuration of weights that maximizes the total likelihood of the model over all client data. This optimization allows each client to aggregate the maximum amount of information from the server while avoiding interference from invalid information.
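A minimal sketch of the DLUC weight generation and prompt aggregation described above, assuming a single fully connected layer that maps client uncertainties to normalized weights; the softmax normalization, layer shape, class name, and example values are assumptions for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DLUCAggregator(nn.Module):
    """Generates aggregation weights from client uncertainties and mixes client prompts."""

    def __init__(self, num_clients: int, agg_rate: float = 0.1):
        super().__init__()
        # Fully connected layer mapping the clients' uncertainties to weight logits.
        self.weight_net = nn.Linear(num_clients, num_clients)
        self.agg_rate = agg_rate  # aggregation rate (gamma in Eq. (7))

    def forward(self, uncertainties: torch.Tensor, prompts: torch.Tensor, i: int) -> torch.Tensor:
        """uncertainties: (num_clients,); prompts: (num_clients, L_p, d); i: target client index."""
        logits = self.weight_net(uncertainties)
        # Exclude the client's own prompt and normalize the weights over the other clients.
        mask = torch.zeros_like(logits)
        mask[i] = float("-inf")
        w = F.softmax(logits + mask, dim=-1)
        mixed = (w.view(-1, 1, 1) * prompts).sum(dim=0)
        # P_i <- P_i + gamma * sum_{j != i} w_ij * P_j
        return prompts[i] + self.agg_rate * mixed


# Usage with illustrative sizes: three clients, prompt length 24, hidden dimension 512.
aggregator = DLUCAggregator(num_clients=3)
u = torch.tensor([0.2, 0.5, 0.9])           # per-client uncertainties from Sec. 2.2
prompts = torch.randn(3, 24, 512)           # per-client learnable prompts
new_prompt_0 = aggregator(u, prompts, i=0)  # aggregated prompt for Client1
```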
3 Experiments
We are motivated by the excellent performance of Transformers on multimodal tasks in the medical domain. Considering the wide range of VQA applications, we selected medical images and QA pairs from the VQA-RAD [6] and Slake [10] datasets. These datasets were further partitioned into sub-datasets based on the anatomical focus of the images. Referring to the setup of different departments in an actual clinic, we split Slake into "Lung" (Client1), "Abdomen" (Client2), "Brain" (Client3), and "Other" (Client4), and VQA-RAD into "Abdomen" (Client1), "Chest" (Client2), and "Head" (Client3), respectively. Details of the datasets can be found in Table 1 of the Supplementary. To mitigate the impact of accuracy fluctuations observed among some clients, attributed to the limited size of their test datasets, we determined the final accuracy based on the average results and the variance from the last round of local training. Furthermore, we also tested the method's performance when the number of clients increases dramatically, using the large-scale PMC-VQA dataset [27].
Table 1. Comparison of VQA accuracy (%) and model parameters (M) with state-of-the-art medical VQA methods (Enc.: encoder parameters; Upd.: parameters updated during training).

| Method | Slake | VQA-RAD | Enc. (M) | Upd. (M) |
| --- | --- | --- | --- | --- |
| PMC-VQA [27] | 82.5 | 86.8 | 7000 | 7000 |
| BioMedGPT [25] | 82.5 | 81.3 | 1500 | 1500 |
| LoRA [18] | 82.1 | - | 1500 | 15 |
| M2I2 [7] | 83.2 | 70.8 | 600 | 600 |
| BioMedCLIP [26] | 86.7 | 79.8 | 150 | 340 |
| MEVF-BAN [13] | 75.4 | 75.1 | 50 | 50 |
| VGG+SAN [10] | 75.4 | 74.0 | 50 | 50 |

| Method | Slake Client1 | Slake Client2 | Slake Client3 | Slake Client4 | VQA-RAD Client1 | VQA-RAD Client2 | VQA-RAD Client3 | Enc. (M) | Upd. (M) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PMVQA [29] | 84.4 | 72.3 | 76.9 | 81.4 | 68.9 | 64.7 | 74.2 | 63 | 0.3 |
| PM (ours) | | | | | | | | 63 | 0.01 |
Table 2. VQA accuracy (%) of the proposed method on different Transformer baselines.

| Baseline | Param. | Slake: Lung | Slake: Abdomen | Slake: Brain | Slake: Others | VQA-RAD: Chest | VQA-RAD: Abdomen | VQA-RAD: Head |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ViT-B/32 [15] | 63M | | | | | | | |
| PubMedCLIP [3] | 63M | | | | | | | |
| ViT-L/14 [15] | 300M | | | | | | | |
| ViT-L/14@336px [15] | 300M | | | | | | | |
Table 3. Ablation study on Slake (VQA accuracy, %).

| Method | CLOSED Client1 | CLOSED Client2 | CLOSED Client3 | CLOSED Client4 | OPEN Client1 | OPEN Client2 | OPEN Client3 | OPEN Client4 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PM w/o client | | | | | | | | |
| PM w/o server | | | | | | | | |
| PM w/o DLUC | | | | | | | | |
| PM | | | | | | | | |
Table 4. Ablation study on VQA-RAD (VQA accuracy, %).

| Method | CLOSED Client1 | CLOSED Client2 | CLOSED Client3 | OPEN Client1 | OPEN Client2 | OPEN Client3 |
| --- | --- | --- | --- | --- | --- | --- |
| PM w/o client | | | | | | |
| PM w/o server | | | | | | |
| PM w/o DLUC | | | | | | |
| PM | | | | | | |
For specific modeling details, we employed the pre-trained Contrastive Language-Image Pre-training (CLIP) [15] model as the image and text encoders $E_{\text{img}}$ and $E_{\text{text}}$. All clients share a CLIP model with fixed weights, and only each client's prompts are updated during training. In our experiments, we obtained the best results when we added prompts to the first four blocks of the Transformer and set the prompt length to 24 and 30 for Slake and VQA-RAD, respectively. The reason is that the features of different organs depend more on patterns in smaller regions, and the outputs of the Transformer's first few blocks tend to capture more localized information. The learning rate of the prompts is initially 0.001 and is multiplied by 0.5 after each $T_l$ steps, i.e., one complete local update. The answer layer is implemented as a 2-layer MLP with a hidden size of 512 and a dropout rate of 0.2. We used a fully connected layer for the DLUC, whose learning rate is initially 0.01 and is multiplied by 0.1 after each complete inter-client communication.
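For illustration, a minimal sketch of the answer layer described above; the input dimension, the GELU activation, and the softplus evidence head (cf. Section 2.2) are assumptions:

```python
import torch.nn as nn

def make_answer_head(hidden_dim: int, num_answers: int) -> nn.Sequential:
    """2-layer MLP answer layer (hidden size 512, dropout 0.2) with a non-negative
    evidence output; the GELU and softplus choices are assumptions."""
    return nn.Sequential(
        nn.Linear(hidden_dim, 512),
        nn.GELU(),
        nn.Dropout(p=0.2),
        nn.Linear(512, num_answers),
        nn.Softplus(),
    )
```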
Comparison Methods: Given the complexity and the substantial size of the baseline model employed, along with the intricacies of the multimodal task setup, a direct comparison with previous FL methods is not feasible. Considering that we aim to enhance the personalized performance of medical VQA, we compare the proposed method with several state-of-the-art medical VQA models in terms of accuracy and the number of model parameters that need to be updated. Specifically, we employed PMC-VQA [27], BioMedGPT [25], M2I2 [7], BioMedCLIP [26], MEVF-BAN [13], VGG+SAN [10], LoRA [18], and PMVQA [29].
As shown by the results in Table 1, compared with PMVQA [29] under the same task setup, our method achieves an average accuracy improvement of 3.5% on most clients. Compared with other locally trained models, our method achieves higher accuracy when the number of baseline parameters is comparable. However, a gap remains in comparison with methods built on larger baselines. It is worth noting that the parameters our method needs to update amount to only 0.01% of the Vision Transformer-based model and 0.00001% of the current SOTA model.
To demonstrate that our methodology can be effectively extended to all Transformer-based networks, we tested the performance of our method on multiple baselines. As shown in Table 2, our method can effectively improve the personalization performance of multimodal Transformers, achieving an average improvement of 3% on most clients. Further analysis of the differences in weights and performance across clients is provided in the Supplementary.
Ablation studies. We designed three variants: training on the whole data without setting up clients, setting up clients but without parameter sharing on the server, and aggregating information without DLUC (i.e., with equal weights). As shown in Tables 3 and 4, the results demonstrate that DLUC can effectively aggregate clients' information to improve the performance of each client.
Analysis of inter-client weights. Fig. 2 illustrates the weights generated by DLUC during inter-client communication for each client, and the results show significant differences in the dependencies between clients. For example, the image prompt of "Lung" is highly dependent on "Abdomen" information because the images of these clients in Slake are tomographic CT scans of adjacent regions that share similar features. Furthermore, the images of the three VQA-RAD clients are obtained from CT, X-ray, and MRI, so the results show no significant inter-client dependence owing to the large modality gap. These results indicate that the generated weights are clinically meaningful.
Cross evaluation between prompts. To further demonstrate that the prompts effectively personalize the model rather than learn generalizable information, we performed accuracy tests for each client using the prompts of the other clients. As shown in Fig. 4, there is a significant drop in accuracy when the prompts are exchanged, proving that the prompts learn information about different data distributions.





Impact of hyperparameters. The model's performance is affected by three primary parameters: the local training step $T_l$, the DLUC step in inter-client communication $T_c$, and the aggregation rate $\gamma$. The comparison results on the steps are shown in Fig. 1 of the Supplementary. The results show that a higher $T_l$ tends to perform better, and performance with a higher $T_c$ is better when $T_l$ is low. This is because when there is less local training, the model cannot fully adapt to the local data, and optimizing the aggregation weights by DLUC can integrate information from different clients more effectively. As for $\gamma$, values that are either too high or too low degrade the results owing to excessive or insufficient influence of external information, as shown in Fig. 4.
Extreme Tests: Since each of our clients essentially retains only learnable prompts with 0.01M parameters, our method can support a tremendous number of clients, which was difficult to achieve with previous pFL methods. We experimented with setting up 500 clients on the PMC-VQA dataset, which consumed only 13 GB of VRAM, and the experimental results show an average 1.8% accuracy improvement on 56% of the clients.
4 Conclusion
We present a prompt-based, reliable pFL method for medical VQA. Our method creates clients for heterogeneous medical data and aggregates information through the novel DLUC process. In the DLUC process, we calculate the uncertainty of each client through DST and generate aggregation weights from the uncertainty based on maximum likelihood estimation. The experimental results demonstrate that our method can effectively improve the performance of Transformer-based client models with minimal overhead. Moreover, the computation of our method is independent of the size of the baseline network and can be applied to personalized learning with Transformers of any size. In future work, we will introduce the method to personalize larger-scale models.
References
- [1] Abouelmehdi, K., Beni-Hessane, A., Khaloufi, H.: Big healthcare data: preserving security and privacy. Journal of big data 5(1), 1–18 (2018)
- [2] Devlin, J., Chang, M.W., Lee, K., Toutanova, K.: Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805 (2018)
- [3] Eslami, S., de Melo, G., Meinel, C.: Does clip benefit visual question answering in the medical domain as much as it does in the general domain? arXiv preprint arXiv:2112.13906 (2021)
- [4] Fallah, A., Mokhtari, A., Ozdaglar, A.: Personalized federated learning: A meta-learning approach. arXiv preprint arXiv:2002.07948 (2020)
- [5] Jiang, M., Zhong, Y., Le, A., Li, X., Dou, Q.: Client-level differential privacy via adaptive intermediary in federated medical imaging. In: MICCAI. pp. 500–510 (2023)
- [6] Lau, J.J., Gayen, S., Abacha, A.B., Demner-Fushman, D.: A dataset of clinically generated visual questions and answers about radiology images. Scientific Data 5, 1–10 (2018)
- [7] Li, P., Liu, G., Tan, L., Liao, J., Zhong, S.: Self-supervised vision-language pretraining for medical visual question answering. In: ISBI. pp. 1–5 (2023)
- [8] Li, X.L., Liang, P.: Prefix-tuning: Optimizing continuous prompts for generation. arXiv preprint arXiv:2101.00190 (2021)
- [9] Lin, Z., Zhang, D., Tao, Q., Shi, D., Haffari, G., Wu, Q., He, M., Ge, Z.: Medical visual question answering: A survey. Artificial Intelligence in Medicine 143, 102611 (2023)
- [10] Liu, B., Zhan, L.M., Xu, L., Ma, L., Yang, Y., Wu, X.M.: Slake: A semantically-labeled knowledge-enhanced dataset for medical visual question answering. In: ISBI. pp. 1650–1654 (2021)
- [11] Liu, P., Yuan, W., Fu, J., Jiang, Z., Hayashi, H., Neubig, G.: Pre-train, prompt, and predict: A systematic survey of prompting methods in natural language processing. ACM Computing Surveys 55(9), 1–35 (2023)
- [12] McMahan, B., Moore, E., Ramage, D., Hampson, S., y Arcas, B.A.: Communication-efficient learning of deep networks from decentralized data. In: AISTATS. pp. 1273–1282. PMLR (2017)
- [13] Nguyen, B.D., Do, T.T., Nguyen, B.X., Do, T., Tjiputra, E., Tran, Q.D.: Overcoming data limitation in medical visual question answering. In: MICCAI. pp. 522–530 (2019)
- [14] Nguyen, D.C., Pham, Q.V., Pathirana, P.N., Ding, M., Seneviratne, A., Lin, Z., Dobre, O., Hwang, W.J.: Federated learning for smart healthcare: A survey. ACM Comput. Surv. 55(3) (feb 2022)
- [15] Radford, A., Kim, J.W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., et al.: Learning transferable visual models from natural language supervision. In: ICML. pp. 8748–8763 (2021)
- [16] Sensoy, M., Kaplan, L., Kandemir, M.: Evidential deep learning to quantify classification uncertainty. In: NIPS. p. 3183–3193 (2018)
- [17] Sheller, M.J., Edwards, B., Reina, G.A., Martin, J., Pati, S., Kotrotsou, A., Milchenko, M., Xu, W., Marcus, D., Colen, R.R., et al.: Federated learning in medicine: facilitating multi-institutional collaborations without sharing patient data. Scientific reports 10(1), 12598 (2020)
- [18] van Sonsbeek, T., Derakhshani, M.M., Najdenkoska, I., Snoek, C.G.M., Worring, M.: Open-ended medical visual question answering through prefix tuning of language models. In: MICCAI. pp. 726–736 (2023)
- [19] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł., Polosukhin, I.: Attention is all you need. Advances in neural information processing systems 30 (2017)
- [20] Wang, J., Jin, Y., Stoyanov, D., Wang, L.: Feddp: Dual personalization in federated medical image segmentation. IEEE Transactions on Medical Imaging (2023)
- [21] Wang, M., Wang, L., Xu, X., Zou, K., Qian, Y., Goh, R.S.M., Liu, Y., Fu, H.: Federated uncertainty-aware aggregation for fundus diabetic retinopathy staging. In: MICCAI (2023)
- [22] Wang, Z., Zhang, Z., Ebrahimi, S., Sun, R., Zhang, H., Lee, C.Y., Ren, X., Su, G., Perot, V., Dy, J., et al.: Dualprompt: Complementary prompting for rehearsal-free continual learning. In: ECCV. pp. 631–648 (2022)
- [23] Wang, Z., Zhang, Z., Lee, C.Y., Zhang, H., Sun, R., Ren, X., Su, G., Perot, V., Dy, J., Pfister, T.: Learning to prompt for continual learning. In: CVPR. pp. 139–149 (2022)
- [24] Xu, J., Glicksberg, B.S., Su, C., Walker, P., Bian, J., Wang, F.: Federated learning for healthcare informatics. Journal of Healthcare Informatics Research 5, 1–19 (2021)
- [25] Zhang, K., Yu, J., Yan, Z., Liu, Y., Adhikarla, E., Fu, S., Chen, X., Chen, C., Zhou, Y., Li, X., et al.: Biomedgpt: A unified and generalist biomedical generative pre-trained transformer for vision, language, and multimodal tasks. arXiv preprint arXiv:2305.17100 (2023)
- [26] Zhang, S., Xu, Y., Usuyama, N., Bagga, J., Tinn, R., Preston, S., Rao, R., Wei, M., Valluri, N., Wong, C., et al.: Large-scale domain-specific pretraining for biomedical vision-language processing. arXiv preprint arXiv:2303.00915 2(3), 6 (2023)
- [27] Zhang, X., Wu, C., Zhao, Z., Lin, W., Zhang, Y., Wang, Y., Xie, W.: Pmc-vqa: Visual instruction tuning for medical visual question answering. arXiv preprint arXiv:2305.10415 (2023)
- [28] Zhou, Q., Zheng, G.: Fedcontrast-gpa: Heterogeneous federated optimization via local contrastive learning and global process-aware aggregation. In: MICCAI. pp. 660–670. Springer (2023)
- [29] Zhu, H., Togo, R., Ogawa, T., Haseyama, M.: Prompt-based personalized federated learning for medical visual question answering. arXiv preprint arXiv:2402.09677 (2024)