
Geng LI, The Chinese University of Hong Kong. Email: [email protected]
Di QIU, Google Research. Email: [email protected]
Lok Ming LUI, The Chinese University of Hong Kong. Email: [email protected]

A deep neural network framework for dynamic multi-valued mapping estimation and its applications

This work is supported by HKRGC GRF (Project ID: 14305919).

Geng LI    Di QIU    Lok Ming LUI
(Received: date / Accepted: date)
Abstract

This paper addresses the problem of modeling and estimating dynamic multi-valued mappings. While most mathematical models provide a unique solution for a given input, real-world applications often lack deterministic solutions. In such scenarios, estimating dynamic multi-valued mappings is necessary to suggest different reasonable solutions for each input. This paper introduces a deep neural network framework incorporating a generative network and a classification component. The objective is to model the dynamic multi-valued mapping between the input and output by providing a reliable uncertainty measurement. Generating multiple solutions for a given input involves utilizing a discrete codebook comprising finite variables. These variables are fed into a generative network along with the input, producing various output possibilities. The discreteness of the codebook enables efficient estimation of the output’s conditional probability distribution for any given input using a classifier. By jointly optimizing the discrete codebook and its uncertainty estimation during training using a specially designed loss function, a highly accurate approximation is achieved. The effectiveness of our proposed framework is demonstrated through its application to various imaging problems, using both synthetic and real imaging data. Experimental results show that our framework accurately estimates the dynamic multi-valued mapping with uncertainty estimation.

Keywords:
dynamic multi-valued mapping, deep neural network framework, uncertainty estimation

1 Introduction

Uncertainty is a significant challenge that arises when making predictions for various tasks, including pose estimation, action prediction, and clinical diagnosis. In many cases, obtaining an accurate and definitive solution is difficult due to missing information or noise. There can be multiple possible interpretations or solutions. Moreover, in complex real-world scenarios, the relationship between inputs and solutions becomes even more intricate. Different inputs may correspond to different numbers of potential solutions. For example, in natural language processing, a sentence may have two different meanings if contextual information is lacking, while another sentence may have three meanings. Similarly, in diagnosing lung lesions, different medical professionals may provide varying diagnoses for a patient’s CT scan showing potential lung damage. This variation in diagnosis can be attributed to incomplete information and the inherent uncertainties in medical imaging. If we consider all the diagnoses provided by different medical professionals as potential solutions, each scan can have a varying number of potential diagnoses (see Fig 1). Therefore, there is a need to develop an effective framework that can address the challenge of uncertainty estimation in real-life scenarios.

Mathematically, the aforementioned problems can be described as follows. Let X and Y represent the input space and output space, respectively. We are provided with a collection of paired datasets \{(x_{i},y_{i}^{k})_{k=1}^{\mathcal{N}_{i}}\}_{i=1}^{\mathcal{T}}, where y_{i}^{k}\in Y is a possible output corresponding to x_{i}\in X. These paired datasets consist of unorganized pairs, allowing for flexibility as each collection \{y_{i}^{k}\}_{k=1}^{\mathcal{N}_{i}} may contain repetitions. Our objective is to find a suitable mapping f that fits this dataset. However, the required mapping for this type of dataset is non-standard. For each x\in X, f(x)=\{y_{j}\in Y\}_{j=1}^{\mathcal{N}_{x}}, where \mathcal{N}_{x}\in\mathbb{N} depends on x and the y_{j}'s are distinct for different j. In other words, the number of plausible outputs associated with different x\in X may vary, and hence \mathcal{N}_{x} is dynamic, depending on x\in X. We refer to this type of mapping as a dynamic multi-valued mapping (DMM). To find an optimal DMM that fits the given dataset, we need to solve the following mapping problem:

f=\mathop{\arg\min}\limits_{g:X\rightarrow\mathcal{P}^{f}(Y)}\mathcal{L}(g) (1)

Here, \mathcal{L} is a suitable loss function, such as the L^{2} fidelity data loss, that depends on the datasets and applications. In addition to fitting the dataset, it is also desirable to estimate the probability of each plausible output. The likelihood depends on the occurrence of an output in the input dataset. However, an immediate challenge we face is how to mathematically model a DMM. Directly implementing such a mapping can be mathematically challenging, particularly when dealing with the diversity of \mathcal{N}_{x}. Motivated by this challenge, we are interested in developing a numerical framework to solve the dynamic multi-valued mapping problem described above.

Figure 1: Some samples of lung CT scans. The first column shows lung CT scans and the remaining columns show labels from four experts. Experts provide different annotations for each CT scan, resulting in varying numbers and probabilities of outputs.

In many scenarios, the number of possible outputs \mathcal{N}_{x} is bounded by a fixed number N\in\mathbb{N}. In other words, for each input x\in X, the number of potential outputs is always less than or equal to N. In this case, a DMM can be viewed as a one-to-N mapping with uncertainty estimation. To measure the likelihood of each plausible output, we introduce a probability measure for each output while disregarding those with zero probabilities. This approach allows us to obtain a sequence of results with a dynamic number of elements \mathcal{N}_{x}, which depends on the input x. Although \mathcal{N}_{x} is constrained by a predefined value N, we can effectively handle most real-world problems by selecting an appropriate value for N. To model a one-to-N mapping with uncertainty estimation, we can utilize a dictionary C=\{c_{1},c_{2},...,c_{N}\}, also known as a codebook, which consists of a collection of finite variables. Thus, a DMM can be described by two bivariate functions: f:X\times C\rightarrow Y^{N} and p:C\times X\rightarrow[0,1]^{N}. For each j=1,2,...,N, f(x,c_{j}) represents a potential output associated with input x, while p(c_{j}|x) represents the probability of generating the output f(x,c_{j}). Outputs with a probability of 0 are discarded. This formulation allows us to effectively represent a DMM as two bivariate functions. In this work, our numerical framework to solve the DMM problem is based on this construction. For ease of computation, the bivariate functions f and p are parameterized by deep neural networks. Different parameterizations of the deep neural networks result in different f and p functions. Consequently, the optimal DMM can be obtained by optimizing the parameters of the deep neural networks.
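To make this construction concrete, the following minimal Python sketch illustrates how a DMM could be evaluated once f and p are available. It is only an illustration of the interface: the callables G and P stand in for the trained networks introduced later, and the function name evaluate_dmm is ours.

import torch

def evaluate_dmm(G, P, codebook, x, eps=1e-5):
    # G: generative network, G(x, c_j) returns the plausible output f(x, c_j)
    # P: classifier, P(codebook, x) returns a probability vector over the N codes
    # codebook: tensor of shape (N, m); row j is the code c_j
    probs = P(codebook, x)                      # shape (N,), entries sum to 1
    outputs = []
    for j in range(codebook.shape[0]):
        p_j = probs[j].item()
        if p_j < eps:                           # outputs with (near-)zero probability are discarded
            continue
        outputs.append((G(x, codebook[j]), p_j))
    # the number of surviving outputs is the dynamic N_x <= N
    return outputs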

Numerous studies have explored the use of auto-encoders to model multi-valued mappings by sampling and decoding latent codes from the latent space. Conditional generation has shown promise in producing multiple outcomes for various tasks. Noteworthy examples include the works of Zhu et al. zhu2017toward and Huang et al. huang2018multimodal , which focus on the task of multi-modal image-to-image translation. Additionally, Zheng et al. zheng2019pluralistic introduced a specialized approach for image inpainting. These methods demonstrate the capability to generate a variety of plausible outputs. However, they often lack uncertainty estimation for each plausible output, making them inadequate for solving the DMM problem. One notable exception is the Probabilistic U-Net proposed by Kohl et al. kohl2018probabilistic . Built upon the conditional VAE framework sohn2015learning , the Probabilistic U-Net allows for quantitative performance evaluation thanks to its application and associated datasets. Moreover, the Probabilistic U-Net outperforms many other methods in calibrated uncertainty estimation, including the Image2Image VAE of zhu2017toward . However, it inherits the Gaussian latent representation from CVAE sohn2015learning , which leads to the drawback of posterior collapse, resulting in wrong outputs in the generated samples. In our previous conference paper qiu2021modal , we proposed a preliminary method to address the challenges faced by auto-encoders in modeling multi-valued mappings. By utilizing a discrete representation space to approximate the multi-modal distribution of the output space, our preliminary method aimed to overcome the issue of posterior collapse and provide rough estimates of the conditional probability associated with each output. However, further analysis revealed several drawbacks. Firstly, the preliminary method exhibited repetitions in outputs, where the same output corresponded to different modes, rendering it unsuitable for representing a DMM. Secondly, its performance suffered when dealing with imbalanced datasets, as the uncertainty estimation associated with different modes was often inaccurate. Lastly, its output results were found to be inaccurate when dealing with imbalanced and unorganized datasets. These drawbacks make it challenging to use the preliminary method to represent a DMM.

In this work, building upon our previous preliminary model qiu2021modal , we develop a deep neural network framework that is capable of addressing the DMM problem. The bivariate functions f and p are parameterized by a deep generative network G_{\theta} and a classification network P_{\theta}, respectively. The generative network G_{\theta} is responsible for generating multiple results as a one-to-N mapping f, while the classification network P_{\theta} predicts the probabilities associated with these multiple results. The multiple outcomes of a given input x are associated with a codebook, which is a set of discrete variables \{c_{1},c_{2},...,c_{N}\}. Specifically, we utilize the generative network G_{\theta} to generate multiple results G_{\theta}(x,c_{j}) based on the discrete variables c_{j}. Simultaneously, our classification network P_{\theta} predicts the probability P_{\theta}(c_{j},x) for each variable c_{j}, which serves as the probability for each outcome G_{\theta}(x,c_{j}). There are two key challenges that need to be addressed in the proposed framework. Firstly, to ensure that the deep neural network framework accurately represents a DMM, it is essential that the outputs G_{\theta}(x,c_{j}) of the generative network differ for different code values c_{j}. In other words, for each input x\in X, the mapping G_{\theta}(x,\cdot) must be injective, guaranteeing the generation of a unique output for each code and preventing duplication. Additionally, another significant challenge is to enrich the codebook with essential information, allowing the generative network to produce diverse and accurate results, while also accurately predicting the probabilities associated with each outcome. To overcome these challenges, we propose a specialized loss function that incorporates the covariance loss and the ETF cross-entropy loss. This loss function enables the generation of diverse, plausible, and unique outputs for each code, as well as giving an accurate uncertainty estimation. Consequently, our framework effectively parameterizes a DMM with a dynamic range of plausible outputs for each input. Even when dealing with imbalanced data, our proposed framework can identify an optimal DMM that fits the given dataset. We evaluate the performance of our proposed framework on various imaging problems, including both synthetic and real images. Experimental results demonstrate the effectiveness of our model in solving dynamic multi-valued mapping problems across a range of imaging applications.

The rest of the paper is organized as follows. Section 2 outlines the primary contributions of this work. We introduce our general framework and model details in Section 3. In Section 4, we report experiments on real and synthetic datasets and describe the implementation details. Finally, we conclude the paper in Section 5.

2 Contribution

The main contributions of this paper are listed as follows.

  1. We propose the notion of dynamic multi-valued mapping (DMM) and formulate a general optimization problem for DMM to address the challenge of computing multiple plausible solutions with uncertainty estimation in real-life scenarios.

  2. By considering a DMM as a 1-to-N mapping with uncertainty estimation, a DMM is represented by two bi-variate functions, which are parameterized by deep neural networks. A specialized loss function, based on the covariance loss and the incorporation of the ETF cross-entropy loss, is proposed to train the deep neural networks. This enables the framework to represent a DMM with a dynamic number of plausible outputs.

  3. We apply the proposed framework to various practical imaging problems, demonstrating its efficacy and effectiveness in solving real-world challenges.

3 Proposed model

In this section, we will describe our proposed deep neural network framework for computing multiple plausible solutions with uncertainty estimation in real-life scenarios.

3.1 Mathematical formulation of dynamic multi-valued mapping problem

In this subsection, we first provide a mathematical formulation of our proposed problem. Our objective is to develop a framework that suggests multiple plausible outputs and estimates their likelihood based on a given dataset. Let X and Y represent the input space and output space, respectively. Suppose we are given a collection of paired datasets \mathcal{D}=\{\{(x_{i},y_{i}^{k})\}_{k=1}^{\mathcal{N}_{i}}\}_{i=1}^{\mathcal{T}}. Here, y_{i}^{k}\in Y represents a plausible solution associated with the input x_{i}\in X. In practice, the dataset \mathcal{D} can be imbalanced and unorganized. For each i, y_{i}^{k} may be repeated for different values of k. For example, when labeling the lesion position for ambiguous medical images, different medical experts may provide the same predictions. Consequently, this results in duplicated outputs within \mathcal{D}. On the other hand, quantifying the likelihood of each suggested output can be challenging. The extent of repetition in \mathcal{D} can provide valuable information for estimating the probability associated with each plausible output.

In this context, our proposed problem can be mathematically formulated as finding an appropriate mapping f to fit the given dataset \mathcal{D}. The main challenge arises from the fact that each input x\in X can be associated with multiple plausible outputs. Furthermore, the number \mathcal{N}_{x} of plausible outputs corresponding to different inputs x\in X can vary, making it dynamic and dependent on x. Traditional mappings between X and Y are inadequate for fitting \mathcal{D}. To address this unique dataset, we introduce the concept of a dynamic multi-valued mapping (DMM), defined as follows:

Definition 1

Let X and Y be two metric spaces. A dynamic multi-valued mapping (DMM) between X and Y is a mapping f:X\to\mathcal{P}^{f}(Y)\subset\mathcal{P}(Y), where \mathcal{P}(Y) denotes the power set of Y and

\mathcal{P}^{f}(Y):=\{\mathcal{S}\in\mathcal{P}(Y):\mathcal{S}\text{ is non-empty and finite}\}.

In other words, for each x\in X, a DMM maps x to a non-empty and finite subset of Y. We assume that the number of plausible outputs for each input x\in X is finite, and that each input x\in X must be associated with at least one plausible output. These assumptions generally hold true in most imaging problems.

Our objective is to find an optimal DMM f that accurately represents the dataset \mathcal{D}. Specifically, for each input x_{i}, we aim to find an optimal f such that f(x_{i}) is a subset containing all plausible outputs y_{i}^{1}, y_{i}^{2},..., y_{i}^{\mathcal{N}_{i}}. It is important to note that for each i, the same plausible output y_{i}^{k} may appear multiple times for different values of k. Consequently, the cardinality |f(x_{i})| of the subset is always less than or equal to \mathcal{N}_{i}. Additionally, it is desirable to estimate the likelihood of each plausible output by considering the extent of repetition of each output in \mathcal{D}. Next, we will discuss how we can mathematically formulate this problem.

For this purpose, the first task is to mathematically model a DMM. In most real-world scenarios, the number of plausible outputs \mathcal{N}_{x} can be bounded by a fixed number N\in\mathbb{N} for all x. That means the number of possible outputs for each input is at most N. In this case, a DMM can be represented by a one-to-N mapping f:X\to Y^{N} with f(x)=(y_{1},y_{2},...,y_{N}) for all x\in X, together with an uncertainty estimation p(j|x). Here, p(j|x) measures the probability or likelihood of the solution y_{j}. If p(j|x)=0, the output y_{j} is discarded and we can simply set y_{j}=0. The probability measure p helps us to model the dynamic nature of \mathcal{N}_{x}. More specifically,

\mathcal{N}_{x}=N-|\{j:p(j|x)=0\}|,

where |\cdot| is the cardinality of a set. Also, we require that y_{j}\neq y_{k} if j\neq k, p(j|x)\neq 0 and p(k|x)\neq 0. This requirement is necessary to ensure that (f,p) can effectively represent a DMM.

Under this setup, we can formulate our problem of fitting \mathcal{D}=\{\{(x_{i},y_{i}^{k})\}_{k=1}^{\mathcal{N}_{i}}\}_{i=1}^{\mathcal{T}} as an optimization problem over the space of DMMs that minimizes:

E_{1}(f,p)=\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{k=1}\log\left(\frac{1}{p(s_{i}^{k}|x_{i})^{\alpha}}d(f(x_{i})_{s_{i}^{k}},\ y_{i}^{k})\right) (2)
such that s_{i}^{k}=\mathop{\arg\min}\limits_{s=1,...,N}d^{\prime}(f(x_{i})_{s},\ y_{i}^{k}),
\sum^{N}_{j=1}p(j|x_{i})=1,\ \forall\ x_{i}\in X.

where f(x)_{s} is the s-th item of f(x), d(\cdot,\cdot) is the data fitting term, and d^{\prime}(\cdot,\cdot) is the distance function for index choice. \alpha>0 is a fixed parameter. The objective E_{1}(f,p) aims to encourage the suggested solutions f(x_{i}) to closely match the given dataset \mathcal{D} by minimizing the discrepancy between the mapped outputs and the true values. Moreover, minimizing E_{1}(f,p) promotes larger values of p(s_{i}^{k}|x_{i}) when the mapped outputs f(x_{i})_{s_{i}^{k}} appear more frequently in the paired dataset \{x_{i},y_{i}^{k}\}_{k=1}^{\mathcal{N}_{i}}. This allows p to capture the repetitive patterns in the dataset.

To ensure the capability of f to represent a DMM, we impose a constraint on f(x): for each x\in X, (f(x))_{j}\neq(f(x))_{k} if j\neq k and both p(j|x)\neq 0 and p(k|x)\neq 0. In order to effectively enforce this property, we leverage the concept of a codebook in formulating f and p. Let C=\{c_{1},c_{2},...,c_{N}\} be a codebook, where each c_{j}\in\mathbb{R}^{m}. We can now express f and p as bivariate functions:

f:X\times C\to Y^{N}\quad\text{where}\quad f(x,c_{j})=y_{j},
p:C\times X\to[0,1]\quad\text{where}\quad p(c_{j},x)=p(j|x).

By formulating f and p in this manner, we can optimize the codebook to control the properties of f and p, ensuring their suitability for representing a DMM. Our optimization can now be rewritten as finding two bivariate functions f:X\times C\to Y^{N} and p:C\times X\to[0,1] by minimizing:

E_{2}(f,p)=\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{k=1}\log\left(\frac{1}{p(s_{i}^{k},x_{i})^{\alpha}}d(f(x_{i},c_{s_{i}^{k}}),\ y_{i}^{k})\right) (3)
such that s_{i}^{k}=\mathop{\arg\min}\limits_{s=1,...,N}d^{\prime}(f(x_{i},c_{s}),\ y_{i}^{k}),
\sum^{N}_{j=1}p(j,x_{i})=1,\ \forall\ x_{i}\in X.

Note that obtaining the code index from the codebook by traversing all the output results can introduce a significant computational burden, especially when dealing with complex output spaces. Additionally, the choice of distance function d^{\prime}(\cdot,\cdot) defined on the output space can greatly impact the performance of our model. If the structure of the output space is excessively complex or overly simplistic, it may lead to suboptimal results when using certain distance functions. To address these challenges, we introduce the concept of cluster mapping, denoted as z:(X,Y)\rightarrow\mathbb{R}^{m}. This mapping allows us to bypass the need for traversing the entire output space and instead focus on finding the most suitable distance metric within the discrete codebook, leading to improved efficiency and effectiveness. The problem of obtaining the code index can then be formulated as follows:

s_{i}^{k}=\mathop{\arg\min}\limits_{s=1,...,N}||z(x_{i},y_{i}^{k})-c_{s}||^{2} (4)

Additionally, the codebook can also be simultaneously optimized with a suitable regularization to obtain the best collection of codes to capture multiple outputs. The final optimization problem can now be written as finding optimal f, p, and z, which minimizes:

E_{3}(f,p,z,C)=\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{k=1}\log\left(\frac{1}{p(c_{s_{i}^{k}}|x_{i})^{\alpha}}d(f(x_{i},c_{s_{i}^{k}}),\ y_{i}^{k})\right)+\beta d^{\prime\prime}(z(x_{i},y_{i}^{k}),c_{s_{i}^{k}})+\gamma d^{\prime\prime\prime}(C) (5)
such that c_{s_{i}^{k}}=\mathop{\arg\min}\limits_{c\in C}||z(x_{i},y_{i}^{k})-c||^{2},
\sum^{N}_{j=1}p(c_{j}|x_{i})=1,\ \forall\ x_{i}\in X.

where \beta,\gamma>0 are weight parameters, d^{\prime\prime}(\cdot,\cdot) is a regularization term to restrict z, and d^{\prime\prime\prime} is a regularization term to control the properties of C.

The primary challenge lies in effectively modeling f, p, and z in a way that enables simultaneous optimization of the multi-valued mapping and uncertainty estimation. To address this challenge, we adopt a strategy of formulating the problem using a deep neural network, whereby the optimization problem is solved through training the network parameters. The specific details of this approach will be elaborated upon in the subsequent subsection.

3.2 Deep neural network framework for DMM problem

As outlined in the previous subsection, the problem at hand can be formulated as an optimization problem involving three bi-variate mappings: f, p, and z. In this work, we propose to parameterize these mappings using deep neural networks. In this subsection, we will provide a comprehensive explanation of our proposed deep neural network framework for the dynamic multi-valued mapping (DMM) problem.

3.2.1 Overall Network structure

Our objective is to develop a deep neural network framework to solve the optimization problem 5. To achieve this, we parameterize the mappings f, p, and z using deep neural networks G_{\theta}, P_{\varphi} and E_{\phi}, respectively. The network structures are illustrated in Fig. 2. The framework for formulating the DMM is shown in Fig. 3.

Figure 2: The architecture of our model in the training process.
Figure 3: The architecture of our framework for modeling DMM.

The deep neural network G_{\theta} represents the bi-variate mapping f. In other words, we have f(x,c_{j})=G_{\theta}(x,c_{j}). It is depicted within the region bounded by the red dotted boundary in Fig. 2. The input x\in X is passed through an embedding network, which generates a latent representation l. This latent representation captures the meaningful features of x. Subsequently, the latent representation, along with the codebook C=\{c_{1},c_{2},...,c_{N}\}, is fed into another deep generative network that produces plausible outputs f(x,c_{1}),f(x,c_{2}),...,f(x,c_{N}). Here, we assume that the discrete codebook C captures shared label information across different x\in X. It is worth noting that different parameters \theta of the deep neural network result in different one-to-N mappings. Within this framework, we can effectively search for the optimal f by optimizing \theta.

Similarly, the deep neural network P_{\varphi} represents the bi-variate mapping p. In other words, p(c_{j}|x)=P_{\varphi}(c_{j},x). It is shown within the region bounded by the green dotted boundary in Fig. 2. The input x\in X is first passed through an embedding network, which produces a latent vector h. By performing a matrix multiplication Ah and applying a softmax operation, we obtain a probability vector p=(p(c_{1}|x),p(c_{2}|x),...,p(c_{N}|x)). The output of P_{\varphi} estimates the probability p(c_{j}|x) for each plausible output f(x,c_{j}). The choice of the fixed matrix A is crucial for accurately predicting uncertainty estimation, and we will discuss this in detail later. In addition, we fuse the embedding networks for G_{\theta} and P_{\varphi} into one, named E_{\theta}. This reduces the number of parameters of the model and improves the efficiency of training. For simplicity, we abbreviate G_{\theta}(E_{\theta}(x),c_{j}) and P_{\varphi}(c_{j},E_{\theta}(x)) to G_{\theta}(x,c_{j}) and P_{\varphi}(c_{j},x), respectively.

To solve the optimization problem 5, we introduce the cluster mapping z. The cluster mapping takes x\in X and an associated plausible output y\in Y from the dataset as input and outputs a vector of the same dimension as the codes in the codebook. To parameterize z, we utilize another deep neural network E_{\phi}. In other words, z(x,y)=E_{\phi}(x,y). As shown within the region bounded by the blue dotted boundary in Fig. 2, x\in X and y\in Y are fed into an embedding network to output a vector E_{\phi}(x,y). With E_{\phi}(x,y), we can find the code in C closest to E_{\phi}(x,y) that solves 4.

Under this setting, all mappings f, p, and z to be optimized are parameterized using deep neural networks. Therefore, they can be optimized by training the deep neural networks to obtain the optimal parameters that minimize a loss function defined by the energy functional in our optimization problem.

3.2.2 Optimization of discrete codebook CC

A crucial component of our proposed framework is the use of a discrete codebook C=\{c_{1},c_{2},...,c_{N}\}\subset\mathbb{R}^{m} to represent a one-to-N mapping. For each x\in X, we can obtain N plausible outputs: f(x,c_{1}), f(x,c_{2}), ..., f(x,c_{N}). Additionally, we can estimate the corresponding probabilities associated with each output: p(c_{1},x), p(c_{2},x), ..., p(c_{N},x).

It is important to highlight that for our framework to effectively represent a DMM, the following condition must hold for any j\neq k: if both p(c_{j},x) and p(c_{k},x) are non-zero, then f(x,c_{j})\neq f(x,c_{k}). In other words, each code in the codebook should be associated with a distinct output for every x\in X. The choice of the codebook is therefore crucial in enforcing this requirement.

In our framework, for every x\in X, a code c_{j}\in C corresponds to a plausible output G_{\theta}(x,c_{j}). It is important to note that c_{j}\in\mathbb{R}^{m}, where the dimension m of c_{j} is usually significantly smaller than the dimension of the output space Y. Consequently, the generator produces plausible outputs G_{\theta}(x,c_{j})\in Y that have a much higher dimension. The separability of the codes c_{j} encourages the separability of the corresponding outputs f(x,c_{j}). Conversely, if two codes c_{j} and c_{k} are close to each other in the codebook C, the plausible outputs f(x,c_{j}) and f(x,c_{k}) will also be close to each other. This can hinder the capability of our framework to effectively represent a DMM.

In practice, although the discrete codebook inherently lends itself to modeling multi-modal label data \{(x_{i},y_{i}^{k})_{k=1}^{\mathcal{N}_{i}}\}_{i=1}^{T}, the occurrence of similar codes within the codebook during continuous updates can lead to highly similar results. To prevent this repetition and ensure diversity in the generated outcomes, it becomes crucial to maximize the separation between each code in the codebook C.

To maximize the separation between each code in the codebook, our strategy is to reduce the mutual correlation among vectors in the codebook. For this purpose, we introduce the following covariance loss.

Definition 2

Let \mathcal{X}=\{x_{1},x_{2},...,x_{N}\}\subset\mathbb{R}^{m} be a finite subset of \mathbb{R}^{m} with |x_{1}|=...=|x_{N}|=1. The covariance loss with a threshold \tau is defined as

L_{cov}^{\tau}(\mathcal{X})=\frac{||T^{\tau}(\mathcal{X}^{T}\mathcal{X}-I_{N})||_{F}^{2}}{\mathbb{I}\{T^{\tau}(\mathcal{X}^{T}\mathcal{X}-I_{N})\}} (6)

where ||\cdot||_{F} is the Frobenius norm and T^{\tau}:\mathbb{R}^{N\times N}\to\mathbb{R}^{N\times N} is given by:

T^{\tau}(M)=\begin{cases}M_{ij},&\text{ if }|M_{ij}|>\tau;\\ 0,&\text{ otherwise},\end{cases} (7)

where M\in\mathbb{R}^{N\times N} and M_{ij} is the i-th row, j-th column entry of M. Also, \mathbb{I}:\mathbb{R}^{N\times N}\to\mathbb{N}, where \mathbb{I}(M) is the number of non-zero entries of M.

To understand the meaning of the covariance loss, observe that all diagonal entries of \mathcal{X}^{T}\mathcal{X}\in\mathbb{R}^{N\times N} are equal to 1. Thus, \mathcal{X}^{T}\mathcal{X}-I_{N} has 0 on its diagonal. Each non-diagonal entry of \mathcal{X}^{T}\mathcal{X}-I_{N} represents the inner product between two distinct data points in \mathcal{X}. By applying a threshold \tau in T^{\tau}(\mathcal{X}^{T}\mathcal{X}-I_{N}), we retain only the entries that exceed \tau. These entries signify pairs of data points that are close to each other, as their inner product surpasses the threshold. The covariance loss measures the mean squared sum of all non-zero entries in T^{\tau}(\mathcal{X}^{T}\mathcal{X}-I_{N}). Minimizing the covariance loss L_{cov}^{\tau} encourages a greater separation between the data points in \mathcal{X}. With this loss function, we are only concerned with the vectors that are not almost orthogonal, i.e., those whose inner products exceed the threshold. This helps us focus on preventing the occurrence of similar codes within the selected codes. In particular, when L_{cov}^{\tau}=0, the inner product between any pair of data points in \mathcal{X} is below the threshold \tau. This indicates that the separability of each pair of data points in \mathcal{X} satisfies a set tolerance, promoting greater distinctiveness among the data points.

Using L_{cov}^{\tau}, we can encourage the separation of the codes in the codebook C=\{c_{1},c_{2},...,c_{N}\}. To achieve this, we first normalize the codes in the codebook to \widetilde{C}=\{\frac{c_{1}}{|c_{1}|},\frac{c_{2}}{|c_{2}|},...,\frac{c_{N}}{|c_{N}|}\}. By minimizing the covariance loss L_{cov}^{\tau}(\widetilde{C}), we can effectively separate the normalized codes in the codebook C. This, in turn, promotes the separability of the corresponding plausible outputs f(x,c_{j}) in the output space Y, ensuring a diverse set of generated results for each input x.
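As an illustration, the thresholded covariance loss in Definition 2 can be implemented with a few lines of PyTorch. The sketch below is our own simplified version (the function name and the row-wise storage of the codes are assumptions, not the exact implementation used in our experiments):

import torch

def covariance_loss(codebook, tau):
    # codebook: tensor of shape (N, m); row j is the code c_j, so codebook.t() plays
    # the role of the matrix X (codes as columns) in Definition 2.
    C = codebook / codebook.norm(dim=1, keepdim=True)   # normalize each code to unit length
    N = C.shape[0]
    gram = C @ C.t() - torch.eye(N, device=C.device)    # off-diagonal entries are pairwise inner products
    mask = gram.abs() > tau                              # keep only pairs that are not "almost orthogonal"
    num_active = mask.sum()
    if num_active == 0:                                  # every pair already satisfies the tolerance
        return torch.zeros((), device=C.device)
    return (gram[mask] ** 2).sum() / num_active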

Furthermore, choosing suitable values for the size of the codebook N and the dimensionality of the codes m is also important. For a set of unit vectors \{v_{1},...,v_{N}\} in \mathbb{R}^{m}, if N\leq m, it is possible for their inner products to all be equal to zero, ensuring perfect orthogonality. However, in practical applications, the hyperparameters N and m of the codebook may require a broader range of choices to adapt to different datasets, and it is not always the case that N\leq m. In situations where N>m, we require the variables in the codebook to be "almost orthogonal". The Kabatjanskii-Levenstein bound for almost orthogonal vectors tao2013almostorthogonal helps us decide on appropriate choices for the values of N, m, and the threshold \tau in such cases. This theorem provides theoretical guidance on the trade-offs between the codebook size N, the code dimensionality m, and the degree of orthogonality required, allowing us to configure these parameters effectively for different datasets.

Theorem 3.1

Let v_{1},v_{2},...,v_{m} be unit vectors in \mathbb{R}^{n} such that |\langle v_{i},v_{j}\rangle|\leq An^{-1/2} for all distinct i,j, with \frac{1}{2}\leq A\leq\frac{\sqrt{n}}{2}. Then we have m\leq(\frac{Cn}{A^{2}})^{CA^{2}} for some absolute constant C.

For the special case when the hyperparameter A=1/2, we have the following theorem tao2013almostorthogonal :

Theorem 3.2

Let v_{1},v_{2},...,v_{m} be unit vectors in \mathbb{R}^{n} such that |\langle v_{i},v_{j}\rangle|\leq\frac{1}{2n^{1/2}} for all distinct i,j. Then, m<2n.

Therefore, in our case, we set the hyperparameter A=1/2. Then, the threshold \tau is chosen as \frac{1}{2}m^{-1/2}. As a condition on the hyperparameter choices, we must have N<2m, where N is the size of the codebook and m is the dimensionality of the codes. In our work, we carefully select the values of N and m such that the separability tolerance can be achieved using this threshold. With these parameters, we incorporate the thresholded covariance loss L_{cov}^{\tau} into the overall loss function for training the deep neural network to solve the optimization problem in 5. Our experimental results demonstrate the powerful impact of the covariance loss in improving the separability and diversity of the generated outputs.

In practice, the number of codes N is often larger than the actual number of plausible outputs. This means that \mathcal{N}_{x}\ll N for all x\in X, resulting in some codes being unused and left idle. Initially, the codes are configured such that the inner products between distinct pairs are less than the threshold. However, when updating the codes that have been used, they tend to accumulate in similar positions, reducing their separability. Minimizing the covariance loss plays a crucial role in addressing this issue. When minimizing the covariance loss, only the active codes are modified, while the inner products involving non-active codes remain 0 after applying the threshold function T^{\tau}. This ensures that the focus is on adjusting the positions of the codes that actually contribute to the plausible outputs, enhancing their separability and promoting diversity in the output space.

Another important challenge to address is that when solving the optimization problem 5, the following condition 4 has to be considered:

c_{s_{i}^{k}}=\mathop{\arg\min}\limits_{c\in C}||E_{\phi}(x_{i},y_{i}^{k})-c||^{2}

During the training process, a data-label pair (x_{i},y_{i}^{k}) is randomly sampled from the unorganized dataset \{(x_{i},y_{i}^{k})\}_{k=1}^{\mathcal{N}_{i}} and fed into the embedding network E_{\phi}. Subsequently, the nearest code c is selected from the codebook for the corresponding feature E_{\phi}(x_{i},y_{i}^{k}), as follows:

c(x_{i},y_{i}^{k})=c_{s_{i}^{k}},\ \text{where }\ s_{i}^{k}=\mathop{\arg\min}\limits_{s=1,...,N}\|c_{s}-E_{\phi}(x_{i},y_{i}^{k})\|^{2}. (8)

Note that the selected code c(x_{i},y_{i}^{k}) replaces the original feature E_{\phi}(x_{i},y_{i}^{k}) as the input for the generative network G_{\theta}. However, this approach presents a potential challenge: there is no direct gradient of E_{\phi}(x_{i},y_{i}^{k}) from the data fidelity term in the backward propagation process. To address this problem, we use a simple gradient approximation method, following the approach of van2017neural . The key idea is to copy the gradient of the selected code c and assign it to the feature E_{\phi}(x_{i},y_{i}^{k}) so that the parameters \phi of the embedding network can be updated using the gradient information from the loss function. Specifically, in the forward pass, we directly input the feature representation c to the generator G_{\theta}. In the backward computation, we directly assign the gradient \nabla_{c}L to the embedding network E_{\phi}. To ensure that this gradient approximation is meaningful, we need to make the output E_{\phi}(x_{i},y_{i}^{k}) as close as possible to the selected code c. To achieve this, we add an extra regularization loss for E_{\phi} to the overall loss function, given by:

E_{zreg}(\phi)=\beta\sum_{i=1}^{T}\sum_{k=1}^{\mathcal{N}_{i}}|E_{\phi}(x_{i},y_{i}^{k})-\text{sg}[c(x_{i},y_{i}^{k})]|^{2},

where \beta>0 and sg is the stop-gradient operation (identity in the forward pass, zero derivative in the backward pass, and easily implemented in neural network algorithms). This regularization loss encourages the embedding network to output features that are close to the selected codes, without backpropagating gradients to the codebook C itself.
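The nearest-code lookup of (8), the gradient-copy trick, and the stop-gradient regularization E_{zreg} can be sketched in PyTorch as follows. This is a minimal sketch under our notation; the function name quantize and the batched tensor shapes are assumptions, not the exact implementation:

import torch

def quantize(z_e, codebook, beta):
    # z_e:      embeddings E_phi(x_i, y_i^k), shape (B, m)
    # codebook: tensor of shape (N, m); row s is the code c_s
    # beta:     weight of the regularization term E_zreg
    dist = torch.cdist(z_e, codebook)            # pairwise Euclidean distances, shape (B, N)
    idx = dist.argmin(dim=1)                     # nearest-code index s_i^k of Eq. (8)
    c = codebook[idx]                            # selected codes c(x_i, y_i^k), shape (B, m)

    # regularization E_zreg: pull the embedding towards the stop-gradient of the selected code
    zreg_loss = beta * ((z_e - c.detach()) ** 2).sum(dim=1).mean()

    # straight-through estimator: the forward pass uses c, while the backward pass
    # copies the gradient of c onto z_e (the codebook receives no gradient here)
    c_st = z_e + (c - z_e).detach()
    return c_st, idx, zreg_loss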

The rationale behind this approach is that a learnable codebook, while exhibiting a slower learning capability, is required to capture more information from the data. Therefore, we propose to update the codebook C using a dictionary learning algorithm, Vector Quantization (VQ) van2017neural , which computes the exponential moving average of the corresponding embedding network's outputs. Our ablation analysis 10 has shown that a learnable codebook outperforms a fixed codebook.

3.2.3 Probability prediction

Another crucial component in our framework is the estimation of the probability associated with each plausible output. In practice, the ground truth probability of each plausible output associated with an input is not known. Our goal is to estimate these probabilities from the sample training dataset \{\{(x_{i},y_{i}^{k})\}_{k=1}^{\mathcal{N}_{i}}\}_{i=1}^{T}. For each input x_{i}, recall that the collection of sampled plausible outputs \{y_{i}^{k}\}_{k=1}^{\mathcal{N}_{i}} can contain repeated values. For example, when labeling lesions of a medical image, different medical experts might provide the same label, resulting in repeated labels in the dataset. The more repetitions of a particular label, the higher the estimated probability associated with that plausible output. The intuition behind this approach is that the frequency of a plausible output in the sample dataset can serve as a proxy for its true probability. The more often a plausible output appears, the more likely it is to be the correct label for the given input. By leveraging this idea, we can estimate the probability distribution over the plausible outputs for each input, even though the ground truth probabilities are not known.

In our framework, the probability p:C\times X\to[0,1] is parameterized by a deep neural network P_{\varphi}. Given an input x_{i} from the training dataset, it is fed into the network to obtain a feature vector h\in\mathbb{R}^{m}. Next, we compute a vector \tilde{p}_{i}\in\mathbb{R}^{N} as \tilde{p}_{i}=Ah, where A\in\mathbb{R}^{N\times m} is a suitable matrix. We then pass \tilde{p}_{i} through a softmax operation to obtain the probability vector p_{i}, where:

p_{ij}=\frac{e^{\tilde{p}_{ij}}}{\sum_{k=1}^{N}e^{\tilde{p}_{ik}}}. (9)

Here, p_{ij} represents the probability of the plausible output f(x_{i},c_{j}). If p_{ij}=0, the output f(x_{i},c_{j}) is considered meaningless and can be ignored.

The choice of the matrix A is crucial in this setup. The value of p_{ij} is related to the inner product between the j-th row of A and the feature vector h. Specifically, p_{ij} will be larger if h is closer to the j-th row of A. One possible approach is to learn the matrix A during the training process, allowing the network to discover the optimal projection that captures the relationship between the input features and the plausible output probabilities.

However, when training on an imbalanced dataset, the learnable vectors of the minority classes may collapse, a phenomenon known as “minority collapse”. To alleviate this issue, we define AA using the simplex equiangular tight frame (ETF) papyan2020prevalence . This results in a fixed ETF classifier, as proposed in yang2022we .

The simplex equiangular tight frame (ETF) is formally defined as follows:

Definition 3

(Simplex Equiangular Tight Frame) A collection of vectors m_{i}\in\mathbb{R}^{m}, i=1,2,...,N, m\geq N-1, is said to be a simplex equiangular tight frame if:

M=\sqrt{\frac{N}{N-1}}U(I_{N}-\frac{1}{N}1_{N}1_{N}^{T}),

where M=[m_{1},...,m_{N}]\in\mathbb{R}^{m\times N}, U\in\mathbb{R}^{m\times N} allows a rotation and satisfies U^{T}U=I_{N}, I_{N} is the identity matrix, and 1_{N} is an all-ones vector.

All vectors in a simplex ETF have an equal l_{2} norm and the same pair-wise angle, i.e., m_{i}^{T}m_{j}=\frac{N}{N-1}\delta_{i,j}-\frac{1}{N-1} for all i,j\in\{1,2,...,N\}, where \delta_{i,j} equals 1 when i=j and 0 otherwise. The pair-wise angle -\frac{1}{N-1} is the maximal equiangular separation of N vectors in \mathbb{R}^{m} papyan2020prevalence .

We can then define A as the matrix whose j-th row is given by m_{j}. To find the optimal p, we optimize the parameters of P_{\varphi} given a training dataset \{\{(x_{i},y_{i}^{k})\}_{k=1}^{\mathcal{N}_{i}}\}_{i=1}^{T} to minimize E_{3} in our optimization problem 5. In particular, the first term in E_{3} involves P_{\varphi}:

\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{k=1}\log\left(\frac{1}{p(c_{s_{i}^{k}}|x_{i})^{\alpha}}d(f(x_{i},c_{s_{i}^{k}}),\ y_{i}^{k})\right)=\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{k=1}\left[\log(d(f(x_{i},c_{s_{i}^{k}}),\ y_{i}^{k}))-\alpha\log(P_{\varphi}(c_{s_{i}^{k}},x_{i}))\right] (10)

Therefore, P_{\varphi} can be optimized by minimizing the following cross-entropy loss:

L_{CE}(h,M)=\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{j=1}\left(-\log\frac{\exp(h^{T}m_{s_{i}^{j}})}{\sum_{k=1}^{N}\exp(h^{T}m_{k})}\right),

where M=[m_{1},...,m_{N}]\in\mathbb{R}^{m\times N} is the fixed ETF classifier generated by Definition 3.
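For illustration, a fixed simplex ETF classifier and the corresponding cross-entropy loss can be constructed as in the following sketch. The rotation U is generated here by a QR decomposition, which is one common choice; the function names are ours and, for simplicity, the construction assumes m >= N (the definition only requires m >= N-1):

import torch
import torch.nn.functional as F

def simplex_etf(m, N):
    # Build M = sqrt(N/(N-1)) U (I_N - (1/N) 1_N 1_N^T) of shape (m, N); column j is m_j.
    assert m >= N
    U, _ = torch.linalg.qr(torch.randn(m, N))      # U has orthonormal columns, U^T U = I_N
    M = (N / (N - 1)) ** 0.5 * U @ (torch.eye(N) - torch.ones(N, N) / N)
    return M

def etf_cross_entropy(h, M, target_idx):
    # h:          features from the embedding network, shape (B, m)
    # M:          fixed ETF matrix of shape (m, N)
    # target_idx: indices s_i^k of the selected codes, shape (B,)
    logits = h @ M                                  # inner products h^T m_j, shape (B, N)
    return F.cross_entropy(logits, target_idx)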

3.2.4 Loss Function

To solve the optimization problem 5, our framework is reduced to finding optimal parameters \theta, \varphi, and \phi, which minimize E_{3}. More specifically, E_{3} can now be written as follows:

E_{3}(\theta,\varphi,\phi,C)=\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{k=1}\log\left(\frac{1}{P_{\varphi}(c_{s_{i}^{k}},x_{i})^{\alpha}}d(G_{\theta}(x_{i},c_{s_{i}^{k}}),\ y_{i}^{k})\right)+\beta d^{\prime\prime}(E_{\phi}(x_{i},y_{i}^{k}),c_{s_{i}^{k}})+\gamma d^{\prime\prime\prime}(C) (11)

As discussed in the previous subsections, the regularization term d^{\prime\prime} for E_{\phi} is chosen as d^{\prime\prime}=E_{zreg}. Also, the regularization for C is chosen as d^{\prime\prime\prime}=L_{cov}^{\tau}. Hence,

E_{3}(\theta,\varphi,\phi,C)=L_{recon}(\theta)+\alpha L_{CE}(h,M)+\beta E_{zreg}(\phi)+\gamma L_{cov}^{\tau}(C) (12)

where

L_{recon}(\theta)=\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{k=1}\log d(G_{\theta}(x_{i},c(x_{i},y_{i}^{k})),\ y_{i}^{k}). (13)

The choice of distance function d depends on the datasets and tasks. The overall loss function to train the deep neural network to solve the optimization problem 5 can now be summarized as follows.

\begin{split}L(\theta,\varphi,\phi,C)=&\frac{1}{\mathcal{T}}\sum^{\mathcal{T}}_{i=1}\frac{1}{\mathcal{N}_{i}}\sum^{\mathcal{N}_{i}}_{k=1}\log d(G_{\theta}(x_{i},c(x_{i},y_{i}^{k})),\ y_{i}^{k})-\alpha\log(P_{\varphi}(c(x_{i},y_{i}^{k}),x_{i}))\\ &+\beta\|E_{\phi}(x_{i},y_{i}^{k})-sg[c(x_{i},y_{i}^{k})]\|^{2}+\gamma\frac{||T^{\tau}(\tilde{C}^{T}\tilde{C}-I_{N})||_{F}^{2}}{\mathbb{I}\{T^{\tau}(\tilde{C}^{T}\tilde{C}-I_{N})\}}\end{split} (14)

The parameters of the deep neural network can then be optimized by stochastic gradient descent through backward propagation.

3.2.5 Numerical algorithm

We describe the numerical algorithms in detail. Several techniques, such as simple gradient approximation, stop-gradient operation, and exponential moving average, are used in our numerical algorithms.

Firstly, the simple gradient approximation is used during the update of \phi. Since the encoder E_{\phi} receives no gradient from the reconstruction loss, we assign it the gradient of the code c(x_{i},y_{i}^{k}), abbreviated as c here. Thus, with the input data pair (x_{i},y_{i}^{k}), we have:

\partial_{\phi}L(\theta,\varphi,\phi,C)=\frac{\partial L_{recon}(G_{\theta}(x_{i},c),y_{i}^{k})}{\partial c}\cdot\frac{\partial E_{\phi}(x_{i},y_{i}^{k})}{\partial\phi}+\frac{\beta\cdot\partial\|E_{\phi}(x_{i},y_{i}^{k})-sg[c]\|^{2}}{\partial\phi} (15)
\partial_{\theta}L(\theta,\varphi,\phi,C)=\partial_{\theta}L_{recon}(G_{\theta}(x_{i},c),y_{i}^{k})
\partial_{\varphi}L(\theta,\varphi,\phi,C)=\alpha\cdot\partial_{\varphi}L_{CE}(h,M)
\partial_{C}L(\theta,\varphi,\phi,C)=\gamma\cdot\partial_{C}L_{cov}(C).

Here, the first equation approximates the gradient of the loss function with respect to the encoder parameters \phi, combining the copied gradient from the reconstruction loss with the gradient of the regularization term. The second equation computes the gradient with respect to the generator parameters \theta, which comes from the reconstruction loss. The third equation computes the gradient with respect to the classifier parameters \varphi, which comes from the cross-entropy loss, and the fourth computes the gradient with respect to the codebook C, which comes from the covariance loss.

Note that the selected codes receive no gradient from the reconstruction loss and regularization loss, but instead from the covariance loss. In addition, the codes are also updated by the exponential moving average van2017neural as follows:

n^{t}=n^{t-1}*\kappa+count^{t}*(1-\kappa) (16)
m^{t}=m^{t-1}*\kappa+\sum^{count^{t}}_{j}e_{j}^{t}*(1-\kappa)
c^{t}=\frac{m^{t}}{n^{t}},

where \{e_{1},e_{2},...,e_{count}\} is the set of embedding features that are closest to the code c, and 0<\kappa<1 is the decay coefficient.
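The code update in (16) can be sketched as follows, assuming running buffers n_ema and m_ema are kept for every code; the function name, buffer layout, and batched form are our own simplification:

import torch
import torch.nn.functional as F

def ema_update(codebook, n_ema, m_ema, z_e, idx, kappa=0.99):
    # codebook: (N, m) tensor of codes c_1, ..., c_N
    # n_ema:    (N,) running counts n^t
    # m_ema:    (N, m) running sums m^t
    # z_e:      (B, m) embedding features of the current batch
    # idx:      (B,) index of the nearest code for each feature
    N = codebook.shape[0]
    one_hot = F.one_hot(idx, N).float()                   # (B, N) assignment matrix
    count = one_hot.sum(dim=0)                            # number of features assigned to each code
    summed = one_hot.t() @ z_e                            # sum of assigned features per code, (N, m)

    n_ema.mul_(kappa).add_(count * (1 - kappa))           # n^t = kappa * n^{t-1} + (1 - kappa) * count^t
    m_ema.mul_(kappa).add_(summed * (1 - kappa))          # m^t = kappa * m^{t-1} + (1 - kappa) * sum_j e_j^t

    used = count > 0                                      # only codes used in this batch are moved
    codebook[used] = m_ema[used] / n_ema[used].unsqueeze(1)
    return codebook, n_ema, m_ema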

During the testing process, some probabilities predicted by P_{\varphi} are not exactly equal to 0 but are very close to 0. A small threshold \epsilon=1e-5 is set to eliminate results with extremely low probability.

The details of the numerical algorithms for the training and testing processes are described in Algorithms 1 and 2, respectively.

Input: weights \alpha, \beta, \gamma in the loss function 14, codebook parameters N, m, learning rate lr.
Output: network parameters \theta, \varphi, \phi and codebook C.
Initialize the embedding networks E_{\theta}, E_{\phi}, the generator G_{\theta}, the classifier P_{\varphi} and the codebook C.
for t=0,1,2,... do
       Sample a batch of data pairs (x_{i},y_{i}^{k}) from the dataset \{\{(x_{i},y_{i}^{k})\}_{k=1}^{\mathcal{N}_{i}}\}_{i=1}^{T} randomly.
       Compute the nearest c_{s_{i}^{k}} by 8.
       Compute \partial_{\theta}L(\theta^{t},\varphi^{t},\phi^{t},C^{t}), \partial_{\varphi}L(\theta^{t},\varphi^{t},\phi^{t},C^{t}), \partial_{\phi}L(\theta^{t},\varphi^{t},\phi^{t},C^{t}), and \partial_{C}L(\theta^{t},\varphi^{t},\phi^{t},C^{t}) by 15.
       Update c_{s_{i}^{k}} by the exponential moving average 16.
       Update \theta, \varphi, \phi, and C by
         \theta^{t+1}=\theta^{t}-lr\cdot\partial_{\theta}L(\theta^{t},\varphi^{t},\phi^{t},C^{t})
         \varphi^{t+1}=\varphi^{t}-lr\cdot\partial_{\varphi}L(\theta^{t},\varphi^{t},\phi^{t},C^{t})
         \phi^{t+1}=\phi^{t}-lr\cdot\partial_{\phi}L(\theta^{t},\varphi^{t},\phi^{t},C^{t})
         C^{t+1}=C^{t}-lr\cdot\partial_{C}L(\theta^{t},\varphi^{t},\phi^{t},C^{t}).

return \theta, \varphi, \phi and C.
Algorithm 1 Training process of DMM Framework.
Input: x\in X.
Output: multi-labels y_{1},y_{2},...,y_{\mathcal{N}_{x}} and their probabilities P(y_{1}|x),P(y_{2}|x),...,P(y_{\mathcal{N}_{x}}|x).
Sample a test data x.
Compute \{P(\hat{c}_{1}|x),P(\hat{c}_{2}|x),...,P(\hat{c}_{N}|x)\} by P_{\varphi}(x).
\mathcal{N}_{x}=N-count\{P(\hat{c}_{j}|x)<\epsilon\}.
Obtain \{\hat{c}_{1},\hat{c}_{2},...,\hat{c}_{\mathcal{N}_{x}}\} by reordering \{c_{1},c_{2},...,c_{N}\} in descending probabilities while removing items with P(\hat{c}_{j}|x)<\epsilon.
for j=1,2,...,\mathcal{N}_{x} do
       y_{j}=G_{\theta}(x,\hat{c}_{j}).
return \{y_{1},y_{2},...,y_{\mathcal{N}_{x}}\}, P(y_{1}|x),P(y_{2}|x),...,P(y_{\mathcal{N}_{x}}|x).
Algorithm 2 Testing process of DMM Framework.

4 Experiments

To evaluate the effectiveness of our proposed framework, we conducted experiments on both synthetic examples and real-world imaging problems. We also performed ablation studies to analyze the key components of the framework. In this section, the experimental results will be reported.

4.1 Experimental setup

We utilize convolutional neural networks (CNNs) for both the embedding networks and the generative module, resembling the U-Net architecture as in ronneberger2015u . To be more precise, both our embedding networks E_{\phi} and E_{\theta} consist of a sequence of downsampling residual blocks, while the generative module G_{\theta} is comprised of a sequence of upsampling residual blocks. Additionally, the generative network incorporates feature information from the embedding network E_{\theta} at each resolution level. There are four downsampling or upsampling blocks in the sequence of each module. The downsampling and upsampling operations use bilinear interpolation. Each residual block comprises three convolution layers, utilizing 3\times 3 kernels and ReLU activation. The two embedding networks have the same architecture with output channel dimensions [32,64,128,256]. A 1\times 1 convolution and a global average pooling layer follow the data-label embedding network E_{\phi} to obtain a feature of the same dimension as the codes in the codebook. A fixed ETF classifier, serving as the probability network, follows the embedding network E_{\theta} to output the probability prediction for all codes in the codebook; the number of classifier categories equals the size of the codebook. We incorporate the selected code and the features of the generative network into the model's last layers with 1\times 1 convolutions, finally activated by softmax. Note that the code can be incorporated through any skip connection to the generative module G_{\theta} for different tasks. Since the code selected from the codebook is a vector, it cannot be directly incorporated into a generative network with spatial dimensions; we therefore repeatedly extend each value in the code to the spatial size of the corresponding features and then concatenate them. We initialize the codebook C as a (256,256) matrix obtained from i.i.d. random rotations, where each column of the codebook represents an individual code with a norm of 1.

During training, we utilized the binary cross-entropy loss as the label reconstruction loss and set the batch size to 32. Additionally, we employed a learning rate schedule with values of [1e^{-4},5e^{-5},1e^{-5},5e^{-6}] at epochs [0,300,900,1200]. The l^{2} penalization weight was fixed at \beta=0.25, the weight of the covariance constraint was \gamma=0.01, and we utilized the Adam optimizer kingma2014adam with its default settings for all of our experiments. Specifically, we train the first 20 epochs without the cross-entropy loss to avoid the impact of violent fluctuations in early code selection on the learning of the probability network parameters. For the Probabilistic U-Net, we followed the parameter num_{filter}=[32,64,128,192] in its released version and the suggested hyperparameters for segmentation tasks in kohl2018probabilistic .
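The codebook initialization described above can be sketched as follows; this is our own illustration of drawing a random rotation and taking its columns as unit-norm codes (the exact random-rotation routine used in the implementation may differ):

import torch

def init_codebook(m=256, N=256):
    # Draw a random rotation of R^m and use its first N columns as codes.
    # Each column is a unit-norm code; distinct columns are orthogonal, so the pairwise
    # inner products start below the threshold tau = 0.5 / sqrt(m). Requires N <= m.
    A = torch.randn(m, m)
    Q, R = torch.linalg.qr(A)                     # Q is orthogonal
    Q = Q * torch.sign(torch.diagonal(R))         # sign fix so the rotation is uniformly distributed
    return Q[:, :N]                               # shape (m, N); column j is the code c_j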

4.2 Synthetic examples: Shape reconstruction

To rigorously evaluate the capabilities of our proposed framework, we first conducted experiments on a synthetic dataset for the task of shape reconstruction. We generated a dataset of shape data-label pairs, where the data consisted of a randomly generated triangle image of size (160, 160), and the labels represented four distinct shapes derived from the properties of the input triangle.

Specifically, the dataset generation process was as follows. First, we randomly created a triangle image to serve as the input data. We then synthesized four corresponding label images, each representing a shape related to, but distinct from, the original triangle:

  1. The first label was a smaller triangle created by cropping a random portion of the input triangle.

  2. The second label was the original input triangle.

  3. The third label was a pentagon shape formed by cutting a random triangle from the upper-right corner of the parallelogram constructed from the input triangle.

  4. The fourth label was a complete parallelogram shape generated by extending the parallelogram formed by the input triangle.

The label shapes were thus not simple transformations of the input triangle, but new shapes algorithmically derived from the properties and geometry of the original triangle. This allowed us to evaluate how well our model could capture the underlying relationships between the input triangle and these related output shapes.

Using this procedure, we constructed a synthetic dataset consisting of 2,000 data-label pairs for the training set and 200 pairs for the test set. Our goal is to find an optimal DMM that fits the training dataset. Given an input triangle, the optimal DMM should predict all four shapes, along with an estimation of the probability associated with each output shape.

During training, we repeatedly sample a triangle x_{i} from the training set and one of its four labels uniformly at random as y_{i}^{k}. Since the four labels are distinct, this random sampling ensures that the label distribution consists of four modes, each with probability 0.25.
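A minimal sketch of this sampling step is given below; the dataset structure (a list of triangles, each paired with its four label shapes) is an assumption for illustration.

import random

def sample_training_pair(dataset):
    """Draw one (x_i, y_i^k) pair for training.

    dataset: list of (triangle_image, [label_1, label_2, label_3, label_4]) tuples.
    Since the label is chosen uniformly, each of the four modes has probability 0.25.
    """
    x_i, labels = random.choice(dataset)  # pick a triangle uniformly at random
    y_ik = random.choice(labels)          # pick one of its four label shapes uniformly
    return x_i, y_ik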

In the testing phase, we sample a triangle from the test set and record its most likely outputs \{\hat{y}_{i}^{j}\mid P(\hat{y}_{i}^{j}|x_{i})>10^{-5},\ j=1,\dots,N\}. Fig. 4 shows the results for two input triangles. In each case, the leftmost image in the first row shows the input sample, and the other four images show the four labeled shapes associated with the input triangle. Note that the input-output pairs used in testing do not appear in the training process.

The second row of Fig. 4 shows the predictions of the optimal DMM. Observe that the generated plausible outputs closely resemble the four ground-truth shapes, demonstrating the efficacy of our proposed framework in making multi-modal predictions. The value in the top-left corner of each output is the probability associated with that plausible output generated by the DMM. All values are very close to the ground-truth probability of 0.25, which shows that our framework obtains accurate uncertainty estimates for each output.

Fig. 5 shows the results of 8 more input shapes. Again, the optimal DMM accurately predicts the 4 labeled shapes and their associated probabilities. Overall, these results on the synthetic dataset validate the ability of the optimal DMM to not only generate the diverse set of shapes derived from the input triangle, but also provide reliable probability estimates for each predicted output.

Refer to caption
Refer to caption
Figure 4: Results visualization for shape reconstruction. The first row shows the input samples and their labels, and the next row shows the predictions from our method. The probability for each prediction is annotated in the upper-left corner.
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Figure 5: Results from our model on the shape reconstruction task. Results with predicted probabilities greater than 10^{-5} are shown.

4.3 Lung segmentation with pulmonary opacity

We next test our proposed framework on lung segmentation of chest radiographs with pulmonary opacities, which is a challenging task. The presence of pulmonary opacities, such as infections, masses, or consolidations, can introduce significant uncertainty in accurately delineating the lung boundaries: the opacities can obscure the true lung margins, making it difficult to determine the exact extent of the lungs. Because of this ambiguity, different radiologists or experts may provide varying lung segmentation labels for the same chest radiograph, with disagreement on where to draw the precise lung boundaries, especially in regions with opacity.

In this experiment, a collection of chest radiographs with pulmonary opacities is synthesized based on the lung X-ray dataset candemir2013lung ; jaeger2013automatic , which comprises 640 data-label pairs for training and 63 pairs for testing, each of size (180,180). To simulate real-world scenarios, we introduce random intensity occlusions to the intact radiographs. Two corresponding segmentation masks, representing annotations by two different experts, are generated under different rules: the first-kind label is the segmentation mask from the original dataset, indicating the complete region of the lungs, while the second-kind label differs from the original segmentation in that part of the opacity area is randomly removed, simulating a clinician's empirical annotation of the lung mask. Using our proposed framework, our goal is to obtain an optimal DMM that fits the dataset with the labeled segmentation masks and predicts the probability associated with each plausible segmentation result.
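The sketch below illustrates one way such a corrupted radiograph and its two labels could be synthesized; the elliptical occlusion shape, the intensity range, and the rule for deriving the second label are illustrative assumptions, not the exact generation procedure.

import numpy as np

def synthesize_pair(image: np.ndarray, mask: np.ndarray, rng: np.random.Generator):
    """Corrupt a radiograph with a random intensity occlusion and derive two labels.

    image: (180, 180) radiograph with intensities in [0, 1].
    mask:  (180, 180) boolean lung mask from the original dataset.
    Returns the occluded image, the first-kind label (full lung region), and the
    second-kind label (lung region with the occluded part removed).
    """
    h, w = image.shape
    cy, cx = rng.integers(40, h - 40), rng.integers(40, w - 40)
    ry, rx = rng.integers(15, 35), rng.integers(15, 35)
    yy, xx = np.ogrid[:h, :w]
    occ = ((yy - cy) / ry) ** 2 + ((xx - cx) / rx) ** 2 <= 1.0  # elliptical opacity
    corrupted = image.astype(np.float32).copy()
    corrupted[occ] = np.clip(corrupted[occ] + rng.uniform(0.3, 0.7), 0.0, 1.0)
    label1 = mask.copy()                 # complete lung region
    label2 = mask & ~occ                 # empirical annotation: opacity area removed
    return corrupted, label1, label2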

During training, as in the shape reconstruction task, we randomly sample a chest radiograph with pulmonary opacities x_{i} from the training set and one of its two labels as y_{i}^{k}. The label distribution then consists of two modes, each with probability 0.5.

To illustrate the test results, the second row of Fig. 6 shows predictions for chest radiographs from the testing set, with explicit probability estimates (P(\hat{y}_{i}^{j}|x_{i})>10^{-5}) annotated in the upper-left corner; more results are shown in Fig. 11. The two outputs closely resemble the corresponding labels and have probabilities close to 0.5. We also trained the Probabilistic U-Net model on this dataset, with results shown in Figs. 13, 14 and 15. Since it cannot directly predict the probability of an output, we draw 2, 4, and 16 samples, respectively, to estimate the mode probabilities: for each mode, we take the proportion of sampled predictions similar to that label as its probability. For example, with 2 samples, if both results are similar to the first-kind label, then the probability of the first-kind label is 1 and that of the other label is 0, meaning that only one mode is predicted; if the two results are similar to the two labels respectively, then both labels receive probability 0.5, which is the correct prediction. The resulting probability distributions are shown in Fig. 8(a) and Table 1, which demonstrate that our model achieves excellent results in both segmentation and probabilistic prediction accuracy.
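The counting procedure for turning Probabilistic U-Net samples into mode probabilities can be sketched as follows; the IoU-based rule for deciding which label a sample is "similar" to is our assumption.

import numpy as np

def iou(a: np.ndarray, b: np.ndarray) -> float:
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return inter / union if union > 0 else 1.0

def mode_probabilities(samples, labels):
    """Estimate mode probabilities from Probabilistic U-Net samples by counting.

    samples: list of n binary masks drawn from the model.
    labels:  list of reference masks (here, the two grader annotations).
    Each sample is assigned to its closest label by IoU, and the label frequencies
    are returned; e.g. two samples both matching label 0 give [1.0, 0.0].
    """
    counts = np.zeros(len(labels))
    for s in samples:
        counts[int(np.argmax([iou(s, y) for y in labels]))] += 1
    return counts / len(samples)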

Refer to caption
Refer to caption
Figure 6: Results visualization for lung corrupt segmentation. The first row shows the input samples and their labels, and the next row shows the predictions from our method. The predicted probability for each output is annotated in the upper-left corner.
model           Ours      Prob. U-net 2   Prob. U-net 4   Prob. U-net 16
Grader 0 mean   0.5040    0.4921          0.5238          0.5496
Grader 1 mean   0.4960    0.5079          0.4762          0.4504
std             0.0262    0.3391          0.2345          0.1482
Table 1: Mean and standard deviation of the predicted probabilities for lung corrupt segmentation.

4.4 Real applications

To evaluate the performance of our model on more complex real-world datasets, we apply it to the lesion segmentation task on ambiguous lung CT scans. In this experiment, we use the LIDC-IDRI dataset armato2011lung ; clark2013cancer , which contains 1018 lung CT scans from 1010 patients; it is an unorganized, multi-modally distributed dataset. Each scan is annotated with a mask by four (out of twelve) medical experts, who independently judged the location and shape of existing lesions based on their respective knowledge and experience. This does not mean that the label set follows a balanced 4-modal discrete distribution: for one scan the masks given by the experts may coincide, while for another scan they may differ. The first rows of Fig. 7, sampled from the testing set, illustrate this; the experts may even disagree on whether a lesion exists in the scan at all. Our task is to learn these imbalanced multi-modal distributions and make accurate uncertainty estimates while outputting the possible outcomes.

During the training process, we randomly sample a CT scan x_{i} from the training set and one of its four segmentations as y_{i}^{k}. Fig. 7 shows examples from the testing set that are predicted with high precision. The first row contains the lesion scan and its four labels; the last two rows (eight positions are reserved for the results in the figure) show our top predictions, with the probability (larger than 10^{-5}) associated with each prediction annotated in the upper-left corner. More results are shown in Fig. 12. Our method effectively captures the uncertainty present in the segmentation labels, as evidenced by significant probability scores, and the predicted probabilities are almost equal to those of the true distribution. In particular, the probability of the "no lesion" prediction is approximately 0.75 for the example on the left and close to 0 for the example on the right.

Refer to caption
Refer to caption
Figure 7: Results visualization for LIDC-IDRI lesion segmentation. The first row shows the input samples and their labels, and the following rows show all the predictions from our method with probabilities larger than 10^{-5}. The predicted probability for each output is annotated in the upper-left corner.

Unlike the synthetic datasets, there is no ground-truth distribution for the LIDC-IDRI dataset. To evaluate the model performance, we adopt the generalized energy distance metric D^{2}_{\text{GED}} bellemare2017cramer ; szekely2013energy , which only requires access to samples from the distributions that the models induce; it has also been used to evaluate the Probabilistic U-Net. However, the Probabilistic U-Net cannot generate dynamic result sets with corresponding uncertainty estimates. We therefore adapt this metric to the type of output we obtain as follows: given the label set Y_{x}\subset Y and the prediction set S_{x}\subset Y corresponding to the data x, the generalized energy distance between them is

D_{\mathrm{GED}}^{2}\left(Y_{x},S_{x}\right)=2\sum_{y\in Y_{x}}\sum_{s\in S_{x}}p_{s}p_{y}d(y,s)-\sum_{y\in Y_{x}}\sum_{y^{\prime}\in Y_{x}}p_{y}p_{y^{\prime}}d\left(y,y^{\prime}\right)-\sum_{s\in S_{x}}\sum_{s^{\prime}\in S_{x}}p_{s}p_{s^{\prime}}d\left(s,s^{\prime}\right),

where d(y,s)=1-\operatorname{IoU}(y,s) measures the dissimilarity between masks, p_{s} is the predicted probability for the output s, and p_{y} is the probability of the ground truth y. When ground-truth probabilities are not available, as in LIDC-IDRI, we use p_{y}=\frac{1}{\left|Y_{x}\right|}, where \left|Y_{x}\right| denotes the cardinality of Y_{x}; in particular, p_{y}=\frac{1}{4} for the LIDC-IDRI task. For the Probabilistic U-Net, which provides no probability prediction for each output, we set p_{s}=\frac{1}{\left|S_{x}\right|}=\frac{1}{n} when n samples are drawn from the model.
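A direct computation of this adapted metric from the label set, the prediction set, and their probabilities is sketched below.

import numpy as np

def ged_squared(labels, preds, p_pred, p_label=None):
    """Generalized energy distance D^2_GED between a label set Y_x and prediction set S_x.

    labels:  list of ground-truth binary masks (Y_x).
    preds:   list of predicted binary masks (S_x).
    p_pred:  predicted probability for each mask in S_x.
    p_label: probabilities for Y_x; defaults to uniform 1/|Y_x| (e.g. 1/4 on LIDC-IDRI).
    """
    def d(a, b):
        # d(y, s) = 1 - IoU(y, s); two empty masks are treated as identical.
        inter = np.logical_and(a, b).sum()
        union = np.logical_or(a, b).sum()
        return 1.0 - (inter / union if union > 0 else 1.0)

    if p_label is None:
        p_label = np.full(len(labels), 1.0 / len(labels))
    cross = sum(py * ps * d(y, s)
                for y, py in zip(labels, p_label) for s, ps in zip(preds, p_pred))
    within_y = sum(py * pz * d(y, z)
                   for y, py in zip(labels, p_label) for z, pz in zip(labels, p_label))
    within_s = sum(ps * pt * d(s, t)
                   for s, ps in zip(preds, p_pred) for t, pt in zip(preds, p_pred))
    return 2.0 * cross - within_y - within_s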

The quantitative results of our model and the Probabilistic U-Net are shown in Fig. 8(b) and Table 2; the lower values demonstrate the superior performance of our model. More test predictions from the Probabilistic U-Net are shown in Fig. 16 and Fig. 17.

model   Ours     Prob. U-net 4   Prob. U-net 16
mean    0.3058   0.4552          0.3253
std     0.2761   0.3375          0.2743
Table 2: Mean and standard deviation of the generalized energy distance D^{2}_{\mathrm{GED}} on the LIDC-IDRI test set.
Refer to caption
(a)
Refer to caption
(b)
Figure 8: Quantitative comparison. The small dots represent the quantities for test samples, and the small triangles represent their means. The numbers after Prob. U-net indicate the number of sampled predictions. (a) shows that our model produces an accurate uncertainty estimate for each mode. The Probabilistic U-Net uses a conventional Gaussian latent parametrization, so we can only sample a fixed number of results, e.g., 2, 4, 16, and then count the proportion of similar predictions, as shown in the last three groups; however, these results are far less accurate than ours on the lung corrupt dataset. (b) shows the D^{2}_{\text{GED}} values of the test data on the LIDC-IDRI segmentation task; our model achieves state-of-the-art performance.

4.5 Ablation analysis

In this subsection, we examine several design choices in the model and individual terms in the loss function to verify how they contribute to its performance.

To test the effect of the covariance loss, we train our model on the LIDC dataset with the covariance-loss weight set to \gamma=0 and \gamma=0.01, respectively. Given the same initial codebook, we count the number of codes selected in each epoch and their average pairwise inner product during training. The results are shown in Fig. 9. The covariance loss helps separate the codes so that the data features concentrate on a small number of codes. In contrast, without the balancing effect of the covariance loss, the frequency of code usage fluctuates significantly and the codes become more similar to one another.
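The two statistics tracked here, namely the number of distinct codes selected per epoch and their mean pairwise inner product, can be computed as in the following sketch (code_indices is assumed to collect the codebook column indices selected during an epoch).

import torch

def codebook_statistics(codebook: torch.Tensor, code_indices):
    """Number of distinct codes used in an epoch and their mean pairwise inner product.

    codebook:     (D, K) matrix whose columns are unit-norm codes.
    code_indices: iterable of codebook column indices selected during the epoch.
    """
    used = sorted(set(int(i) for i in code_indices))
    codes = codebook[:, used]                         # (D, n_used)
    gram = codes.T @ codes                            # pairwise inner products
    n = gram.shape[0]
    if n < 2:
        return len(used), float('nan')
    off_diag = gram[~torch.eye(n, dtype=torch.bool)]  # drop the diagonal (self products)
    return len(used), off_diag.mean().item()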

Refer to caption
Refer to caption
Figure 9: Ablation analysis. The left plot shows the number of codes used in each epoch during training, and the right plot shows the mean pairwise inner product of these codes.

In addition, we examine the effect of updating the codebook and of the fixed ETF classifier by training our model on the LIDC dataset. We introduce two variations of our original approach: a fixed codebook and a learnable classifier. In the fixed-codebook variant, the codes are no longer updated once the codebook is initialized; in the learnable-classifier variant, the fixed ETF classifier is replaced by a commonly used linear classifier whose weights are learned to predict the probabilities. To evaluate these variations, we measure the D_{\mathrm{GED}}^{2} metric on the test dataset and analyze the usage of codes during training. The results in Fig. 10 indicate that our original approach outperforms both variations. Not updating the codebook results in worse D_{\mathrm{GED}}^{2} scores and unstable code utilization, because the covariance loss no longer takes effect when the codebook is fixed. Furthermore, Fig. 10 demonstrates the advantage of the fixed ETF classifier over the learnable linear classifier: in the absence of an explicit data distribution to learn from in the LIDC dataset, the fixed ETF classifier provides more accurate distribution predictions with fewer classes. In summary, these findings highlight the superiority of our original approach, which combines codebook updates with a fixed ETF classifier and yields both accurate distribution prediction and stable code utilization.
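For reference, a fixed simplex-ETF classifier of the kind used in our model can be constructed as in the sketch below; this follows the standard simplex equiangular tight frame construction, with the dimensions chosen for illustration.

import torch

def simplex_etf_classifier(feat_dim: int, num_classes: int) -> torch.Tensor:
    """Fixed classifier whose class prototypes form a simplex equiangular tight frame.

    Returns a (num_classes, feat_dim) weight matrix W that is kept frozen during
    training; logits for a feature vector f are given by W @ f.
    """
    assert feat_dim >= num_classes, "this construction assumes feat_dim >= num_classes"
    # Random matrix with orthonormal columns, U: (feat_dim, num_classes).
    U, _ = torch.linalg.qr(torch.randn(feat_dim, num_classes))
    K = num_classes
    center = torch.eye(K) - torch.ones(K, K) / K   # centering matrix
    W = (K / (K - 1)) ** 0.5 * (U @ center)        # columns are the ETF directions
    return W.T

# Example: 256 classes (one per code in the codebook), 256-dimensional features.
etf_weights = simplex_etf_classifier(256, 256)

In our setting, the number of classes equals the codebook size, and the frozen logits are passed through a softmax to yield the predicted probability of each code.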

Refer to caption
Refer to caption
Figure 10: Ablation visualization. The left plot compares the variations of our approach using the generalized energy distance; the small dots represent the quantities for test samples, and the small triangles represent their means. The right plot shows the number of latent codes used in each epoch during training.
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Figure 11: Results from our model on the lung corrupt segmentation task. Results with predicted probabilities greater than 10^{-5} are shown.
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Figure 12: Results from our model on the LIDC-IDRI segmentation task. Results with predicted probabilities greater than 10^{-5} are shown.
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Figure 13: Results from Probabilistic U-Net on the lung corrupt segmentation task. 2 random sample results are shown.
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Figure 14: Results from Probabilistic U-Net on the lung corrupt segmentation task. 4 random sample results are shown.
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Figure 15: Results from Probabilistic U-Net on the lung corrupt segmentation task. 16 random sample results are shown.
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Figure 16: Results from Probabilistic U-Net on the LIDC-IDRI segmentation task. 4 random sample results are shown.
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Refer to caption
Figure 17: Results from Probabilistic U-Net on the LIDC-IDRI segmentation task. 16 random sample results are shown.

5 Conclusion

We have adopted dynamic multi-valued mappings to describe the uncertainty that arises in practice: a multi-valued mapping equipped with a probability measure for each output. We further formulated a data-driven optimization problem for estimating dynamic multi-valued mappings and proposed a general deep neural network framework to solve it. A codebook is introduced to capture the correspondence between inputs and outputs in the multi-valued mapping, together with an effective estimate of the prediction probabilities that captures the uncertainty in the dataset. Through extensive validation on synthetic and real tasks, we have demonstrated the superior performance of our method compared with state-of-the-art approaches.

Acknowledgements.
L.M. Lui is supported by HKRGC GRF (Project ID: ).

References

  • (1) Armato III, S.G., McLennan, G., Bidaut, L., McNitt-Gray, M.F., Meyer, C.R., Reeves, A.P., Zhao, B., Aberle, D.R., Henschke, C.I., Hoffman, E.A., et al.: The lung image database consortium (lidc) and image database resource initiative (idri): a completed reference database of lung nodules on ct scans. Medical physics 38(2), 915–931 (2011)
  • (2) Bellemare, M.G., Danihelka, I., Dabney, W., Mohamed, S., Lakshminarayanan, B., Hoyer, S., Munos, R.: The cramer distance as a solution to biased wasserstein gradients. arXiv preprint arXiv:1705.10743 (2017)
  • (3) Candemir, S., Jaeger, S., Palaniappan, K., Musco, J.P., Singh, R.K., Xue, Z., Karargyris, A., Antani, S., Thoma, G., McDonald, C.J.: Lung segmentation in chest radiographs using anatomical atlases with nonrigid registration. IEEE transactions on medical imaging 33(2), 577–590 (2013)
  • (4) Clark, K., Vendt, B., Smith, K., Freymann, J., Kirby, J., Koppel, P., Moore, S., Phillips, S., Maffitt, D., Pringle, M., et al.: The cancer imaging archive (tcia): maintaining and operating a public information repository. Journal of digital imaging 26, 1045–1057 (2013)
  • (5) Huang, X., Liu, M.Y., Belongie, S., Kautz, J.: Multimodal unsupervised image-to-image translation. In: Proceedings of the European conference on computer vision (ECCV), pp. 172–189 (2018)
  • (6) Jaeger, S., Karargyris, A., Candemir, S., Folio, L., Siegelman, J., Callaghan, F., Xue, Z., Palaniappan, K., Singh, R.K., Antani, S., et al.: Automatic tuberculosis screening using chest radiographs. IEEE transactions on medical imaging 33(2), 233–245 (2013)
  • (7) Kingma, D.P., Ba, J.: Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
  • (8) Kohl, S., Romera-Paredes, B., Meyer, C., De Fauw, J., Ledsam, J.R., Maier-Hein, K., Eslami, S., Jimenez Rezende, D., Ronneberger, O.: A probabilistic u-net for segmentation of ambiguous images. Advances in neural information processing systems 31 (2018)
  • (9) Papyan, V., Han, X., Donoho, D.L.: Prevalence of neural collapse during the terminal phase of deep learning training. Proceedings of the National Academy of Sciences 117(40), 24652–24663 (2020)
  • (10) Qiu, D., Lui, L.M.: Modal uncertainty estimation for medical imaging based diagnosis. In: Uncertainty for Safe Utilization of Machine Learning in Medical Imaging, and Perinatal Imaging, Placental and Preterm Image Analysis: 3rd International Workshop, UNSURE 2021, and 6th International Workshop, PIPPI 2021, Held in Conjunction with MICCAI 2021, Strasbourg, France, October 1, 2021, Proceedings 3, pp. 3–13. Springer (2021)
  • (11) Ronneberger, O., Fischer, P., Brox, T.: U-net: Convolutional networks for biomedical image segmentation. In: Medical image computing and computer-assisted intervention–MICCAI 2015: 18th international conference, Munich, Germany, October 5-9, 2015, proceedings, part III 18, pp. 234–241. Springer (2015)
  • (12) Sohn, K., Lee, H., Yan, X.: Learning structured output representation using deep conditional generative models. Advances in neural information processing systems 28 (2015)
  • (13) Székely, G.J., Rizzo, M.L.: Energy statistics: A class of statistics based on distances. Journal of statistical planning and inference 143(8), 1249–1272 (2013)
  • (14) Tao, T.: A cheap version of the kabatjanskii-levenstein bound for almost orthogonal vectors. https://terrytao.wordpress.com/2013/07/18/ (2013)
  • (15) Van Den Oord, A., Vinyals, O., et al.: Neural discrete representation learning. Advances in neural information processing systems 30 (2017)
  • (16) Yang, Y., Xie, L., Chen, S., Li, X., Lin, Z., Tao, D.: Do we really need a learnable classifier at the end of deep neural network? arXiv e-prints pp. arXiv–2203 (2022)
  • (17) Zheng, C., Cham, T.J., Cai, J.: Pluralistic image completion. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1438–1447 (2019)
  • (18) Zhu, J.Y., Zhang, R., Pathak, D., Darrell, T., Efros, A.A., Wang, O., Shechtman, E.: Toward multimodal image-to-image translation. Advances in neural information processing systems 30 (2017)