
Institute of Science and Technology for Brain-inspired Intelligence, Fudan University

BerDiff: Conditional Bernoulli Diffusion Model for Medical Image Segmentation

Tao Chen    Chenhui Wang    Hongming Shan
Abstract

Medical image segmentation is a challenging task with inherent ambiguity and high uncertainty, attributed to factors such as unclear tumor boundaries and multiple plausible annotations. The accuracy and diversity of segmentation masks are both crucial for providing valuable references to radiologists in clinical practice. While existing diffusion models have shown strong capacities in various visual generation tasks, handling discrete masks in segmentation remains challenging. To achieve accurate and diverse medical image segmentation masks, we propose a novel conditional Bernoulli Diffusion model for medical image segmentation (BerDiff). Instead of using Gaussian noise, we first propose to use Bernoulli noise as the diffusion kernel to enhance the capacity of the diffusion model for binary segmentation tasks, resulting in more accurate segmentation masks. Second, by leveraging the stochastic nature of the diffusion model, our BerDiff randomly samples the initial Bernoulli noise and intermediate latent variables multiple times to produce a range of diverse segmentation masks, highlighting salient regions of interest that can serve as valuable references for radiologists. In addition, our BerDiff can efficiently sample sub-sequences from the overall trajectory of the reverse diffusion, thereby speeding up the segmentation process. Extensive experimental results on two medical image segmentation datasets with different modalities demonstrate that our BerDiff outperforms other recently published state-of-the-art methods. Our results suggest that diffusion models could serve as a strong backbone for medical image segmentation.

Keywords:
Conditional diffusion · Bernoulli noise · Medical image segmentation.

1 Introduction

Medical image segmentation plays a crucial role in enabling better diagnosis, surgical planning, and image-guided surgery [8]. The inherent ambiguity and high uncertainty of medical images pose significant challenges for accurate segmentation [5], attributed to factors such as unclear tumor boundaries in brain magnetic resonance imaging (MRI) images and multiple plausible annotations in lung nodule computed tomography (CT) images. Existing medical image segmentation works typically provide a single, deterministic, most likely hypothesis mask, which may lead to misdiagnosis or sub-optimal treatment. Therefore, providing accurate and diverse segmentation masks as valuable references [16] for radiologists is crucial in clinical practice.

Recently, diffusion models [10] have shown strong capacities in various visual generation tasks [19, 20]. However, how to better integrate them with discrete segmentation tasks needs further consideration. Although many studies [1, 24] have combined diffusion models with segmentation tasks and made some modifications, these methods do not take full account of the discrete nature of the segmentation task and still use Gaussian noise as their diffusion kernel. To achieve accurate and diverse segmentation, we propose a novel Conditional Bernoulli Diffusion model for medical image segmentation (BerDiff). Instead of using Gaussian noise, we first propose to use Bernoulli noise as the diffusion kernel to enhance the capacity of the diffusion model for segmentation, resulting in more accurate segmentation masks. Moreover, by leveraging the stochastic nature of the diffusion model, our BerDiff randomly samples the initial Bernoulli noise and intermediate latent variables multiple times to produce a range of diverse segmentation masks, which can highlight salient regions of interest (ROI) that serve as a valuable reference for radiologists. In addition, our BerDiff can efficiently sample sub-sequences from the overall trajectory of the reverse diffusion based on the rationale behind the Denoising Diffusion Implicit Models (DDIM) [23], thereby speeding up the segmentation process.

The contributions of this work are summarized as follows. 1) We propose a novel conditional diffusion model that uses Bernoulli noise rather than Gaussian noise for discrete binary segmentation tasks, achieving accurate and diverse medical image segmentation masks. 2) Our BerDiff can efficiently sample sub-sequences from the overall trajectory of the reverse diffusion, thereby speeding up the segmentation process. 3) Experimental results on two medical imaging modalities, CT and MRI, specifically the LIDC-IDRI and BRATS 2021 datasets, demonstrate that our BerDiff outperforms other state-of-the-art methods.

2 Methodology

In this section, we first describe the problem definitions, then demonstrate the Bernoulli forward and diverse reverse processes of our BerDiff, as shown in Fig. 1. Finally, we provide an overview of the training and sampling procedures.

2.1 Problem definition

Figure 1: Illustration of the Bernoulli forward and diverse reverse processes of our BerDiff.

Let us assume that $\bm{x}\in\mathbb{R}^{H\times W\times C}$ denotes the input medical image with a spatial resolution of $H\times W$ and $C$ channels. The ground-truth mask is represented as $\bm{y}_{0}\in\{0,1\}^{H\times W}$, where $0$ represents the background and $1$ the ROI. Inspired by diffusion-based models such as the denoising diffusion probabilistic model (DDPM) and DDIM, we propose a novel conditional Bernoulli diffusion model, which can be represented as $p_{\theta}(\bm{y}_{0}\mid\bm{x}):=\int p_{\theta}(\bm{y}_{0:T}\mid\bm{x})\,\mathrm{d}\bm{y}_{1:T}$, where $\bm{y}_{1},\ldots,\bm{y}_{T}$ are latent variables of the same size as the mask $\bm{y}_{0}$. For medical binary segmentation tasks, the diverse reverse process of our BerDiff starts from the initial Bernoulli noise $\bm{y}_{T}\sim\mathcal{B}(\bm{y}_{T};\frac{1}{2}\cdot\mathbf{1})$ and progresses through intermediate latent variables constrained by the input medical image $\bm{x}$ to produce segmentation masks, where $\mathbf{1}$ denotes an all-ones matrix of size $H\times W$.

2.2 Bernoulli forward process

In previous generation-related diffusion models, Gaussian noise is progressively added with increasing timestep $t$. However, for segmentation tasks, the ground-truth masks are represented by discrete values. To address this, our BerDiff gradually adds more Bernoulli noise using a noise schedule $\beta_{1},\ldots,\beta_{T}$, as shown in Fig. 1. The Bernoulli forward process $q(\bm{y}_{1:T}\mid\bm{y}_{0})$ of our BerDiff is a Markov chain, which can be represented as:

q(𝒚1:T𝒚0):=\displaystyle q\left(\bm{y}_{1:T}\mid\bm{y}_{0}\right):= t=1Tq(𝒚t𝒚t1),\displaystyle\prod\nolimits_{t=1}^{T}q\left(\bm{y}_{t}\mid\bm{y}_{t-1}\right), (1)
q(𝒚t𝒚t1):=\displaystyle q\left(\bm{y}_{t}\mid\bm{y}_{t-1}\right):= (𝒚t;(1βt)𝒚t1+βt/2),\displaystyle\mathcal{B}(\bm{y}_{t};(1-\beta_{t})\bm{y}_{t-1}+\beta_{t}/2), (2)

where \mathcal{B} denotes the Bernoulli distribution with the probability parameters (1βt)𝒚t1+βt/2(1-\beta_{t})\bm{y}_{t-1}+\beta_{t}/2. Using the notation αt=1βt\alpha_{t}=1-\beta_{t} and α¯t=τ=1tατ\bar{\alpha}_{t}={\textstyle\prod_{\tau=1}^{t}}{\alpha}_{\tau}, we can efficiently sample 𝒚t\bm{y}_{t} at an arbitrary timestep tt as follows:

q(𝒚t𝒚0)=(𝒚t;α¯t𝒚0+(1α¯t)/2)).\displaystyle q\left(\bm{y}_{t}\mid\bm{y}_{0}\right)=\mathcal{B}(\bm{y}_{t};\bar{\alpha}_{t}\bm{y}_{0}+(1-\bar{\alpha}_{t})/2)). (3)
Algorithm 1 Training
repeat
    $(\bm{x},\bm{y}_{0})\sim q(\bm{x},\bm{y}_{0})$
    $t\sim\mathrm{Uniform}(\{1,\dots,T\})$
    $\bm{\epsilon}\sim\mathcal{B}(\bm{\epsilon};(1-\bar{\alpha}_{t})/2)$
    $\bm{y}_{t}=\bm{y}_{0}\oplus\bm{\epsilon}$
    Calculate Eq. (4)
    Estimate $\hat{\bm{\epsilon}}(\bm{y}_{t},t,\bm{x})$
    Calculate Eq. (6)
    Take a gradient descent step on $\nabla_{\theta}\mathcal{L}_{\text{Total}}$
until converged

Algorithm 2 Sampling
$\bm{y}_{T}\sim\mathcal{B}(\bm{y}_{T};\frac{1}{2}\cdot\mathbf{1})$
for $t=T$ to $1$ do
    $\hat{\bm{\mu}}(\bm{y}_{t},t,\bm{x})=\mathcal{F}_{C}(\bm{y}_{t},\hat{\bm{\epsilon}}(\bm{y}_{t},t,\bm{x}))$
    For DDPM: $\bm{y}_{t-1}\sim\mathcal{B}(\bm{y}_{t-1};\hat{\bm{\mu}}(\bm{y}_{t},t,\bm{x}))$
    For DDIM: $\bm{y}_{t-1}\sim\mathcal{B}\big(\bm{y}_{t-1};\sigma_{t}\bm{y}_{t}+(\bar{\alpha}_{t-1}-\sigma_{t}\bar{\alpha}_{t})|\bm{y}_{t}-\hat{\bm{\epsilon}}(\bm{y}_{t},t,\bm{x})|+((1-\bar{\alpha}_{t-1})-(1-\bar{\alpha}_{t})\sigma_{t})/2\big)$
end for
return $\bm{y}_{0}$

To ensure that the objective function described in Sec. 2.4 is tractable and easy to compute, we use the sampled Bernoulli noise $\bm{\epsilon}\sim\mathcal{B}(\bm{\epsilon};\frac{1-\bar{\alpha}_{t}}{2}\cdot\mathbf{1})$ to reparameterize $\bm{y}_{t}$ of Eq. (3) as $\bm{y}_{0}\oplus\bm{\epsilon}$, where $\oplus$ denotes the logical operation of "exclusive or" (XOR). Additionally, let $\odot$ denote the elementwise product, and $\operatorname{Norm}(\cdot)$ denote normalizing the input data along the channel dimension and then returning the second channel. The concrete Bernoulli posterior can be represented as:

$q(\bm{y}_{t-1}\mid\bm{y}_{t},\bm{y}_{0}) = \mathcal{B}(\bm{y}_{t-1};\theta_{\text{post}}(\bm{y}_{t},\bm{y}_{0}))$,   (4)

where $\theta_{\text{post}}(\bm{y}_{t},\bm{y}_{0})=\operatorname{Norm}\big([\alpha_{t}[1-\bm{y}_{t},\bm{y}_{t}]+\tfrac{1-\alpha_{t}}{2}]\odot[\bar{\alpha}_{t-1}[1-\bm{y}_{0},\bm{y}_{0}]+\tfrac{1-\bar{\alpha}_{t-1}}{2}]\big)$.

2.3 Diverse reverse process

The diverse reverse process $p_{\theta}(\bm{y}_{0:T})$ can also be viewed as a Markov chain that starts from the Bernoulli noise $\bm{y}_{T}\sim\mathcal{B}(\bm{y}_{T};\frac{1}{2}\cdot\mathbf{1})$ and progresses through intermediate latent variables constrained by the input medical image $\bm{x}$ to produce diverse segmentation masks, as shown in Fig. 1. The concrete diverse reverse process of our BerDiff can be represented as:

$p_{\theta}(\bm{y}_{0:T}\mid\bm{x}) := p(\bm{y}_{T})\prod_{t=1}^{T}p_{\theta}(\bm{y}_{t-1}\mid\bm{y}_{t},\bm{x})$,   (5)
$p_{\theta}(\bm{y}_{t-1}\mid\bm{y}_{t},\bm{x}) := \mathcal{B}(\bm{y}_{t-1};\hat{\bm{\mu}}(\bm{y}_{t},t,\bm{x}))$.   (6)

Specifically, we utilize the estimated Bernoulli noise $\hat{\bm{\epsilon}}(\bm{y}_{t},t,\bm{x})$ of $\bm{y}_{t}$ to parameterize $\hat{\bm{\mu}}(\bm{y}_{t},t,\bm{x})$ via a calibration function $\mathcal{F}_{C}$, as follows:

$\hat{\bm{\mu}}(\bm{y}_{t},t,\bm{x}) = \mathcal{F}_{C}(\bm{y}_{t},\hat{\bm{\epsilon}}(\bm{y}_{t},t,\bm{x})) = \theta_{\text{post}}(\bm{y}_{t},|\bm{y}_{t}-\hat{\bm{\epsilon}}(\bm{y}_{t},t,\bm{x})|)$,   (7)

where $|\cdot|$ denotes the absolute value operation.
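For concreteness, the sketch below implements $\theta_{\text{post}}$ of Eq. (4) and the calibration function $\mathcal{F}_{C}$ of Eq. (7). It is not the reference implementation; it assumes masks of shape (B, H, W), 1-indexed timesteps, the `alphas`/`alpha_bars` arrays from the forward-process sketch above, and the convention $\bar{\alpha}_{0}=1$.

```python
# Minimal sketch of theta_post (Eq. 4) and the calibration function F_C (Eq. 7).
import torch

def theta_post(y_t, y_0, t, alphas, alpha_bars):
    """Probability that y_{t-1} = 1 under q(y_{t-1} | y_t, y_0)."""
    a_t = alphas[t - 1]
    ab_prev = alpha_bars[t - 2] if t > 1 else torch.tensor(1.0)  # assume alpha_bar_0 = 1
    # Two-channel unnormalized probabilities for y_{t-1} in {0, 1}
    term_t = a_t * torch.stack([1 - y_t, y_t], dim=1) + (1 - a_t) / 2
    term_0 = ab_prev * torch.stack([1 - y_0, y_0], dim=1) + (1 - ab_prev) / 2
    unnorm = term_t * term_0                      # elementwise product (odot)
    prob = unnorm / unnorm.sum(dim=1, keepdim=True)
    return prob[:, 1]                             # Norm(.) returns the second channel

def calibration(y_t, eps_hat, t, alphas, alpha_bars):
    """F_C: estimate y_0 as |y_t - eps_hat| (a soft XOR) and plug it into theta_post."""
    return theta_post(y_t, torch.abs(y_t - eps_hat), t, alphas, alpha_bars)
```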

2.4 Detailed procedure

Here, we provide an overview of the training and sampling procedures in Algorithms 1 and 2. During the training phase, given an image and mask data pair $\{\bm{x},\bm{y}_{0}\}$, we sample a random timestep $t$ from a uniform distribution over $\{1,\dots,T\}$, which is used to sample the Bernoulli noise $\bm{\epsilon}$.

We then use $\bm{\epsilon}$ to sample $\bm{y}_{t}$ from $q(\bm{y}_{t}\mid\bm{y}_{0})$, which allows us to obtain the Bernoulli posterior $q(\bm{y}_{t-1}\mid\bm{y}_{t},\bm{y}_{0})$. We pass the estimated Bernoulli noise $\hat{\bm{\epsilon}}(\bm{y}_{t},t,\bm{x})$ through the calibration function $\mathcal{F}_{C}$ to parameterize $p_{\theta}(\bm{y}_{t-1}\mid\bm{y}_{t},\bm{x})$. Based on the variational upper bound on the negative log-likelihood in previous diffusion models [3], we adopt the Kullback-Leibler (KL) divergence and binary cross-entropy (BCE) loss to optimize our BerDiff as follows:

$\mathcal{L}_{\text{KL}} = \mathbb{E}_{q(\bm{x},\bm{y}_{0})}\mathbb{E}_{q(\bm{y}_{t}\mid\bm{y}_{0})}\big[D_{\mathrm{KL}}[q(\bm{y}_{t-1}\mid\bm{y}_{t},\bm{y}_{0})\,\|\,p_{\theta}(\bm{y}_{t-1}\mid\bm{y}_{t},\bm{x})]\big]$,   (8)
$\mathcal{L}_{\text{BCE}} = -\mathbb{E}_{(\bm{\epsilon},\hat{\bm{\epsilon}})}\sum_{i,j}^{H,W}\big[\epsilon_{i,j}\log\hat{\epsilon}_{i,j}+(1-\epsilon_{i,j})\log(1-\hat{\epsilon}_{i,j})\big]$.   (9)

Finally, the overall objective function is presented as:

$\mathcal{L}_{\text{Total}} = \mathcal{L}_{\text{KL}} + \lambda_{\text{BCE}}\mathcal{L}_{\text{BCE}}$,   (10)

where $\lambda_{\text{BCE}}$ is set to $1$ in our experiments.
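A minimal sketch of this objective is given below. It is not the reference training code; it assumes the posterior and model probabilities have already been computed as in the earlier sketches and that the network outputs noise probabilities in (0, 1).

```python
# Minimal sketch of the training objective in Eqs. (8)-(10).
import torch
import torch.nn.functional as F

def bernoulli_kl(p, q, eps=1e-7):
    """Elementwise KL( B(p) || B(q) ), averaged over all pixels and the batch."""
    p, q = p.clamp(eps, 1 - eps), q.clamp(eps, 1 - eps)
    return (p * (p / q).log() + (1 - p) * ((1 - p) / (1 - q)).log()).mean()

def berdiff_loss(post_prob, model_prob, eps_true, eps_hat, lambda_bce=1.0):
    """post_prob = theta_post(y_t, y_0); model_prob = mu_hat(y_t, t, x);
    eps_true is the sampled Bernoulli noise, eps_hat the network estimate."""
    l_kl = bernoulli_kl(post_prob, model_prob)      # Eq. (8)
    l_bce = F.binary_cross_entropy(eps_hat, eps_true)  # Eq. (9)
    return l_kl + lambda_bce * l_bce                # Eq. (10)
```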

During the sampling phase, our BerDiff first samples the initial latent variable $\bm{y}_{T}$, followed by iterative calculation of the probability parameters of $\bm{y}_{t-1}$ for different $t$. In Algorithm 2, we present two different sampling strategies, from DDPM and DDIM, for the latent variable $\bm{y}_{t-1}$. Finally, our BerDiff is capable of producing diverse segmentation masks. By taking the mean of these masks, we can further obtain a saliency segmentation mask to highlight salient ROI that can serve as a valuable reference for radiologists. Note that our BerDiff proposes a novel parameterization technique, i.e., the calibration function, to estimate the Bernoulli noise of $\bm{y}_{t}$, which is different from previous discrete-state diffusion-based models [3, 11, 22].
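As an illustration of Algorithm 2, the sketch below implements the DDPM-style reverse loop; the DDIM-style update replaces the probability parameter with the alternative expression in Algorithm 2 and iterates over a uniformly sampled sub-sequence of timesteps. The noise-estimation network `eps_net(y_t, t, x)` is a placeholder, and the `calibration` helper comes from the sketch in Sec. 2.3.

```python
# Minimal sketch of the DDPM-style reverse process in Algorithm 2; `eps_net` is a
# hypothetical noise-estimation U-net returning probabilities in (0, 1).
import torch

@torch.no_grad()
def sample_masks(eps_net, x, shape, T, alphas, alpha_bars, n_samples=16):
    """Draw n_samples diverse masks and their mean (saliency mask) for one input x."""
    masks = []
    for _ in range(n_samples):
        y_t = torch.bernoulli(torch.full(shape, 0.5))          # y_T ~ B(1/2 * 1)
        for t in range(T, 0, -1):
            eps_hat = eps_net(y_t, torch.full((shape[0],), float(t)), x)
            mu_hat = calibration(y_t, eps_hat, t, alphas, alpha_bars)
            y_t = torch.bernoulli(mu_hat)                      # y_{t-1} ~ B(mu_hat)
        masks.append(y_t)
    masks = torch.stack(masks)
    return masks, masks.mean(dim=0)   # diverse masks and their saliency mask
```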

3 Experiment

3.1 Experimental setup

Dataset and preprocessing The data used in this experiment are obtained from the LIDC-IDRI [2, 7] and BRATS 2021 [4] datasets. LIDC-IDRI contains 1,018 lung CT scans with plausible segmentation masks annotated by four radiologists. We adopt a standard preprocessing pipeline for lung CT scans and the train-validation-test partition used in previous work [5, 14, 21]. BRATS 2021 provides MRI images of four different sequences (T1, T2, FLAIR, T1CE) for each patient. All 3D scans are sliced into axial slices, and the bottom 80 and top 26 slices are discarded. Note that we treat the original four types of brain tumors as one type following previous work [23], converting the multi-target segmentation problem into a binary one. Our training set includes 55,174 2D images scanned from 1,126 patients, and the test set comprises 3,991 2D images scanned from 125 patients. Finally, the images from LIDC-IDRI and BRATS 2021 are resized to $128\times128$ and $224\times224$, respectively.

Table 1: Ablation results of hyperparameters on LIDC-IDRI. GED is reported for 16, 8, 4, and 1 samples; HM-IoU for 16 samples.

Loss | Estimation Target | GED (16) | GED (8) | GED (4) | GED (1) | HM-IoU (16)
$\mathcal{L}_{\text{KL}}$ | Bernoulli noise | 0.332 | 0.365 | 0.430 | 0.825 | 0.517
$\mathcal{L}_{\text{BCE}}$ | Bernoulli noise | 0.251 | 0.287 | 0.359 | 0.785 | 0.566
$\mathcal{L}_{\text{BCE}}+\mathcal{L}_{\text{KL}}$ | Bernoulli noise | 0.249 | 0.287 | 0.358 | 0.775 | 0.575
$\mathcal{L}_{\text{BCE}}+\mathcal{L}_{\text{KL}}$ | Ground-truth mask | 0.277 | 0.317 | 0.396 | 0.866 | 0.509
Table 2: Ablation results of diffusion kernel on LIDC-IDRI.

Training Iterations | Diffusion Kernel | GED (16) | GED (8) | GED (4) | GED (1) | HM-IoU (16)
21,000 | Gaussian | 0.671 | 0.732 | 0.852 | 1.573 | 0.020
21,000 | Bernoulli | 0.252 | 0.287 | 0.358 | 0.775 | 0.575
86,500 | Gaussian | 0.251 | 0.282 | 0.345 | 0.719 | 0.587
86,500 | Bernoulli | 0.238 | 0.271 | 0.340 | 0.748 | 0.596

Implementation Details We implement all the methods with the PyTorch library and train the models on NVIDIA V100 GPUs. All the networks are trained using the AdamW [17] optimizer with a batch size of 32. The initial learning rate is set to $1\times10^{-4}$ for BRATS 2021 and $5\times10^{-5}$ for LIDC-IDRI. The Bernoulli noise estimation U-net in Fig. 1 of our BerDiff is the same as in previous diffusion-based models [18]. We employ a linear noise schedule over $T=1000$ timesteps for all the diffusion models and use the sub-sequence sampling strategy of DDIM to accelerate the segmentation process. During mini-batch training on LIDC-IDRI, our BerDiff learns diverse expertise by randomly sampling one of the four annotated segmentation masks for each image. Three metrics are used for performance evaluation: Generalized Energy Distance (GED), Hungarian-matched Intersection over Union (HM-IoU), and the Dice coefficient. We compute GED using a varying number of segmentation samples (1, 4, 8, and 16) and HM-IoU using 16 samples.
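For reference, GED can be computed with the distance $d(a,b)=1-\mathrm{IoU}(a,b)$, a common formulation in prior LIDC-IDRI work [14]. The sketch below is our own illustration of that formulation on a single image, not the exact evaluation script used here.

```python
# Illustrative GED computation with d(a, b) = 1 - IoU(a, b); `preds` and `gts` are
# lists of binary masks (torch tensors) for one image.
import torch

def iou_distance(a, b):
    inter = (a * b).sum()
    union = ((a + b) > 0).float().sum()
    return 0.0 if union == 0 else 1.0 - (inter / union).item()

def ged(preds, gts):
    cross = sum(iou_distance(p, g) for p in preds for g in gts) / (len(preds) * len(gts))
    diversity_p = sum(iou_distance(p, q) for p in preds for q in preds) / len(preds) ** 2
    diversity_g = sum(iou_distance(g, h) for g in gts for h in gts) / len(gts) ** 2
    return 2 * cross - diversity_p - diversity_g
```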

3.2 Ablation study

We start by conducting ablation experiments to demonstrate the effectiveness of different losses and estimation targets, as shown in Table 1. All models are trained for 21,000 iterations on LIDC-IDRI. We first conduct the ablation study of different losses while estimating Bernoulli noise in the top three rows. We find that the combination of KL divergence and BCE loss achieves the best performance. Then, we conduct an ablation study on the choice of estimation target in the bottom two rows. We observe that estimating the Bernoulli noise, rather than directly estimating the ground-truth mask, is more suitable for our binary segmentation task. All of these findings are consistent with previous works [3, 10]. Please refer to Appendix 0.A for extra ablation studies on the sampling strategy and sampled timesteps.

Table 3: Results on LIDC-IDRI.

Methods | GED (16) | HM-IoU (16)
Prob.U-net [14] | 0.320±0.03 | 0.500±0.03
Hprob.U-net [15] | 0.270±0.01 | 0.530±0.01
CAR [13] | 0.264±0.00 | 0.592±0.01
JPro.U-net [26] | 0.260±0.00 | 0.585±0.00
PixelSeg [25] | 0.260±0.00 | 0.587±0.01
SegDiff [1] | 0.248±0.01 | 0.585±0.00
MedSegDiff [24] | 0.420±0.03 | 0.413±0.03
BerDiff (ours) | 0.238±0.01 | 0.596±0.00
Table 4: Results on BRATS 2021.

Methods | Dice
nnU-net [12] | 88.2
TransU-net [6] | 88.6
Swin UNETR [9] | 89.0
U-net$^\sharp$ | 89.2
SegDiff [1] | 89.3
BerDiff (ours) | 89.7

$^\sharp$ The U-net has the same architecture as the noise estimation network in our BerDiff and previous diffusion-based models.

Figure 2: Diverse segmentation masks and the corresponding saliency mask of two lung nodules randomly selected from LIDC-IDRI. $\bm{y}^{i}_{0}$ and $\bm{y}^{i}_{\text{gt}}$ refer to the $i$-th generated and ground-truth segmentation masks, respectively. The saliency mask is the mean of the diverse segmentation masks.

Here, we conduct ablation experiments on our BerDiff with Gaussian or Bernoulli noise, and the results are shown in Table 2. For discrete segmentation tasks, we find that using Bernoulli noise produces favorable results when training iterations are limited (e.g., 21,000 iterations) and still outperforms Gaussian noise when training iterations are sufficient (e.g., 86,500 iterations). We also provide a more detailed performance comparison between Bernoulli- and Gaussian-based diffusion models over training iterations in Appendix 0.B. In addition, we present a toy example to demonstrate the superiority of Bernoulli diffusion over Gaussian diffusion in Appendix 0.C.

3.3 Comparison to other state-of-the-art methods

Figure 3: Segmentation masks of four MRI images randomly selected from BRATS 2021. The segmentation masks of diffusion-based models (SegDiff and ours) presented here are saliency segmentation masks.

Results on LIDC-IDRI Here, we present the quantitative comparison results on LIDC-IDRI in Table 3 and find that our BerDiff performs well for discrete segmentation tasks. Probabilistic U-net (Prob.U-net), Hierarchical Prob.U-net (Hprob.U-net), and Joint Prob.U-net (JPro.U-net) use conditional variational autoencoders (cVAE) to accomplish segmentation tasks. Calibrated Adversarial Refinement (CAR) employs generative adversarial networks (GAN) to refine segmentation. PixelSeg is based on autoregressive models, while SegDiff and MedSegDiff are diffusion-based models. We have the following two observations: 1) diffusion-based methods demonstrate significant superiority over traditional approaches based on VAE, GAN, and autoregressive models for discrete segmentation tasks; and 2) our BerDiff outperforms other diffusion-based models that use Gaussian noise as the diffusion kernel. We also present qualitative segmentation results in Fig. 2. Compared to other models, our BerDiff can effectively learn diverse expertise, resulting in more diverse and accurate segmentation masks. Especially for small nodules that can create ambiguity, such as the lung nodule on the left, our BerDiff produces segmentation masks that are more in line with the ground-truth masks.

Results on BRATS 2021 Here, we present the quantitative and qualitative comparison results on BRATS 2021 in Table 4 and Fig. 3, respectively. We compare our BerDiff with nnU-net, transformer-based models such as TransU-net and Swin UNETR, and diffusion-based methods such as SegDiff. First, we find that diffusion-based methods show superior performance compared to traditional U-net and transformer-based approaches. Besides, the high performance achieved by the U-net, which shares the same architecture as our noise estimation network, highlights the effectiveness of the backbone design in diffusion-based models. Moreover, our proposed BerDiff surpasses other diffusion-based models that use Gaussian noise as the diffusion kernel. Finally, from Fig. 3, we find that our BerDiff segments more accurately on parts that are difficult to recognize by the human eye, such as the tumor in the 3rd row. At the same time, we can also generate diverse plausible segmentation masks to produce a saliency segmentation mask. We note that some of these masks may be false positives, as shown in the 1st row, but they can be filtered out due to low saliency. Please refer to Appendix 0.D for more examples of diverse segmentation masks generated by our BerDiff.

4 Conclusion

We propose to use Bernoulli noise as the diffusion kernel to enhance the capacity of the diffusion model for binary segmentation tasks, achieving accurate and diverse medical image segmentation results. Our BerDiff currently focuses only on binary segmentation tasks and, like other diffusion-based models, spends considerable time on the iterative sampling process; e.g., our BerDiff takes 0.4 s to segment one medical image, about ten times that of a traditional U-net. In the future, we will extend our BerDiff to the multi-target segmentation problem and implement additional strategies for speeding up the segmentation process.

References

  • [1] Amit, T., Nachmani, E., Shaharbany, T., Wolf, L.: SegDiff: Image segmentation with diffusion probabilistic models. arXiv:2112.00390 (2021)
  • [2] Armato III, S.G., McLennan, G., Bidaut, L., McNitt-Gray, M.F., Meyer, C.R., Reeves, A.P., Zhao, B., Aberle, D.R., Henschke, C.I., Hoffman, E.A., et al.: The lung image database consortium (LIDC) and image database resource initiative (IDRI): a completed reference database of lung nodules on CT scans. Medical physics 38(2), 915–931 (2011)
  • [3] Austin, J., Johnson, D.D., Ho, J., Tarlow, D., van den Berg, R.: Structured denoising diffusion models in discrete state-spaces. Advances in Neural Information Processing Systems 34, 17981–17993 (2021)
  • [4] Baid, U., Ghodasara, S., Mohan, S., Bilello, M., Calabrese, E., Colak, E., Farahani, K., Kalpathy-Cramer, J., Kitamura, F.C., Pati, S., et al.: The RSNA-ASNR-MICCAI BraTS 2021 benchmark on brain tumor segmentation and radiogenomic classification. arXiv:2107.02314 (2021)
  • [5] Baumgartner, C.F., Tezcan, K.C., Chaitanya, K., Hötker, A.M., Muehlematter, U.J., Schawkat, K., Becker, A.S., Donati, O., Konukoglu, E.: PHISeg: Capturing uncertainty in medical image segmentation. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2019, Part II. pp. 119–127. Springer (2019)
  • [6] Chen, J., Lu, Y., Yu, Q., Luo, X., Adeli, E., Wang, Y., Lu, L., Yuille, A.L., Zhou, Y.: TransUNet: Transformers make strong encoders for medical image segmentation. arXiv:2102.04306 (2021)
  • [7] Clark, K., Vendt, B., Smith, K., Freymann, J., Kirby, J., Koppel, P., Moore, S., Phillips, S., Maffitt, D., Pringle, M., et al.: The cancer imaging archive (TCIA): maintaining and operating a public information repository. Journal of digital imaging 26(6), 1045–1057 (2013)
  • [8] Haque, I.R.I., Neubert, J.: Deep learning approaches to biomedical image segmentation. Informatics in Medicine Unlocked 18, 100297 (2020)
  • [9] Hatamizadeh, A., Nath, V., Tang, Y., Yang, D., Roth, H.R., Xu, D.: Swin UNETR: Swin transformers for semantic segmentation of brain tumors in MRI images. In: Brainlesion: Glioma, Multiple Sclerosis, Stroke and Traumatic Brain Injuries, Part I. pp. 272–284. Springer (2022)
  • [10] Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems 33, 6840–6851 (2020)
  • [11] Hoogeboom, E., Nielsen, D., Jaini, P., Forré, P., Welling, M.: Argmax flows and multinomial diffusion: Learning categorical distributions. Advances in Neural Information Processing Systems 34, 12454–12465 (2021)
  • [12] Isensee, F., Jaeger, P.F., Kohl, S.A., Petersen, J., Maier-Hein, K.H.: nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods 18(2), 203–211 (2021)
  • [13] Kassapis, E., Dikov, G., Gupta, D.K., Nugteren, C.: Calibrated adversarial refinement for stochastic semantic segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 7057–7067 (2021)
  • [14] Kohl, S., Romera-Paredes, B., Meyer, C., De Fauw, J., Ledsam, J.R., Maier-Hein, K., Eslami, S., Jimenez Rezende, D., Ronneberger, O.: A probabilistic U-Net for segmentation of ambiguous images. Advances in neural information processing systems 31 (2018)
  • [15] Kohl, S.A., Romera-Paredes, B., Maier-Hein, K.H., Rezende, D.J., Eslami, S., Kohli, P., Zisserman, A., Ronneberger, O.: A hierarchical probabilistic U-Net for modeling multi-scale ambiguities. arXiv:1905.13077 (2019)
  • [16] Lenchik, L., Heacock, L., Weaver, A.A., Boutin, R.D., Cook, T.S., Itri, J., Filippi, C.G., Gullapalli, R.P., Lee, J., Zagurovskaya, M., et al.: Automated segmentation of tissues using CT and MRI: a systematic review. Academic radiology 26(12), 1695–1706 (2019)
  • [17] Loshchilov, I., Hutter, F.: Decoupled weight decay regularization. arXiv:1711.05101 (2017)
  • [18] Nichol, A.Q., Dhariwal, P.: Improved denoising diffusion probabilistic models. In: International Conference on Machine Learning. pp. 8162–8171. PMLR (2021)
  • [19] Ramesh, A., Dhariwal, P., Nichol, A., Chu, C., Chen, M.: Hierarchical text-conditional image generation with clip latents. arXiv:2204.06125 (2022)
  • [20] Saharia, C., Chan, W., Saxena, S., Li, L., Whang, J., Denton, E., Ghasemipour, S.K.S., Ayan, B.K., Mahdavi, S.S., Lopes, R.G., et al.: Photorealistic text-to-image diffusion models with deep language understanding. arXiv:2205.11487 (2022)
  • [21] Selvan, R., Faye, F., Middleton, J., Pai, A.: Uncertainty quantification in medical image segmentation with normalizing flows. In: Machine Learning in Medical Imaging. pp. 80–90. Springer (2020)
  • [22] Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N., Ganguli, S.: Deep unsupervised learning using nonequilibrium thermodynamics. In: International Conference on Machine Learning. pp. 2256–2265. PMLR (2015)
  • [23] Wolleb, J., Sandkühler, R., Bieder, F., Valmaggia, P., Cattin, P.C.: Diffusion models for implicit image segmentation ensembles. arXiv:2112.03145 (2021)
  • [24] Wu, J., Fang, H., Zhang, Y., Yang, Y., Xu, Y.: MedSegDiff: Medical image segmentation with diffusion probabilistic model. arXiv:2211.00611 (2022)
  • [25] Zhang, W., Zhang, X., Huang, S., Lu, Y., Wang, K.: PixelSeg: Pixel-by-pixel stochastic semantic segmentation for ambiguous medical images. In: Proceedings of the 30th ACM International Conference on Multimedia. pp. 4742–4750 (2022)
  • [26] Zhang, W., Zhang, X., Huang, S., Lu, Y., Wang, K.: A probabilistic model for controlling diversity and accuracy of ambiguous medical image segmentation. In: Proceedings of the 30th ACM International Conference on Multimedia. pp. 4751–4759 (2022)

Appendix

Appendix 0.A Ablation study on sampling strategy and timestep

Our BerDiff is compatible with various sampling strategies, and here, we compare the performance of BerDiff using DDPM’s and DDIM’s sampling strategies. The concrete sampling algorithms can be found in Algorithm 2. Our results in Table A1 indicate that for binary segmentation tasks, BerDiff using DDIM’s sampling strategy achieves better performance compared to using DDPM’s. Furthermore, to attain satisfactory performance with limited computational resources, we uniformly sample 10 timesteps from the complete trajectory in all other experiments.

Table A1: Ablation results of sampling strategy and timestep on LIDC-IDRI. The model used in this study was trained for 21,000 training iterations.

Configuration | Sampled Timesteps | GED (16) | GED (8) | GED (4) | GED (1) | HM-IoU (16)
BerDiff + DDPM's sampling strategy | 2 | 0.441 | 0.483 | 0.568 | 1.076 | 0.303
BerDiff + DDPM's sampling strategy | 10 | 0.266 | 0.302 | 0.377 | 0.824 | 0.533
BerDiff + DDPM's sampling strategy | 100 | 0.258 | 0.296 | 0.372 | 0.829 | 0.539
BerDiff + DDPM's sampling strategy | 1000 | 0.254 | 0.293 | 0.369 | 0.832 | 0.539
BerDiff + DDIM's sampling strategy | 2 | 0.432 | 0.481 | 0.579 | 1.167 | 0.341
BerDiff + DDIM's sampling strategy | 10 | 0.252 | 0.287 | 0.358 | 0.775 | 0.575
BerDiff + DDIM's sampling strategy | 100 | 0.250 | 0.284 | 0.351 | 0.759 | 0.582
BerDiff + DDIM's sampling strategy | 1000 | 0.247 | 0.280 | 0.348 | 0.758 | 0.585
Figure A1: Performance curves over training iterations for Gaussian- and Bernoulli-based diffusion models on LIDC-IDRI.
Figure A2: 1D binary classification tasks.

Appendix 0.B Performance curves

Here we present a detailed performance comparison between Bernoulli- and Gaussian-based diffusion models over training iterations in Fig. A1. Results show that employing Bernoulli noise leads to faster convergence and higher performance in contrast to Gaussian noise.

Appendix 0.C Toy example

We provide intuitive insight into why Bernoulli diffusion outperforms Gaussian diffusion by designing and conducting experiments on two simple one-dimensional binary classification tasks. To create the datasets, we apply a predefined conditional probability function, as shown in the first row of Fig. A2, to map an input $x\in[0,1]$ to an output $y\in\{0,1\}$, which serves as the ground truth. For each data configuration, we use a 4-layer MLP as the noise estimation network. We train all runs with a learning rate of $1\times10^{-3}$. During inference, we uniformly sample 1,000 points from $[0,1]$ and generate 100 samples with the diffusion model for each point. We take the mean of the 100 samples as the estimated probability parameter and evaluate performance using the mean-squared error (MSE) between the estimated and ground-truth probability parameters. Fig. A2 shows that Bernoulli diffusion offers superior training stability, faster convergence, and better fitting of the conditional distribution compared to Gaussian diffusion.
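A self-contained sketch of this toy setup is given below. The conditional probability function, noise schedule, number of timesteps, and optimizer are illustrative assumptions rather than the exact configuration used for Fig. A2; the mean of 100 sampled outputs per point approximates $p(y=1\mid x)$ and is scored with MSE as described above.

```python
# Toy sketch: 1D Bernoulli diffusion with a small MLP noise estimator.
# The probability function, schedule, and optimizer are illustrative assumptions.
import torch
import torch.nn as nn

torch.manual_seed(0)
T = 100
betas = torch.linspace(1e-3, 5e-2, T)
alphas, alpha_bars = 1 - betas, torch.cumprod(1 - betas, dim=0)

def true_prob(x):                                   # a predefined conditional probability
    return 0.1 + 0.8 * (x > 0.5).float()

net = nn.Sequential(nn.Linear(3, 64), nn.ReLU(),    # 4-layer MLP on (y_t, t/T, x)
                    nn.Linear(64, 64), nn.ReLU(),
                    nn.Linear(64, 64), nn.ReLU(),
                    nn.Linear(64, 1), nn.Sigmoid())
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

for _ in range(2000):                               # train with BCE on the Bernoulli noise
    x = torch.rand(256, 1)
    y0 = torch.bernoulli(true_prob(x))
    t = torch.randint(1, T + 1, (256, 1))
    eps = torch.bernoulli((1 - alpha_bars[t - 1]) / 2)
    yt = torch.logical_xor(y0.bool(), eps.bool()).float()
    loss = nn.functional.binary_cross_entropy(net(torch.cat([yt, t / T, x], dim=1)), eps)
    opt.zero_grad()
    loss.backward()
    opt.step()

@torch.no_grad()
def sample(x):                                      # DDPM-style Bernoulli reverse process
    yt = torch.bernoulli(torch.full_like(x, 0.5))
    for t in range(T, 0, -1):
        eps_hat = net(torch.cat([yt, torch.full_like(x, t / T), x], dim=1))
        y0_hat = torch.abs(yt - eps_hat)            # calibration: soft XOR estimate of y_0
        a_t = alphas[t - 1]
        ab_prev = alpha_bars[t - 2] if t > 1 else torch.tensor(1.0)
        u1 = a_t * torch.cat([1 - yt, yt], dim=1) + (1 - a_t) / 2
        u0 = ab_prev * torch.cat([1 - y0_hat, y0_hat], dim=1) + (1 - ab_prev) / 2
        prob = (u1 * u0)[:, 1:] / (u1 * u0).sum(dim=1, keepdim=True)
        yt = torch.bernoulli(prob)
    return yt

xs = torch.linspace(0, 1, 1000).unsqueeze(1)
est_prob = torch.stack([sample(xs) for _ in range(100)]).mean(dim=0)
mse = ((est_prob - true_prob(xs)) ** 2).mean()
```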

Appendix 0.D More examples of diverse segmentation masks

In this section, we provide additional examples of diverse segmentation masks generated by our BerDiff for LIDC-IDRI and BRATS 2021, as shown in Figs. A3 and A4, respectively.

Figure A3: More segmentation masks generated by our BerDiff on LIDC-IDRI. $\bm{x}$ represents the input medical image. $\bm{y}^{i}_{0}$ and $\bm{y}^{i}_{\text{gt}}$ refer to the $i$-th generated and ground-truth segmentation masks, respectively. The saliency mask is the mean of the diverse segmentation masks.
Figure A4: More segmentation masks generated by our BerDiff on BRATS 2021. $\bm{x}$ represents the input medical image, while $\bm{y}_{\text{gt}}$ denotes the corresponding ground-truth segmentation mask. $\bm{y}^{i}_{0}$ refers to the $i$-th generated segmentation mask. The saliency mask is obtained by calculating the mean of the diverse segmentation masks. Note that the variance of the generated segmentation masks is presented in the last row.