
¹ Harbin Engineering University   ² The Second Affiliated Hospital of Mudanjiang Medical University
Email: [email protected]

CTS: A Consistency-Based Medical Image Segmentation Model

Kejia Zhang¹, Lan Zhang¹, Haiwei Pan¹, Baolong Yu²
Abstract

Diffusion models have shown significant potential in medical image segmentation. However, mainstream diffusion models require many sampling steps, which makes prediction slow. Recently, consistency models, a standalone class of generative models, have resolved this issue. Compared to diffusion models, consistency models can reduce sampling to a single step. They achieve similar generative quality while significantly speeding up training and prediction. However, they are not directly suited to image segmentation tasks, and their application to medical imaging has not yet been explored. Therefore, this paper applies the consistency model to medical image segmentation. It designs a multi-scale feature supervision scheme and a loss function that guide the model toward convergence. Experiments verify that the CTS model obtains better medical image segmentation results with a single sampling step at test time.

Keywords: Diffusion models · Consistency models · Medical image segmentation.
Figure 1: Overall flowchart of the CTS model. (a) Input of the multi-scale feature supervision signals. (b) Fusion of the feature supervision signals through a channel attention mechanism.

1 Introduction

Medical image segmentation has long been an active research direction within image segmentation. Besides traditional segmentation methods [15, 22], generative models can also achieve good segmentation results [20]. Diffusion models [17, 10] are generative models that sample starting from Gaussian noise, and the images they generate are robust to noise and smooth. Consequently, a growing number of studies use diffusion models for non-generative image tasks. Researchers use masks as the target of generative sampling and add constraints that guide the direction of generation. However, diffusion models require many resampling steps during training and prediction, so their low computational efficiency urgently needs to be addressed. Consistency models [18] turn multi-step sampling into a single step. They do so by learning a mapping from any point on the trajectory of a probability-flow ODE to its unique origin, which greatly reduces the time spent on sampling. Moreover, consistency models preserve sample quality while reducing the number of sampling steps. Compared to DDPMs [10], consistency models offer a superior generative paradigm, yet studies applying them to medical image segmentation are still lacking. Therefore, this paper builds a medical image segmentation model on the consistency model. It designs a loss function that combines a segmentation loss with the consistency training loss, enabling end-to-end training. The CTS code is available at https://github.com/LanHEU/CTS.

The specific contributions of this paper are as follows:

  • We construct a medical image segmentation model based on a consistency model, with a newly designed joint loss function.

  • During the decoding phase, multi-scale feature supervision signals are used to guide the model's convergence direction.

2 Related Works

In this section, we briefly review the lines of research relevant to our work. Diffusion models have been applied to many fields, ranging from sequence modeling [12, 5], speech processing [14] and computer vision [16, 9] to computed tomography (CT) and magnetic resonance imaging (MRI). In computer vision, many methods have been proposed to reduce the number of sampling steps. There are also sampling algorithms tailored for conditional generation, such as classifier-free guidance [11] and classifier guidance [6]. Image segmentation is an important task in computer vision that simplifies an image by decomposing it into multiple meaningful segments [2, 8]. Because of the time, cost, and expertise required [1, 4], the number of annotated images available for medical image segmentation is limited. For this reason, diffusion models have become a promising tool in image segmentation research, because they can synthesize labeled data and thereby reduce the need for pixel-level annotations. BrainSPADE [7] proposed a generative model for synthesizing labeled brain MRI images that can be used to train segmentation models. However, diffusion models for medical image segmentation still require many sampling steps and long inference times.

3 Method

This paper aims to fully leverage the single-step sampling of consistency models while retaining the benefits of segmentation models. In the consistency model work [18], directly training a consistency model is called consistency training (CT), which is the origin of the "CT" in "CTS". The overall process is shown in Fig. 1.

Similar to the consistency model, the basic framework of this paper consists of two parts: a model $M$ and a target model $TM$. Sampling starts from the mask $x^{m}$ of each image, and the corresponding image data $x^{d}$ is input as a supervisory signal. The parameters of $M$ are initialized and copied to $TM$.

Step 1. The input to the model is the mask $x^{m}$. Noise $z_{n}$, sampled from a Gaussian distribution at step $n$, is added to it: $x^{m}_{n} = x^{m} + z_{n}$, where $z_{n} = t_{n}\mathbf{z}$ with $\mathbf{z} \sim \mathcal{N}(0, \mathbf{I})$, as in Alg. 1.

Step 2. Simultaneously, compute the time-dependent scaling coefficients $c_{in}$, $c_{out}$ and $c_{skip}$ from the time step $t$.

Step 3. The final input to model $M$ is $x^{m}_{in} = c_{in} \cdot x^{m}_{n}$.

Step 4. Generate the multi-scale signals $\bigcup_i x_{i}^{d}$ and the prediction $\hat{y}$ from the image data: $(\bigcup_i x_{i}^{d}, \hat{y}) = h^{T}(x^{d})$, where $i \in \{1, 2, \dots\}$.

Step 5. The feature signals $\bigcup_i x_{i}^{d}$ are incorporated into the UNet model: $y^{m}_{out} = g^{T}(x^{m}_{in}, \bigcup_i x_{i}^{d}, t)$.

Step 6. The final output of model $M$ is $y^{m}_{n} = c_{out} \cdot y^{m}_{out} + c_{skip} \cdot x^{m}_{n}$.
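Steps 1-6 can be sketched as follows on flattened pixel lists. This is a minimal illustration, not the authors' implementation. The paper does not give formulas for $c_{in}$, $c_{out}$ and $c_{skip}$, so the sketch makes two assumptions. First, it uses the boundary-respecting parameterization of Song et al. [18] ($\sigma_{data} = 0.5$, $\epsilon = 0.002$) with an EDM-style $c_{in}$. Second, `toy_unet` is a hypothetical stand-in for the UNet $g^{T}$.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def scalings(t: float, sigma_data: float = 0.5, eps: float = 0.002):
    """Return (c_in, c_out, c_skip) for noise level t.

    Assumed parameterization (Song et al. [18]); CTS does not state its own.
    At t == eps it gives c_skip = 1 and c_out = 0, so the model is the identity.
    """
    c_in = 1.0 / math.sqrt(t * t + sigma_data ** 2)
    c_skip = sigma_data ** 2 / ((t - eps) ** 2 + sigma_data ** 2)
    c_out = sigma_data * (t - eps) / math.sqrt(sigma_data ** 2 + t * t)
    return c_in, c_out, c_skip


def consistency_forward(x_noisy: Vector, signals: List[Vector], t: float,
                        g: Callable[[Vector, List[Vector], float], Vector]) -> Vector:
    """Steps 2-6: compute the scalings, scale the input, run g, add the skip term."""
    c_in, c_out, c_skip = scalings(t)                    # Step 2
    x_in = [c_in * v for v in x_noisy]                   # Step 3
    y_out = g(x_in, signals, t)                          # Step 5
    return [c_out * yo + c_skip * xn                     # Step 6
            for yo, xn in zip(y_out, x_noisy)]


def toy_unet(x_in: Vector, signals: List[Vector], t: float) -> Vector:
    """Hypothetical stand-in for g^T: adds the per-pixel mean of the signals."""
    k = max(len(signals), 1)
    return [v + sum(s[i] for s in signals) / k for i, v in enumerate(x_in)]


rng = random.Random(0)
mask = [0.0, 1.0, 1.0, 0.0]                              # x^m
t_n = 1.0
z = [rng.gauss(0.0, 1.0) for _ in mask]
x_n = [m + t_n * zi for m, zi in zip(mask, z)]           # Step 1 with z_n = t_n * z
y_n = consistency_forward(x_n, [[0.1, 0.2, 0.2, 0.1]], t_n, toy_unet)
```

The skip connection is what lets a single evaluation act as a complete sampler. At the smallest noise level the output reduces exactly to the input, which is the boundary condition a consistency function must satisfy.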

Step 7. Use the standard sampling method to draw the noise $z_{n+1}$ for the $(n+1)$-th step from the Gaussian distribution, and obtain $x^{m}_{n+1} = f_{\mathcal{U}}(x^{m}, x^{m}_{n}, n, n+1)$.

Step 8. Use $x^{m}_{n+1}$ as the input to model $TM$ and repeat Steps 2-5, so that here $x^{m}_{in} = c_{in} \cdot x^{m}_{n+1}$. The output is $y^{m}_{n+1} = c_{out} \cdot g^{TM}(x^{m}_{in}, \bigcup_i x_{i}^{d}, t) + c_{skip} \cdot x^{m}_{n+1}$.
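Steps 7-8 need two noisy versions of the same mask at adjacent noise levels. The sketch below makes two assumptions. First, it uses the Karras-style discretization from Song et al. [18] ($\epsilon = 0.002$, $T = 80$, $\rho = 7$). Second, following Alg. 1, it uses a single shared Gaussian draw $\mathbf{z}$ for both levels. `karras_times` and `noisy_pair` are hypothetical helper names, and indices are 0-based.

```python
import random
from typing import List, Tuple


def karras_times(n_steps: int, eps: float = 0.002, t_max: float = 80.0,
                 rho: float = 7.0) -> List[float]:
    """Increasing noise levels t_0 < ... < t_{N-1} (assumed schedule from [18])."""
    if n_steps < 2:
        raise ValueError("need at least two discretization steps")
    lo, hi = eps ** (1 / rho), t_max ** (1 / rho)
    return [(lo + i / (n_steps - 1) * (hi - lo)) ** rho for i in range(n_steps)]


def noisy_pair(mask: List[float], times: List[float], n: int,
               rng: random.Random) -> Tuple[List[float], List[float]]:
    """Return (x_n, x_{n+1}) built from one shared draw z, so both points
    lie on the same noising trajectory, as in Alg. 1."""
    z = [rng.gauss(0.0, 1.0) for _ in mask]
    x_n = [m + times[n] * zi for m, zi in zip(mask, z)]
    x_n1 = [m + times[n + 1] * zi for m, zi in zip(mask, z)]
    return x_n, x_n1


rng = random.Random(0)
times = karras_times(18)
x_n, x_n1 = noisy_pair([0.0, 1.0, 1.0, 0.0], times, n=5, rng=rng)
```

Sharing $\mathbf{z}$ is the key design choice. The online and target models then see the same underlying sample at two noise levels, and penalizing disagreement between them is what enforces consistency along the trajectory.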

The consistency training segmentation loss is

$\mathcal{L}_{CT} = \| y^{m}_{n+1} - y^{m}_{n} \|^{2}$.

To speed up convergence and improve training results, the structure generated from the multi-scale signals is also supervised with the mask:

$\mathcal{L}_{S} = \| \hat{y} - x^{m} \|^{2}$.

The overall loss function is $\mathcal{L}_{CTS} = \mathcal{L}_{CT} + \alpha \mathcal{L}_{S}$, where $\alpha$ is a hyperparameter.
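As a worked illustration of the joint objective, the sketch below evaluates $\mathcal{L}_{CTS}$ on flattened tensors. It makes three assumptions. First, the squared L2 norm is summed over pixels; a mean would only rescale the loss. Second, the prediction in $\mathcal{L}_{S}$ is the structure $\hat{y}$ produced from the multi-scale signals in Step 4. Third, `alpha=1.0` is a placeholder, because the paper does not report the value of $\alpha$.

```python
from typing import List


def squared_l2(a: List[float], b: List[float]) -> float:
    """||a - b||^2 for two equally sized flattened tensors."""
    if len(a) != len(b):
        raise ValueError("shape mismatch")
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cts_loss(y_n1: List[float], y_n: List[float], y_hat: List[float],
             mask: List[float], alpha: float = 1.0) -> float:
    """L_CTS = L_CT + alpha * L_S.

    y_n1 : target-model output at step n+1 (treated as a constant)
    y_n  : online-model output at step n
    y_hat: structure predicted from the multi-scale signals
    alpha: weighting hyperparameter (placeholder value)
    """
    l_ct = squared_l2(y_n1, y_n)
    l_s = squared_l2(y_hat, mask)
    return l_ct + alpha * l_s


print(cts_loss([0.5, 0.5], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], alpha=0.5))  # prints 0.5
```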

Step 9. Update the parameters of $TM$ as an exponential moving average (EMA) of those of $M$, with decay rate $\mu(k)$ at iteration $k$: $\theta^{TM} \leftarrow \mathrm{stopgrad}\big(\mu(k)\theta^{TM} + (1-\mu(k))\theta^{M}\big)$.
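Step 9 is an exponential moving average. The sketch below also includes the step schedule $N(k) = \lceil \sqrt{\tfrac{k}{K}((s_1+1)^2 - s_0^2) + s_0^2} - 1 \rceil + 1$ and the decay schedule $\mu(k) = \exp(s_0 \log \mu_0 / N(k))$ that Song et al. [18] pair with consistency training. Whether CTS uses these schedules is an assumption, as are the defaults $s_0 = 2$, $s_1 = 150$ and $\mu_0 = 0.9$. Parameters are plain float lists here, so `stopgrad` has no counterpart.

```python
import math
from typing import List


def n_schedule(k: int, total_iters: int, s0: int = 2, s1: int = 150) -> int:
    """Step schedule N(k) of Song et al. [18]; s0 and s1 are assumed defaults."""
    frac = k / total_iters
    return math.ceil(math.sqrt(frac * ((s1 + 1) ** 2 - s0 ** 2) + s0 ** 2) - 1) + 1


def mu_schedule(k: int, total_iters: int, mu0: float = 0.9,
                s0: int = 2, s1: int = 150) -> float:
    """EMA decay mu(k) = exp(s0 * log(mu0) / N(k)); approaches 1 as N(k) grows."""
    return math.exp(s0 * math.log(mu0) / n_schedule(k, total_iters, s0, s1))


def ema_update(theta_tm: List[float], theta_m: List[float], mu: float) -> List[float]:
    """Step 9: theta_TM <- mu * theta_TM + (1 - mu) * theta_M, element-wise."""
    return [mu * a + (1.0 - mu) * b for a, b in zip(theta_tm, theta_m)]


k, total = 0, 1000
theta_tm = ema_update([1.0, -1.0], [0.0, 0.0], mu_schedule(k, total))
```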

The pseudocode of the CTS algorithm is shown in Alg. 1.

Multi-scale Feature Supervision Signal. Fig. 1(a) shows how the multi-scale feature supervision signals $\bigcup_i x_{i}^{d}$ are integrated. In the decoder stage of the image data encoding network, a feature map is generated at each scale, and a corresponding supervision signal $x_{i}^{d}$ is produced for each of them. As shown in Fig. 1(b), each signal $x_{i}^{d}$ is integrated into model $M$ through a channel attention mechanism, which adds the multi-scale supervision signals. Because these feature maps contain information at various scales, they help the model capture both the details and the context of the image. The channel attention mechanism automatically learns an importance weight for each channel, so that the information from the supervision signals is used more effectively.
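The fusion described above can be sketched with a squeeze-and-excitation style channel attention. This is a hypothetical variant, not the authors' module. The paper specifies neither the layer sizes nor whether the attention weights gate the supervision signal or the feature map. Here they gate the signal $x_{i}^{d}$, which is then added to the decoder feature map. Tensors are lists of flattened channels, and the caller supplies the weights `w1` and `w2`.

```python
import math
from typing import List

Channel = List[float]          # one flattened H*W channel
Matrix = List[List[float]]     # dense weight matrix, row-major


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def matvec(w: Matrix, x: List[float]) -> List[float]:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def channel_attention_fuse(feature: List[Channel], signal: List[Channel],
                           w1: Matrix, w2: Matrix) -> List[Channel]:
    """Squeeze-and-excitation style fusion (hypothetical CTS variant).

    1. squeeze: global-average-pool each channel of the supervision signal
    2. excite : FC -> ReLU -> FC -> sigmoid gives one weight per channel
    3. fuse   : add the re-weighted signal to the decoder feature map
    """
    pooled = [sum(ch) / len(ch) for ch in signal]
    hidden = [max(0.0, h) for h in matvec(w1, pooled)]
    weights = [sigmoid(a) for a in matvec(w2, hidden)]
    return [[f + w * s for f, s in zip(f_ch, s_ch)]
            for f_ch, s_ch, w in zip(feature, signal, weights)]


feature = [[1.0, 1.0], [2.0, 2.0]]
signal = [[2.0, 2.0], [4.0, 4.0]]
fused = channel_attention_fuse(feature, signal, w1=[[0.0, 0.0]], w2=[[0.0], [0.0]])
```

With all-zero weights every channel gate is $\sigma(0) = 0.5$, so half of each signal channel is added. Trained weights let the module emphasize the channels that carry useful supervision.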

Algorithm 1 Consistency Training Segmentation (CTS)

  Input: dataset $\mathcal{D}$, initial model parameters $\theta^{M}$, learning rate $\eta$, step schedule $N(\cdot)$, EMA decay rate schedule $\mu(\cdot)$, distance metric $d(\cdot,\cdot)$, and weighting function $\lambda(\cdot)$
  $\theta^{TM} \leftarrow \theta^{M}$ and $k \leftarrow 0$
  repeat
        Sample $x^{m}, x^{d} \sim \mathcal{D}$ and $n \sim \mathcal{U}[1, N(k)-1]$
        Sample $\mathbf{z} \sim \mathcal{N}(0, \mathbf{I})$
        $\mathcal{L}_{CT}(\theta^{M}, \theta^{TM}) \leftarrow$
         $\lambda(t_{n})\, d\big(g_{\theta}^{M}[x^{m} + t_{n+1}\mathbf{z}, h_{\theta}^{M}(x^{d}), t_{n+1}],\; g_{\theta}^{TM}[x^{m} + t_{n}\mathbf{z}, h_{\theta}^{TM}(x^{d}), t_{n}]\big)$
        $\theta^{M} \leftarrow \theta^{M} - \eta \nabla_{\theta^{M}} \mathcal{L}_{CT}(\theta^{M}, \theta^{TM})$
        $\theta^{TM} \leftarrow \mathrm{stopgrad}\big(\mu(k)\theta^{TM} + (1-\mu(k))\theta^{M}\big)$
        $k \leftarrow k + 1$
  until convergence
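To show how the pieces of Alg. 1 interact, the following is a deliberately tiny, runnable toy. The "network" is a single scalar weight $\theta$ acting on a single mask pixel, $d$ is the squared error and $\lambda \equiv 1$. The image branch $h(x^{d})$ and the $\mathcal{L}_{S}$ term are omitted, $N$ is held fixed at 18, and the EMA rate is a fixed $\mu = 0.95$. The scalings and noise levels are assumed to follow Song et al. [18]. The sketch illustrates the control flow only and makes no claim about the actual CTS architecture.

```python
import math
import random


def scalings(t, sigma_data=0.5, eps=0.002):
    """Assumed (c_in, c_out, c_skip) from Song et al. [18]."""
    c_in = 1.0 / math.sqrt(t * t + sigma_data ** 2)
    c_skip = sigma_data ** 2 / ((t - eps) ** 2 + sigma_data ** 2)
    c_out = sigma_data * (t - eps) / math.sqrt(sigma_data ** 2 + t * t)
    return c_in, c_out, c_skip


def model(theta, x, t):
    """Toy consistency function: c_skip * x + c_out * g(c_in * x), g(u) = theta * u."""
    c_in, c_out, c_skip = scalings(t)
    return c_skip * x + c_out * theta * c_in * x


def train(steps=2000, lr=1e-2, mu=0.95, n_steps=18, seed=0):
    rng = random.Random(seed)
    lo, hi = 0.002 ** (1 / 7), 80.0 ** (1 / 7)
    times = [(lo + i / (n_steps - 1) * (hi - lo)) ** 7 for i in range(n_steps)]
    theta = 0.5                                   # theta^M
    theta_tm = theta                              # theta^TM <- theta^M
    for _ in range(steps):
        x = rng.choice([0.0, 1.0])                # toy mask pixel x^m
        n = rng.randrange(n_steps - 1)            # 0-based, so n + 1 is valid
        z = rng.gauss(0.0, 1.0)                   # shared noise draw
        t_n, t_n1 = times[n], times[n + 1]
        x_n1, x_n = x + t_n1 * z, x + t_n * z
        pred = model(theta, x_n1, t_n1)           # online model M at t_{n+1}
        target = model(theta_tm, x_n, t_n)        # target model TM at t_n, no grad
        c_in, c_out, _ = scalings(t_n1)
        # d/dtheta of (pred - target)^2, with target treated as a constant
        grad = 2.0 * (pred - target) * c_out * c_in * x_n1
        theta -= lr * grad                        # gradient step on L_CT
        theta_tm = mu * theta_tm + (1.0 - mu) * theta   # EMA update of TM
    return theta, theta_tm


theta, theta_tm = train()
```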

4 Experiment

This section demonstrates experimentally the advantages of CTS in medical image segmentation. We start with a thorough comparison against existing alternatives, followed by additional analysis of the reasons behind CTS's success.

Datasets. We conducted experiments on medical segmentation tasks in different image modalities. These are brain tumor segmentation in MRI on the BraTS 2021 dataset [3], thyroid nodule segmentation in ultrasound images, and liver tumor segmentation on the SEHPI dataset [21]. We applied anisotropic diffusion filtering [13] to the medical images, which removes Poisson noise while preserving edge information and salient feature structures. This preprocessing further improved the model's performance.
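The preprocessing step can be illustrated with the classic Perona-Malik anisotropic diffusion scheme. This is an assumption, because the paper cites [13] but does not give the exact filter, conductance function, or parameters. The values of `kappa`, `step` and `n_iter` below are illustrative for intensities scaled to [0, 1].

```python
import math
from typing import List

Image = List[List[float]]


def perona_malik(img: Image, n_iter: int = 10, kappa: float = 0.1,
                 step: float = 0.2) -> Image:
    """Perona-Malik anisotropic diffusion on a 2-D grayscale image.

    The conductance exp(-(d / kappa)^2) is close to 1 for small differences
    (strong smoothing) and close to 0 across large gradients (edges kept).
    step <= 0.25 keeps this explicit 4-neighbour scheme stable.
    """
    h, w = len(img), len(img[0])
    u = [row[:] for row in img]
    for _ in range(n_iter):
        new = [row[:] for row in u]
        for i in range(h):
            for j in range(w):
                flux = 0.0
                for di, dj in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                    ni, nj = i + di, j + dj
                    if 0 <= ni < h and 0 <= nj < w:   # zero-flux border
                        d = u[ni][nj] - u[i][j]
                        flux += math.exp(-(d / kappa) ** 2) * d
                new[i][j] = u[i][j] + step * flux
        u = new
    return u


edge = [[0.0, 0.0, 1.0, 1.0] for _ in range(4)]
smoothed = perona_malik(edge)
```

A sharp intensity step survives essentially unchanged, while small fluctuations such as noise are diffused away. This is why such filters are used to denoise medical images without blurring lesion boundaries.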

Table 1: Results (Dice and IoU, %). CTS-nM: without multi-scale feature supervision signals. CTS-M: with multi-scale feature supervision signals. CTS-FM: with the FFTP structure, the same as in MedSegDiff.
Method          Brain Tumor     SEHPI           Thyroid Nodule
                Dice   IoU      Dice   IoU      Dice   IoU
CENet           76.2   68.9     82.4   70.5     78.9   71.2
MRNet           83.4   75.6     85.9   73.4     80.4   73.4
SegNet          80.2   72.9     86.8   73.1     81.7   74.5
nnUNet          88.2   80.4     88.2   72.9     84.2   76.2
TransUNet       86.6   79.0     89.3   75.4     83.5   75.1
MedSegDiff-L    89.9   82.3     89.8   78.5     86.1   79.6
CTS-nM          90.0   82.5     -      -        -      -
MedSegDiff++    90.5   82.8     90.3   79.3     86.6   80.2
CTS-M           91.7   83.9     91.0   80.5     87.3   81.3
CTS-FM          92.1   84.0     -      -        -      -

Experiment Details. We used a 4× UNet. In the testing phase, we used a single sampling step for inference, far fewer than in most previous studies. All experiments were implemented in PyTorch and run on a single NVIDIA RTX 4090 GPU. All images were resized to 256×256 pixels. Training was conducted end-to-end with the AdamW optimizer and a batch size of 8. The initial learning rate was $1\times 10^{-4}$.
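For reference, the reported training settings can be collected in a small configuration object. The field names are hypothetical, and only the values come from the text; the 700,000 iterations are from the main results. In PyTorch the optimizer would typically be built with `torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CTSTrainConfig:
    """Training settings reported for CTS; field names are hypothetical."""
    image_size: int = 256             # images resized to 256 x 256 pixels
    batch_size: int = 8
    optimizer: str = "AdamW"
    learning_rate: float = 1e-4       # initial learning rate
    test_sampling_steps: int = 1      # single-step inference
    train_iterations: int = 700_000   # per dataset


cfg = CTSTrainConfig()
```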

Figure 2: Result visualization.

Main result. We compared CTS with the SOTA segmentation methods proposed for each task and with general medical image segmentation methods. Table 1 presents the main results; some of the baseline results are taken from [19]. In our experiments, we trained on each dataset for 700,000 iterations, which took one month. CTS-nM denotes the model without multi-scale feature supervision signals, and CTS-M denotes the model with them. For a fair comparison, since the Meg method incorporates a Fourier filter, CTS-FM includes the same FFTP structure as MedSegDiff. Fig. 2 shows visualization results. CTS-nM outperforms most of the compared methods. This demonstrates that the consistency model not only reduces the number of sampling steps but also improves effectiveness. Adding multi-scale feature signals (CTS-M) further improves performance, and CTS-FM shows that Fourier filtering brings an additional gain. All CTS variants were tested with a single sampling step, taking 1.9 s on average, which is significantly less time than the other models require. Therefore, CTS maintains segmentation quality while accelerating sampling.
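Table 1 reports Dice and IoU. For a single pair of binary masks, the two scores can be computed as below. This is a minimal sketch on flattened masks. The paper does not state how it averages over images or treats empty masks, so returning 1.0 for two empty masks is an assumption. For any one image the scores satisfy Dice = 2·IoU/(1 + IoU).

```python
from typing import Sequence, Tuple


def dice_iou(pred: Sequence[int], target: Sequence[int]) -> Tuple[float, float]:
    """Dice = 2|P∩G| / (|P| + |G|) and IoU = |P∩G| / |P∪G| for binary masks."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    inter = sum(1 for p, g in zip(pred, target) if p and g)
    p_sum = sum(1 for p in pred if p)
    g_sum = sum(1 for g in target if g)
    union = p_sum + g_sum - inter
    if union == 0:
        return 1.0, 1.0              # both masks empty: treated as perfect agreement
    return 2 * inter / (p_sum + g_sum), inter / union


print(dice_iou([1, 1, 0, 0], [1, 0, 1, 0]))  # prints (0.5, 0.3333333333333333)
```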

Fig. 3 shows the convergence process of the model. It plots the trend of the training loss, together with the IoU and Dice scores on the test set for model checkpoints saved at different times. As training time increases, the training loss levels off, but the test results do not saturate and still show considerable room for improvement. This is likely related to the strong learning and representational capabilities of the consistency model. These findings agree with the conclusions of previous consistency model studies.

Figure 3: Models saved at different stages: their training loss and corresponding results on the test set.
Figure 4: Accelerated convergence of the model with multi-scale feature signals.
Table 2: Ablation experiment. Mul_s denotes the multi-scale supervision signal; FFTP denotes the Fourier filter structure.
Brain Tumor SEHPI
Mul_s FFTP Dice IoU Dice IoU
85.3 79.2 78.2 85.1
86.3 80.1 79.1 84.9
86.2 80.3 79.3 85.6
86.9 82.1 81.1 86.3

Ablation experiment. Fig. 4 compares the convergence speed of CTS-nM and CTS-M. Since the models typically enter a plateau after 10,000 iterations, only the first 20,000 iterations are compared here. The inclusion of multi-scale feature signals clearly accelerates convergence. Table 2 compares the ablation results on two datasets; for this comparison, each experiment was trained with 500,000 samples. Incorporating multi-scale signals yields better results, and FFTP also improves performance. Adding either component alone gives a similar improvement. This is possibly because both modify how the supervisory signals are added and differ only in the method. This further indicates that better supervisory signals can strongly affect the model's results. Finally, the configuration combining both components was used to train the overall model.

5 Conclusion and Discussion

This paper is the first to establish a medical image segmentation method based on a consistency model, CTS. It not only yields better results but also significantly reduces prediction time. Moreover, constructing multi-scale feature supervision signals accelerates training convergence, and processing the medical image data with anisotropic edge-preserving filters further improves the results. However, some shortcomings remain. Because of limited access to RTX 4090 GPUs, we did not train for the one million iterations initially planned following the consistency model work [18]. Furthermore, the results did not saturate as the number of training iterations increased, so we infer that more training iterations could further improve the model's performance.

References

  • [1] Azad, R., Heidari, M., Shariatnia, M., Aghdam, E.K., Karimijafarbigloo, S., Adeli, E., Merhof, D.: Transdeeplab: Convolution-free transformer-based deeplab v3+ for medical image segmentation. In: International Workshop on PRedictive Intelligence In MEdicine. pp. 91–102. Springer (2022)
  • [2] Azad, R., Heidari, M., Wu, Y., Merhof, D.: Contextual attention network: Transformer meets u-net. In: International Workshop on Machine Learning in Medical Imaging. pp. 377–386. Springer (2022)
  • [3] Baid, U., Ghodasara, S., Mohan, S., Bilello, M., Calabrese, E., Colak, E., Farahani, K., Kalpathy-Cramer, J., Kitamura, F.C., Pati, S., et al.: The rsna-asnr-miccai brats 2021 benchmark on brain tumor segmentation and radiogenomic classification. arXiv preprint arXiv:2107.02314 (2021)
  • [4] Cao, H., Wang, Y., Chen, J., Jiang, D., Zhang, X., Tian, Q., Wang, M.: Swin-unet: Unet-like pure transformer for medical image segmentation. In: European conference on computer vision. pp. 205–218. Springer (2022)
  • [5] Chen, T., Zhang, R., Hinton, G.: Analog bits: Generating discrete data using diffusion models with self-conditioning. arXiv preprint arXiv:2208.04202 (2022)
  • [6] Dhariwal, P., Nichol, A.: Diffusion models beat gans on image synthesis. Advances in neural information processing systems 34, 8780–8794 (2021)
  • [7] Fernandez, V., Pinaya, W.H.L., Borges, P., Tudosiu, P.D., Graham, M.S., Vercauteren, T., Cardoso, M.J.: Can segmentation models be trained with fully synthetically generated data? In: International Workshop on Simulation and Synthesis in Medical Imaging. pp. 79–90. Springer (2022)
  • [8] Heidari, M., Kazerouni, A., Soltany, M., Azad, R., Aghdam, E.K., Cohen-Adad, J., Merhof, D.: Hiformer: Hierarchical multi-scale representations using transformers for medical image segmentation. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. pp. 6202–6212 (2023)
  • [9] Ho, J., Chan, W., Saharia, C., Whang, J., Gao, R., Gritsenko, A., Kingma, D.P., Poole, B., Norouzi, M., Fleet, D.J., et al.: Imagen video: High definition video generation with diffusion models. arXiv preprint arXiv:2210.02303 (2022)
  • [10] Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. Advances in neural information processing systems 33, 6840–6851 (2020)
  • [11] Ho, J., Salimans, T.: Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598 (2022)
  • [12] Li, X., Thickstun, J., Gulrajani, I., Liang, P.S., Hashimoto, T.B.: Diffusion-lm improves controllable text generation. Advances in Neural Information Processing Systems 35, 4328–4343 (2022)
  • [13] Passalacqua, P., Do Trung, T., Foufoula-Georgiou, E., Sapiro, G., Dietrich, W.E.: A geometric framework for channel network extraction from lidar: Nonlinear diffusion and geodesic paths. Journal of Geophysical Research: Earth Surface 115(F1) (2010)
  • [14] Popov, V., Vovk, I., Gogoryan, V., Sadekova, T., Kudinov, M.: Grad-tts: A diffusion probabilistic model for text-to-speech. In: International Conference on Machine Learning. pp. 8599–8608. PMLR (2021)
  • [15] Ronneberger, O., Fischer, P., Brox, T.: U-net: Convolutional networks for biomedical image segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention (2015)
  • [16] Saharia, C., Chan, W., Chang, H., Lee, C., Ho, J., Salimans, T., Fleet, D., Norouzi, M.: Palette: Image-to-image diffusion models. In: ACM SIGGRAPH 2022 Conference Proceedings. pp. 1–10 (2022)
  • [17] Sohl-Dickstein, J., Weiss, E.A., Maheswaranathan, N., Ganguli, S.: Deep unsupervised learning using nonequilibrium thermodynamics. JMLR.org (2015)
  • [18] Song, Y., Dhariwal, P., Chen, M., Sutskever, I.: Consistency models (2023)
  • [19] Wu, J., Fu, R., Fang, H., Zhang, Y., Yang, Y., Xiong, H., Liu, H., Xu, Y.: Medsegdiff: Medical image segmentation with diffusion probabilistic model. In: Medical Imaging with Deep Learning. pp. 1623–1639. PMLR (2024)
  • [20] Xue, Y., Xu, T., Zhang, H., Long, R., Huang, X.: Segan: Adversarial network with multi-scale $l_1$ loss for medical image segmentation. Neuroinformatics 16(6), 1–10 (2017)
  • [21] Zhang, L., Zhang, K., Pan, H.: Sunet++: A deep network with channel attention for small-scale object segmentation on 3d medical images. Tsinghua Science and Technology 28(4), 628–638 (2023)
  • [22] Zhou, Z., Rahman Siddiquee, M.M., Tajbakhsh, N., Liang, J.: Unet++: A nested u-net architecture for medical image segmentation. In: Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support. pp. 3–11. Springer International Publishing, Cham (2018)