
Being Bayesian about Categorical Probability

Taejong Joo    Uijung Chung    Min-Gwan Seo
Abstract

Neural networks utilize the softmax as a building block in classification tasks, but the softmax suffers from overconfidence and lacks the ability to represent uncertainty. As a Bayesian alternative to the softmax, we consider a random variable of a categorical probability over class labels. In this framework, the prior distribution explicitly models the presumed noise inherent in the observed label, which provides consistent gains in generalization performance on multiple challenging tasks. The proposed method inherits the advantages of Bayesian approaches, achieving better uncertainty estimation and model calibration. Our method can be implemented as a plug-and-play loss function with negligible computational overhead compared to the softmax with the cross-entropy loss function.

Deep learning, Bayesian principle, Neural network, Softmax alternative, Variational inference

1 Introduction

Softmax (Bridle, 1990) is the de facto standard for post-processing the logits of neural networks (NNs) for classification. When combined with the maximum likelihood objective, it enables efficient gradient computation with respect to the logits and has achieved state-of-the-art performance on many benchmark datasets. However, softmax lacks the ability to represent the uncertainty of predictions (Blundell et al., 2015; Gal & Ghahramani, 2016) and exhibits poorly calibrated behavior (Guo et al., 2017). For instance, an NN with softmax can easily be fooled into confidently producing wrong outputs; when a digit 3 is rotated, the NN predicts it as a digit 8 or 4 with high confidence (Louizos & Welling, 2017). Another concern is that the confident predictive behavior of softmax makes NNs prone to overfitting (Xie et al., 2016; Pereyra et al., 2017). This issue raises the need for effective regularization techniques for improving generalization performance.

Bayesian NNs (BNNs; MacKay, 1992) can address the aforementioned issues of softmax. BNNs provide quantifiable measures of uncertainty such as predictive entropy and mutual information (Gal, 2016) and enable an automatic embodiment of Occam’s razor (MacKay, 1995). However, some practical obstacles have impeded the wide adoption of BNNs. First, the intractable posterior inference in BNNs demands approximate methods such as variational inference (VI; Graves, 2011; Blundell et al., 2015) and Monte Carlo (MC) dropout (Gal & Ghahramani, 2016). Even with such approximation methods, concerns arise regarding both the degree of approximation and the computationally expensive posterior inference (Wu et al., 2019a; Osawa et al., 2019). In addition, under the extreme non-linearity between parameters and outputs in NNs, determining a meaningful weight prior distribution is challenging (Sun et al., 2019). Last but not least, BNNs often require considerable modifications to existing baselines, or they result in performance degradation (Lakshminarayanan et al., 2017).

In this paper, we apply the Bayesian principle to construct the target distribution for learning classifiers. Specifically, we regard the categorical probability as a random variable and construct the target distribution over the categorical probability by means of Bayesian inference, which is approximated by NNs. The resulting target distribution can be thought of as being regularized via the prior belief, whose impact is controlled by the number of observations. By considering only the random variable of categorical probability, the Bayesian principle can be efficiently adopted within existing deep learning building blocks without huge modifications. Our extensive experiments show the effectiveness of being Bayesian about the categorical probability in improving generalization performance, uncertainty estimation, and calibration.

Our contributions can be summarized as follows: 1) we show the importance of considering the categorical probability as a random variable instead of being determined by the label; 2) we provide experimental results showing the usefulness of the Bayesian principle in improving the generalization performance of large models on standard benchmark datasets, e.g., ResNext-101 on ImageNet; 3) we enable NNs to inherit the advantages of Bayesian methods, namely better uncertainty representation and well-calibrated behavior, with a negligible increase in computational complexity.

(a) Softmax cross-entropy loss
(b) Belief matching framework
Figure 1: Illustration of the difference between the softmax cross-entropy loss and the belief matching framework when each image is unique in the training set. In the softmax cross-entropy loss, the label “cat” is directly transformed into the target categorical distribution. In the belief matching framework, the label “cat” is combined with the prior Dirichlet distribution over the categorical probability. Then, Bayes’ rule updates the belief about the categorical probability, which produces the target distribution.

2 Preliminary

This paper focuses on classification problems in which, given i.i.d. training samples $\mathcal{D}=\left\{{\bm{x}}^{(i)},y^{(i)}\right\}_{i=1}^{N}\in(\mathcal{X}\times\mathcal{Y})^{N}$, we construct a classifier $\mathcal{F}:\mathcal{X}\rightarrow\mathcal{Y}$. Here, $\mathcal{X}$ is an input space and $\mathcal{Y}=\left\{1,\cdots,K\right\}$ is a set of labels. We denote ${\mathbf{x}}$ and ${\textnormal{y}}$ as random variables whose unknown probability distributions generate inputs and labels, respectively. Also, we let $\tilde{{\mathbf{y}}}$ be a one-hot representation of ${\textnormal{y}}$.

Let $f^{{\bm{W}}}:\mathcal{X}\rightarrow\mathcal{X}^{\prime}$ be a NN with parameters ${\bm{W}}$, where $\mathcal{X}^{\prime}=\mathbb{R}^{K}$ is a logit space. In this paper, we assume $\operatorname{arg\,max}_{j}f_{j}^{{\bm{W}}}\subseteq\mathcal{F}$ is the classification model, where $f_{j}^{{\bm{W}}}$ denotes the $j$-th output basis of $f^{{\bm{W}}}$, and we concentrate on the problem of learning ${\bm{W}}$. Given $(({\bm{x}},y)\in\mathcal{D},f^{{\bm{W}}})$, a standard loss function to minimize is the softmax cross-entropy loss, which applies the softmax to the logit and then computes the cross-entropy between a one-hot encoded label and the softmax output (Figure 1(a)). Specifically, the softmax, denoted by $\phi:\mathcal{X}^{\prime}\rightarrow\triangle^{K-1}$, transforms a logit $f^{{\bm{W}}}({\bm{x}})$ into a normalized exponential form:

$$\phi_{k}(f^{{\bm{W}}}({\bm{x}}))=\frac{\exp(f^{{\bm{W}}}_{k}({\bm{x}}))}{\sum_{j}\exp(f^{{\bm{W}}}_{j}({\bm{x}}))} \qquad (1)$$

and then the cross-entropy loss can be computed by $l_{CE}(\tilde{{\bm{y}}},\phi(f^{{\bm{W}}}({\bm{x}})))=-\sum_{k}\tilde{{\bm{y}}}_{k}\log\phi_{k}(f^{{\bm{W}}}({\bm{x}}))$. Here, note that the softmax output can be viewed as a parameter of the categorical distribution, which can be denoted by $\mathcal{P}^{C}(\phi(f^{{\bm{W}}}({\bm{x}})))$.
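For concreteness, a minimal PyTorch sketch of this loss is given below (our own illustration, not the authors' code); it checks that applying equation 1 and then the cross-entropy coincides with the standard library call.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                     # f^W(x) for a batch of 4 inputs, K = 10
labels = torch.randint(0, 10, (4,))             # observed labels y

probs = torch.softmax(logits, dim=-1)           # eq. (1)
manual = -probs[torch.arange(4), labels].log().mean()
library = F.cross_entropy(logits, labels)       # softmax + cross-entropy in one call
print(torch.allclose(manual, library))          # True
```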

We can formulate the minimization of the softmax cross-entropy loss over $\mathcal{D}$ as a collection of distribution matching problems. To this end, let $c^{\mathcal{D}}({\bm{x}})$ be a vector-valued function that counts the label frequency at ${\bm{x}}\in\mathcal{X}$ in $\mathcal{D}$, which is defined as:

$$c^{\mathcal{D}}({\bm{x}})=\sum_{({\bm{x}}^{\prime},y^{\prime})\in\mathcal{D}}\tilde{{\bm{y}}}^{\prime}\,\mathds{1}_{\{{\bm{x}}\}}({\bm{x}}^{\prime}) \qquad (2)$$

where $\mathds{1}_{\mathcal{A}}(x)$ is an indicator function that takes 1 when $x\in\mathcal{A}$ and 0 otherwise. Then, the empirical risk on $\mathcal{D}$ can be expressed as follows:

$$\hat{\mathcal{L}}_{\mathcal{D}}({\bm{W}})=-\frac{1}{N}\sum_{i=1}^{N}\log\phi_{y^{(i)}}(f^{{\bm{W}}}({\bm{x}}^{(i)}))=\sum_{{\bm{x}}\in G(\mathcal{D})}\frac{\sum_{i}c_{i}^{\mathcal{D}}({\bm{x}})}{N}\,l_{{\bm{x}}}^{\mathcal{D}}({\bm{W}})+C \qquad (3)$$

where $G(\mathcal{D})$ is a set of unique values in $\mathcal{D}$, e.g., $G(\{1,2,2\})=\{1,2\}$, and $C$ is a constant with respect to ${\bm{W}}$; $l^{\mathcal{D}}_{{\bm{x}}}({\bm{W}})$ measures the KL divergence between the empirical target distribution and the categorical distribution modeled by the NN at location ${\bm{x}}$, which is given by:

$$l^{\mathcal{D}}_{{\bm{x}}}({\bm{W}})=KL\left(\mathcal{P}^{C}\left(\frac{c^{\mathcal{D}}({\bm{x}})}{\sum_{i}c_{i}^{\mathcal{D}}({\bm{x}})}\right)\parallel\mathcal{P}^{C}\left(\phi(f^{{\bm{W}}}({\bm{x}}))\right)\right) \qquad (4)$$

Therefore, the normalized value of $c^{\mathcal{D}}({\bm{x}})$ becomes the estimator of the categorical probability of the target distribution at location ${\bm{x}}$. However, directly approximating this target distribution can be problematic because the estimator uses only a single or very few samples: most of the inputs are unique or very rare in the training set.

One simple heuristic to handle this problem is label smoothing (Szegedy et al., 2016), which constructs a regularized target estimator in which a one-hot encoded label $\tilde{{\bm{y}}}$ is relaxed to $(1-\lambda)\tilde{{\bm{y}}}+\frac{\lambda}{K}\mathbf{1}$ with hyperparameter $\lambda$. Under the smoothing operation, the target estimator is regularized by a mixture of the empirical counts and the parameter of the discrete uniform distribution $\mathcal{P}^{U}$, such that the target becomes $(1-\lambda)\mathcal{P}^{C}\left(\frac{c^{\mathcal{D}}({\bm{x}})}{\sum_{i}c_{i}^{\mathcal{D}}({\bm{x}})}\right)+\lambda\mathcal{P}^{U}$. One concern is that the mixing coefficient is constant with respect to the number of observations, which can prevent the exploitation of the empirical counting information when it is needed.
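As an illustration, the smoothed target could be constructed as follows (a minimal PyTorch sketch under the notation above; the function name is ours):

```python
import torch
import torch.nn.functional as F

def smoothed_target(labels, num_classes, lam=0.1):
    """Relax one-hot labels to (1 - lam) * y_tilde + lam / K."""
    y_tilde = F.one_hot(labels, num_classes).float()
    return (1.0 - lam) * y_tilde + lam / num_classes

labels = torch.tensor([2, 0])
print(smoothed_target(labels, num_classes=4, lam=0.1))
# tensor([[0.0250, 0.0250, 0.9250, 0.0250],
#         [0.9250, 0.0250, 0.0250, 0.0250]])
```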

Another, more principled approach is BNNs, which prevent full exploitation of the noisy estimation by balancing the distance to the target distribution with model complexity and by maintaining a weight ensemble instead of choosing a single best configuration. Specifically, in BNNs with the Gaussian weight prior $\mathcal{N}(\mathbf{0},\tau^{-1}{\bm{I}})$, the score of a configuration ${\bm{W}}$ is measured by the posterior density $p_{{\mathbf{W}}}({\bm{W}}|\mathcal{D})\propto p(\mathcal{D}|{\bm{W}})p_{{\mathbf{W}}}({\bm{W}})$, where $\log p_{{\mathbf{W}}}({\bm{W}})\propto-\tau\parallel{\bm{W}}\parallel_{2}^{2}$. Therefore, the complexity penalty term induced by the prior prevents the softmax output from exactly matching a one-hot encoded target. In modern deep NNs, however, $\parallel{\bm{W}}\parallel_{2}^{2}$ may be a poor proxy for the model complexity due to the extreme non-linear relationship between weights and outputs (Hafner et al., 2018; Sun et al., 2019) as well as the weight-scaling invariance of batch normalization (Ioffe & Szegedy, 2015). This issue may result in poorly regularized predictions, i.e., it cannot prevent NNs from fully exploiting the information contained in the noisy target estimator.

3 Method

3.1 Constructing Target Distribution

We propose a Bayesian approach to construct the target distribution for classification, called a belief matching framework (BM; Figure 1(b)), in which the categorical probability of a label is regarded as a random variable ${\mathbf{z}}$. Specifically, we express the likelihood of ${\mathbf{z}}$ (given ${\mathbf{x}}$) about the label ${\textnormal{y}}$ as a categorical distribution $p_{{\textnormal{y}}|{\mathbf{x}},{\mathbf{z}}}=\mathcal{P}^{C}({\mathbf{z}}|{\mathbf{x}})$ (in this paper, $p_{{\mathbf{x}}}=\mathcal{P}(\theta)$ is read as: the random variable ${\mathbf{x}}$ follows a probability distribution $\mathcal{P}$ with parameter $\theta$). Then, specification of the prior distribution over ${\mathbf{z}}|{\mathbf{x}}$ automatically determines the target distribution by means of Bayesian inference: $p_{{\mathbf{z}}|{\mathbf{x}},{\textnormal{y}}}({\bm{z}})\propto p_{{\textnormal{y}}|{\mathbf{z}},{\mathbf{x}}}(y)\,p_{{\mathbf{z}}|{\mathbf{x}}}({\bm{z}})$.

We consider a conjugate prior for simplicity, i.e., the Dirichlet distribution. A random variable ${\mathbf{z}}$ (given ${\mathbf{x}}$) following the Dirichlet distribution with concentration parameter vector $\beta$, denoted by $\mathcal{P}^{D}(\beta)$, has the following density:

$$p_{{\mathbf{z}}|{\mathbf{x}}}({\bm{z}})=\frac{\Gamma(\beta_{0})}{\prod_{j}\Gamma(\beta_{j})}\prod_{k=1}^{K}z_{k}^{\beta_{k}-1} \qquad (5)$$

where $\Gamma(\cdot)$ is the gamma function, $\sum_{i}z_{i}=1$ (meaning that ${\bm{z}}$ belongs to the $K-1$ simplex $\triangle^{K-1}$), $\beta_{i}>0\;\forall i$, and $\beta_{0}=\sum_{i}\beta_{i}$. The mean of ${\mathbf{z}}|{\mathbf{x}}$ is $\beta/\beta_{0}$, and $\beta_{0}$ controls the sharpness of the density: more mass concentrates around the mean as $\beta_{0}$ becomes larger.

By the characteristics of the conjugate family, we have the following posterior distribution given $\mathcal{D}$:

$$p_{{\mathbf{z}}|{\mathbf{x}},{\textnormal{y}}}=\mathcal{P}^{D}(\beta+c^{\mathcal{D}}({\mathbf{x}})) \qquad (6)$$

where the target posterior mean is explicitly smoothed by the prior belief, and the smoothing operation is performed in the principled way of applying Bayes’ rule. Specifically, the posterior mean is given by $\frac{1}{\beta_{0}+\sum_{i}c_{i}^{\mathcal{D}}({\mathbf{x}})}(\beta+c^{\mathcal{D}}({\mathbf{x}}))$, in which the prior distribution acts as adding pseudo counts. We note that the relative strength between the prior belief and the empirical count information becomes adaptive with respect to each data point.
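A small numeric sketch of equation 6 (our own illustration with hypothetical counts) makes the pseudo-count interpretation concrete:

```python
import torch

beta = torch.full((4,), 1.0)             # symmetric Dirichlet prior over K = 4 classes
counts = torch.tensor([3., 0., 1., 0.])  # hypothetical label counts c^D(x) at location x

posterior_alpha = beta + counts                           # eq. (6)
posterior_mean = posterior_alpha / posterior_alpha.sum()  # prior acts as pseudo counts
print(posterior_mean)                    # tensor([0.5000, 0.1250, 0.2500, 0.1250])
```

With more observed labels at the same location, the empirical counts dominate the pseudo counts; with a single observation (the common case), the prior smooths the target more strongly.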

3.2 Representing Approximate Distribution

Now, we specify the approximate posterior distribution modeled by the NNs, which aims to approximate $p_{{\mathbf{z}}|{\mathbf{x}},{\textnormal{y}}}$. In this paper, we model the approximate posterior as the Dirichlet distribution. To this end, we use an exponential function $g(x)=\exp(x)$ to transform logits to the concentration parameter of $\mathcal{P}^{D}$, and we let $\alpha^{{\bm{W}}}=\exp\circ f^{{\bm{W}}}$. Then, the NN represents the density over $\triangle^{K-1}$ as follows:

$$q_{{\mathbf{z}}|{\mathbf{x}}}^{{\bm{W}}}({\bm{z}})=\frac{\Gamma(\alpha^{{\bm{W}}}_{0}({\mathbf{x}}))}{\prod_{j}\Gamma(\alpha^{{\bm{W}}}_{j}({\mathbf{x}}))}\prod_{k=1}^{K}z_{k}^{\alpha^{{\bm{W}}}_{k}({\mathbf{x}})-1} \qquad (7)$$

where $\alpha_{0}^{{\bm{W}}}({\mathbf{x}})=\sum_{i}\alpha^{{\bm{W}}}_{i}({\mathbf{x}})$.

From equation 7, we can see that outputs under BM encode much more information than those under the softmax. Specifically, it can easily be shown that the approximate posterior mean corresponds to the softmax output. In this regard, BM enables neural networks to represent richer information in their outputs, i.e., the density over $\triangle^{K-1}$ itself rather than a single point on it such as the mean. This capability allows capturing more diverse characteristics of predictions at different locations, such as how strongly the density is concentrated around its mean, which can be extremely helpful in many applications. For instance, BM gives a more sophisticated measure of the difference between the predictions of two neural networks, which can benefit the consistency-based loss for semi-supervised learning, as we will show in section 5.4. Besides, BM provides more sophisticated measures of predictive uncertainty based on the density over the simplex, such as mutual information.
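As a concrete illustration of this richer output, the sketch below (our own PyTorch example, not part of the released code) treats the exponentiated logits as Dirichlet concentration parameters, checks that the posterior mean equals the softmax output, and computes the mutual information between the label and the categorical probability using the standard closed form for the expected entropy of a categorical under a Dirichlet.

```python
import torch

def dirichlet_stats(logits):
    """Interpret logits (shape (K,)) as log-concentrations of a Dirichlet
    and return (posterior mean, mutual information)."""
    alpha = logits.exp()                      # concentration parameters
    alpha0 = alpha.sum()
    mean = alpha / alpha0                     # equals softmax(logits)

    # Entropy of the mean categorical (total uncertainty).
    total = -(mean * mean.log()).sum()
    # Expected entropy of Cat(z) under Dir(alpha): closed form of E[-sum_k z_k log z_k].
    expected = -(mean * (torch.digamma(alpha + 1) - torch.digamma(alpha0 + 1))).sum()
    return mean, total - expected             # mutual information I[y, z]

logits = torch.tensor([2.0, 0.5, -1.0])
mean, mi = dirichlet_stats(logits)
print(torch.allclose(mean, torch.softmax(logits, dim=0)))  # True
print(float(mi))
```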

From the perspective of learning the target distribution, BM can be considered a generalization of the softmax in the sense that it turns the moment matching problem into a distribution matching problem in $\mathcal{P}(\triangle^{K-1})$. To understand the distribution matching objective in BM, we reformulate equation 7 as follows:

$$q_{{\mathbf{z}}|{\mathbf{x}}}^{{\bm{W}}}({\bm{z}})\propto\exp\left(\sum_{k}\alpha^{{\bm{W}}}_{k}({\mathbf{x}})\log z_{k}-\sum_{k}\log z_{k}\right)\propto\exp\left(-l_{CE}(\phi(f^{{\bm{W}}}({\mathbf{x}})),{\bm{z}})+\frac{KL(\mathcal{P}^{U}\parallel\mathcal{P}^{C}({\bm{z}}))}{\alpha^{{\bm{W}}}_{0}({\mathbf{x}})/K}\right) \qquad (8)$$

In the limit of $q_{{\mathbf{z}}|{\mathbf{x}}}^{{\bm{W}}}\rightarrow p_{{\mathbf{z}}|{\mathbf{x}},{\textnormal{y}}}$, the mean of the target posterior (equation 6) becomes a virtual label that each individual ${\bm{z}}$ ought to match; the penalty for an ambiguous configuration ${\bm{z}}$ is determined by the number of observations. Therefore, the distribution matching in BM can be thought of as learning to score a categorical probability based on its closeness to the target posterior mean, where the exploitation of the closeness information is automatically controlled by the data.

3.3 Distribution Matching

We have defined the target distribution $p_{{\mathbf{z}}|{\mathbf{x}},{\textnormal{y}}}$ and the approximate distribution $q_{{\mathbf{z}}|{\mathbf{x}}}^{{\bm{W}}}$ modeled by the neural network. We now present a solution to the distribution matching problem by maximizing the evidence lower bound (ELBO), defined by $l_{EB}({\textnormal{y}},\alpha^{{\bm{W}}}({\bm{x}}))=\mathbb{E}_{q_{{\mathbf{z}}|{\mathbf{x}}}^{{\bm{W}}}}[\log p({\textnormal{y}}|{\mathbf{x}},{\mathbf{z}})]-KL(q^{{\bm{W}}}_{{\mathbf{z}}|{\mathbf{x}}}\parallel p_{{\mathbf{z}}|{\mathbf{x}}})$. Using the ELBO can be motivated by the following equality (Jordan et al., 1999):

$$\log p({\textnormal{y}}|{\mathbf{x}})=\int q_{{\mathbf{z}}|{\mathbf{x}}}^{{\bm{W}}}({\bm{z}})\log\left(\frac{p({\textnormal{y}},{\bm{z}}|{\mathbf{x}})}{p({\bm{z}}|{\mathbf{x}},{\textnormal{y}})}\right)d{\bm{z}}=l_{EB}({\textnormal{y}},\alpha^{{\bm{W}}}({\bm{x}}))-KL(q_{{\mathbf{z}}|{\mathbf{x}}}^{{\bm{W}}}\parallel p_{{\mathbf{z}}|{\mathbf{x}},{\textnormal{y}}}) \qquad (9)$$

where we can see that maximizing $l_{EB}({\textnormal{y}},\alpha^{{\bm{W}}}({\mathbf{x}}))$ corresponds to minimizing $KL(q_{{\mathbf{z}}|{\mathbf{x}}}^{{\bm{W}}}\parallel p_{{\mathbf{z}}|{\mathbf{x}},{\textnormal{y}}})$, i.e., matching the approximate distribution to the target distribution, because the KL divergence is non-negative and $\log p({\textnormal{y}}|{\mathbf{x}})$ is a constant with respect to ${\bm{W}}$. Here, each term in the ELBO can be analytically computed by:

$$\mathbb{E}_{q_{{\mathbf{z}}|{\mathbf{x}}}}\left[\log p({\textnormal{y}}|{\mathbf{x}},{\mathbf{z}})\right]=\mathbb{E}_{q_{{\mathbf{z}}|{\mathbf{x}}}}[\log{\textnormal{z}}_{{\textnormal{y}}}]=\psi(\alpha^{{\bm{W}}}_{{\textnormal{y}}}({\mathbf{x}}))-\psi(\alpha^{{\bm{W}}}_{0}({\mathbf{x}})) \qquad (10)$$

where $\psi(\cdot)$ is the digamma function (the logarithmic derivative of $\Gamma(\cdot)$), and

$$KL(q^{{\bm{W}}}_{{\mathbf{z}}|{\mathbf{x}}}\parallel p_{{\mathbf{z}}|{\mathbf{x}}})=\log\frac{\Gamma(\alpha^{{\bm{W}}}_{0}({\mathbf{x}}))\prod_{k}\Gamma(\beta_{k})}{\prod_{k}\Gamma(\alpha^{{\bm{W}}}_{k}({\mathbf{x}}))\Gamma(\beta_{0})}+\sum_{k}\left(\alpha^{{\bm{W}}}_{k}({\mathbf{x}})-\beta_{k}\right)\left(\psi(\alpha^{{\bm{W}}}_{k}({\mathbf{x}}))-\psi(\alpha^{{\bm{W}}}_{0}({\mathbf{x}}))\right) \qquad (11)$$

where $p_{{\mathbf{z}}|{\mathbf{x}}}$ is assumed to be an input-independent conjugate prior for simplicity; that is, $p_{{\mathbf{z}}|{\mathbf{x}}}=\mathcal{P}^{D}(\beta)$. With this analytical solution, we maximize the ELBO with a mini-batch approximation, which gives the following loss function: $\mathcal{L}({\bm{W}})=\mathbb{E}_{{\mathbf{x}},{\textnormal{y}}}[l_{EB}({\textnormal{y}},\alpha^{{\bm{W}}}({\mathbf{x}}))]\approx\frac{1}{m}\sum_{i=1}^{m}l_{EB}(y^{(i)},\alpha^{{\bm{W}}}({\bm{x}}^{(i)}))$. We note that computing the ELBO and its gradient has a complexity of $\mathcal{O}(K)$ per sample, which is equal to that of the softmax. This means that BM preserves the scalability and efficiency of the existing baseline. We also note that the analytical solution of the ELBO under BM allows us to implement the distribution matching loss as a plug-and-play loss function applied directly to the logit.
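A minimal sketch of such a plug-and-play loss in PyTorch is given below (our own implementation of equations 10 and 11 under the stated assumptions, not the authors' released code); the `kl_coef` argument anticipates the $\lambda$-weighted variant introduced in section 3.4.

```python
import torch

def belief_matching_loss(logits, labels, prior_beta=1.0, kl_coef=1.0):
    """Negative ELBO of eqs. (10)-(11). logits: (B, K); labels: (B,) int64.
    prior_beta is a scalar, symmetric Dirichlet prior concentration."""
    alpha = logits.exp()                                   # concentration parameters
    alpha0 = alpha.sum(dim=-1, keepdim=True)
    beta = torch.full_like(alpha, prior_beta)
    beta0 = beta.sum(dim=-1, keepdim=True)

    # Expected log-likelihood, eq. (10): psi(alpha_y) - psi(alpha_0)
    exp_ll = torch.digamma(alpha.gather(-1, labels.unsqueeze(-1))) - torch.digamma(alpha0)

    # KL(Dir(alpha) || Dir(beta)), eq. (11)
    kl = (torch.lgamma(alpha0.squeeze(-1)) - torch.lgamma(alpha).sum(-1)
          + torch.lgamma(beta).sum(-1) - torch.lgamma(beta0.squeeze(-1))
          + ((alpha - beta) * (torch.digamma(alpha) - torch.digamma(alpha0))).sum(-1))

    return (kl_coef * kl - exp_ll.squeeze(-1)).mean()      # minimize the negative ELBO

# Usage: replace F.cross_entropy(logits, labels) with this call in an existing loop.
logits, labels = torch.randn(8, 10, requires_grad=True), torch.randint(0, 10, (8,))
print(belief_matching_loss(logits, labels, kl_coef=0.01))  # e.g. lambda = 0.01
```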

3.4 On Prior Distributions

The success of the Bayesian approach largely depends on how we specify the prior distribution due to its impact on the resulting posterior distribution. For example, the target posterior mean in equation 6 becomes the counting estimator as $\beta_{0}\rightarrow 0$. On the contrary, as $\beta_{0}$ becomes larger, the effect of the empirical counting information is weakened, eventually disappearing in the limit of $\beta_{0}\rightarrow\infty$. Therefore, considering that most of the inputs are unique in $\mathcal{D}$, choosing a small $\beta_{0}$ is appropriate to prevent the resulting posterior distribution from being dominated by the prior (in an ideal fully Bayesian treatment, $\beta$ can be modeled hierarchically; we leave this as future research).

However, a prior distribution with small $\beta_{0}$ implicitly makes $\alpha_{0}^{{\bm{W}}}({\mathbf{x}})$ small, which poses significant challenges for gradient-based optimization. This is because the gradient of the ELBO is notoriously large in the small-value regime of $\alpha_{0}^{{\bm{W}}}({\mathbf{x}})$, e.g., $\psi^{\prime}(0.01)>10000$. In addition, various standard building blocks, including normalization (Ioffe & Szegedy, 2015), initialization (He et al., 2015), and architecture (He et al., 2016a), are implicitly or explicitly designed to make $\mathbb{E}[f^{{\bm{W}}}({\mathbf{x}})]\approx\mathbf{0}$, that is, $\mathbb{E}[\alpha^{{\bm{W}}}({\mathbf{x}})]\approx\mathbf{1}$. Therefore, making $\alpha_{0}^{{\bm{W}}}({\mathbf{x}})$ small would be wasteful or would require huge modifications to the existing building blocks. Also, $\mathbb{E}[\alpha^{{\bm{W}}}({\mathbf{x}})]\approx\mathbf{1}$ is encouraged from the perspective of the natural gradient (Amari, 1998), as it improves the conditioning of the Fisher information matrix (Schraudolph, 1998; LeCun et al., 1998; Raiko et al., 2012; Wiesler et al., 2014).

In order to resolve this gradient-based optimization challenge in learning the posterior distribution while preventing dominance of the prior distribution, we set $\beta=\mathbf{1}$ for the prior distribution and then multiply the KL divergence term in the ELBO by $\lambda$: $l^{\lambda}_{EB}({\textnormal{y}},\alpha^{{\bm{W}}}({\mathbf{x}}))=\mathbb{E}_{q_{{\mathbf{z}}|{\mathbf{x}}}}\left[\log p({\textnormal{y}}|{\mathbf{x}},{\mathbf{z}})\right]-\lambda KL(q^{{\bm{W}}}_{{\mathbf{z}}|{\mathbf{x}}}\parallel\mathcal{P}^{D}(\mathbf{1}))$. This trick significantly stabilizes the optimization process while leaving the local optimum unchanged. To see this, we can compare the gradients of the ELBO and the $\lambda$-multiplied ELBO:

$$\frac{\partial l_{EB}({\textnormal{y}},\alpha^{{\bm{W}}}({\mathbf{x}}))}{\partial\alpha_{k}^{{\bm{W}}}({\bm{x}})}=\left(\tilde{{\mathbf{y}}}_{k}-(\alpha_{k}^{{\bm{W}}}({\bm{x}})-\beta_{k})\right)\psi^{\prime}(\alpha_{k}^{{\bm{W}}}({\bm{x}}))-\left(1-(\alpha_{0}^{{\bm{W}}}({\bm{x}})-\beta_{0})\right)\psi^{\prime}(\alpha_{0}^{{\bm{W}}}({\bm{x}})) \qquad (12)$$

$$\frac{\partial l^{\lambda}_{EB}({\textnormal{y}},\alpha^{{\bm{W}}}({\mathbf{x}}))}{\partial\alpha_{k}^{{\bm{W}}}({\mathbf{x}})}=\left(\tilde{{\mathbf{y}}}_{k}-(\tilde{\alpha}_{k}^{{\bm{W}}}({\mathbf{x}})-\lambda)\right)\frac{\psi^{\prime}(\tilde{\alpha}_{k}^{{\bm{W}}}({\mathbf{x}}))}{\psi^{\prime}(\tilde{\alpha}^{{\bm{W}}}_{0}({\mathbf{x}}))}-\left(1-(\tilde{\alpha}_{0}^{{\bm{W}}}({\mathbf{x}})-\lambda K)\right) \qquad (13)$$

where $\tilde{\alpha}_{k}^{{\bm{W}}}({\mathbf{x}})=\lambda\alpha_{k}^{{\bm{W}}}({\mathbf{x}})$. Here, we can see that a local optimum of equation 12 is achieved when $\alpha^{{\bm{W}}}({\mathbf{x}})=\beta+\tilde{{\mathbf{y}}}$, and a local optimum of equation 13 is $\alpha^{{\bm{W}}}({\mathbf{x}})=1+\frac{1}{\lambda}\tilde{{\mathbf{y}}}$. Therefore, the ratio between $\alpha_{i}^{{\bm{W}}}({\mathbf{x}})$ and $\alpha_{j}^{{\bm{W}}}({\mathbf{x}})$ is equal to that of a local optimum of equation 12 for every pair of $i$ and $j$. In this regard, searching for $\lambda$ with $l^{\lambda}_{EB}({\textnormal{y}},\alpha^{{\bm{W}}}({\mathbf{x}}))$ and then multiplying by $\lambda$ after training corresponds to the process of searching for the prior distribution's parameter $\beta$ with $\mathcal{L}({\bm{W}})$.

4 Related Work

BNNs are the dominant approach for applying Bayesian principles to neural networks. Because BNNs require intractable posterior inference, many posterior approximation schemes have been developed to reduce the approximation gap and improve scalability (e.g., VI (Graves, 2011; Blundell et al., 2015; Wu et al., 2019a) and stochastic gradient Markov chain Monte Carlo (Welling & Teh, 2011; Ma et al., 2015; Gong et al., 2019)). However, even with these novel approximation techniques, BNNs do not scale to state-of-the-art architectures on large-scale datasets, or they often reduce generalization performance in practice, which impedes the wide adoption of BNNs despite their numerous potential benefits.

Other approaches avoid explicit modeling of the weight posterior distribution. MC dropout (Gal & Ghahramani, 2016) reinterprets dropout (Srivastava et al., 2014) as approximate VI, which retains the standard NN training procedure and modifies only the inference procedure for posterior MC approximation. In a similar spirit, some approaches (Mandt et al., 2017; Zhang et al., 2018; Maddox et al., 2019; Osawa et al., 2019) sequentially estimate the mean and covariance of the weight posterior distribution by using gradients computed at each step. Unlike BNNs, deep kernel learning (Wilson et al., 2016a, b) places Gaussian processes (GPs) on top of “deterministic” NNs, combining NNs’ capability of handling complex high-dimensional data with GPs’ capability of principled uncertainty representation and robust extrapolation.

Non-Bayesian approaches also help to resolve the limitations of softmax. Lakshminarayanan et al. (2017) propose an ensemble-based method to achieve better uncertainty representation and improved self-calibration. Guo et al. (2017) and Neumann et al. (2018) propose temperature scaling-based methods for post-hoc modification of softmax for improved calibration. To improve generalization by penalizing over-confidence, Pereyra et al. (2017) propose an auxiliary loss function that penalizes low predictive entropy, and Szegedy et al. (2016) and Xie et al. (2016) consider the types of noise included in ground-truth labels.

We also note that some recent studies use NNs to model the concentration parameter of the Dirichlet distribution, but with a different purpose than BM. Sensoy et al. (2018) use a loss function that explicitly minimizes prediction variances on training samples, which can help to produce high-uncertainty predictions for out-of-distribution (OOD) or adversarial samples. Prior network (Malinin & Gales, 2018) investigates two types of auxiliary losses computed on in-distribution and OOD samples, respectively. Similar to prior network, Chen et al. (2018) consider an auxiliary loss computed on adversarially generated samples.

5 Experiment

In this section, we show the versatility of BM through extensive empirical evaluations. We first verify its improvement of the generalization error in image classification tasks (section 5.1). We then verify whether BM inherits the advantages of the Bayesian method by placing the prior distribution only on the label categorical probability (section 5.2). We conclude this section by providing further applications that show the versatility of BM. To support reproducibility, we release our code at: https://github.com/tjoo512/belief-matching-framework. We performed all experiments on a single workstation with 8 GPUs (NVIDIA GeForce RTX 2080 Ti).

Throughout all experiments, we employ various large-scale models based on residual connections (He et al., 2016a), which are standard benchmark models in practice. For a fair comparison and to reduce the burden of hyperparameter search, we fix experimental configurations to the reference implementation of the corresponding architecture. However, we additionally use an initial learning rate warm-up and gradient clipping, which are extremely helpful for stable training of BM. Specifically, we use learning rates of [0.1$\epsilon$, 0.2$\epsilon$, 0.4$\epsilon$, 0.6$\epsilon$, 0.8$\epsilon$] for the first five epochs when the reference learning rate is $\epsilon$, and we clip the gradient when its norm exceeds 1.0. Without these methods, we had difficulty training deep models, e.g., ResNet-50, due to gradient explosion at the initial stage of training.
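A minimal sketch of this warm-up and clipping scheme in PyTorch is shown below (the model, optimizer, and data are placeholders; the reference learning rate $\epsilon$ is set to 0.1 for illustration).

```python
import torch

model = torch.nn.Linear(32, 10)                       # placeholder for e.g. ResNet-50
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()               # or the BM loss sketched in section 3.3
base_lr, warmup = 0.1, [0.1, 0.2, 0.4, 0.6, 0.8]      # reference lr and 5-epoch warm-up factors

for epoch in range(10):
    # Linear warm-up of the learning rate over the first five epochs.
    factor = warmup[epoch] if epoch < len(warmup) else 1.0
    for group in optimizer.param_groups:
        group["lr"] = base_lr * factor

    for _ in range(5):                                # placeholder mini-batches
        x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
```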

We compare BM to the following baseline methods: softmax, which is our primary object of improvement; MC dropout with 100 MC samples, which is a simple and efficient BNN; and deep ensemble with five NNs, which greatly improves the uncertainty representation ability. While there are other methods that use NNs to model the Dirichlet distribution (Sensoy et al., 2018; Malinin & Gales, 2018, 2019), we note that these methods are not scalable to ResNet. Similarly, we observe that training a mixture of Dirichlet distributions (Wu et al., 2019b) with ResNet suffers from gradient explosion, even with a 10x lower learning rate. Besides, BNNs with VI (or MCMC) are not directly comparable to our approach due to the huge modifications they require to existing baselines. For example, Heek & Kalchbrenner (2019) replace batch normalization and ReLU, use additional techniques (2x more filters, a cyclic learning rate, snapshot ensembling), and require almost 10x more computation on ImageNet to converge.

Table 1: Test classification error rates on CIFAR. Here, we split the train set of 50K examples into a train set of 40K examples and a validation set of 10K examples. Numbers indicate $\mu\pm\sigma$ computed across five trials, and boldface indicates the minimum mean error rate. Model and hyperparameters are selected based on validation error rates. We searched for the coefficient of BM over $\{0.01, 0.003, 0.001\}$ and the dropout rate of MC dropout over $\{0.1, 0.2, 0.5\}$.
Model Method C-10 C-100
Res-18 Softmax 6.13 ± 0.13 26.44 ± 0.33
MC Drop (last) 6.13 ± 0.08 26.15 ± 0.10
MC Drop (all) 6.50 ± 0.14 27.32 ± 0.45
BM 5.93 ± 0.07 24.19 ± 0.34
Res-50 Softmax 5.76 ± 0.06 25.00 ± 0.23
MC Drop (last) 5.75 ± 0.22 25.17 ± 0.09
MC Drop (all) 5.84 ± 0.23 26.74 ± 0.37
BM 5.59 ± 0.05 23.86 ± 0.37

5.1 Generalization Performance

We evaluate the generalization performance of BM on CIFAR (Krizhevsky, 2009) with the pre-activation ResNet (He et al., 2016b). CIFAR-10 and CIFAR-100 contain 50K training and 10K test images, and each 32x32x3-sized image belongs to one of 10 categories in CIFAR-10 and one of 100 categories in CIFAR-100. Table 1 lists the classification error rates of the softmax cross-entropy loss, BM, and MC dropout. In all configurations, BM consistently achieves the best generalization performance on both datasets. On the other hand, last-layer MC dropout sometimes results in higher generalization errors than softmax, and all-layer MC dropout significantly increases error rates, even though both consume 100x more computation for inference.

We next perform a large-scale experiment using ResNext-50 32x4d and ResNext-101 32x8d (Xie et al., 2017) on ImageNet (Russakovsky et al., 2015). ImageNet contains approximately 1.3M training samples and 50K validation samples; each sample is resized to 224x224x3 and belongs to one of 1K categories. That is, ImageNet has more categories, a larger image size, and more training samples compared to CIFAR, which may enable a more precise evaluation of methods. Consistent with the results on CIFAR, BM improves the test errors of softmax (Table 2). This result is appealing because improving the generalization error of deep NNs on large-scale datasets by adopting a Bayesian principle without computational overhead has rarely been reported in the literature.

Table 2: Classification error rates on ImageNet. Here, we use only $\lambda=0.001$ for ResNext-50 and $\lambda=0.0001$ for ResNext-101, and measure the validation error rates directly. We report the result obtained by a single run due to computational constraints.
Model Method Top1 Top5
ResNext-50 Softmax 22.23 6.36
BM 22.03 6.29
ResNext-101 Softmax 20.72 5.59
BM 20.23 5.26
Figure 2: Penultimate layer’s activations of examples belonging to one of three classes (beaver, dolphin, and otter; indexed by 0,1,2 in CIFAR-100).

Regularization Effect of Prior

In theory, BM has two regularization effects, which may explain the improvements in generalization performance under BM: the prior distribution, which smooths the target posterior mean by adding pseudo counts, and the computation of the distribution matching loss by averaging over all possible categorical probabilities. In this regard, ablating the KL term in the ELBO, which removes only the effect of the prior distribution, helps to examine these two effects separately.

We first examine the impact on generalization performance by training ResNet-50 on CIFAR without the KL term. The resulting test error rates were 5.68% on CIFAR-10 and 24.69% on CIFAR-100. These significant reductions in generalization performance indicate the powerful regularization effect of the prior distribution (cf. Table 1). The result that BM without the KL term still achieves lower test error rates than softmax demonstrates the regularization effect of considering all possible categorical probabilities via the Dirichlet distribution instead of choosing a single categorical probability.

Considering the role of the prior distribution in smoothing the posterior mean, we conjecture that the impact of the prior distribution is similar to the effect of label smoothing. Müller et al. (2019) show that label smoothing makes the learned representation reveal tight clusters of data points within the same class and smaller deviations among the data points. Inspired by this result, we analyze the activations in the penultimate layer with the visualization method proposed in Müller et al. (2019). Figure 2 illustrates that the prior distribution significantly reduces the value ranges of the activations of data points, which implies an implicit function regularization effect, considering that the $L^{p}$ norm of $f^{{\bm{W}}}\in L^{p}(\mathcal{X})$ can be approximated by $\parallel f^{{\bm{W}}}\parallel_{p}\approx\left(\frac{1}{N}\sum_{i}|f^{{\bm{W}}}({\bm{x}}^{(i)})|^{p}\right)^{1/p}$. Besides, Figure 2 shows that the prior distribution makes activations belonging to the same class form much tighter clusters, which can be thought of as an implicit manifold regularization effect. To see this, assume that two images belonging to the same class are close in the data manifold. Then, the difference between the logits of same-class examples becomes a good proxy for the gradient of $f^{{\bm{W}}}$ along the data manifold, since the gradient measures changes in the output space with respect to small changes in the input space.

Impact of $\beta$

In section 3.4, we claimed that the value of $\beta$ is implicitly related to the distribution of logit values and that extreme values can be detrimental to training stability. We verify this claim by training ResNet-18 on CIFAR-10 with different values of $\beta$. Specifically, we examine two strategies of changing $\beta$: modifying only $\beta$, or jointly modifying $\lambda$ proportionally to $1/\beta$ to match the local optima (cf. section 3.4). As a result, we obtain robust generalization performance under both strategies when $\beta\in[\exp(-1),\exp(4)]$ (Figure 3). However, when $\beta$ becomes extremely small ($\exp(-2)$ when changing only $\beta$ and $\exp(-8)$ when jointly tuning $\lambda$ and $\beta$), gradient explosion occurs due to the extreme slope of the digamma function near 0. Conversely, when we increase only $\beta$ to extremely large values, the error rate increases by a large margin (7.37) at $\beta=\exp(8)$ and eventually explodes at $\beta=\exp(16)$. This is because a large $\beta$ increases the values of activations, so the gradient with respect to parameters explodes. Under the joint tuning strategy, such a high-value region makes $\lambda\approx 0$, which removes the impact of the prior distribution.

Figure 3: Impact of $\beta$ on generalization performance. We exclude the ranges $\beta<\exp(-2)$ and $\beta>\exp(8)$ because they result in gradient explosion under the strategy of changing only $\beta$.

5.2 Uncertainty Representation

One of the most attractive benefits of Bayesian methods is their ability to represent the uncertainty about their predictions. In a naive sense, uncertainty representation ability is the ability to “know what it doesn’t know.” For instance, models having a good uncertainty representation ability would increase some form of predictive uncertainty on misclassified examples compared to those on correctly classified examples. This ability is extremely useful in both real-world applications and downstream tasks in machine learning. For example, underconfident NNs can produce many false alarms, which makes humans ignore the predictions of NNs; conversely, overconfident NNs can exclude humans from the decision-making loop, which results in catastrophic accidents. Also, better uncertainty representation enables balancing exploitation and exploration in reinforcement learning (Gal & Ghahramani, 2016) and detecting OOD samples (Malinin & Gales, 2018).

We evaluate the uncertainty representation ability on both in-distribution and OOD datasets. Specifically, we measure the calibration performance on in-distribution test samples, which examines a model’s ability to match its probabilistic output associated with an event to the actual long-term frequency of the event (Dawid, 1982). The notion of calibration in NNs is associated with how well its confidence matches the actual accuracy; e.g., we expect the average accuracy of a group of predictions having the confidence around 0.7 to be close to 70%. We also examine the predictive uncertainty for OOD samples. Since the examples belong to none of the classes seen during training, we expect neural networks to produce outputs of “I don’t know.”

(a) CIFAR-10
(b) CIFAR-100
Figure 4: Reliability plots of ResNet-50 with BM and softmax. Here, ECE is computed with 15 groups.

In-Distribution Uncertainty

We measure the calibration performance by the expected calibration error (ECE; Naeini et al., 2015), in which $\max_{i}\phi_{i}(f^{{\bm{W}}}({\bm{x}}))$ is regarded as the prediction confidence for the input ${\bm{x}}$. ECE is calculated by grouping predictions based on the confidence score and then computing the absolute difference between the average accuracy and average confidence for each group; that is, the ECE of $f^{{\bm{W}}}$ on $\mathcal{D}$ with $M$ groups is as follows:

$$ECE^{M}(f^{{\bm{W}}},\mathcal{D})=\sum_{i=1}^{M}\frac{|\mathcal{G}_{i}|}{|\mathcal{D}|}\left|\text{acc}(\mathcal{G}_{i})-\text{conf}(\mathcal{G}_{i})\right| \qquad (14)$$

where $\mathcal{G}_{i}$ is the set of samples in the $i$-th group, defined as $\mathcal{G}_{i}=\left\{j:i/M<\max_{k}\phi_{k}(f^{{\bm{W}}}({\bm{x}}^{(j)}))\leq(1+i)/M\right\}$, $\text{acc}(\mathcal{G}_{i})$ is the average accuracy in the $i$-th group, and $\text{conf}(\mathcal{G}_{i})$ is the average confidence in the $i$-th group.
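A small sketch of how ECE could be computed from model outputs (our own NumPy illustration using standard equal-width confidence bins; the arrays are placeholders):

```python
import numpy as np

def expected_calibration_error(probs, labels, num_bins=15):
    """probs: (N, K) predicted probabilities; labels: (N,) true classes."""
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    accuracies = (predictions == labels).astype(float)

    ece, n = 0.0, len(labels)
    bin_edges = np.linspace(0.0, 1.0, num_bins + 1)
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(accuracies[in_bin].mean() - confidences[in_bin].mean())
            ece += (in_bin.sum() / n) * gap
    return ece

# Example with random placeholder predictions.
rng = np.random.default_rng(0)
logits = rng.normal(size=(1000, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
labels = rng.integers(0, 10, size=1000)
print(expected_calibration_error(probs, labels))
```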

We analyze the calibration property of the ResNet-50 examined in section 5.1. As Figure 4 presents, BM’s predictive probability is well matched to its accuracy compared to softmax; that is, BM improves the calibration property of NNs. Specifically, BM improves the ECE of softmax from 3.82 to 1.66 on CIFAR-10 and from 13.48 to 4.25 on CIFAR-100. These improvements are comparable to the deep ensemble, which achieves 1.04 on CIFAR-10 and 3.54 on CIFAR-100 with 5x more computation for both training and inference. In the case of all-layer MC dropout, ECE decreases to 1.50 on CIFAR-10 and 9.76 on CIFAR-100; however, these improvements come at the cost of generalization performance. On the other hand, last-layer MC dropout, which often improves the generalization performance, does not show meaningful ECE improvements (3.78 on CIFAR-10 and 13.52 on CIFAR-100). We note that there are also post-hoc solutions for improving calibration performance, e.g., temperature scaling (Guo et al., 2017). However, these methods often require an additional dataset for tuning the behavior of NNs, which may prevent the use of all available samples for training.

(a) Softmax
(b) BM
(c) Deep ensemble
(d) MC dropout (all)
Figure 5: Uncertainty representation for in-distribution (CIFAR-100) and OOD (SVHN) of ResNet-50 under softmax and BM. We exclude the result of last-layer MC dropout because it shows no meaningful difference compared to the softmax.

Out-of-Distribution Uncertainty

We quantify uncertainty by the predictive entropy, which measures the uncertainty of $f^{{\bm{W}}}({\bm{x}})$ as follows: $H[\mathbb{E}_{q^{{\bm{W}}}_{{\mathbf{z}}|{\bm{x}}}}[{\mathbf{z}}]]=H[\phi(f^{{\bm{W}}}({\bm{x}}))]=-\sum_{k=1}^{K}\phi_{k}(f^{{\bm{W}}}({\bm{x}}))\log\phi_{k}(f^{{\bm{W}}}({\bm{x}}))$. This uncertainty measure has an intuitive interpretation: an “I don’t know” answer is close to the uniform distribution, whereas an “I confidently know” answer has one dominating categorical probability.
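For instance, the predictive entropy of a batch of predictions could be computed as follows (a minimal PyTorch sketch; under BM the softmax output equals the Dirichlet posterior mean, so the same expression applies to both methods):

```python
import torch

def predictive_entropy(logits):
    """Entropy of the mean categorical prediction, H[phi(f(x))], per sample."""
    log_probs = torch.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)

logits = torch.randn(4, 100)              # placeholder logits for a 100-class problem
print(predictive_entropy(logits))         # near log(100) for flat, near 0 for confident
```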

Figure 5 presents density plots of the predictive entropy, showing that BM provides notably better uncertainty estimation than the other methods. Specifically, BM produces clear peaks of predictive entropy in the high-uncertainty region for OOD samples (Figure 5(b)). In contrast, softmax produces relatively flat uncertainty for OOD samples (Figure 5(a)). Even though both MC dropout and deep ensemble successfully increase predictive uncertainty for OOD samples compared to softmax, they fail to make a clear peak in the high-uncertainty region for such samples, unlike BM. We note that this remarkable result is obtained by being Bayesian only about the categorical probability.

Note that some in-distribution samples should also be answered with “I don’t know” because the network does not achieve perfect test accuracy. As Figure 5 shows, BM produces more high-uncertainty predictions for in-distribution samples than softmax, which is almost certain in its predictions. This result consistently supports the previous finding that BM resolves the overconfidence problem of softmax.

Table 3: Transfer learning performance (test error rates) from ResNet-50 pretrained on ImageNet to smaller datasets. $\mu$ and $\sigma$ are obtained by five experiments, and boldface indicates the minimum mean error rate. We examine only $\lambda=0.01$ for BM.
Method C-10 Food-101 Cars
Softmax 5.44 ± 0.10 28.49 ± 0.08 42.99 ± 0.14
BM 5.03 ± 0.04 26.41 ± 0.07 39.99 ± 0.20

5.3 Transfer Learning

BM adopts the Bayesian principle outside the NNs, so unlike BNNs it can be applied to models already trained on different tasks. In this regard, we examine the effectiveness of BM in the transfer learning scenario. Specifically, we download the ImageNet-pretrained ResNet-50 and fine-tune the weights of the last linear layer for 100 epochs with the Adam optimizer (Kingma & Ba, 2015) and a learning rate of 3e-4 on three different datasets (CIFAR-10, Food-101 (Bossard et al., 2014), and Cars (Krause et al., 2013)).

Table 3 compares the test error rates of softmax and BM, in which BM consistently achieves better performance than softmax. Next, we examine the predictive uncertainty for OOD samples (Figure 6). Surprisingly, we observe that BM significantly improves the uncertainty representation ability of pretrained models by fine-tuning only the last-layer weights. These results suggest the possibility of adopting BM as a post-hoc solution to enhance the uncertainty representation ability of pretrained models without sacrificing their generalization performance. We believe that the interaction between BM and pretrained models is particularly attractive, considering recent efforts of the deep learning community to construct general baseline models trained on extremely large-scale datasets and then transfer them to multiple downstream tasks (e.g., BERT (Devlin et al., 2018) and MoCo (He et al., 2019)).

(a) Softmax
(b) BM
Figure 6: Uncertainty representation for in-distribution samples (CIFAR-10) and OOD samples (SVHN, Food-101, and Cars) in transfer learning tasks. BM produces clear peaks in the high-uncertainty region on SVHN and Food-101. We note that BM confidently predicts examples in Cars because CIFAR-10 contains the object category “automobile”. On the other hand, softmax produces confident predictions on all datasets compared to BM.

5.4 Semi-Supervised Learning

BM enables NNs to represent rich information in their predictions (cf. section 3.2). We exploit this characteristic to benefit consistency-based loss functions for semi-supervised learning. The idea of consistency-based losses is to employ information from unlabeled samples to determine where to promote robustness of predictions under stochastic perturbations (Belkin et al., 2006; Oliver et al., 2018). In this section, we investigate two baselines that consider stochastic perturbations on inputs (VAT; Miyato et al., 2018) and on networks ($\Pi$-model; Laine & Aila, 2017), respectively. Specifically, VAT generates an adversarial direction ${\bm{r}}$ and then measures the KL divergence between predictions at ${\bm{x}}$ and ${\bm{x}}+{\bm{r}}$:

$$\mathcal{L}_{VAT}({\bm{x}})=KL(\phi(f^{{\bm{W}}}({\bm{x}}))\parallel\phi(f^{{\bm{W}}}({\bm{x}}+{\bm{r}}))) \qquad (15)$$

where the adversarial direction is chosen by ${\bm{r}}=\operatorname{arg\,max}_{\parallel{\bm{r}}\parallel\leq\epsilon}KL(\phi(f^{{\bm{W}}}({\bm{x}}))\parallel\phi(f^{{\bm{W}}}({\bm{x}}+{\bm{r}})))$, and the $\Pi$-model measures the $L^{2}$ distance between predictions with and without enabling the stochastic parts of the NN:

$$\mathcal{L}_{\Pi}({\bm{x}})=\parallel\phi(\bar{f}^{{\bm{W}}}({\bm{x}}))-\phi(f^{{\bm{W}}}({\bm{x}}))\parallel_{2}^{2} \qquad (16)$$

where $\bar{f}$ is a prediction without the stochastic parts.

We can see that both methods achieve perturbation-invariant predictions by minimizing the divergence between two categorical probabilities under some perturbation. In this regard, BM can provide a more delicate measure of prediction consistency, namely the divergence between Dirichlet distributions, which can capture richer probabilistic structure, e.g., (co)variances of categorical probabilities. This generalization from the moment matching problem to the distribution matching problem can be achieved by replacing only the consistency measure in equation 15 with $KL(q_{{\mathbf{z}}|{\bm{x}}}^{{\bm{W}}}\parallel q_{{\mathbf{z}}|{\bm{x}}+{\bm{r}}}^{{\bm{W}}})$ and in equation 16 with $KL(\bar{q}_{{\mathbf{z}}|{\bm{x}}}^{{\bm{W}}}\parallel q_{{\mathbf{z}}|{\bm{x}}}^{{\bm{W}}})$.
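A sketch of how such a Dirichlet-to-Dirichlet consistency term could be computed from two sets of logits (our own PyTorch illustration using the built-in Dirichlet KL; the perturbation here is a random stand-in for the adversarial or stochastic perturbations above):

```python
import torch
from torch.distributions import Dirichlet, kl_divergence

def dirichlet_consistency(logits_clean, logits_perturbed):
    """KL between the Dirichlet predictions at x and at its perturbed version."""
    q_clean = Dirichlet(logits_clean.exp())
    q_perturbed = Dirichlet(logits_perturbed.exp())
    return kl_divergence(q_clean, q_perturbed).mean()

logits = torch.randn(8, 10)
perturbed = logits + 0.1 * torch.randn_like(logits)   # placeholder perturbation
print(dirichlet_consistency(logits, perturbed))
```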

We train a Wide ResNet 28-2 (Zagoruyko & Komodakis, 2016) with the consistency-based loss functions on CIFAR-10, using 4K labeled training, 41K unlabeled training, and 5K validation samples. Our experimental results show that the distribution matching metric of BM is more effective than the moment matching metric under softmax in reducing error rates (Table 4). We think that the improvement in semi-supervised learning with more sophisticated consistency measures shows the potential usefulness of BM in other applications, because employing the prediction difference of neural networks is a prevalent technique in many domains such as knowledge distillation (Hinton et al., 2015) and model interpretation (Zintgraf et al., 2017).

Table 4: Classification error rates on CIFAR-10. $\mu$ and $\sigma$ are obtained by five experiments, and boldface indicates the minimum mean error rate. We matched the configurations to those of Oliver et al. (2018) except for a consistency loss coefficient of 0.03 for VAT and 0.5 for the $\Pi$-model to match the scale between supervised and unsupervised losses. We use only $\lambda=0.01$ for BM.
Method $\Pi$-Model VAT
Softmax 16.52 ± 0.21 13.33 ± 0.37
BM 16.01 ± 0.36 12.40 ± 0.23

6 Conclusion

We adopted the Bayesian principle for constructing the target distribution by considering the categorical probability as a random variable rather than as something given by the training label. The proposed method can be flexibly applied to standard deep learning models by replacing only the softmax and the cross-entropy loss, and it provides consistent improvements in generalization performance, better uncertainty estimation, and well-calibrated behavior. We believe that BM shows promising advantages of being Bayesian about categorical probability.

We think that accommodating more expressive distributions in the belief matching framework is an interesting future direction. For example, parameterizing a logistic normal distribution (or a mixture distribution) could enable neural networks to capture strong semantic similarities among class labels, which would be helpful in large-class classification problems such as machine translation and classification on ImageNet. Besides, considering an input-dependent prior would result in interesting properties. For example, under the teacher-student framework, the teacher could dictate the prior for each input, thereby controlling the desired smoothness of the student's prediction at that location. This property could benefit various domains such as imbalanced datasets and multi-domain learning.

Acknowledgements

We would like to thank Dong-Hyun Lee and anonymous reviewers for the discussions and suggestions.

References

  • Amari (1998) Amari, S.-I. Natural gradient works efficiently in learning. Neural Computation, 10(2):251–276, 1998.
  • Belkin et al. (2006) Belkin, M., Niyogi, P., and Sindhwani, V. Manifold regularization: A geometric framework for learning from labeled and unlabeled examples. Journal of Machine Learning Research, 7(Nov):2399–2434, 2006.
  • Blundell et al. (2015) Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. Weight uncertainty in neural networks. In International Conference on Machine Learning, 2015.
  • Bossard et al. (2014) Bossard, L., Guillaumin, M., and Van Gool, L. Food-101 – mining discriminative components with random forests. In European Conference on Computer Vision, 2014.
  • Bridle (1990) Bridle, J. S. Probabilistic interpretation of feedforward classification network outputs, with relationships to statistical pattern recognition. In Neurocomputing, pp.  227–236. Springer, 1990.
  • Chen et al. (2018) Chen, W., Shen, Y., Jin, H., and Wang, W. A variational Dirichlet framework for out-of-distribution detection. arXiv preprint arXiv:1811.07308, 2018.
  • Dawid (1982) Dawid, A. P. The well-calibrated Bayesian. Journal of the American Statistical Association, 77(379):605–610, 1982.
  • Devlin et al. (2018) Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • Gal (2016) Gal, Y. Uncertainty in Deep Learning. PhD thesis, University of Cambridge, 2016.
  • Gal & Ghahramani (2016) Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, 2016.
  • Gong et al. (2019) Gong, W., Li, Y., and Hernández-Lobato, J. M. Meta-learning for stochastic gradient MCMC. In International Conference on Learning Representations, 2019.
  • Graves (2011) Graves, A. Practical variational inference for neural networks. In Advances in Neural Information Processing Systems, 2011.
  • Guo et al. (2017) Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In International Conference on Machine Learning, 2017.
  • Hafner et al. (2018) Hafner, D., Tran, D., Lillicrap, T., Irpan, A., and Davidson, J. Noise contrastive priors for functional uncertainty. arXiv preprint arXiv:1807.09289, 2018.
  • He et al. (2015) He, K., Zhang, X., Ren, S., and Sun, J. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In IEEE International Conference on Computer Vision, 2015.
  • He et al. (2016a) He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In IEEE Conference on Computer Vision and Pattern Recognition, 2016a.
  • He et al. (2016b) He, K., Zhang, X., Ren, S., and Sun, J. Identity mappings in deep residual networks. In European Conference on Computer Vision, 2016b.
  • He et al. (2019) He, K., Fan, H., Wu, Y., Xie, S., and Girshick, R. Momentum contrast for unsupervised visual representation learning. arXiv preprint arXiv:1911.05722, 2019.
  • Heek & Kalchbrenner (2019) Heek, J. and Kalchbrenner, N. Bayesian inference for large scale image classification. arXiv preprint arXiv:1908.03491, 2019.
  • Hinton et al. (2015) Hinton, G., Vinyals, O., and Dean, J. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
  • Ioffe & Szegedy (2015) Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, 2015.
  • Jordan et al. (1999) Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., and Saul, L. K. An introduction to variational methods for graphical models. Machine Learning, 37(2):183–233, 1999.
  • Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
  • Krause et al. (2013) Krause, J., Stark, M., Deng, J., and Fei-Fei, L. 3d object representations for fine-grained categorization. In ICCV Workshop on 3D Representation and Recognition, pp. 554–561, 2013.
  • Krizhevsky (2009) Krizhevsky, A. Learning multiple layers of features from tiny images. Technical report, 2009.
  • Laine & Aila (2017) Laine, S. and Aila, T. Temporal ensembling for semi-supervised learning. In International Conference on Learning Representations, 2017.
  • Lakshminarayanan et al. (2017) Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, 2017.
  • LeCun et al. (1998) LeCun, Y., Bottou, L., Orr, G. B., and Müller, K.-R. Efficient backprop. In Neural Networks: Tricks of the Trade, pp.  9–50. Springer, 1998.
  • Louizos & Welling (2017) Louizos, C. and Welling, M. Multiplicative normalizing flows for variational Bayesian neural networks. In International Conference on Machine Learning, 2017.
  • Ma et al. (2015) Ma, Y.-A., Chen, T., and Fox, E. A complete recipe for stochastic gradient MCMC. In Advances in Neural Information Processing Systems, 2015.
  • MacKay (1992) MacKay, D. J. A practical Bayesian framework for backpropagation networks. Neural Computation, 4(3):448–472, 1992.
  • MacKay (1995) MacKay, D. J. Probable networks and plausible predictions—a review of practical Bayesian methods for supervised neural networks. Network: Computation in Neural Systems, 6(3):469–505, 1995.
  • Maddox et al. (2019) Maddox, W. J., Izmailov, P., Garipov, T., Vetrov, D. P., and Wilson, A. G. A simple baseline for Bayesian uncertainty in deep learning. In Advances in Neural Information Processing Systems, 2019.
  • Malinin & Gales (2018) Malinin, A. and Gales, M. Predictive uncertainty estimation via prior networks. In Advances in Neural Information Processing Systems, 2018.
  • Malinin & Gales (2019) Malinin, A. and Gales, M. Reverse KL-divergence training of prior networks: Improved uncertainty and adversarial robustness. In Advances in Neural Information Processing Systems, 2019.
  • Mandt et al. (2017) Mandt, S., Hoffman, M. D., and Blei, D. M. Stochastic gradient descent as approximate Bayesian inference. Journal of Machine Learning Research, 18(1):4873–4907, 2017.
  • Miyato et al. (2018) Miyato, T., Maeda, S.-i., Koyama, M., and Ishii, S. Virtual adversarial training: A regularization method for supervised and semi-supervised learning. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(8):1979–1993, 2018.
  • Müller et al. (2019) Müller, R., Kornblith, S., and Hinton, G. When does label smoothing help? In Advances in Neural Information Processing Systems, 2019.
  • Naeini et al. (2015) Naeini, M. P., Cooper, G., and Hauskrecht, M. Obtaining well calibrated probabilities using Bayesian binning. In AAAI Conference on Artificial Intelligence, 2015.
  • Neumann et al. (2018) Neumann, L., Zisserman, A., and Vedaldi, A. Relaxed softmax: Efficient confidence auto-calibration for safe pedestrian detection. In NIPS Workshop on Machine Learning for Intelligent Transportation Systems, 2018.
  • Oliver et al. (2018) Oliver, A., Odena, A., Raffel, C. A., Cubuk, E. D., and Goodfellow, I. Realistic evaluation of deep semi-supervised learning algorithms. In Advances in Neural Information Processing Systems, 2018.
  • Osawa et al. (2019) Osawa, K., Swaroop, S., Jain, A., Eschenhagen, R., Turner, R. E., Yokota, R., and Khan, M. E. Practical deep learning with Bayesian principles. In Advances in Neural Information Processing Systems, 2019.
  • Pereyra et al. (2017) Pereyra, G., Tucker, G., Chorowski, J., Kaiser, Ł., and Hinton, G. Regularizing neural networks by penalizing confident output distributions. arXiv preprint arXiv:1701.06548, 2017.
  • Raiko et al. (2012) Raiko, T., Valpola, H., and LeCun, Y. Deep learning made easier by linear transformations in perceptrons. In International Conference on Artificial Intelligence and Statistics, 2012.
  • Russakovsky et al. (2015) Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A., Khosla, A., Bernstein, M., Berg, A. C., and Fei-Fei, L. ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision, 115(3):211–252, 2015.
  • Schraudolph (1998) Schraudolph, N. Accelerated gradient descent by factor-centering decomposition. Technical report, 1998.
  • Sensoy et al. (2018) Sensoy, M., Kaplan, L., and Kandemir, M. Evidential deep learning to quantify classification uncertainty. In Advances in Neural Information Processing Systems, 2018.
  • Srivastava et al. (2014) Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15(1):1929–1958, 2014.
  • Sun et al. (2019) Sun, S., Zhang, G., Shi, J., and Grosse, R. Functional variational Bayesian neural networks. In International Conference on Learning Representations, 2019.
  • Szegedy et al. (2016) Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., and Wojna, Z. Rethinking the inception architecture for computer vision. In IEEE Conference on Computer Vision and Pattern Recognition, 2016.
  • Welling & Teh (2011) Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient Langevin dynamics. In International Conference on Machine Learning, 2011.
  • Wiesler et al. (2014) Wiesler, S., Richard, A., Schlüter, R., and Ney, H. Mean-normalized stochastic gradient for large-scale deep learning. In IEEE International Conference on Acoustics, Speech and Signal Processing, 2014.
  • Wilson et al. (2016a) Wilson, A. G., Hu, Z., Salakhutdinov, R. R., and Xing, E. P. Deep kernel learning. In International Conference on Artificial Intelligence and Statistics, 2016a.
  • Wilson et al. (2016b) Wilson, A. G., Hu, Z., Salakhutdinov, R. R., and Xing, E. P. Stochastic variational deep kernel learning. In Advances in Neural Information Processing Systems, 2016b.
  • Wu et al. (2019a) Wu, A., Nowozin, S., Meeds, E., Turner, R. E., Hernández-Lobato, J. M., and Gaunt, A. L. Deterministic variational inference for robust Bayesian neural networks. In International Conference on Learning Representations, 2019a.
  • Wu et al. (2019b) Wu, Q., Li, H., Li, L., and Yu, Z. Quantifying intrinsic uncertainty in classification via deep dirichlet mixture networks. arXiv preprint arXiv:1906.04450, 2019b.
  • Xie et al. (2016) Xie, L., Wang, J., Wei, Z., Wang, M., and Tian, Q. Disturblabel: Regularizing cnn on the loss layer. In IEEE Conference on Computer Vision and Pattern Recognition, 2016.
  • Xie et al. (2017) Xie, S., Girshick, R., Dollár, P., Tu, Z., and He, K. Aggregated residual transformations for deep neural networks. In IEEE Conference on Computer Vision and Pattern Recognition, 2017.
  • Zagoruyko & Komodakis (2016) Zagoruyko, S. and Komodakis, N. Wide residual networks. In British Machine Vision Conference, 2016.
  • Zhang et al. (2018) Zhang, G., Sun, S., Duvenaud, D., and Grosse, R. Noisy natural gradient as variational inference. In International Conference on Machine Learning, 2018.
  • Zintgraf et al. (2017) Zintgraf, L. M., Cohen, T. S., Adel, T., and Welling, M. Visualizing deep neural network decisions: Prediction difference analysis. In International Conference on Learning Representations, 2017.