
Scale-invariant Bayesian Neural Networks with Connectivity Tangent Kernel

SungYub Kim1, Sihwan Park1, Kyungsu Kim2,3, Eunho Yang1,4
1Korea Advanced Institute of Science and Technology (KAIST)
2Medical AI Research Center, Research Institute for Future Medicine, Samsung Medical Center, Seoul, Korea
3Department of Data Convergence and Future Medicine, Sungkyunkwan University School of Medicine, Seoul, Korea
4AITRICS, Seoul, Korea
E-mail: [email protected], [email protected], [email protected], [email protected]
Co-corresponding author
Abstract

Explaining generalization and preventing over-confident predictions are central goals of research on the loss landscape of neural networks. Flatness, defined as loss invariance under perturbations of a pre-trained solution, is widely accepted as a predictor of generalization in this context. However, it has been pointed out that flatness and generalization bounds can be changed arbitrarily according to the scale of the parameters, and previous studies addressed this problem only partially and under restrictions: counter-intuitively, their generalization bounds remained variant under function-preserving parameter scaling transformations, or were limited to impractical network structures. As a more fundamental solution, we propose new prior and posterior distributions that are invariant to scaling transformations by decomposing the scale and connectivity of parameters, thereby allowing the resulting generalization bound to describe the generalizability of a broad class of networks under the more practical class of transformations such as weight decay with batch normalization. We also show that the above issue adversely affects the uncertainty calibration of the Laplace approximation and propose a solution using our invariant posterior. We empirically demonstrate that our posterior provides effective flatness and calibration measures with low complexity in such practical parameter transformation cases, supporting its practical effectiveness in line with our rationale.

1 Introduction

Though neural networks (NNs) have achieved extraordinary success, understanding their generalizability and using them successfully in real-world contexts still face a number of obstacles [1, 2]. It is a well-known enigma, for instance, why such NNs generalize well and do not suffer from overfitting [3, 4, 5]. Recent research on the loss landscape of NNs seeks to reduce these obstacles. Hochreiter and Schmidhuber [6] proposed a theory known as flat minima (FM): the flatness of local minima (i.e., loss invariance w.r.t. parameter perturbations) is positively correlated with network generalizability, as empirically demonstrated by Jiang et al. [7]. Concerning overconfidence, MacKay [8] suggested an approximate Bayesian posterior using the curvature information of local minima, and Daxberger et al. [9] underlined its practical utility.

Nonetheless, the limitations of the FM hypothesis were pointed out by Dinh et al. [10] and Li et al. [11]. By rescaling two successive layers, Dinh et al. [10] demonstrated that it is possible to modify a flatness measure without modifying the function, hence allowing extraneous variability to be captured in the computation of generalizability. Meanwhile, Li et al. [11] argued that weight decay regularization [12] is an important limitation of the FM hypothesis, as it leads to a contradiction of the hypothesis in practice: weight decay sharpens the pre-trained solutions of NNs by downscaling the parameters, whereas it actually improves the generalization performance of NNs in general [13]. In short, they suggest that scaling transformations on network parameters (e.g., re-scaling layers and weight decay) may lead to a contradiction of the FM hypothesis.

To resolve this contradiction, we investigate PAC-Bayesian prior and posterior distributions to derive a new scale-invariant generalization bound. Unlike related works [14, 15], our bound guarantees invariance under a general class of function-preserving parameter scaling transformations for a broad class of networks [16] (Secs. 2.2 and 2.3).

This bound is derived from the scale invariance of the prior and posterior distributions, which guarantees not only the scale invariance of the bound itself but also that of its key ingredient, a Kullback-Leibler (KL) divergence-based kernel; we name this scale-invariant kernel the empirical Connectivity Tangent Kernel (CTK), as it can be considered a modification of the empirical Neural Tangent Kernel [17]. Consequently, we define a novel sharpness metric named Connectivity Sharpness (CS) as the trace of the CTK. Empirically, we verify that CS predicts the generalization performance of NNs better than existing sharpness measures [18, 19, 20] at low complexity (Sec. 2.5), confirming its stronger correlation with generalization error (Sec. 4.1).

We also find that the contradiction of the FM hypothesis turns into a meaningless amplification of predictive uncertainty in the Bayesian NN regime (Sec. 3.1), and that this issue can be alleviated by a Bayesian NN based on the posterior distribution of our PAC-Bayesian analysis. We call the resulting Bayesian NN Connectivity Laplace (CL), as it can be seen as a variant of the Laplace Approximation (LA; MacKay [8]) using a different Jacobian. In particular, we point out pitfalls of weight decay regularization with BN in LA and present a remedy using our posterior (Sec. 3.1), suggesting the practical utility of our Bayesian NNs (Sec. 4.2). We summarize our contributions as follows:

  • Unlike related studies, our resulting (PAC-Bayes) generalization bound guarantees invariance under a general class of function-preserving parameter scaling transformations for a broad class of networks (Secs. 2.2 and 2.3). Based on this novel PAC-Bayes bound, we propose a low-complexity sharpness measure (Sec. 2.5).

  • We identify pitfalls of weight decay regularization with BN in LA and provide a remedy using our posterior (Sec. 3.1).

  • We empirically confirm the strong correlation between generalization error and our sharpness metric (Sec. 4.1), and visualize the pitfalls of weight decay with LA on synthetic data as well as the practical utility of our Bayesian NNs (Sec. 4.2).

2 PAC-Bayes bound with scale-invariance

2.1 Background

Setup and Definitions We consider a Neural Network (NN), $f(\cdot,\cdot):\mathbb{R}^{D}\times\mathbb{R}^{P}\rightarrow\mathbb{R}^{K}$, given input $x\in\mathbb{R}^{D}$ and network parameter $\theta\in\mathbb{R}^{P}$. Hereafter, we treat a one-dimensional vector as a single-column matrix unless otherwise stated. We use the output of the NN, $f(x,\theta)$, as the prediction for input $x$. Let $\mathcal{S}:=\{(x_{n},y_{n})\}_{n=1}^{N}$ denote the independently and identically distributed (i.i.d.) training data drawn from the true data distribution $\mathcal{D}$, where $x_{n}\in\mathbb{R}^{D}$ and $y_{n}\in\mathbb{R}^{K}$ are the input and output representations of the $n$-th training instance, respectively. For simplicity, we denote the concatenated inputs and outputs of all instances as $\mathcal{X}:=\{x:(x,y)\in\mathcal{S}\}$ and $\mathcal{Y}:=\{y:(x,y)\in\mathcal{S}\}$, respectively, and $f(\mathcal{X},\theta)\in\mathbb{R}^{NK}$ as the concatenation of $\{f(x_{n},\theta)\}_{n=1}^{N}$. Given a prior distribution of the network parameter $p(\theta)$ and a likelihood function $p(\mathcal{S}|\theta):=\prod_{n=1}^{N}p(y_{n}|x_{n},\theta):=\prod_{n=1}^{N}p(y_{n}|f(x_{n},\theta))$, Bayesian inference defines the posterior distribution of the network parameter $\theta$ as $p(\theta|\mathcal{S})=\frac{1}{Z(\mathcal{S})}\exp(-\mathcal{L}(\mathcal{S},\theta)):=\frac{1}{Z(\mathcal{S})}p(\theta)p(\mathcal{S}|\theta)$ with $Z(\mathcal{S}):=\int p(\theta)p(\mathcal{S}|\theta)d\theta$, where $\mathcal{L}(\mathcal{S},\theta):=-\log p(\theta)-\sum_{n=1}^{N}\log p(y_{n}|x_{n},\theta)$ is the training loss and $Z(\mathcal{S})$ is the normalizing factor. For example, the likelihood function for a regression task is Gaussian: $p(y|x,\theta)=\mathcal{N}(y|f(x,\theta),\sigma^{2}\mathbf{I}_{K})$, where $\sigma$ is the (homoscedastic) observation noise scale. For classification tasks, we treat the problem as one-hot regression following Lee et al. [21] and He et al. [22]. While we adopt this modification for theoretical tractability, Lee et al. [23] and Hui and Belkin [24] showed that it offers performance competitive with the cross-entropy loss. Details on this treatment are given in Appendix B.
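For illustration, a minimal PyTorch sketch of the one-hot regression treatment just described (our own illustrative code, not the paper's implementation): labels become one-hot targets and the Gaussian likelihood reduces to a squared-error objective; the batch size, class count, and $\sigma$ below are arbitrary placeholders.

```python
# One-hot regression treatment of classification (illustrative sketch).
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)                       # f(x, theta) for a batch, K = 10 classes
labels = torch.randint(0, 10, (8,))
targets = F.one_hot(labels, num_classes=10).float()

sigma = 1.0                                       # homoscedastic observation noise scale
nll = 0.5 * ((logits - targets) ** 2).sum(dim=1) / sigma**2   # Gaussian NLL up to a constant
loss = nll.mean()
```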

Laplace Approximation In general, the exact computation of the Bayesian posterior of a network parameter is intractable. The Laplace Approximation (LA; [8]) proposes to approximate the posterior distribution with a Gaussian distribution defined as $p_{\mathrm{LA}}(\psi|\mathcal{S})\sim\mathcal{N}(\psi|\theta^{*},(\nabla^{2}_{\theta}\mathcal{L}(\mathcal{S},\theta^{*}))^{-1})$, where $\theta^{*}\in\mathbb{R}^{P}$ is a parameter pre-trained with the training loss and $\nabla^{2}_{\theta}\mathcal{L}(\mathcal{S},\theta^{*})\in\mathbb{R}^{P\times P}$ is the Hessian of the loss function w.r.t. the parameter at $\theta^{*}$.

Recent works on LA replace the Hessian matrix with the (Generalized) Gauss-Newton matrix to make the computation tractable [25, 26]. With this approximation, the LA posterior for the regression problem can be represented as:

$$p_{\mathrm{LA}}(\psi|\mathcal{S})\sim\mathcal{N}\left(\psi\,\middle|\,\theta^{*},\left(\underbrace{\mathbf{I}_{P}/\alpha^{2}}_{\textrm{Damping}}+\underbrace{\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}/\sigma^{2}}_{\textrm{Curvature}}\right)^{-1}\right) \qquad (1)$$

where $\alpha,\sigma>0$, $\mathbf{I}_{P}\in\mathbb{R}^{P\times P}$ is the identity matrix, and $\mathbf{J}_{\theta}\in\mathbb{R}^{NK\times P}$ is the concatenation of $\mathbf{J}_{\theta}(x,\theta^{*})\in\mathbb{R}^{K\times P}$ (the Jacobian of $f$ w.r.t. $\theta$ at input $x$ and parameter $\theta^{*}$) along the training inputs $\mathcal{X}$. Since the covariance in equation 1 is the inverse of a $P\times P$ matrix, further sub-curvature approximations have been considered, including diagonal, Kronecker-factored approximate curvature (KFAC), last-layer, and sub-network approximations [27, 28, 29]. Furthermore, these works found that a proper selection of the prior scale $\alpha$ is needed to balance the dilemma between overconfidence and underfitting in LA.
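For intuition, the following NumPy sketch (our own simplification) assembles the damping-plus-curvature precision of equation 1 from an explicitly stacked Jacobian; in practice $\mathbf{J}_{\theta}$ is never materialized for large networks and the sub-curvature approximations above are used instead.

```python
# Gauss-Newton Laplace covariance of equation 1 (dense toy sketch).
import numpy as np

def la_covariance(J_theta, alpha, sigma):
    """(I_P / alpha^2 + J^T J / sigma^2)^{-1}: damping plus curvature."""
    P = J_theta.shape[1]
    precision = np.eye(P) / alpha**2 + J_theta.T @ J_theta / sigma**2
    return np.linalg.inv(precision)

# toy example: NK = 5 stacked outputs, P = 3 parameters
J_theta = np.random.randn(5, 3)
Sigma_LA = la_covariance(J_theta, alpha=1.0, sigma=0.1)
```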

PAC-Bayes bound with data-dependent prior We consider a PAC-Bayes generalization error bound for the classification task used in McAllester [30] and Perez-Ortiz et al. [31] (specifically, equation (7) of Perez-Ortiz et al. [31]). Let $\mathbb{P}$ be any PAC-Bayes prior distribution over $\mathbb{R}^{P}$ independent of the training dataset $\mathcal{S}$, and let $\mathrm{err}(\cdot,\cdot):\mathbb{R}^{K}\times\mathbb{R}^{K}\rightarrow[0,1]$ be an error function defined separately from the loss function. For any constant $\delta\in(0,1]$ and $\lambda>0$, and any PAC-Bayes posterior distribution $\mathbb{Q}$ over $\mathbb{R}^{P}$, the following holds with probability at least $1-\delta$: $\mathrm{err}_{\mathcal{D}}(\mathbb{Q})\leq\mathrm{err}_{\mathcal{S}}(\mathbb{Q})+\sqrt{\frac{\mathrm{KL}[\mathbb{Q}\|\mathbb{P}]+\log(2\sqrt{N}/\delta)}{2N}}$, where $\mathrm{err}_{\mathcal{D}}(\mathbb{Q}):=\mathbb{E}_{(x,y)\sim\mathcal{D},\theta\sim\mathbb{Q}}[\mathrm{err}(f(x,\theta),y)]$, $\mathrm{err}_{\mathcal{S}}(\mathbb{Q}):=\mathbb{E}_{(x,y)\sim\mathcal{S},\theta\sim\mathbb{Q}}[\mathrm{err}(f(x,\theta),y)]$, and $N$ denotes the cardinality of $\mathcal{S}$. That is, $\mathrm{err}_{\mathcal{D}}(\mathbb{Q})$ and $\mathrm{err}_{\mathcal{S}}(\mathbb{Q})$ are the generalization error and empirical error, respectively. The only restriction on $\mathbb{P}$ here is that it cannot depend on the dataset $\mathcal{S}$.

Following the recent discussion in Perez-Ortiz et al. [31], one can construct data-dependent PAC-Bayes bounds by (i) randomly partitioning the dataset $\mathcal{S}$ into $\mathcal{S}_{\mathbb{Q}}$ and $\mathcal{S}_{\mathbb{P}}$ so that they are independent, (ii) using a PAC-Bayes prior distribution $\mathbb{P}_{\mathcal{D}}$ dependent only on $\mathcal{S}_{\mathbb{P}}$ (i.e., independent of $\mathcal{S}_{\mathbb{Q}}$, so $\mathbb{P}_{\mathcal{D}}$ qualifies as a valid $\mathbb{P}$), (iii) using a PAC-Bayes posterior distribution $\mathbb{Q}$ dependent on the entire dataset $\mathcal{S}$, and (iv) computing the empirical error $\mathrm{err}_{\mathcal{S}_{\mathbb{Q}}}(\mathbb{Q})$ with the target subset $\mathcal{S}_{\mathbb{Q}}$ (not the entire dataset $\mathcal{S}$). In summary, one can modify the aforementioned original PAC-Bayes bound through our data-dependent prior $\mathbb{P}_{\mathcal{D}}$ as

$$\mathrm{err}_{\mathcal{D}}(\mathbb{Q})\leq\mathrm{err}_{\mathcal{S}_{\mathbb{Q}}}(\mathbb{Q})+\sqrt{\frac{\mathrm{KL}[\mathbb{Q}\|\mathbb{P}_{\mathcal{D}}]+\log(2\sqrt{N_{\mathbb{Q}}}/\delta)}{2N_{\mathbb{Q}}}} \qquad (2)$$

where $N_{\mathbb{Q}}$ is the cardinality of $\mathcal{S}_{\mathbb{Q}}$. We denote the sets of inputs and outputs of the split datasets ($\mathcal{S}_{\mathbb{P}},\mathcal{S}_{\mathbb{Q}}$) as $\mathcal{X}_{\mathbb{P}},\mathcal{Y}_{\mathbb{P}},\mathcal{X}_{\mathbb{Q}},\mathcal{Y}_{\mathbb{Q}}$ for simplicity.

2.2 Design of PAC-Bayes prior and posterior

Our goal is to construct scale-invariant $\mathbb{P}_{\mathcal{D}}$ and $\mathbb{Q}$. To this end, we first assume a parameter $\theta^{*}\in\mathbb{R}^{P}$ pre-trained on the negative log-likelihood with $\mathcal{S}_{\mathbb{P}}$. This parameter can be attained with standard NN optimization procedures (e.g., stochastic gradient descent (SGD) with momentum). Then, we consider the linearized NN at the pre-trained parameter with the auxiliary variable $c\in\mathbb{R}^{P}$ as

$$g^{\mathrm{lin}}_{\theta^{*}}(x,c):=f(x,\theta^{*})+\mathbf{J}_{\theta}(x,\theta^{*})\mathrm{diag}(\theta^{*})c \qquad (3)$$

where $\mathrm{diag}$ is the vector-to-matrix diagonal operator. Note that equation 3 is the first-order Taylor approximation (i.e., linearization) of the NN with perturbation $\theta^{*}\odot c$ given input $x$ and parameter $\theta^{*}$: $g^{\mathrm{pert}}_{\theta^{*}}(x,c):=f(x,\theta^{*}+\theta^{*}\odot c)=f(x,\theta^{*}+\mathrm{diag}(\theta^{*})c)\approx g^{\mathrm{lin}}_{\theta^{*}}(x,c)$, where $\odot$ denotes element-wise multiplication (Hadamard product) of two vectors. Here we write the perturbation in parameter space as $\theta^{*}\odot c$ instead of a single variable such as $\delta\in\mathbb{R}^{P}$. This design of linearization matches the scale of the perturbation (i.e., $\mathrm{diag}(\theta^{*})c$) to the scale of $\theta^{*}$ in a component-wise manner. A similar idea was proposed in the context of pruning at initialization [32, 33] to measure the importance of each connection independently of its weight. In this context, our perturbation can be viewed as a perturbation in connectivity space, obtained by decomposing the scale and connectivity of the parameter.
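The connectivity Jacobian $\mathbf{J}_{c}(x,\mathbf{0}_{P})=\mathbf{J}_{\theta}(x,\theta^{*})\mathrm{diag}(\theta^{*})$ can be obtained by differentiating the perturbed network w.r.t. $c$ at $c=\mathbf{0}_{P}$. The following PyTorch sketch illustrates this on a small stand-in model (the two-layer network, shapes, and use of torch.func are our own illustrative choices, not the paper's code).

```python
# Jacobian of g_pert(x, c) w.r.t. connectivity c at c = 0 (toy sketch).
import torch
from torch.func import functional_call, jacrev

net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
params = dict(net.named_parameters())
x = torch.randn(1, 4)

def f_connectivity(c_flat):
    # perturb each parameter as theta* + theta* * c (element-wise)
    offset, perturbed = 0, {}
    for name, p in params.items():
        c = c_flat[offset:offset + p.numel()].view_as(p)
        perturbed[name] = p.detach() * (1.0 + c)
        offset += p.numel()
    return functional_call(net, perturbed, (x,)).squeeze(0)

P = sum(p.numel() for p in params.values())
J_c = jacrev(f_connectivity)(torch.zeros(P))   # shape (K, P): Jacobian in connectivity space
```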

Based on this, we define a data-dependent prior ($\mathbb{P}_{\mathcal{D}}$) over connectivity as

$$\mathbb{P}_{\theta^{*}}(c):=\mathcal{N}(c\,|\,\mathbf{0}_{P},\alpha^{2}\mathbf{I}_{P}). \qquad (4)$$

This distribution can be translated to a distribution over the parameter by considering the distribution of the perturbed parameter ($\psi:=\theta^{*}+\theta^{*}\odot c$): $\mathbb{P}_{\theta^{*}}(\psi):=\mathcal{N}(\psi\,|\,\theta^{*},\alpha^{2}\mathrm{diag}(\theta^{*})^{2})$. We now define the PAC-Bayes posterior over connectivity, $\mathbb{Q}(c)$, as follows:

$$\mathbb{Q}_{\theta^{*}}(c):=\mathcal{N}(c\,|\,\mu_{\mathbb{Q}},\Sigma_{\mathbb{Q}}) \qquad (5)$$
$$\mu_{\mathbb{Q}}:=\frac{\Sigma_{\mathbb{Q}}\mathbf{J}_{c}^{\top}\left(\mathcal{Y}-f(\mathcal{X},\theta^{*})\right)}{\sigma^{2}}=\frac{\Sigma_{\mathbb{Q}}\mathrm{diag}(\theta^{*})\mathbf{J}_{\theta}^{\top}\left(\mathcal{Y}-f(\mathcal{X},\theta^{*})\right)}{\sigma^{2}} \qquad (6)$$
$$\Sigma_{\mathbb{Q}}:=\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\mathbf{J}_{c}}{\sigma^{2}}\right)^{-1}=\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathrm{diag}(\theta^{*})\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}\mathrm{diag}(\theta^{*})}{\sigma^{2}}\right)^{-1} \qquad (7)$$

where $\mathbf{J}_{c}\in\mathbb{R}^{NK\times P}$ is the concatenation of $\mathbf{J}_{c}(x,\mathbf{0}_{P}):=\mathbf{J}_{\theta}(x,\theta^{*})\mathrm{diag}(\theta^{*})\in\mathbb{R}^{K\times P}$ (i.e., the Jacobian of the perturbed NN $g^{\mathrm{pert}}_{\theta^{*}}(x,c)$ w.r.t. $c$ at input $x$ and connectivity $\mathbf{0}_{P}$) along the training inputs $\mathcal{X}$. Our PAC-Bayes posterior $\mathbb{Q}_{\theta^{*}}$ is the posterior of the Bayesian linear regression problem w.r.t. connectivity $c$: $y_{i}=f(x_{i},\theta^{*})+\mathbf{J}_{\theta}(x_{i},\theta^{*})\mathrm{diag}(\theta^{*})c+\epsilon_{i}$, where $(x_{i},y_{i})\in\mathcal{S}$ and the $\epsilon_{i}$ are i.i.d. samples of $\mathcal{N}(\epsilon_{i}|\mathbf{0}_{K},\sigma^{2}\mathbf{I}_{K})$. Again, it is equivalent to the posterior distribution over the parameter $\mathbb{Q}_{\theta^{*}}(\psi)=\mathcal{N}\left(\psi\,|\,\theta^{*}+\theta^{*}\odot\mu_{\mathbb{Q}},\ (\mathrm{diag}(\theta^{*})^{-2}/\alpha^{2}+\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}/\sigma^{2})^{-1}\right)$, where $\mathrm{diag}(\theta^{*})^{-2}:=(\mathrm{diag}(\theta^{*})^{-1})^{2}$, assuming that all components of $\theta^{*}$ are non-zero. Note that this assumption can be easily satisfied by considering the prior and posterior distributions of only the non-zero components of the NN. Although we choose this restriction for theoretical tractability, future work may relax it to achieve diverse predictions by also considering the distribution of the zero coordinates. We refer to Appendix C for detailed derivations of $\mathbb{Q}_{\theta^{*}}(c)$ and $\mathbb{Q}_{\theta^{*}}(\psi)$.
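For small problems where $\mathbf{J}_{c}$ fits in memory, the posterior moments in equations 6 and 7 reduce to standard Bayesian linear regression. A minimal NumPy sketch (our own, assuming $\mathbf{J}_{c}$ and the residual $\mathcal{Y}-f(\mathcal{X},\theta^{*})$ are given as dense arrays; for large networks they are never formed explicitly):

```python
# Posterior moments of equations 6 and 7 for a dense connectivity Jacobian (toy sketch).
import numpy as np

def connectivity_posterior(J_c, residual, alpha, sigma):
    P = J_c.shape[1]
    Sigma_Q = np.linalg.inv(np.eye(P) / alpha**2 + J_c.T @ J_c / sigma**2)  # equation 7
    mu_Q = Sigma_Q @ J_c.T @ residual / sigma**2                            # equation 6
    return mu_Q, Sigma_Q

J_c = np.random.randn(20, 6)          # toy shapes: NK = 20, P = 6
residual = np.random.randn(20)
mu_Q, Sigma_Q = connectivity_posterior(J_c, residual, alpha=1.0, sigma=0.5)
```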

Remark 2.1 (Two-phase training).

The prior distribution in equation 4 is data-dependent since it depends on the pre-trained parameter $\theta^{*}$ optimized on the training dataset $\mathcal{S}_{\mathbb{P}}$. On the other hand, the posterior distribution in equation 5 depends on both $\mathcal{S}_{\mathbb{P}}$ (through $\theta^{*}$) and $\mathcal{S}_{\mathbb{Q}}$ (through Bayesian linear regression). Intuitively, one attains the PAC-Bayes posterior $\mathbb{Q}$ with two-phase training: pre-training with $\mathcal{S}_{\mathbb{P}}$ and Bayesian fine-tuning with $\mathcal{S}_{\mathbb{Q}}$. A similar idea of linearized fine-tuning was proposed in the context of transfer learning by Achille et al. [34] and Maddox et al. [35].

Now we provide an invariance property of the prior and posterior distributions w.r.t. function-preserving scale transformations. The main intuition behind the following proposition is that the Jacobian w.r.t. connectivity is invariant to any function-preserving scaling transformation, i.e., $\mathbf{J}_{\theta}(x,\mathcal{T}(\theta^{*}))\mathrm{diag}(\mathcal{T}(\theta^{*}))=\mathbf{J}_{\theta}(x,\theta^{*})\mathrm{diag}(\theta^{*})$.

Proposition 2.2 (Scale-invariance of PAC-Bayes prior and posterior).

Let $\mathcal{T}:\mathbb{R}^{P}\rightarrow\mathbb{R}^{P}$ be an invertible diagonal linear transformation such that $f(x,\mathcal{T}(\psi))=f(x,\psi)$ for all $x\in\mathbb{R}^{D}$ and $\psi\in\mathbb{R}^{P}$. Then, both the PAC-Bayes prior and posterior are invariant under $\mathcal{T}$:

$$\mathbb{P}_{\mathcal{T}(\theta^{*})}(c)\stackrel{d}{=}\mathbb{P}_{\theta^{*}}(c),\quad\mathbb{Q}_{\mathcal{T}(\theta^{*})}(c)\stackrel{d}{=}\mathbb{Q}_{\theta^{*}}(c).$$

Furthermore, the generalization and empirical errors are also invariant to $\mathcal{T}$.
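The identity underlying Proposition 2.2 can be checked numerically. The sketch below (PyTorch, our own toy example) uses the layer-rescaling transformation of Dinh et al. [10] on a two-layer ReLU network, which is function-preserving for any $g>0$, and verifies that the connectivity Jacobian $\mathbf{J}_{\theta}(x,\theta)\mathrm{diag}(\theta)$ is unchanged.

```python
# Numerical check of J_theta(x, T(theta)) diag(T(theta)) = J_theta(x, theta) diag(theta).
import torch
from torch.func import jacrev

def f(theta, x):                         # toy two-layer ReLU network, theta = (W1, W2)
    W1, W2 = theta
    return W2 @ torch.relu(W1 @ x)

def J_connectivity(theta, x):            # rows of J_theta(x, theta) diag(theta)
    J = jacrev(lambda th: f(th, x))(theta)
    return torch.cat([(Jp * p).reshape(Jp.shape[0], -1) for Jp, p in zip(J, theta)], dim=1)

torch.manual_seed(0)
W1, W2, x, g = torch.randn(8, 4), torch.randn(2, 8), torch.randn(4), 3.0
print(torch.allclose(f((W1, W2), x), f((g * W1, W2 / g), x), atol=1e-6))   # function preserved
print(torch.allclose(J_connectivity((W1, W2), x),
                     J_connectivity((g * W1, W2 / g), x), atol=1e-5))      # J_c invariant
```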

2.3 Resulting PAC-Bayes bound

Now we plug the prior and posterior into the modified PAC-Bayes generalization error bound in equation 2. Accordingly, we obtain a novel generalization error bound, named PAC-Bayes-CTK, which is guaranteed to be invariant to scale transformations (hence free of the contradiction of the FM hypothesis mentioned in Sec. 1).

Theorem 2.3 (PAC-Bayes-CTK and its invariance).

Let us assume a pre-trained parameter $\theta^{*}$ with data $\mathcal{S}_{\mathbb{P}}$. By applying $\mathbb{P}_{\theta^{*}},\mathbb{Q}_{\theta^{*}}$ to the data-dependent PAC-Bayes bound (equation 2), we get

$$\mathrm{err}_{\mathcal{D}}(\mathbb{Q}_{\theta^{*}})\leq\mathrm{err}_{\mathcal{S}_{\mathbb{Q}}}(\mathbb{Q}_{\theta^{*}})+\sqrt{\overbrace{\underbrace{\frac{\mu_{\mathbb{Q}}^{\top}\mu_{\mathbb{Q}}}{4\alpha^{2}N_{\mathbb{Q}}}}_{\textrm{(average) perturbation}}+\underbrace{\sum_{i=1}^{P}\frac{h\left(\beta_{i}\right)}{4N_{\mathbb{Q}}}}_{\textrm{sharpness}}}^{\textrm{KL divergence}}+\frac{\log(2\sqrt{N_{\mathbb{Q}}}/\delta)}{2N_{\mathbb{Q}}}} \qquad (8)$$

where $\{\beta_{i}\}_{i=1}^{P}$ are the eigenvalues of $(\mathbf{I}_{P}+\frac{\alpha^{2}}{\sigma^{2}}\mathbf{J}_{c}^{\top}\mathbf{J}_{c})^{-1}$ and $h(x):=x-\log(x)-1$. This upper bound is invariant to any function-preserving scale transformation $\mathcal{T}$ of Proposition 2.2.

Note that recent works on resolving the FM contradiction only focused on the scale-invariance of the sharpness metric: indeed, their generalization bounds are not invariant to scale transformations due to scale-dependent terms (equation (34) in Tsuzuku et al. [14] and equation (6) in Kwon et al. [15]). On the other hand, the generalization bound in Petzka et al. [16] (Theorem 11 in their paper) only holds for single-layer NNs, whereas ours places no restrictions on network structure. Therefore, to the best of our knowledge, our PAC-Bayes bound is the first scale-invariant PAC-Bayes bound. To highlight the theoretical implications, we show representative cases of $\mathcal{T}$ in Proposition 2.2 in Appendix D (e.g., weight decay for a network with BN), where the generalization bounds of the other studies are variant but ours is invariant, resolving the FM contradiction at the bound level.

The following corollary explains why we call the PAC-Bayes bound in Theorem 2.3 PAC-Bayes-CTK.

Corollary 2.4 (Relation between CTK and PAC-Bayes-CTK).

Let us define the empirical Connectivity Tangent Kernel (CTK) of $\mathcal{S}$ as $\mathbf{C}_{\mathcal{X}}^{\theta^{*}}:=\mathbf{J}_{c}\mathbf{J}_{c}^{\top}=\mathbf{J}_{\theta}\mathrm{diag}(\theta^{*})^{2}\mathbf{J}_{\theta}^{\top}\in\mathbb{R}^{NK\times NK}$. Note that the empirical CTK has $T(\leq NK)$ non-zero eigenvalues $\{\lambda_{i}\}_{i=1}^{T}$. Then the following holds for $\{\beta_{i}\}_{i=1}^{P}$ in Theorem 2.3: (i) $\beta_{i}=\sigma^{2}/(\sigma^{2}+\alpha^{2}\lambda_{i})<1$ for $i=1,\dots,T$ and (ii) $\beta_{i}=1$ for $i=T+1,\dots,P$. Since $h(1)=0$, $P-T$ terms of the summation in the sharpness part of PAC-Bayes-CTK vanish. Furthermore, the sharpness term of PAC-Bayes-CTK is a monotonically increasing function of each eigenvalue of the empirical CTK.

Note that Corollary 2.4 clarifies why $\sum_{i=1}^{P}h(\beta_{i})/4N_{\mathbb{Q}}$ in Theorem 2.3 is called the sharpness term of PAC-Bayes-CTK. As the eigenvalues of the CTK measure the sensitivity of the output w.r.t. perturbations on connectivity, a sharp pre-trained parameter has large CTK eigenvalues, which increases the sharpness term and hence the generalization bound according to Corollary 2.4.
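Given estimated CTK eigenvalues, the sharpness term can be evaluated directly. A small NumPy sketch (our own), assuming the non-zero eigenvalues $\{\lambda_{i}\}_{i=1}^{T}$ have already been obtained (e.g., via Lanczos as in Sec. 2.4); the remaining $P-T$ terms contribute zero:

```python
# Sharpness term of PAC-Bayes-CTK via Corollary 2.4.
import numpy as np

def ctk_sharpness_term(lam, alpha, sigma, N_Q):
    beta = sigma**2 / (sigma**2 + alpha**2 * np.asarray(lam))   # beta_i < 1 for non-zero lambda_i
    h = beta - np.log(beta) - 1.0                               # h(x) = x - log(x) - 1, h(1) = 0
    return h.sum() / (4.0 * N_Q)

print(ctk_sharpness_term([50.0, 10.0, 1.0], alpha=0.1, sigma=1.0, N_Q=5000))
```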

Finally, Proposition 2.5 shows that empirical CTK is also scale-invariant.

Proposition 2.5 (Scale-invariance of empirical CTK).

Let $\mathcal{T}:\mathbb{R}^{P}\rightarrow\mathbb{R}^{P}$ be a function-preserving scale transformation as in Proposition 2.2. Then the empirical CTK at parameter $\psi$ is invariant under $\mathcal{T}$:

$$\mathbf{C}^{\mathcal{T}(\psi)}_{xy}:=\mathbf{J}_{\theta}(x,\mathcal{T}(\psi))\mathrm{diag}(\mathcal{T}(\psi)^{2})\mathbf{J}_{\theta}(y,\mathcal{T}(\psi))^{\top}=\mathbf{C}^{\psi}_{xy}\quad\forall x,y\in\mathbb{R}^{D},\ \forall\psi\in\mathbb{R}^{P}. \qquad (9)$$
Remark 2.6 (Connections to empirical NTK).

The empirical CTK $\mathbf{C}^{\psi}_{xy}$ resembles the existing empirical Neural Tangent Kernel (NTK) at parameter $\psi$ [17]: $\Theta_{xy}^{\psi}:=\mathbf{J}_{\theta}(x,\psi)\mathbf{J}_{\theta}(y,\psi)^{\top}\in\mathbb{R}^{K\times K}$. Note that the deterministic NTK in Jacot et al. [17] is the infinite-width limiting kernel at initialized parameters, while the empirical NTK can be defined at any parameter of a finite-width NN. The main difference between the empirical CTK and the empirical NTK is the definition of the Jacobian: in the CTK, the Jacobian is computed w.r.t. connectivity $c$, while the empirical NTK uses the Jacobian w.r.t. parameters $\theta$. Therefore, another PAC-Bayes bound can be derived from the linearization $f^{\mathrm{pert}}_{\theta^{*}}(x,\delta):=f(x,\theta^{*}+\delta)\approx f^{\mathrm{lin}}_{\theta^{*}}(x,\delta)$, where $f^{\mathrm{lin}}_{\theta^{*}}(x,\delta):=f(x,\theta^{*})+\mathbf{J}_{\theta}(x,\theta^{*})\delta$. As this bound is related to the eigenvalues of $\Theta^{\theta^{*}}_{\mathcal{X}}$, we call it PAC-Bayes-NTK and provide its derivation in Appendix A. Note that PAC-Bayes-NTK is not scale-invariant since $\Theta^{\mathcal{T}(\psi)}_{xy}\neq\Theta^{\psi}_{xy}$ in general.

2.4 Computing approximate bound in real world problems

To verify that the PAC-Bayes bound in Theorem 2.3 is non-vacuous, we compute it for real-world problems. We use the CIFAR-10 and CIFAR-100 datasets [36], where the 50K training instances are randomly partitioned into $\mathcal{S}_{\mathbb{P}}$ of cardinality 45K and $\mathcal{S}_{\mathbb{Q}}$ of cardinality 5K. We pre-train ResNet-18 [37] with a mini-batch size of 1K on $\mathcal{S}_{\mathbb{P}}$ using SGD with initial learning rate 0.4 and momentum 0.9. We use cosine annealing for learning rate scheduling [38] with a warmup during the initial 10% of training steps. We fix $\delta=0.1$ and select $\alpha,\sigma$ based on the negative log-likelihood of $\mathcal{S}_{\mathbb{Q}}$.

To compute equation 8, one needs (i) the $\mu_{\mathbb{Q}}$-based perturbation term, (ii) the $\mathbf{C}^{\theta^{*}}_{\mathcal{X}}$-based sharpness term, and (iii) samples from the PAC-Bayes posterior $\mathbb{Q}_{\theta^{*}}$. By the first-order optimality condition, $\mu_{\mathbb{Q}}$ in equation 6 can be obtained by solving $\arg\min_{c\in\mathbb{R}^{P}}L(c)=\frac{1}{2N}\|\mathcal{Y}-f(\mathcal{X},\theta^{*})-\mathbf{J}_{c}c\|^{2}+\frac{\sigma^{2}}{2\alpha^{2}N}c^{\top}c$. Note that this is a convex optimization problem w.r.t. $c$, since $c$ is the parameter of a linear regression problem. We use the Adam optimizer [39] with a fixed learning rate of 1e-4 to solve it. For the sharpness term, we apply the Lanczos algorithm to approximate the eigenspectrum of $\mathbf{C}^{\theta^{*}}_{\mathcal{X}}$ following [40], using 100 Lanczos iterations based on their setting. Lastly, we estimate the empirical error and test error with 8 samples from the CL/LL implementation of the Randomize-Then-Optimize (RTO) framework [41, 42]. We refer to Appendix E for the RTO implementation of CL/LL.
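As a sketch of the $\mu_{\mathbb{Q}}$ step, the following PyTorch code (our own simplification) minimizes the ridge objective above with Adam for a dense $\mathbf{J}_{c}$; in practice $\mathbf{J}_{c}c$ would be computed matrix-free with Jacobian-vector products rather than an explicit $NK\times P$ matrix.

```python
# Convex ridge problem whose minimizer is mu_Q (dense toy sketch).
import torch

def estimate_mu_Q(J_c, residual, alpha, sigma, N, steps=2000, lr=1e-4):
    c = torch.zeros(J_c.shape[1], requires_grad=True)
    opt = torch.optim.Adam([c], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = (residual - J_c @ c).pow(2).sum() / (2 * N) \
               + (sigma**2 / (2 * alpha**2 * N)) * c.dot(c)
        loss.backward()
        opt.step()
    return c.detach()

J_c, residual = torch.randn(50, 10), torch.randn(50)      # toy shapes: NK = 50, P = 10
mu_Q = estimate_mu_Q(J_c, residual, alpha=1.0, sigma=0.5, N=25)
```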

Table 1: Results for experiments on PAC-Bayes-CTK estimation.
| Parameter scale | CIFAR-10 (0.5) | CIFAR-10 (1.0) | CIFAR-10 (2.0) | CIFAR-10 (4.0) | CIFAR-100 (0.5) | CIFAR-100 (1.0) | CIFAR-100 (2.0) | CIFAR-100 (4.0) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| $\mathrm{tr}(\mathbf{C}^{\theta^{*}}_{\mathcal{X}})$ | 14515.0039 | 14517.7793 | 14517.3506 | 14518.4746 | 12872.6895 | 12874.4395 | 12873.9512 | 12875.541 |
| Bias | 13.9791 | 13.4685 | 13.3559 | 13.3122 | 25.3686 | 24.8064 | 24.9102 | 24.7557 |
| Sharpness | 24.6874 | 24.6938 | 24.6926 | 24.6941 | 26.0857 | 26.0894 | 26.0874 | 26.0916 |
| KL | 19.3332 | 19.0812 | 19.0243 | 19.0032 | 25.7271 | 25.4479 | 25.4988 | 25.4236 |
| Test err. | 0.0468 ± 0.0014 | 0.0463 ± 0.0013 | 0.0462 ± 0.0012 | 0.0460 ± 0.0013 | 0.2257 ± 0.0020 | 0.2252 ± 0.0017 | 0.2256 ± 0.0015 | 0.2253 ± 0.0017 |
| PAC-Bayes-CTK | 0.0918 ± 0.0013 | 0.0911 ± 0.0011 | 0.0909 ± 0.0011 | 0.0907 ± 0.0009 | 0.2874 ± 0.0034 | 0.2862 ± 0.0031 | 0.2860 ± 0.0032 | 0.2862 ± 0.0032 |

Table 1 provides the bound and its related terms for the resulting model. First of all, we find that our estimated PAC-Bayes-CTK is non-vacuous [43]: the estimated bound is better than guessing at random. Note that a non-vacuous bound is not trivial in PAC-Bayes analysis: only a few PAC-Bayes works [44, 43, 31] verified the non-vacuous property of their bounds, and others [45, 14] did not. To check the invariance property of our bound, we scale the scale-invariant parameters in ResNet-18 (i.e., parameters preceding BN layers) by fixed constants $\{0.5,1.0,2.0,4.0\}$. Note that this scaling does not change the function represented by the NN due to the BN layers, so the error bound should be preserved. Table 1 shows that our bound and its related terms are indeed invariant to these transformations. On the other hand, PAC-Bayes-NTK is very sensitive to the parameter scale, as shown in Table 7 in Appendix J.

2.5 Connectivity Sharpness and its efficient computation

Now, we focus on the fact that the trace of the CTK is also invariant to the parameter scale by Proposition 2.5. Unlike PAC-Bayes-CTK/NTK, the trace of the CTK/NTK does not require an onerous hyper-parameter selection of $\delta,\alpha,\sigma$. Therefore, we simply define $\textbf{CS}(\theta^{*}):=\mathrm{tr}(\mathbf{C}^{\theta^{*}}_{\mathcal{X}})$ as a practical sharpness measure at $\theta^{*}$, named Connectivity Sharpness (CS), bypassing the complex computation of PAC-Bayes-CTK. This metric can be easily applied to find NNs with better generalization, like other sharpness metrics (e.g., the trace of the Hessian), as shown in [7]. We evaluate the predictive performance of CS in Sec. 4.1. The following corollary shows, conceptually, how CS can explain the generalization performance of NNs.

Corollary 2.7 (Connectivity sharpness, Informal).

Let the CTK and the KL divergence term of PAC-Bayes-CTK be as defined in Theorem 2.3. Then, if CS vanishes to zero or diverges to infinity, the KL divergence term of PAC-Bayes-CTK does so as well.

As the trace of a matrix can be efficiently estimated by Hutchinson's method [46], one can compute CS without explicitly computing the entire CTK. We refer to Appendix F for the detailed procedure for computing CS. As CS is invariant to function-preserving scale transformations by Proposition 2.5, it also does not contradict the FM hypothesis.
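A minimal PyTorch sketch of Hutchinson's estimator for $\mathrm{tr}(\mathbf{C}^{\theta^{*}}_{\mathcal{X}})$ (our own illustration on a stand-in model; Appendix F describes the actual procedure): for Rademacher probes $v$ in output space, $\mathbb{E}[\|\mathbf{J}_{c}^{\top}v\|^{2}]=\mathrm{tr}(\mathbf{J}_{c}\mathbf{J}_{c}^{\top})$, and each probe costs a single backward pass.

```python
# Hutchinson estimate of CS(theta*) = tr(J_c J_c^T) without materializing the CTK.
import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
x = torch.randn(32, 4)   # toy batch standing in for the training inputs X

def connectivity_sharpness(net, x, n_probes=16):
    est = 0.0
    for _ in range(n_probes):
        out = net(x)
        v = (torch.randint(0, 2, out.shape) * 2 - 1).to(out.dtype)        # Rademacher probe
        grads = torch.autograd.grad((out * v).sum(), net.parameters())    # J_theta^T v
        # J_c^T v = diag(theta*) J_theta^T v: scale each gradient by its parameter
        est = est + sum(((g * p.detach())**2).sum() for g, p in zip(grads, net.parameters()))
    return est / n_probes

print(connectivity_sharpness(net, x))
```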

3 Bayesian NNs with scale-invariance

This section provides the practical implications of the posterior distribution used in our PAC-Bayes analysis. We interpret the PAC-Bayes posterior $\mathbb{Q}_{\theta^{*}}$ in equation 5 as a modification of LA [8]. Then, we show that this modification improves the existing LA in the presence of weight decay regularization. Finally, we explain how one can efficiently construct a Bayesian NN from equation 5.

3.1 Pitfalls of weight decay with BN in Laplace Approximation

One can view the parameter-space version of $\mathbb{Q}_{\theta^{*}}$ as a modified version of $p_{\mathrm{LA}}(\psi|\mathcal{S})$ in equation 1 obtained by (i) replacing the isotropic damping term with a parameter-scale-dependent damping term ($\mathrm{diag}(\theta^{*})^{-2}$) and (ii) adding the perturbation $\theta^{*}\odot\mu_{\mathbb{Q}}$ to the mean of the Gaussian distribution. In this section, we focus on the effect of replacing the damping term of LA in the presence of weight decay for batch-normalized NNs. We refer to [47, 48] for a discussion of the effect of adding a perturbation to LA with linearized NNs.

The main difference between the covariance terms of LA (equation 1) and equation 7 is the definition of the Jacobian (i.e., w.r.t. parameter or connectivity), similar to the difference between the empirical CTK and NTK in Remark 2.6. Therefore, we name $p_{\mathrm{CL}}(\psi|\mathcal{S})\sim\mathcal{N}(\psi|\theta^{*},\left(\mathrm{diag}(\theta^{*})^{-2}/\alpha^{2}+\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}/\sigma^{2}\right)^{-1})$ the Connectivity Laplace (CL) approximate posterior.

To compare the CL posterior and the existing LA, we explain how weight decay regularization with BN produces unexpected side effects in the existing LA. This side effect can be quantified if we consider the linearized NN for LA, known as Linearized Laplace (LL; Foong et al. [49]). Note that LL is well known to be better calibrated than non-linearized LA for estimating 'in-between' uncertainty. Assuming $\sigma^{2}\ll\alpha^{2}$, the predictive distributions of the linearized NN for equation 1 and for CL are

$$f^{\mathrm{lin}}_{\theta^{*}}(x,\psi)\,|\,p_{\mathrm{LA}}(\psi|\mathcal{S})\sim\mathcal{N}\big(f(x,\theta^{*}),\ \alpha^{2}\Theta_{xx}^{\theta^{*}}-\alpha^{2}\Theta_{x\mathcal{X}}^{\theta^{*}}\,\Theta_{\mathcal{X}}^{\theta^{*}\,-1}\,\Theta_{\mathcal{X}x}^{\theta^{*}}\big) \qquad (10)$$
$$f^{\mathrm{lin}}_{\theta^{*}}(x,\psi)\,|\,p_{\mathrm{CL}}(\psi|\mathcal{S})\sim\mathcal{N}\big(f(x,\theta^{*}),\ \alpha^{2}\mathbf{C}_{xx}^{\theta^{*}}-\alpha^{2}\mathbf{C}_{x\mathcal{X}}^{\theta^{*}}\,\mathbf{C}_{\mathcal{X}}^{\theta^{*}\,-1}\,\mathbf{C}_{\mathcal{X}x}^{\theta^{*}}\big). \qquad (11)$$

for any input $x\in\mathbb{R}^{D}$, where $\mathcal{X}$ in the subscript of the CTK/NTK denotes concatenation over the training inputs. We refer to Appendix G for the detailed derivations. The following proposition shows how weight decay regularization on the scale-invariant parameters introduced by BN can amplify the predictive uncertainty of equation 10. Note that the primary regularizing effect of weight decay originates from regularization on scale-invariant parameters [50, 13].
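When the kernel blocks can be evaluated explicitly, the predictive variances in equations 10 and 11 take the familiar GP form. A small NumPy sketch (our own, assuming the test-test, test-train, and train-train kernel blocks of either the NTK or the CTK are available, and $\sigma^{2}\ll\alpha^{2}$ as in the text):

```python
# GP-form predictive variance of equations 10 and 11.
import numpy as np

def predictive_variance(k_xx, k_xX, K_XX, alpha):
    """alpha^2 * (k_xx - k_xX K_XX^{-1} k_Xx); NTK blocks give LL, CTK blocks give CL."""
    return alpha**2 * (k_xx - k_xX @ np.linalg.solve(K_XX, k_xX.T))

# toy blocks: K = 2 outputs, NK = 6 training outputs
k_xx, k_xX, K_XX = np.eye(2), np.random.randn(2, 6), np.eye(6) * 4.0
print(predictive_variance(k_xx, k_xX, K_XX, alpha=1.0))
```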

Proposition 3.1 (Uncertainty amplifying effect for LL).

Let us assume that $\mathcal{W}_{\gamma}:\mathbb{R}^{P}\rightarrow\mathbb{R}^{P}$ is a weight decay regularization on the scale-invariant parameters (e.g., parameters preceding BN layers) that multiplies them by $\gamma<1$, while all non-scale-invariant parameters are fixed. Then, the predictive uncertainty of LL is amplified by $1/\gamma^{2}>1$, while the predictive uncertainty of CL is preserved:

$$\textrm{Var}_{\psi\sim p_{\mathrm{LA}}(\psi|\mathcal{S})}\big(f^{\mathrm{lin}}_{\mathcal{W}_{\gamma}(\theta^{*})}(x,\psi)\big)=\textrm{Var}_{\psi\sim p_{\mathrm{LA}}(\psi|\mathcal{S})}\big(f^{\mathrm{lin}}_{\theta^{*}}(x,\psi)\big)/\gamma^{2}$$
$$\textrm{Var}_{\psi\sim p_{\mathrm{CL}}(\psi|\mathcal{S})}\big(f^{\mathrm{lin}}_{\mathcal{W}_{\gamma}(\theta^{*})}(x,\psi)\big)=\textrm{Var}_{\psi\sim p_{\mathrm{CL}}(\psi|\mathcal{S})}\big(f^{\mathrm{lin}}_{\theta^{*}}(x,\psi)\big)$$

where $\textrm{Var}(\cdot)$ denotes the variance of a random variable.

Recently, [47, 48] observed pitfalls similar to Proposition 3.1. However, their solution requires a more complicated hyper-parameter search: an independent prior scale for each normalized parameter group. On the other hand, CL does not increase the number of hyper-parameters to be optimized compared to LL. We believe this difference makes CL more attractive to practitioners.

4 Experiments

Here we describe experiments that demonstrate (i) the effectiveness of Connectivity Sharpness (CS) as a generalization measurement metric and (ii) the usefulness of Connectivity Laplace (CL) as an efficient general-purpose Bayesian NN: CL resolves the contradiction of the FM hypothesis and shows calibration performance that is stable w.r.t. the selection of the prior scale.

4.1 Connectivity Sharpness as a generalization measurement metric

To verify that CS actually has a better correlation with generalization performance than existing sharpness measures, we evaluate three correlation metrics on the CIFAR-10 dataset: (a) Kendall's rank-correlation coefficient ($\tau$) [51], (b) granulated Kendall's coefficients and their average ($\Psi$) [7], and (c) the conditional independence test ($\mathcal{K}$) [7]. For all correlation metrics, a larger value means a stronger correlation between sharpness and generalization.

We compare CS to the following baseline sharpness measures: the trace of the Hessian ($\mathrm{tr}(\mathbf{H})$; Keskar et al. [19]), the trace of the empirical Fisher ($\mathrm{tr}(\mathbf{F})$; Jastrzebski et al. [52]), the trace of the empirical NTK at $\theta^{*}$ ($\mathrm{tr}(\mathbf{\Theta}^{\theta^{*}})$), the Fisher-Rao metric (FR; Liang et al. [18]), Adaptive Sharpness (AS; Kwon et al. [15]), and four PAC-Bayes-bound-based measures: Sharpness-Orig. (SO), Pacbayes-Orig. (PO), Sharpness-Mag. (SM), and Pacbayes-Mag. (PM), which are eqs. (52), (49), (62), (61) in Jiang et al. [7]. For computing granulated Kendall's correlation, we use 5 hyper-parameters (network depth, network width, learning rate, weight decay, and mini-batch size) and 3 options for each (thus we train models with $3^{5}=243$ different training configurations). We vary the depth and width of the NN based on VGG-13 [53]. We refer to Appendix H for experimental details.
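For reference, Kendall's $\tau$ itself can be computed with SciPy; the sketch below (our own, with toy numbers) shows the basic call, while the granulated Kendall coefficient and the conditional independence test follow the definitions of Jiang et al. [7] and are not reproduced here.

```python
# Rank correlation between a sharpness measure and the generalization gap (toy sketch).
import numpy as np
from scipy.stats import kendalltau

sharpness = np.array([1.2, 3.4, 2.2, 5.1])       # e.g., CS for four trained models (toy values)
gen_gap = np.array([0.05, 0.12, 0.08, 0.20])     # corresponding generalization gaps

tau, _ = kendalltau(sharpness, gen_gap)          # Kendall's rank-correlation coefficient
print(tau)
```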

Table 2: Correlation analysis of sharpness measures with the generalization gap. We refer to Sec. 4.1 for details of the sharpness measures (columns) and the correlation metrics for the sharpness-generalization relationship (rows).
| | tr($\mathbf{H}$) | tr($\mathbf{F}$) | tr($\mathbf{\Theta}^{\theta^{*}}$) | SO | PO | SM | PM | AS | FR | CS |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| $\tau$ (rank corr.) | 0.706 | 0.679 | 0.703 | 0.490 | 0.436 | 0.473 | 0.636 | 0.755 | 0.649 | 0.837 |
| network depth | 0.764 | 0.652 | 0.978 | -0.358 | -0.719 | 0.774 | 0.545 | 0.756 | 0.771 | 0.978 |
| network width | 0.687 | 0.922 | 0.330 | -0.533 | -0.575 | 0.495 | 0.564 | 0.827 | 0.921 | 0.978 |
| mini-batch size | 0.976 | 0.810 | 0.988 | 0.859 | 0.893 | 0.909 | 0.750 | 0.829 | 0.685 | 0.905 |
| learning rate | 0.966 | 0.713 | 1.000 | 0.829 | 0.874 | 0.057 | 0.621 | 0.794 | 0.565 | 0.897 |
| weight decay | -0.031 | -0.103 | 0.402 | 0.647 | 0.711 | 0.168 | 0.211 | 0.710 | 0.373 | 0.742 |
| $\Psi$ (avg.) | 0.672 | 0.599 | 0.739 | 0.289 | 0.237 | 0.481 | 0.538 | 0.783 | 0.663 | 0.900 |
| $\mathcal{K}$ (cond. MI) | 0.320 | 0.243 | 0.352 | 0.039 | 0.041 | 0.049 | 0.376 | 0.483 | 0.288 | 0.539 |

In Table 2, CS shows the best results for $\tau$, $\Psi$, and $\mathcal{K}$ compared to all other sharpness measures. Also, the granulated Kendall coefficient of CS is higher than that of the other sharpness measures for 3 out of 5 hyper-parameters and competitive on the remaining ones. The main difference between CS and the other sharpness measures lies in the correlation with network width and weight decay: for network width, we find that all sharpness measures except CS, $\mathrm{tr}(\mathbf{F})$, AS, and FR fail to capture a strong correlation. While SO/PO can capture the correlation with weight decay, we believe this is due to the weight-norm term of SO/PO. However, this term interferes with capturing the sharpness-generalization correlation related to the number of parameters (i.e., width/depth), whereas CS/AS do not suffer from this problem. It is also notable that FR fails to capture this correlation despite its invariance property.

4.2 Connectivity Laplace as an efficient general-purpose Bayesian NN

To assess the effectiveness of CL as a general-purpose Bayesian NN, we consider uncertainty calibration on the UCI datasets and CIFAR-10/100.

Table 3: Test negative log-likelihood on two UCI variants [54, 49]
| Dataset | Deep Ensemble (Orig.) | MCDO (Orig.) | LL (Orig.) | CL (Orig.) | Deep Ensemble (GAP) | MCDO (GAP) | LL (GAP) | CL (GAP) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| boston_housing | 2.90 ± 0.03 | 2.63 ± 0.01 | 2.85 ± 0.01 | 2.88 ± 0.02 | 2.71 ± 0.01 | 2.68 ± 0.01 | 2.74 ± 0.01 | 2.75 ± 0.01 |
| concrete_strength | 3.06 ± 0.01 | 3.20 ± 0.00 | 3.22 ± 0.01 | 3.11 ± 0.02 | 4.03 ± 0.07 | 3.42 ± 0.00 | 3.47 ± 0.01 | 4.03 ± 0.02 |
| energy_efficiency | 0.74 ± 0.01 | 1.92 ± 0.01 | 2.12 ± 0.01 | 0.83 ± 0.01 | 0.77 ± 0.01 | 1.78 ± 0.01 | 2.02 ± 0.01 | 0.90 ± 0.02 |
| kin8nm | -1.07 ± 0.00 | -0.80 ± 0.01 | -0.90 ± 0.00 | -1.07 ± 0.00 | -0.94 ± 0.00 | -0.71 ± 0.00 | -0.87 ± 0.00 | -0.93 ± 0.00 |
| naval_propulsion | -4.83 ± 0.00 | -3.85 ± 0.00 | -4.57 ± 0.00 | -4.76 ± 0.00 | -2.22 ± 0.33 | -3.36 ± 0.01 | -3.66 ± 0.11 | -3.80 ± 0.07 |
| power_plant | 2.81 ± 0.00 | 2.91 ± 0.00 | 2.91 ± 0.00 | 2.81 ± 0.00 | 2.91 ± 0.00 | 2.97 ± 0.00 | 2.98 ± 0.00 | 2.91 ± 0.00 |
| protein_structure | 2.89 ± 0.00 | 2.96 ± 0.00 | 2.91 ± 0.00 | 2.89 ± 0.00 | 3.11 ± 0.00 | 3.07 ± 0.00 | 3.07 ± 0.00 | 3.13 ± 0.00 |
| wine | 1.21 ± 0.00 | 0.96 ± 0.01 | 1.24 ± 0.01 | 1.27 ± 0.01 | 1.48 ± 0.01 | 1.03 ± 0.00 | 1.45 ± 0.01 | 1.43 ± 0.00 |
| yacht_hydrodynamics | 1.26 ± 0.04 | 2.17 ± 0.06 | 1.20 ± 0.04 | 1.25 ± 0.04 | 1.71 ± 0.03 | 3.06 ± 0.02 | 1.78 ± 0.02 | 1.74 ± 0.01 |

UCI regression datasets We implement full-curvature versions of LL and CL and evaluate them on the 9 UCI regression datasets [54] and their GAP variants [49] to compare calibration performance on in-between uncertainty. We train an MLP with a single hidden layer. We fix $\sigma=1$ and choose $\alpha$ from {0.01, 0.1, 1, 10, 100} using the log-likelihood of a validation dataset. We use 8 random seeds to compute the average and standard error of the test negative log-likelihoods. Table 3 shows the test NLL for LL/CL and two baselines (Deep Ensemble [55] and Monte-Carlo DropOut (MCDO; Gal and Ghahramani [56])). Eight ensemble members are used in Deep Ensemble, and 32 MC samples are used in LL, CL, and MCDO. Table 3 shows that CL performs better than LL for 6 out of 9 datasets. Although LL shows better calibration results for 3 datasets in both settings, the performance gaps between LL and CL there are not as severe as in the other 6 datasets, where CL performs better.

Image Classification We evaluate the uncertainty calibration performance of CL on CIFAR-10/100. As baseline methods, we consider a Deterministic network, Monte-Carlo Dropout (MCDO; [56]), Monte-Carlo Batch Normalization (MCBN; [57]), Deep Ensemble [55], Batch Ensemble [58], and LL [25, 9]. We use the Randomize-Then-Optimize (RTO) implementation of LL/CL in Appendix E. We measure Expected Calibration Error (ECE; Guo et al. [59]), negative log-likelihood (NLL), and Brier score (Brier.) for ensemble predictions. We also measure the area under the receiver operating curve (AUC) for OOD detection, where we use the SVHN [60] dataset as the OOD dataset. For more details on the experimental setting, please refer to Appendix I.

Table 4 shows uncertainty calibration results on CIFAR-100. We refer to Appendix I for results in other settings, including CIFAR-10 and VGGNet [53]. Our CL shows better results than all baselines except Deep Ensemble on all uncertainty calibration metrics (NLL, ECE, Brier., and AUC). This means the scale-invariance of the CTK improves Bayesian inference, which is consistent with the results on the toy examples. Although Deep Ensemble presents the best results in 3 out of 4 metrics, it requires full training from initialization for each ensemble member, while LL/CL requires only post-hoc training on top of the pre-trained NN for each member. It is particularly noteworthy that CL presents results competitive with Deep Ensemble at a much smaller computational cost.

Table 4: Uncertainty calibration results on CIFAR-100 [36] for ResNet-18 [37]
| CIFAR-100 | NLL ($\downarrow$) | ECE ($\downarrow$) | Brier. ($\downarrow$) | AUC ($\uparrow$) |
| --- | --- | --- | --- | --- |
| Deterministic | 1.5370 ± 0.0117 | 0.1115 ± 0.0017 | 0.3889 ± 0.0031 | - |
| MCDO | 1.4264 ± 0.0110 | 0.0651 ± 0.0008 | 0.3925 ± 0.0020 | 0.6907 ± 0.0121 |
| MCBN | 1.4689 ± 0.0106 | 0.0998 ± 0.0016 | 0.3750 ± 0.0028 | 0.7982 ± 0.0210 |
| Batch Ensemble | 1.4029 ± 0.0031 | 0.0842 ± 0.0005 | 0.3582 ± 0.0010 | 0.7887 ± 0.0115 |
| Deep Ensemble | 1.0110 | 0.0507 | 0.2740 | 0.7802 |
| Linearized Laplace | 1.1673 ± 0.0093 | 0.0632 ± 0.0010 | 0.3597 ± 0.0020 | 0.8066 ± 0.0120 |
| Connectivity Laplace (Ours) | 1.1307 ± 0.0042 | 0.0524 ± 0.0009 | 0.3319 ± 0.0005 | 0.8423 ± 0.0204 |
Figure 1: Sensitivity to $\alpha$. Panels: (a) NLL, (b) ECE, (c) Brier score. Expected Calibration Error (ECE), Negative Log-Likelihood (NLL), and Brier score results on corrupted CIFAR-100 for ResNet-18, showing the mean (line) and standard deviation (shaded area) across four different seeds.

Robustness to the selection of prior scale Figure 1 shows the uncertainty calibration results (i.e., NLL, ECE, Brier) over various $\alpha$ values for LL, CL, and the Deterministic (Det.) baseline. As noted in previous works [27, 28], the uncertainty calibration of LL is extremely sensitive to the selection of $\alpha$. In particular, LL shows severe under-fitting in the large-$\alpha$ (i.e., small damping) regime. On the other hand, CL shows stable performance over a wide range of $\alpha$.

5 Conclusion

This study introduced novel PAC-Bayes prior and posterior distributions that make the generalization bound robust to parameter scaling transformations by decomposing the scale and connectivity of parameters. The resulting generalization bound is guaranteed to be invariant under any function-preserving scale transformation. This resolves the contradiction of the FM hypothesis caused by general scale transformations, which existing generalization error bounds could not address, thereby bringing the theory much closer to practice. As a consequence of this theoretical improvement, our posterior distribution for the PAC-Bayes analysis can also be interpreted as an improved Laplace Approximation without the pathological failures under weight decay regularization. We therefore expect this work to contribute to narrowing the theory-practice gap in understanding the generalization of NNs and to motivate follow-up studies that interpret this effect more clearly.

References

  • Kendall and Gal [2017] Alex Kendall and Yarin Gal. What uncertainties do we need in bayesian deep learning for computer vision? In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
  • Ovadia et al. [2019] Yaniv Ovadia, Emily Fertig, Jie Ren, Zachary Nado, D. Sculley, Sebastian Nowozin, Joshua Dillon, Balaji Lakshminarayanan, and Jasper Snoek. Can you trust your model's uncertainty? evaluating predictive uncertainty under dataset shift. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • Neyshabur et al. [2015a] Behnam Neyshabur, Ryota Tomioka, and Nathan Srebro. In search of the real inductive bias: On the role of implicit regularization in deep learning. In ICLR (Workshop), 2015a.
  • Zhang et al. [2017] Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals. Understanding deep learning requires rethinking generalization, 2017.
  • Arora et al. [2018] Sanjeev Arora, Rong Ge, Behnam Neyshabur, and Yi Zhang. Stronger generalization bounds for deep nets via a compression approach. In Jennifer Dy and Andreas Krause, editors, Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 254–263. PMLR, 10–15 Jul 2018.
  • Hochreiter and Schmidhuber [1995] Sepp Hochreiter and Jürgen Schmidhuber. Simplifying neural nets by discovering flat minima. In G. Tesauro, D. Touretzky, and T. Leen, editors, Advances in Neural Information Processing Systems, volume 7. MIT Press, 1995.
  • Jiang et al. [2020] Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and Samy Bengio. Fantastic generalization measures and where to find them. In International Conference on Learning Representations, 2020.
  • MacKay [1992] David J. C. MacKay. A practical bayesian framework for backpropagation networks. Neural Comput., 4(3):448–472, may 1992. ISSN 0899-7667. doi: 10.1162/neco.1992.4.3.448.
  • Daxberger et al. [2021a] Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. Laplace redux - effortless bayesian deep learning. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021a.
  • Dinh et al. [2017] Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 1019–1028. PMLR, 06–11 Aug 2017.
  • Li et al. [2018] Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer, and Tom Goldstein. Visualizing the loss landscape of neural nets. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018.
  • Loshchilov and Hutter [2019] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. In International Conference on Learning Representations, 2019.
  • Zhang et al. [2019] Guodong Zhang, Chaoqi Wang, Bowen Xu, and Roger Grosse. Three mechanisms of weight decay regularization. In International Conference on Learning Representations, 2019.
  • Tsuzuku et al. [2020] Yusuke Tsuzuku, Issei Sato, and Masashi Sugiyama. Normalized flat minima: Exploring scale invariant definition of flat minima for neural networks using PAC-Bayesian analysis. In Hal Daumé III and Aarti Singh, editors, Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pages 9636–9647. PMLR, 13–18 Jul 2020.
  • Kwon et al. [2021] Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. arXiv preprint arXiv:2102.11600, 2021.
  • Petzka et al. [2021] Henning Petzka, Michael Kamp, Linara Adilova, Cristian Sminchisescu, and Mario Boley. Relative flatness and generalization. Advances in Neural Information Processing Systems, 34, 2021.
  • Jacot et al. [2018] Arthur Jacot, Franck Gabriel, and Clement Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018.
  • Liang et al. [2019] Tengyuan Liang, Tomaso Poggio, Alexander Rakhlin, and James Stokes. Fisher-rao metric, geometry, and complexity of neural networks. In The 22nd International Conference on Artificial Intelligence and Statistics, pages 888–896. PMLR, 2019.
  • Keskar et al. [2017] Nitish Shirish Keskar, Jorge Nocedal, Ping Tak Peter Tang, Dheevatsa Mudigere, and Mikhail Smelyanskiy. On large-batch training for deep learning: Generalization gap and sharp minima. In 5th International Conference on Learning Representations, ICLR 2017, 2017.
  • Neyshabur et al. [2017] Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, and Nathan Srebro. Exploring generalization in deep learning. arXiv preprint arXiv:1706.08947, 2017.
  • Lee et al. [2019a] Jaehoon Lee, Lechao Xiao, Samuel Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, and Jeffrey Pennington. Wide neural networks of any depth evolve as linear models under gradient descent. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019a.
  • He et al. [2020] Bobby He, Balaji Lakshminarayanan, and Yee Whye Teh. Bayesian deep ensembles via the neural tangent kernel. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 1010–1022. Curran Associates, Inc., 2020.
  • Lee et al. [2020] Jaehoon Lee, Samuel Schoenholz, Jeffrey Pennington, Ben Adlam, Lechao Xiao, Roman Novak, and Jascha Sohl-Dickstein. Finite versus infinite neural networks: an empirical study. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 15156–15172. Curran Associates, Inc., 2020.
  • Hui and Belkin [2021] Like Hui and Mikhail Belkin. Evaluation of neural architectures trained with square loss vs cross-entropy in classification tasks. In International Conference on Learning Representations, 2021.
  • Khan et al. [2019] Mohammad Emtiyaz E Khan, Alexander Immer, Ehsan Abedi, and Maciej Korzepa. Approximate inference turns deep networks into gaussian processes. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • Immer et al. [2021] Alexander Immer, Maciej Korzepa, and Matthias Bauer. Improving predictions of bayesian neural nets via local linearization. In AISTATS, pages 703–711, 2021.
  • Ritter et al. [2018] Hippolyt Ritter, Aleksandar Botev, and David Barber. A scalable laplace approximation for neural networks. In International Conference on Learning Representations, 2018.
  • Kristiadi et al. [2020] Agustinus Kristiadi, Matthias Hein, and Philipp Hennig. Being bayesian, even just a bit, fixes overconfidence in relu networks. In International conference on machine learning, pages 5436–5446. PMLR, 2020.
  • Daxberger et al. [2021b] Erik Daxberger, Eric Nalisnick, James U Allingham, Javier Antoran, and Jose Miguel Hernandez-Lobato. Bayesian deep learning via subnetwork inference. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 2510–2521. PMLR, 18–24 Jul 2021b.
  • McAllester [1999] David A McAllester. Pac-bayesian model averaging. In Proceedings of the twelfth annual conference on Computational learning theory, pages 164–170, 1999.
  • Perez-Ortiz et al. [2021] Maria Perez-Ortiz, Omar Rivasplata, John Shawe-Taylor, and Csaba Szepesvári. Tighter risk certificates for neural networks. Journal of Machine Learning Research, 22(227):1–40, 2021.
  • Lee et al. [2019b] Namhoon Lee, Thalaiyasingam Ajanthan, and Philip Torr. SNIP: SINGLE-SHOT NETWORK PRUNING BASED ON CONNECTION SENSITIVITY. In International Conference on Learning Representations, 2019b.
  • Lee et al. [2019c] Namhoon Lee, Thalaiyasingam Ajanthan, Stephen Gould, and Philip HS Torr. A signal propagation perspective for pruning neural networks at initialization. arXiv preprint arXiv:1906.06307, 2019c.
  • Achille et al. [2021] Alessandro Achille, Aditya Golatkar, Avinash Ravichandran, Marzia Polito, and Stefano Soatto. Lqf: Linear quadratic fine-tuning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 15729–15739, 2021.
  • Maddox et al. [2021] Wesley Maddox, Shuai Tang, Pablo Moreno, Andrew Gordon Wilson, and Andreas Damianou. Fast adaptation with linearized neural networks. In International Conference on Artificial Intelligence and Statistics, pages 2737–2745. PMLR, 2021.
  • Krizhevsky [2009] A Krizhevsky. Learning multiple layers of features from tiny images. Master’s thesis, University of Toronto, 2009.
  • He et al. [2016] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
  • Loshchilov and Hutter [2016] Ilya Loshchilov and Frank Hutter. Sgdr: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983, 2016.
  • Kingma and Ba [2014] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Ghorbani et al. [2019] Behrooz Ghorbani, Shankar Krishnan, and Ying Xiao. An investigation into neural net optimization via hessian eigenvalue density. In International Conference on Machine Learning, pages 2232–2241. PMLR, 2019.
  • Bardsley et al. [2014] Johnathan M Bardsley, Antti Solonen, Heikki Haario, and Marko Laine. Randomize-then-optimize: A method for sampling from posterior distributions in nonlinear inverse problems. SIAM Journal on Scientific Computing, 36(4):A1895–A1910, 2014.
  • Matthews et al. [2017] Alexander G de G Matthews, Jiri Hron, Richard E Turner, and Zoubin Ghahramani. Sample-then-optimize posterior sampling for bayesian linear models. In NIPS Workshop on Advances in Approximate Bayesian Inference, 2017.
  • Zhou et al. [2018] Wenda Zhou, Victor Veitch, Morgane Austern, Ryan P Adams, and Peter Orbanz. Non-vacuous generalization bounds at the imagenet scale: a pac-bayesian compression approach. arXiv preprint arXiv:1804.05862, 2018.
  • Dziugaite and Roy [2017] Gintare Karolina Dziugaite and Daniel M. Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. In Proceedings of the 33rd Annual Conference on Uncertainty in Artificial Intelligence (UAI), 2017.
  • Foret et al. [2020] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412, 2020.
  • Hutchinson [1989] Michael F Hutchinson. A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines. Communications in Statistics-Simulation and Computation, 18(3):1059–1076, 1989.
  • Antoran et al. [2021] Javier Antoran, James Urquhart Allingham, David Janz, Erik Daxberger, Eric Nalisnick, and José Miguel Hernández-Lobato. Linearised laplace inference in networks with normalisation layers and the neural g-prior. In Fourth Symposium on Advances in Approximate Bayesian Inference, 2021.
  • Antoran et al. [2022] Javier Antoran, David Janz, James U Allingham, Erik Daxberger, Riccardo Rb Barbano, Eric Nalisnick, and José Miguel Hernández-Lobato. Adapting the linearised laplace model evidence for modern deep learning. In International Conference on Machine Learning, pages 796–821. PMLR, 2022.
  • Foong et al. [2019] Andrew YK Foong, Yingzhen Li, José Miguel Hernández-Lobato, and Richard E Turner. ’in-between’uncertainty in bayesian neural networks. arXiv preprint arXiv:1906.11537, 2019.
  • Van Laarhoven [2017] Twan Van Laarhoven. L2 regularization versus batch and weight normalization. arXiv preprint arXiv:1706.05350, 2017.
  • Kendall [1938] Maurice G Kendall. A new measure of rank correlation. Biometrika, 30(1/2):81–93, 1938.
  • Jastrzebski et al. [2021] Stanislaw Jastrzebski, Devansh Arpit, Oliver Astrand, Giancarlo B Kerg, Huan Wang, Caiming Xiong, Richard Socher, Kyunghyun Cho, and Krzysztof J Geras. Catastrophic fisher explosion: Early phase fisher matrix impacts generalization. In International Conference on Machine Learning, pages 4772–4784. PMLR, 2021.
  • Simonyan and Zisserman [2015] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. In International Conference on Learning Representations, 2015.
  • Hernández-Lobato and Adams [2015] José Miguel Hernández-Lobato and Ryan Adams. Probabilistic backpropagation for scalable learning of bayesian neural networks. In International conference on machine learning, pages 1861–1869. PMLR, 2015.
  • Lakshminarayanan et al. [2017] Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
  • Gal and Ghahramani [2016] Yarin Gal and Zoubin Ghahramani. Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In international conference on machine learning, pages 1050–1059. PMLR, 2016.
  • Teye et al. [2018] Mattias Teye, Hossein Azizpour, and Kevin Smith. Bayesian uncertainty estimation for batch normalized deep networks. In International Conference on Machine Learning, pages 4907–4916. PMLR, 2018.
  • Wen et al. [2020] Yeming Wen, Dustin Tran, and Jimmy Ba. Batchensemble: an alternative approach to efficient ensemble and lifelong learning. arXiv preprint arXiv:2002.06715, 2020.
  • Guo et al. [2017] Chuan Guo, Geoff Pleiss, Yu Sun, and Kilian Q. Weinberger. On calibration of modern neural networks. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 1321–1330. PMLR, 06–11 Aug 2017.
  • Netzer et al. [2011] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y. Ng. Reading digits in natural images with unsupervised feature learning. In NIPS Workshop on Deep Learning and Unsupervised Feature Learning 2011, 2011.
  • Deng et al. [2009] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248–255. Ieee, 2009.
  • Bishop [2006] Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.
  • Murphy [2012] Kevin P Murphy. Machine learning: a probabilistic perspective. MIT press, 2012.
  • Neyshabur et al. [2015b] Behnam Neyshabur, Russ R Salakhutdinov, and Nati Srebro. Path-sgd: Path-normalized optimization in deep neural networks. In C. Cortes, N. Lawrence, D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 28. Curran Associates, Inc., 2015b.
  • Ioffe and Szegedy [2015] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In Francis Bach and David Blei, editors, Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pages 448–456, Lille, France, 07–09 Jul 2015. PMLR.
  • Lobacheva et al. [2021] Ekaterina Lobacheva, Maxim Kodryan, Nadezhda Chirkova, Andrey Malinin, and Dmitry P Vetrov. On the periodic behavior of neural network training with batch normalization and weight decay. Advances in Neural Information Processing Systems, 34:21545–21556, 2021.
  • Devlin et al. [2018] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • Meyer et al. [2021] Raphael A Meyer, Cameron Musco, Christopher Musco, and David P Woodruff. Hutch++: Optimal stochastic trace estimation. In Symposium on Simplicity in Algorithms (SOSA), pages 142–155. SIAM, 2021.
  • Bradbury et al. [2018] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs. github, 2018.
  • Woodbury [1950] M.A. Woodbury. Inverting Modified Matrices. Memorandum Report / Statistical Research Group, Princeton. Statistical Research Group, 1950.
  • Ren et al. [2019] Jie Ren, Peter J. Liu, Emily Fertig, Jasper Snoek, Ryan Poplin, Mark Depristo, Joshua Dillon, and Balaji Lakshminarayanan. Likelihood ratios for out-of-distribution detection. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.
  • Van Amersfoort et al. [2020] Joost Van Amersfoort, Lewis Smith, Yee Whye Teh, and Yarin Gal. Uncertainty estimation using a single deep deterministic neural network. In Hal Daumé III and Aarti Singh, editors, Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, pages 9690–9700. PMLR, 13–18 Jul 2020.

Appendix A Proofs

A.1 Proof of Proposition 2.2

Proof.

Since the prior $\mathbb{P}_{\theta^{*}}(c)$ is independent of the parameter scale, $\mathbb{P}_{\theta^{*}}(c)\stackrel{d}{=}\mathbb{P}_{\mathcal{T}(\theta^{*})}(c)$ holds trivially. For the Jacobian w.r.t. parameters, we have

\mathbf{J}_{\theta}(x,\mathcal{T}(\psi))=\frac{\partial}{\partial\mathcal{T}(\psi)}f(x,\mathcal{T}(\psi))=\frac{\partial}{\partial\mathcal{T}(\psi)}f(x,\psi)=\mathbf{J}_{\theta}(x,\psi)\mathcal{T}^{-1}

Then, the Jacobian of the NN w.r.t. connectivity at $\mathcal{T}(\psi)$ satisfies

\mathbf{J}_{\theta}(x,\mathcal{T}(\psi))\mathrm{diag}(\mathcal{T}(\psi)) =\mathbf{J}_{\theta}(x,\psi)\mathcal{T}^{-1}\mathcal{T}\mathrm{diag}(\psi) \quad (12)
=\mathbf{J}_{\theta}(x,\psi)\mathrm{diag}(\psi) \quad (13)

where the first equality holds from the identity above and the fact that $\mathcal{T}$ is a diagonal linear transformation. Therefore, the posterior covariance is invariant to $\mathcal{T}$:

\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathrm{diag}(\mathcal{T}(\theta^{*}))\mathbf{J}_{\theta}^{\top}(\mathcal{X},\mathcal{T}(\theta^{*}))\mathbf{J}_{\theta}(\mathcal{X},\mathcal{T}(\theta^{*}))\mathrm{diag}(\mathcal{T}(\theta^{*}))}{\sigma^{2}}\right)^{-1}
=\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathrm{diag}(\theta^{*})\mathbf{J}_{\theta}^{\top}(\mathcal{X},\theta^{*})\mathbf{J}_{\theta}(\mathcal{X},\theta^{*})\mathrm{diag}(\theta^{*})}{\sigma^{2}}\right)^{-1}
=\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathrm{diag}(\theta^{*})\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}\mathrm{diag}(\theta^{*})}{\sigma^{2}}\right)^{-1}

Moreover, the posterior mean is also invariant to $\mathcal{T}$:

\frac{\Sigma_{\mathbb{Q}}\mathrm{diag}(\mathcal{T}(\theta^{*}))\mathbf{J}_{\theta}^{\top}(\mathcal{X},\mathcal{T}(\theta^{*}))\left(\mathcal{Y}-f(\mathcal{X},\mathcal{T}(\theta^{*}))\right)}{\sigma^{2}}
=\frac{\Sigma_{\mathbb{Q}}\mathrm{diag}(\mathcal{T}(\theta^{*}))\mathbf{J}_{\theta}^{\top}(\mathcal{X},\mathcal{T}(\theta^{*}))\left(\mathcal{Y}-f(\mathcal{X},\theta^{*})\right)}{\sigma^{2}}
=\frac{\Sigma_{\mathbb{Q}}\mathrm{diag}(\theta^{*})\mathbf{J}_{\theta}^{\top}(\mathcal{X},\theta^{*})\left(\mathcal{Y}-f(\mathcal{X},\theta^{*})\right)}{\sigma^{2}}
=\frac{\Sigma_{\mathbb{Q}}\mathrm{diag}(\theta^{*})\mathbf{J}_{\theta}^{\top}\left(\mathcal{Y}-f(\mathcal{X},\theta^{*})\right)}{\sigma^{2}}

Therefore, equation 6 and equation 7 are invariant to function-preserving scale transformations. The remaining part of the theorem follows from the definition of the function-preserving scale transformation $\mathcal{T}$. For the generalization error, the following holds:

\mathrm{err}_{\mathcal{D}}(\mathbb{Q}_{\mathcal{T}(\theta^{*})}) =\mathbb{E}_{(x,y)\sim\mathcal{D},\psi\sim\mathbb{Q}_{\mathcal{T}(\theta^{*})}}[\mathrm{err}(f(x,\psi),y)]
=\mathbb{E}_{(x,y)\sim\mathcal{D},c\sim\mathbb{Q}_{\mathcal{T}(\theta^{*})}}[\mathrm{err}(g^{\mathrm{pert}}_{\theta^{*}}(x,c),y)]
=\mathbb{E}_{(x,y)\sim\mathcal{D},c\sim\mathbb{Q}_{\theta^{*}}}[\mathrm{err}(g^{\mathrm{pert}}_{\theta^{*}}(x,c),y)]
=\mathbb{E}_{(x,y)\sim\mathcal{D},\psi\sim\mathbb{Q}_{\theta^{*}}}[\mathrm{err}(f(x,\psi),y)]
=\mathrm{err}_{\mathcal{D}}(\mathbb{Q}_{\theta^{*}})

This proof can be extended to the empirical error $\mathrm{err}_{\mathcal{S}_{\mathbb{Q}}}$. ∎

A.2 Proof of Theorem 2.3

Proof.

(Construction of KL divergence) To construct PAC-Bayes-CTK, we need to arrange the KL divergence between the posterior and the prior as follows:

\mathrm{KL}[\mathbb{Q}\|\mathbb{P}] =\frac{1}{2}\left(\mathrm{tr}\left(\Sigma_{\mathbb{P}}^{-1}(\Sigma_{\mathbb{Q}}+(\mu_{\mathbb{Q}}-\mu_{\mathbb{P}})(\mu_{\mathbb{Q}}-\mu_{\mathbb{P}})^{\top})\right)+\log|\Sigma_{\mathbb{P}}|-\log|\Sigma_{\mathbb{Q}}|-P\right)
=\frac{1}{2}\mathrm{tr}\left(\Sigma_{\mathbb{P}}^{-1}(\mu_{\mathbb{Q}}-\mu_{\mathbb{P}})(\mu_{\mathbb{Q}}-\mu_{\mathbb{P}})^{\top}\right)+\frac{1}{2}\left(\mathrm{tr}(\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}})+\log|\Sigma_{\mathbb{P}}|-\log|\Sigma_{\mathbb{Q}}|-P\right)
=\frac{1}{2}(\mu_{\mathbb{Q}}-\mu_{\mathbb{P}})^{\top}\Sigma_{\mathbb{P}}^{-1}(\mu_{\mathbb{Q}}-\mu_{\mathbb{P}})+\frac{1}{2}\left(\mathrm{tr}(\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}})-\log|\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}|-P\right)
=\underbrace{\frac{\mu_{\mathbb{Q}}^{\top}\mu_{\mathbb{Q}}}{2\alpha^{2}}}_{\textrm{perturbation}}+\underbrace{\frac{1}{2}\left(\mathrm{tr}(\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}})-\log|\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}|-P\right)}_{\textrm{sharpness}}

where the first equality uses the KL divergence between two Gaussian distributions, the third equality uses the trace properties $\mathrm{tr}(AB)=\mathrm{tr}(BA)$ and $\mathrm{tr}(a)=a$ for a scalar $a$, and the last equality uses the definition of the PAC-Bayes prior $\mathbb{P}_{\theta^{*}}(c)=\mathcal{N}(c|\mathbf{0}_{P},\alpha^{2}\mathbf{I}_{P})$. For the sharpness term, we first compute $\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}$ as

\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}=\left(\mathbf{I}_{P}+\frac{\alpha^{2}}{\sigma^{2}}\mathbf{J}_{c}^{\top}\mathbf{J}_{c}\right)^{-1}

Since $\alpha^{2},\sigma^{2}>0$ and $\mathbf{J}_{c}^{\top}\mathbf{J}_{c}$ is positive semi-definite, the matrix $\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}$ has non-zero eigenvalues $\{\beta_{i}\}_{i=1}^{P}$. Since the trace is the sum of the eigenvalues and the log-determinant is the sum of the log-eigenvalues, we have

\mathrm{KL}[\mathbb{Q}\|\mathbb{P}] =\frac{\mu_{\mathbb{Q}}^{\top}\mu_{\mathbb{Q}}}{2\alpha^{2}}+\frac{1}{2}\sum_{i=1}^{P}\left(\beta_{i}-\log(\beta_{i})-1\right)
=\frac{\mu_{\mathbb{Q}}^{\top}\mu_{\mathbb{Q}}}{2\alpha^{2}}+\frac{1}{2}\sum_{i=1}^{P}h(\beta_{i})

where $h(x)=x-\log(x)-1$. Plugging this KL divergence into equation 2 yields equation 8.

(Eigenvalues of $\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}$) To show the scale-invariance of PAC-Bayes-CTK, it suffices to show that the KL divergence between the posterior and the prior is scale-invariant: the term $\frac{\log(2\sqrt{N_{\mathbb{Q}}}/\delta)}{2N_{\mathbb{Q}}}$ does not depend on the PAC-Bayes prior/posterior, and we already showed the invariance of the empirical/generalization error terms in Proposition 2.2. To show the invariance of the KL divergence, we consider the Connectivity Tangent Kernel (CTK) defined in equation 2.4:

\mathbf{C}_{\mathcal{X}}^{\theta^{*}}:=\mathbf{J}_{c}\mathbf{J}_{c}^{\top}=\mathbf{J}_{\theta}\mathrm{diag}(\theta^{*})^{2}\mathbf{J}_{\theta}^{\top}\in\mathbb{R}^{NK\times NK}.

Since $\mathbf{J}_{c}^{\top}\mathbf{J}_{c}\in\mathbb{R}^{P\times P}$ is a real symmetric matrix that shares its non-zero eigenvalues with the CTK, we can write its eigenvalue decomposition as $\mathbf{J}_{c}^{\top}\mathbf{J}_{c}=Q\Lambda Q^{\top}$, where $Q\in\mathbb{R}^{P\times P}$ is an orthogonal matrix and $\Lambda\in\mathbb{R}^{P\times P}$ is a diagonal matrix whose non-zero entries are the non-zero eigenvalues of the CTK. Then the following holds for $\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}$:

\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}} =\left(\mathbf{I}_{P}+\frac{\alpha^{2}}{\sigma^{2}}\mathbf{J}_{c}^{\top}\mathbf{J}_{c}\right)^{-1}
=\left(\mathbf{I}_{P}+\frac{\alpha^{2}}{\sigma^{2}}Q\Lambda Q^{\top}\right)^{-1}
=Q\left(\mathbf{I}_{P}+\frac{\alpha^{2}}{\sigma^{2}}\Lambda\right)^{-1}Q^{\top}

Therefore, the eigenvalues of $\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}$ are $\frac{1}{1+\alpha^{2}\lambda_{i}/\sigma^{2}}=\frac{\sigma^{2}}{\sigma^{2}+\alpha^{2}\lambda_{i}}$, where $\{\lambda_{i}\}_{i=1}^{P}$ are the diagonal elements of $\Lambda$, i.e., the eigenvalues of the CTK (padded with zeros when $P>NK$).

(Scale invariance of CTK) The scale-invariance property of CTK is a simple application of equation 13:

\mathbf{C}^{\mathcal{T}(\psi)}_{xy} =\mathbf{J}_{\theta}(x,\mathcal{T}(\psi))\mathrm{diag}(\mathcal{T}(\psi)^{2})\mathbf{J}_{\theta}(y,\mathcal{T}(\psi))^{\top}
=\mathbf{J}_{\theta}(x,\psi)\mathcal{T}^{-1}\mathcal{T}\mathrm{diag}(\psi)\mathrm{diag}(\psi)\mathcal{T}\mathcal{T}^{-1}\mathbf{J}_{\theta}(y,\psi)^{\top}
=\mathbf{J}_{\theta}(x,\psi)\mathrm{diag}(\psi)\mathrm{diag}(\psi)\mathbf{J}_{\theta}(y,\psi)^{\top}
=\mathbf{C}^{\psi}_{xy}\ ,\quad\forall x,y\in\mathbb{R}^{D},\ \forall\psi\in\mathbb{R}^{P}.

Therefore, the CTK is invariant to any function-preserving scale transformation $\mathcal{T}$, and so are its eigenvalues. This guarantees the invariance of $\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}$ and its eigenvalues. In summary, we have shown the scale-invariance of the sharpness term of the KL divergence. All that remains is to show the invariance of the perturbation term, which is already proved in the proof of Proposition 2.2. Therefore, PAC-Bayes-CTK is invariant to any function-preserving scale transformation. ∎
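
To illustrate this invariance concretely, below is a small numerical sanity check in JAX (our own sketch, not the paper's released code): for a bias-free two-layer ReLU network, rescaling the first-layer weights by $\gamma$ and the second-layer weights by $1/\gamma$ preserves the function, and the empirical CTK $\mathbf{J}_{\theta}\,\mathrm{diag}(\theta)^{2}\,\mathbf{J}_{\theta}^{\top}$ stays unchanged. All shapes and names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def forward(params, x):
    W1, W2 = params
    return W2 @ jax.nn.relu(W1 @ x)              # bias-free two-layer ReLU network

def ctk(params, x, y):
    # empirical CTK block: J_theta(x) diag(theta)^2 J_theta(y)^T
    flat, unravel = ravel_pytree(params)
    f = lambda p, inp: forward(unravel(p), inp)
    Jx = jax.jacobian(f)(flat, x)                # (K, P) Jacobian w.r.t. flattened parameters
    Jy = jax.jacobian(f)(flat, y)
    return Jx @ jnp.diag(flat ** 2) @ Jy.T

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
W1 = jax.random.normal(k1, (8, 4))
W2 = jax.random.normal(k2, (3, 8))
x, y = jax.random.normal(k3, (4,)), jax.random.normal(k4, (4,))

gamma = 3.0
params = (W1, W2)
rescaled = (gamma * W1, W2 / gamma)              # activation-wise rescaling, function-preserving

print(jnp.allclose(forward(params, x), forward(rescaled, x), atol=1e-5))   # function preserved
print(jnp.allclose(ctk(params, x, y), ctk(rescaled, x, y), atol=1e-3))     # CTK unchanged
```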

A.3 Proof of Corollary 2.4

Proof.

In the proof of Theorem 2.3, we showed that the eigenvalues of $\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}$ can be represented as

\frac{\sigma^{2}}{\sigma^{2}+\alpha^{2}\lambda_{i}}

where $\{\lambda_{i}\}_{i=1}^{P}$ are the eigenvalues of the CTK. Now we identify these eigenvalues. To this end, consider the singular value decomposition (SVD) of the Jacobian w.r.t. connectivity, $\mathbf{J}_{c}=U\Sigma V^{\top}\in\mathbb{R}^{NK\times P}$, where $U\in\mathbb{R}^{NK\times NK}$ and $V\in\mathbb{R}^{P\times P}$ are orthogonal matrices and $\Sigma\in\mathbb{R}^{NK\times P}$ is a rectangular diagonal matrix. Then the CTK can be represented as $\mathbf{C}_{\mathcal{X}}^{\theta^{*}}=\mathbf{J}_{c}\mathbf{J}_{c}^{\top}=U\Sigma V^{\top}V\Sigma^{\top}U^{\top}=U\Sigma\Sigma^{\top}U^{\top}$. In summary, the columns of $U$ are eigenvectors of the CTK and the eigenvalues of the CTK are the squared singular values of $\mathbf{J}_{c}$, so $\lambda_{i}\geq 0$ for all $i$. Therefore, $\beta_{i}\leq 1$ holds for all eigenvalues $\{\beta_{i}\}_{i=1}^{P}$ of $\Sigma_{\mathbb{P}}^{-1}\Sigma_{\mathbb{Q}}$, with equality exactly when $\lambda_{i}=0$. All that remains is to show that the sharpness term of PAC-Bayes-CTK is a monotonically increasing function of each eigenvalue of the CTK. To this end, first note that

s(x):=\frac{\sigma^{2}}{\sigma^{2}+\alpha^{2}x}

is a monotonically decreasing function for $x\geq 0$ and $h(x):=x-\log(x)-1$ is a monotonically decreasing function for $x\in(0,1]$. Since the sharpness term of the KL divergence is

\sum_{i=1}^{P}\frac{h(\beta_{i})}{4N_{\mathbb{Q}}}=\sum_{i=1}^{P}\frac{(h\circ s)(\lambda_{i})}{4N_{\mathbb{Q}}},

and $s(x)\in(0,1]$ for $x\geq 0$, the composition $h\circ s$ is monotonically increasing for $x\geq 0$, which proves the claim. For reference, we plot $y=h(x)$ and $y=(h\circ s)(x)$ in Figure 2. ∎

Figure 2: Functions used in the proofs: (a) $y=h(x)$; (b) $y=(h\circ s)(x)$ with $\sigma=\alpha=1$.
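
The monotonicity claims above can also be checked numerically; the following tiny snippet (ours, using $\sigma=\alpha=1$ as in Figure 2b) verifies that $h$ decreases on $(0,1]$ and that $h\circ s$ increases for $x\geq 0$.

```python
import jax.numpy as jnp

h = lambda x: x - jnp.log(x) - 1.0
s = lambda x, sigma=1.0, alpha=1.0: sigma**2 / (sigma**2 + alpha**2 * x)

xs = jnp.linspace(0.01, 1.0, 100)
print(bool(jnp.all(jnp.diff(h(xs)) <= 0)))        # h is decreasing on (0, 1]
lams = jnp.linspace(0.0, 100.0, 101)
print(bool(jnp.all(jnp.diff(h(s(lams))) >= 0)))   # h o s is increasing for lambda >= 0
```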

A.4 Proof of Proposition 2.5

We refer to the 'Scale invariance of CTK' part of the proof of Theorem 2.3; the result is a direct application of the scale-invariance of the Jacobian w.r.t. connectivity.

A.5 Proof of Corollary 2.7

Proof.

Since CS is the trace of the CTK, it is the sum of the eigenvalues of the CTK. As shown in the proof of Corollary 2.4, the eigenvalues of the CTK are the squared singular values of the Jacobian w.r.t. connectivity $c$. Therefore, the eigenvalues of the CTK are non-negative and vanish to zero if CS vanishes to zero:

\sum_{i=1}^{P}\lambda_{i}=0\Rightarrow\lambda_{i}=0\Rightarrow\beta_{i}=s(\lambda_{i})=1\Rightarrow h(\beta_{i})=0,\quad\forall i=1,\dots,P

This means the sharpness term of the KL divergence vanishes to zero. Furthermore, the singular values of the Jacobian w.r.t. $c$ also vanish to zero in this case, so $\mu_{\mathbb{Q}}$ vanishes to zero as well. Similarly, if CS diverges to infinity, then at least one eigenvalue of the CTK diverges to infinity. In this case, the following holds for the corresponding index $i$:

\lambda_{i}\rightarrow\infty\Rightarrow\beta_{i}=s(\lambda_{i})\rightarrow 0\Rightarrow h(\beta_{i})\rightarrow\infty

Therefore, the KL divergence term of PAC-Bayes-CTK also diverges to infinity. ∎

A.6 Proof of Proposition 3.1

Proof.

By assumption, we fix all non-scale-invariant parameters. This means we exclude these parameters from the sampling procedures of CL and LL. In terms of the predictive distribution, this can be written as

f^{\mathrm{lin}}_{\theta^{*}}(x,\psi)\,|\,p_{\mathrm{LA}}(\psi|\mathcal{S})\sim\mathcal{N}(f(x,\theta^{*}),\ \alpha^{2}\hat{\Theta}_{xx}^{\theta^{*}}-\alpha^{2}\hat{\Theta}_{x\mathcal{X}}^{\theta^{*}}\hat{\Theta}_{\mathcal{X}}^{\theta^{*}-1}\hat{\Theta}_{\mathcal{X}x}^{\theta^{*}})
f^{\mathrm{lin}}_{\theta^{*}}(x,\psi)\,|\,p_{\mathrm{CL}}(\psi|\mathcal{S})\sim\mathcal{N}(f(x,\theta^{*}),\ \alpha^{2}\hat{\mathbf{C}}_{xx}^{\theta^{*}}-\alpha^{2}\hat{\mathbf{C}}_{x\mathcal{X}}^{\theta^{*}}\hat{\mathbf{C}}_{\mathcal{X}}^{\theta^{*}-1}\hat{\mathbf{C}}_{\mathcal{X}x}^{\theta^{*}})

where $\hat{\Theta}_{xx^{\prime}}^{\psi}:=\sum_{i\in\mathcal{P}}\frac{\partial f(x,\psi)}{\partial\theta_{i}}\frac{\partial f(x^{\prime},\psi)}{\partial\theta_{i}}$ and $\hat{\mathbf{C}}_{xx^{\prime}}^{\psi}:=\sum_{i\in\mathcal{P}}\frac{\partial f(x,\psi)}{\partial\theta_{i}}\frac{\partial f(x^{\prime},\psi)}{\partial\theta_{i}}(\psi_{i})^{2}$ for the scale-invariant parameter set $\mathcal{P}$. That is, we mask the gradients of the non-scale-invariant parameters as zero. These kernels can therefore be arranged as

\hat{\Theta}_{xx^{\prime}}^{\psi}=\mathbf{J}_{\theta}(x,\psi)\mathrm{diag}(\mathbf{1}_{\mathcal{P}})\mathbf{J}_{\theta}(x^{\prime},\psi)^{\top},\quad\hat{\mathbf{C}}_{xx^{\prime}}^{\psi}=\mathbf{J}_{\theta}(x,\psi)\mathrm{diag}(\psi)\mathrm{diag}(\mathbf{1}_{\mathcal{P}})\mathrm{diag}(\psi)\mathbf{J}_{\theta}(x^{\prime},\psi)^{\top}

where $\mathbf{1}_{\mathcal{P}}\in\mathbb{R}^{P}$ is a masking vector (i.e., one for included components and zero for excluded components). Then, weight decay regularization on the scale-invariant parameters can be represented as

\mathcal{W}_{\gamma}(\psi)_{i}=\begin{cases}\gamma\psi_{i},&\text{if }\psi_{i}\in\mathcal{P},\\ \psi_{i},&\text{if }\psi_{i}\not\in\mathcal{P}.\end{cases}

Therefore, we get

\hat{\Theta}_{xx^{\prime}}^{\mathcal{W}_{\gamma}(\psi)} =\mathbf{J}_{\theta}(x,\mathcal{W}_{\gamma}(\psi))\mathrm{diag}(\mathbf{1}_{\mathcal{P}})\mathbf{J}_{\theta}(x^{\prime},\mathcal{W}_{\gamma}(\psi))^{\top}
=\mathbf{J}_{\theta}(x,\psi)\mathcal{W}_{\gamma}^{-1}\mathrm{diag}(\mathbf{1}_{\mathcal{P}})\mathcal{W}_{\gamma}^{-1}\mathbf{J}_{\theta}(x^{\prime},\psi)^{\top}
=\mathbf{J}_{\theta}(x,\psi)\mathrm{diag}(\mathbf{1}_{\mathcal{P}}/\gamma^{2})\mathbf{J}_{\theta}(x^{\prime},\psi)^{\top}
=(1/\gamma^{2})\,\mathbf{J}_{\theta}(x,\psi)\mathrm{diag}(\mathbf{1}_{\mathcal{P}})\mathbf{J}_{\theta}(x^{\prime},\psi)^{\top}
=(1/\gamma^{2})\,\hat{\Theta}_{xx^{\prime}}^{\psi}

for empirical NTK and

\hat{\mathbf{C}}_{xx^{\prime}}^{\mathcal{W}_{\gamma}(\psi)} =\mathbf{J}_{\theta}(x,\mathcal{W}_{\gamma}(\psi))\mathrm{diag}(\mathcal{W}_{\gamma}(\psi))\mathrm{diag}(\mathbf{1}_{\mathcal{P}})\mathrm{diag}(\mathcal{W}_{\gamma}(\psi))\mathbf{J}_{\theta}(x^{\prime},\mathcal{W}_{\gamma}(\psi))^{\top}
=\mathbf{J}_{\theta}(x,\psi)\mathcal{W}_{\gamma}^{-1}\mathcal{W}_{\gamma}\mathrm{diag}(\psi)\mathrm{diag}(\mathbf{1}_{\mathcal{P}})\mathrm{diag}(\psi)\mathcal{W}_{\gamma}\mathcal{W}_{\gamma}^{-1}\mathbf{J}_{\theta}(x^{\prime},\psi)^{\top}
=\mathbf{J}_{\theta}(x,\psi)\mathrm{diag}(\psi)\mathrm{diag}(\mathbf{1}_{\mathcal{P}})\mathrm{diag}(\psi)\mathbf{J}_{\theta}(x^{\prime},\psi)^{\top}
=\hat{\mathbf{C}}_{xx^{\prime}}^{\psi}

for empirical CTK. Therefore, we get

f^{\mathrm{lin}}_{\mathcal{W}_{\gamma}(\theta^{*})}(x,\psi)\,|\,p_{\mathrm{LA}}(\psi|\mathcal{S})\sim\mathcal{N}(f(x,\theta^{*}),\ (\alpha^{2}/\gamma^{2})\hat{\Theta}_{xx}^{\theta^{*}}-(\alpha^{2}/\gamma^{2})\hat{\Theta}_{x\mathcal{X}}^{\theta^{*}}\hat{\Theta}_{\mathcal{X}}^{\theta^{*}-1}\hat{\Theta}_{\mathcal{X}x}^{\theta^{*}})
f^{\mathrm{lin}}_{\mathcal{W}_{\gamma}(\theta^{*})}(x,\psi)\,|\,p_{\mathrm{CL}}(\psi|\mathcal{S})\sim\mathcal{N}(f(x,\theta^{*}),\ \alpha^{2}\hat{\mathbf{C}}_{xx}^{\theta^{*}}-\alpha^{2}\hat{\mathbf{C}}_{x\mathcal{X}}^{\theta^{*}}\hat{\mathbf{C}}_{\mathcal{X}}^{\theta^{*}-1}\hat{\mathbf{C}}_{\mathcal{X}x}^{\theta^{*}})

This gives us

\mathrm{Var}_{\psi\sim p_{\mathrm{LA}}(\psi|\mathcal{S})}(f^{\mathrm{lin}}_{\mathcal{W}_{\gamma}(\theta^{*})}(x,\psi))=\mathrm{Var}_{\psi\sim p_{\mathrm{LA}}(\psi|\mathcal{S})}(f^{\mathrm{lin}}_{\theta^{*}}(x,\psi))/\gamma^{2}
\mathrm{Var}_{\psi\sim p_{\mathrm{CL}}(\psi|\mathcal{S})}(f^{\mathrm{lin}}_{\mathcal{W}_{\gamma}(\theta^{*})}(x,\psi))=\mathrm{Var}_{\psi\sim p_{\mathrm{CL}}(\psi|\mathcal{S})}(f^{\mathrm{lin}}_{\theta^{*}}(x,\psi))

∎

A.7 Derivation of PAC-Bayes-NTK

Theorem A.1 (PAC-Bayes-NTK).

Let us assume a pre-trained parameter $\theta^{*}$ obtained with data $\mathcal{S}_{\mathbb{P}}$, and define the PAC-Bayes prior and posterior as

\mathbb{P^{\prime}}_{\theta^{*}}(\delta) :=\mathcal{N}(\delta|\mathbf{0}_{P},\alpha^{2}\mathbf{I}_{P}) \quad (14)
\mathbb{Q^{\prime}}_{\theta^{*}}(\delta) :=\mathcal{N}(\delta|\mu_{\mathbb{Q^{\prime}}},\Sigma_{\mathbb{Q^{\prime}}}) \quad (15)
\mu_{\mathbb{Q^{\prime}}} :=\frac{\Sigma_{\mathbb{Q^{\prime}}}\mathbf{J}_{\theta}^{\top}\left(\mathcal{Y}-f(\mathcal{X},\theta^{*})\right)}{\sigma^{2}} \quad (16)
\Sigma_{\mathbb{Q^{\prime}}} :=\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}}{\sigma^{2}}\right)^{-1} \quad (17)

By applying $\mathbb{P^{\prime}}_{\theta^{*}}$ and $\mathbb{Q^{\prime}}_{\theta^{*}}$ to the data-dependent PAC-Bayes bound (equation 2), we get

\mathrm{err}_{\mathcal{D}}(\mathbb{Q^{\prime}}_{\theta^{*}}) \leq\mathrm{err}_{\mathcal{S}_{\mathbb{Q^{\prime}}}}(\mathbb{Q^{\prime}}_{\theta^{*}})+\sqrt{\overbrace{\underbrace{\frac{\mu_{\mathbb{Q^{\prime}}}^{\top}\mu_{\mathbb{Q^{\prime}}}}{4\alpha^{2}N_{\mathbb{Q^{\prime}}}}}_{\textrm{(average) perturbation}}+\underbrace{\sum_{i=1}^{P}\frac{h\left(\beta^{\prime}_{i}\right)}{4N_{\mathbb{Q^{\prime}}}}}_{\textrm{sharpness}}}^{\textrm{KL divergence}}+\frac{\log(2\sqrt{N_{\mathbb{Q^{\prime}}}}/\delta)}{2N_{\mathbb{Q^{\prime}}}}} \quad (18)

where $\{\beta^{\prime}_{i}\}_{i=1}^{P}$ are the eigenvalues of $(\mathbf{I}_{P}+\frac{\alpha^{2}}{\sigma^{2}}\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta})^{-1}$ and $h(x):=x-\log(x)-1$. This upper bound is not scale-invariant in general.

Proof.

The main difference between PAC-Bayes-CTK and PAC-Bayes-NTK is the definition of the Jacobian: PAC-Bayes-CTK uses the Jacobian w.r.t. connectivity, while PAC-Bayes-NTK uses the Jacobian w.r.t. parameters. Therefore, the 'Construction of KL divergence' part of the proof of Theorem 2.3 is preserved, except that

\Sigma_{\mathbb{P^{\prime}}}^{-1}\Sigma_{\mathbb{Q^{\prime}}}=\left(\mathbf{I}_{P}+\frac{\alpha^{2}}{\sigma^{2}}\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}\right)^{-1}

and $\beta^{\prime}_{i}$ are the eigenvalues of $\Sigma_{\mathbb{P^{\prime}}}^{-1}\Sigma_{\mathbb{Q^{\prime}}}$. Note that these eigenvalues satisfy

\beta^{\prime}_{i}=\frac{\sigma^{2}}{\sigma^{2}+\alpha^{2}\lambda^{\prime}_{i}}

where $\lambda^{\prime}_{i}$ are the eigenvalues of the NTK. ∎

Remark A.2 (Function-preserving scale transformation to NTK).

In contrast to the CTK, the scale-invariance property does not hold for the NTK, due to the Jacobian w.r.t. parameters:

\mathbf{J}_{\theta}(x,\mathcal{T}(\psi))=\frac{\partial}{\partial\mathcal{T}(\psi)}f(x,\mathcal{T}(\psi))=\frac{\partial}{\partial\mathcal{T}(\psi)}f(x,\psi)=\mathbf{J}_{\theta}(x,\psi)\mathcal{T}^{-1}

If we assume all parameters are scale-invariant (or, equivalently, mask the Jacobian of all non-scale-invariant parameters as in the proof of Proposition 3.1), the empirical NTK scales as the inverse squared scale of the parameters (e.g., it is multiplied by $1/\gamma^{2}$ under $\mathcal{W}_{\gamma}$).

A.8 Deterministic limiting kernel of CTK

Theorem A.3 (Deterministic limiting kernel of CTK).

Let us assume an $L$-layered network with a Lipschitz activation function and NTK initialization. Then the empirical CTK converges in probability to a deterministic limiting kernel $\mathbf{C}_{xy}$ as the layer widths $n_{1},\dots,n_{L}\rightarrow\infty$ sequentially. Furthermore, $\mathbf{C}_{xy}=\Theta_{xy}$ holds.

Proof.

The proof is a modification of the proof of NTK convergence in Jacot et al. [17], accounting for NTK initialization (i.e., a standard Gaussian for all parameters). We proceed by induction. For a single-layer network, the CTK is

(\mathbf{C}_{xx^{\prime}})_{kk^{\prime}} =\frac{1}{n_{0}}\sum_{i=1}^{n_{0}}\sum_{j=1}^{n_{1}}x_{i}x_{i}^{\prime}\delta_{jk}\delta_{jk^{\prime}}W_{ik}W_{ik^{\prime}}+\beta^{2}\sum_{j=1}^{n_{1}}\delta_{jk}\delta_{jk^{\prime}}
\rightarrow(\Theta_{xx^{\prime}})_{kk^{\prime}}

since the weights are sampled from a standard Gaussian distribution, whose variance is 1, and a product of two independent random variables converges in probability to the product of their limits. If we assume the CTK of the $l$-th layer converges in probability to the NTK of the $l$-th layer, then the convergence of the $(l+1)$-th layer also holds, since a product of two random weights, which converges to 1, multiplies the empirical NTK of the $(l+1)$-th layer, which in turn converges to the deterministic limiting NTK of the $(l+1)$-th layer. Therefore, the empirical CTK converges in probability to a deterministic limiting CTK, which coincides with the deterministic limiting NTK. ∎

Appendix B Details of Squared Loss for Classification Tasks

For the classification tasks in Sec. 4.2, we use the squared loss instead of the cross-entropy loss since our theoretical results are built on the squared loss. Here, we describe how we use the squared loss to mimic the cross-entropy loss. There are several works [23, 24] that utilized the squared loss for the classification task instead of the cross-entropy loss. Specifically, Lee et al. [23] used

\mathcal{L}(\mathcal{S},\theta)=\frac{1}{2NK}\sum_{(x_{i},y_{i})\in\mathcal{S}}\|f(x_{i},\theta)-y_{i}\|^{2}

where $K$ is the number of classes, and Hui and Belkin [24] used

\ell((x,c),\theta)=\frac{1}{2K}\left(k(f_{c}(x,\theta)-M)^{2}+\sum_{i=1,i\neq c}^{K}f_{i}(x,\theta)^{2}\right)

as the per-sample loss, where $\ell((x,c),\theta)$ is the loss for input $x$, target class $c$, and parameters $\theta$, $f_{i}(x,\theta)\in\mathbb{R}$ is the $i$-th component of $f(x,\theta)\in\mathbb{R}^{K}$, and $k$ and $M$ are dataset-specific hyper-parameters.

These works used the mean to reduce the vector-valued loss to a scalar. However, this can be problematic when the number of classes is large: as the number of classes increases, the denominator of the mean (the number of classes) grows while the target value remains 1 (one-hot label), so the gradient scale for the target class shrinks. To avoid this unfavorable effect, we instead use the sum to reduce the vector-valued loss to a scalar, i.e.,

\ell((x,c),\theta)=\frac{1}{2}\left((f_{c}(x,\theta)-1)^{2}+\sum_{i=1,i\neq c}^{K}f_{i}(x,\theta)^{2}\right)

This analysis is consistent with the hyper-parameter selection in Hui and Belkin [24]: they used larger $k$ and $M$ as the number of classes increases (e.g., $k=1,M=1$ for CIFAR-10 [36], but $k=15,M=30$ for ImageNet [61]), which manually compensates for the shrinking gradient scale of the target class.
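
For concreteness, a minimal sketch of this sum-reduced squared loss (our illustration; `logits` of shape $(N,K)$ and integer `labels` are assumed):

```python
import jax
import jax.numpy as jnp

def sum_squared_loss(logits, labels):
    """Sum-reduced squared loss: regress the target-class logit to 1 and all others to 0."""
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    per_example = 0.5 * jnp.sum((logits - one_hot) ** 2, axis=-1)  # sum over classes, not mean
    return per_example.mean()                                      # average over the batch
```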

Appendix C Derivation of PAC-Bayes posterior

Derivation of $\mathbb{Q}_{\theta^{*}}(c)$

For Bayesian linear regression, we compute the posterior of $\beta\in\mathbb{R}^{P}$ for the model

y_{i}=x_{i}\beta+\epsilon_{i},\quad\text{for }i=1,\dots,M

where $\epsilon_{i}\sim\mathcal{N}(0,\sigma^{2})$ is sampled i.i.d. and the prior of $\beta$ is $\beta\sim\mathcal{N}(\mathbf{0}_{P},\alpha^{2}\mathbf{I}_{P})$. By concatenation, we get

\mathbf{y}=\mathbf{X}\beta+\varepsilon

where $\mathbf{y}\in\mathbb{R}^{M}$ and $\mathbf{X}\in\mathbb{R}^{M\times P}$ are the concatenations of $y_{i}$ and $x_{i}$, respectively, and $\varepsilon\sim\mathcal{N}(\mathbf{0}_{M},\sigma^{2}\mathbf{I}_{M})$. It is well known [62, 63] that the posterior of $\beta$ for this problem is

\beta \sim\mathcal{N}\left(\beta|\mu,\Sigma\right)
\mu :=\frac{\Sigma\mathbf{X}^{\top}\mathbf{y}}{\sigma^{2}}
\Sigma :=\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{X}^{\top}\mathbf{X}}{\sigma^{2}}\right)^{-1}.

Similarly, we define a Bayesian linear regression problem as

y_{i}=f(x_{i},\theta^{*})+\mathbf{J}_{\theta}(x_{i},\theta^{*})\mathrm{diag}(\theta^{*})c+\epsilon_{i},\quad\text{for }i=1,\dots,NK

where $M=NK$ and the regression coefficient is $\beta=c$. Thus, we treat $y_{i}-f(x_{i},\theta^{*})$ as the target and $\mathbf{J}_{\theta}(x_{i},\theta^{*})\mathrm{diag}(\theta^{*})$ as the input of the linear regression problem. By concatenation, we get

\mathcal{Y}=f(\mathcal{X},\theta^{*})+\mathbf{J}_{c}c+\varepsilon\ \Rightarrow\ \mathcal{Y}-f(\mathcal{X},\theta^{*})=\mathbf{J}_{c}c+\varepsilon.

By plugging this into the posterior of the Bayesian linear regression problem, we get

\mathbb{Q}_{\theta^{*}}(c) :=\mathcal{N}(c|\mu_{\mathbb{Q}},\Sigma_{\mathbb{Q}})
\mu_{\mathbb{Q}} :=\frac{\Sigma_{\mathbb{Q}}\mathbf{J}_{c}^{\top}\left(\mathcal{Y}-f(\mathcal{X},\theta^{*})\right)}{\sigma^{2}}=\frac{\Sigma_{\mathbb{Q}}\mathrm{diag}(\theta^{*})\mathbf{J}_{\theta}^{\top}\left(\mathcal{Y}-f(\mathcal{X},\theta^{*})\right)}{\sigma^{2}}
\Sigma_{\mathbb{Q}} :=\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\mathbf{J}_{c}}{\sigma^{2}}\right)^{-1}=\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathrm{diag}(\theta^{*})\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}\mathrm{diag}(\theta^{*})}{\sigma^{2}}\right)^{-1}

Derivation of $\mathbb{Q}_{\theta^{*}}(\psi)$

We define the perturbed parameter $\psi$ as follows:

\psi:=\theta^{*}+\theta^{*}\odot c.

Since $\psi$ is affine in $c$, the distribution of $\psi$ is

\mathbb{Q}_{\theta^{*}}(\psi) :=\mathcal{N}(\psi|\mu_{\mathbb{Q}}^{\psi},\Sigma_{\mathbb{Q}}^{\psi})
\mu_{\mathbb{Q}}^{\psi} :=\theta^{*}+\theta^{*}\odot\mu_{\mathbb{Q}}
\Sigma_{\mathbb{Q}}^{\psi} :=\mathrm{diag}(\theta^{*})\Sigma_{\mathbb{Q}}\mathrm{diag}(\theta^{*})=\left(\frac{\mathrm{diag}(\theta^{*})^{-2}}{\alpha^{2}}+\frac{\mathbf{J}_{\theta}^{\top}\mathbf{J}_{\theta}}{\sigma^{2}}\right)^{-1}
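
For small problems where $\mathbf{J}_{\theta}$ can be materialized, the posterior $\mathbb{Q}_{\theta^{*}}(c)$ above can be formed directly. The following dense sketch (ours; array shapes are illustrative and it ignores the memory cost of the full Jacobian) implements $\mu_{\mathbb{Q}}$ and $\Sigma_{\mathbb{Q}}$:

```python
import jax.numpy as jnp

def connectivity_posterior(J_theta, theta_star, residual, alpha, sigma):
    """Posterior Q(c) = N(mu_Q, Sigma_Q) with J_c = J_theta diag(theta*).
    J_theta: (NK, P) Jacobian w.r.t. parameters; residual: (NK,) vector Y - f(X, theta*)."""
    J_c = J_theta * theta_star[None, :]                   # right-multiplication by diag(theta*)
    P = J_c.shape[1]
    Sigma_q = jnp.linalg.inv(jnp.eye(P) / alpha**2 + (J_c.T @ J_c) / sigma**2)
    mu_q = Sigma_q @ (J_c.T @ residual) / sigma**2
    return mu_q, Sigma_q
```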

Appendix D Representative cases of function-preserving scaling transformations

Activation-wise rescaling transformation [14, 64] For NNs with ReLU activations, the following holds for all $x\in\mathbb{R}^{d}$ and $\gamma>0$: $f(x,\theta)=f(x,\mathcal{R}_{\gamma,l,k}(\theta))$, where the rescaling transformation $\mathcal{R}_{\gamma,l,k}(\cdot)$ is defined below. (For a simple two-layer NN $f(x):=W_{2}\sigma(W_{1}x)$ with weight matrices $W_{1},W_{2}$, the first case of equation 19 corresponds to the $k$-th row of $W_{1}$ and the second case to the $k$-th column of $W_{2}$.)

(\mathcal{R}_{\gamma,l,k}(\theta))_{i}=\begin{cases}\gamma\cdot\theta_{i},&\text{if $\theta_{i}\in$ \{param. subset connecting as input edges to the $k$-th activation at the $l$-th layer\}}\\ \theta_{i}/\gamma,&\text{if $\theta_{i}\in$ \{param. subset connecting as output edges to the $k$-th activation at the $l$-th layer\}}\\ \theta_{i},&\text{otherwise}\end{cases} \quad (19)

Note that $\mathcal{R}_{\gamma,l,k}(\cdot)$ is a finer-grained rescaling transformation than the layer-wise rescaling transformation (i.e., a common $\gamma$ for all activations in layer $l$) discussed in Dinh et al. [10]. Dinh et al. [10] showed that even layer-wise rescaling transformations can sharpen pre-trained solutions in terms of the trace of the Hessian (i.e., contradicting the FM hypothesis). This contradiction also occurs in previous PAC-Bayes bounds [14, 15] due to their scale-dependent terms.

Weight decay with BN layers [65] For parameters $W\in\mathbb{R}^{n_{l}\times n_{l-1}}$ preceding a BN layer,

\mathrm{BN}((\mathrm{diag}(\gamma)W)u)=\mathrm{BN}(Wu) \quad (20)

for any input $u\in\mathbb{R}^{n_{l-1}}$ and positive vector $\gamma\in\mathbb{R}^{n_{l}}_{+}$. This implies that scaling transformations on these parameters preserve the function represented by the NN, i.e., $f(x,\theta)=f(x,\mathcal{S}_{\gamma,l,k}(\theta))$ for all $x\in\mathbb{R}^{d}$ and $\gamma\in\mathbb{R}^{n_{l}}_{+}$, where the scaling transformation $\mathcal{S}_{\gamma,l,k}(\cdot)$ is defined for $i=1,\dots,P$ as

\left(\mathcal{S}_{\gamma,l,k}(\theta)\right)_{i}=\begin{cases}\gamma_{k}\cdot\theta_{i},&\text{if $\theta_{i}\in$ \{param. subset connecting as input edges to the $k$-th activation at the $l$-th layer\}}\\ \theta_{i},&\text{otherwise}\end{cases} \quad (21)

Note that weight decay regularization [12, 13] can be implemented as a realization of $\mathcal{S}_{\gamma,l,k}(\cdot)$ (e.g., $\gamma=0.9995$ for all activations preceding BN layers). Therefore, thanks to Proposition 2.2 and Proposition 2.5, our CTK-based bound is invariant to weight decay regularization applied to parameters preceding BN layers. We also refer to [50, 66] for an optimization perspective on weight decay with BN. A toy check of equation 20 is given below.
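
The following toy check (our sketch; train-mode BN without affine parameters and with the $\epsilon$ term omitted, so the identity is exact) illustrates equation 20: per-channel rescaling of the weights preceding a BN layer leaves the normalized activations unchanged.

```python
import jax
import jax.numpy as jnp

def batch_norm(z):
    # train-mode BN over the batch axis, no affine parameters, eps omitted
    return (z - z.mean(axis=0)) / z.std(axis=0)

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (16, 8))                                       # W in R^{n_l x n_{l-1}}
U = jax.random.normal(jax.random.PRNGKey(1), (32, 8))                     # a batch of inputs u
gamma = jnp.abs(jax.random.normal(jax.random.PRNGKey(2), (16,))) + 0.1    # positive per-channel scales

out = batch_norm(U @ W.T)                                                 # BN(W u)
out_scaled = batch_norm(U @ (gamma[:, None] * W).T)                       # BN(diag(gamma) W u)
print(jnp.allclose(out, out_scaled, atol=1e-4))                           # True: function preserved
```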

Appendix E Implementation of Connectivity Laplace

To estimate the empirical/generalization bound in Sec. 2.4 and to calibrate uncertainty in Sec. 4.2, we need to sample $c$ from the posterior $\mathbb{Q}_{\theta^{*}}(c)$. For this, we sample perturbations $\delta$ in connectivity space,

\delta\sim\mathcal{N}\left(\delta\,\Big|\,\mathbf{0}_{P},\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\mathbf{J}_{c}}{\sigma^{2}}\right)^{-1}\right)

so that $c=\mu_{\mathbb{Q}}+\delta$ for equation 6. To obtain such samples, we provide a novel approach to sampling from LA/CL without curvature approximation. To this end, we consider the following optimization problem:

\arg\min_{c}L(c):=\arg\min_{c}\frac{1}{2N\sigma^{2}}\|\mathcal{Y}-f(\mathcal{X},\theta^{*})-\mathbf{J}_{c}c+\varepsilon\|^{2}+\frac{1}{2N\alpha^{2}}\|c-c_{0}\|^{2}_{2}

where $\varepsilon\sim\mathcal{N}(\varepsilon|\mathbf{0}_{NK},\sigma^{2}\mathbf{I}_{NK})$ and $c_{0}\sim\mathcal{N}(c_{0}|\mathbf{0}_{P},\alpha^{2}\mathbf{I}_{P})$. By the first-order optimality condition, we have

N\nabla_{c}L(c^{*})=-\frac{\mathbf{J}_{c}^{\top}(\mathcal{Y}-f(\mathcal{X},\theta^{*})-\mathbf{J}_{c}c^{*}+\varepsilon)}{\sigma^{2}}+\frac{c^{*}-c_{0}}{\alpha^{2}}=\mathbf{0}_{P}.

By rearranging this w.r.t. the optimizer $c^{*}$, we get

c^{*} =\left(\mathbf{J}_{c}^{\top}\mathbf{J}_{c}+\frac{\sigma^{2}}{\alpha^{2}}\mathbf{I}_{P}\right)^{-1}\left(\mathbf{J}_{c}^{\top}(\mathcal{Y}-f(\mathcal{X},\theta^{*}))+\frac{\sigma^{2}}{\alpha^{2}}c_{0}+\mathbf{J}_{c}^{\top}\varepsilon\right)
=\left(\mathbf{J}_{c}^{\top}\mathbf{J}_{c}+\frac{\sigma^{2}}{\alpha^{2}}\mathbf{I}_{P}\right)^{-1}\mathbf{J}_{c}^{\top}(\mathcal{Y}-f(\mathcal{X},\theta^{*}))+\left(\mathbf{J}_{c}^{\top}\mathbf{J}_{c}+\frac{\sigma^{2}}{\alpha^{2}}\mathbf{I}_{P}\right)^{-1}\left(\frac{\sigma^{2}}{\alpha^{2}}c_{0}+\mathbf{J}_{c}^{\top}\varepsilon\right)
=\underbrace{\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\mathbf{J}_{c}}{\sigma^{2}}\right)^{-1}\frac{\mathbf{J}_{c}^{\top}(\mathcal{Y}-f(\mathcal{X},\theta^{*}))}{\sigma^{2}}}_{\textrm{Deterministic}}+\underbrace{\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\mathbf{J}_{c}}{\sigma^{2}}\right)^{-1}\left(\frac{c_{0}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\varepsilon}{\sigma^{2}}\right)}_{\textrm{Stochastic}}

Since $\varepsilon$ and $c_{0}$ are sampled from independent Gaussian distributions, we have

z:=\frac{c_{0}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\varepsilon}{\sigma^{2}}\sim\mathcal{N}\left(z\,\Big|\,\mathbf{0}_{P},\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\mathbf{J}_{c}}{\sigma^{2}}\right)

Therefore, the optimal solution of the randomized optimization problem $\arg\min_{c}L(c)$ is distributed as

c^{*}\sim\mathcal{N}\left(c\,\Big|\,\left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\mathbf{J}_{c}}{\sigma^{2}}\right)^{-1}\frac{\mathbf{J}_{c}^{\top}(\mathcal{Y}-f(\mathcal{X},\theta^{*}))}{\sigma^{2}},\ \left(\frac{\mathbf{I}_{P}}{\alpha^{2}}+\frac{\mathbf{J}_{c}^{\top}\mathbf{J}_{c}}{\sigma^{2}}\right)^{-1}\right)=\mathcal{N}(c|\mu_{\mathbb{Q}},\Sigma_{\mathbb{Q}})

Similarly, sampling from CL can be implemented via the following optimization problem:

\arg\min_{c}L(c):=\arg\min_{c}\frac{1}{2N\sigma^{2}}\|\mathbf{J}_{c}c-\varepsilon\|^{2}+\frac{1}{2N\alpha^{2}}\|c-c_{0}\|^{2}_{2}

where $\varepsilon\sim\mathcal{N}(\varepsilon|\mathbf{0}_{NK},\sigma^{2}\mathbf{I}_{NK})$ and $c_{0}\sim\mathcal{N}(c_{0}|\mathbf{0}_{P},\alpha^{2}\mathbf{I}_{P})$. Since we sample the data/prior noise and then optimize the perturbation, this can be interpreted as a Randomize-Then-Optimize implementation of Laplace Approximation and Connectivity Laplace [41, 42]. A matrix-free sketch of this sampler is given below.
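
The sketch below is our reading of the recipe above (not the released implementation): it assumes a network apply function `f(theta, X)` whose output has the same shape as `Y`, flattened parameters `theta_star`, and it solves the normal equations of the randomized objective with conjugate gradients instead of running an explicit optimizer.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def sample_connectivity(f, theta_star, X, Y, alpha, sigma, key):
    """Draw one sample c ~ N(mu_Q, Sigma_Q) by solving the randomized least-squares problem."""
    P = theta_star.shape[0]
    g = lambda c: f(theta_star * (1.0 + c), X)           # g(x, c) = f(x, theta* + theta* . c)
    out, vjp_fn = jax.vjp(g, jnp.zeros(P))               # out = f(X, theta*)
    jvp = lambda v: jax.jvp(g, (jnp.zeros(P),), (v,))[1]  # J_c v  (output-shaped)
    vjpT = lambda u: vjp_fn(u)[0]                         # J_c^T u (P-dimensional)

    k1, k2 = jax.random.split(key)
    eps = sigma * jax.random.normal(k1, out.shape)        # data-noise sample
    c0 = alpha * jax.random.normal(k2, (P,))              # prior sample

    resid = Y - out
    # normal equations: (I/alpha^2 + J_c^T J_c / sigma^2) c = J_c^T (resid + eps)/sigma^2 + c0/alpha^2
    A = lambda v: v / alpha**2 + vjpT(jvp(v)) / sigma**2
    b = vjpT(resid + eps) / sigma**2 + c0 / alpha**2
    c, _ = cg(A, b)
    return c
```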

Appendix F Details of computing Connectivity Sharpness

It is well known that the empirical NTK or the full Jacobian is intractable for modern NN architectures (e.g., ResNet [37] or BERT [67]). Therefore, one might wonder how Connectivity Sharpness can be computed for these architectures. Since Connectivity Sharpness in Sec. 2.5 is defined as the trace of the empirical CTK, one can compute CS with Hutchinson's method [46, 68]. According to Hutchinson's method, the trace of a matrix $A\in\mathbb{R}^{m\times m}$ is

\mathrm{tr}(A)=\mathrm{tr}(A\mathbf{I}_{m})=\mathrm{tr}(A\mathbb{E}_{z}[zz^{\top}])=\mathbb{E}_{z}[\mathrm{tr}(Azz^{\top})]=\mathbb{E}_{z}[\mathrm{tr}(z^{\top}Az)]=\mathbb{E}_{z}[z^{\top}Az]

where $z\in\mathbb{R}^{m}$ is a random variable with $\mathrm{cov}(z)=\mathbf{I}_{m}$ (e.g., standard normal or Rademacher). Since $A=\mathbf{C}^{\theta^{*}}_{\mathcal{X}}=\mathbf{J}_{c}(\mathcal{X},\mathbf{0}_{P})\mathbf{J}_{c}(\mathcal{X},\mathbf{0}_{P})^{\top}\in\mathbb{R}^{NK\times NK}$ in our case, we further use a mini-batch approximation to compute $z^{\top}Az$: (i) sample $z_{\mathcal{M}}\in\mathbb{R}^{MK}$ from the Rademacher distribution for a mini-batch $\mathcal{M}$ of size $M$, (ii) compute $v_{\mathcal{M}}:=\mathbf{J}_{c}(\mathcal{M},\mathbf{0}_{P})^{\top}z_{\mathcal{M}}\in\mathbb{R}^{P}$ with the vector-Jacobian product of JAX [69], and (iii) compute $x_{\mathcal{M}}=\|v_{\mathcal{M}}\|^{2}_{2}$. Then, the sum of $x_{\mathcal{M}}$ over all mini-batches in the training dataset is a Monte Carlo approximation of CS with sample size 1. Empirically, we found this approximation sufficiently stable to capture the correlation between sharpness and generalization, as shown in Sec. 4.1.
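
A sketch of this estimator (ours; the apply function `f(theta, X)` and the batching are illustrative names) that accumulates one Rademacher probe per mini-batch:

```python
import jax
import jax.numpy as jnp

def connectivity_sharpness(f, theta_star, batches, key):
    """One-sample Hutchinson estimate of CS = tr(C), accumulated over mini-batches."""
    total = 0.0
    for X in batches:                                      # mini-batches of inputs
        g = lambda c: f(theta_star * (1.0 + c), X)         # Jacobian of g at c = 0 is J_c
        out, vjp_fn = jax.vjp(g, jnp.zeros_like(theta_star))
        key, sub = jax.random.split(key)
        z = jax.random.rademacher(sub, out.shape).astype(out.dtype)  # output-space probe
        v = vjp_fn(z)[0]                                   # v = J_c^T z
        total = total + jnp.sum(v ** 2)                    # z^T C z for this batch block
    return total
```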

Appendix G Predictive uncertainty of Connectivity/Linearized Laplace

In this section, we derive the predictive uncertainty of Linearized Laplace (LL) and Connectivity Laplace (CL). By the matrix inversion lemma [70], the weight covariance of LL is

(\mathbf{I}_{P}/\alpha^{2}+\mathbf{J}_{\theta}(\mathcal{X},\theta^{*})^{\top}\mathbf{J}_{\theta}(\mathcal{X},\theta^{*})/\sigma^{2})^{-1}=\alpha^{2}\mathbf{I}_{P}-\alpha^{2}\mathbf{J}_{\theta}(\mathcal{X},\theta^{*})^{\top}\left(\frac{\sigma^{2}}{\alpha^{2}}\mathbf{I}_{NK}+\Theta_{\mathcal{XX}}^{\theta^{*}}\right)^{-1}\mathbf{J}_{\theta}(\mathcal{X},\theta^{*}).

Therefore, if $\sigma^{2}/\alpha^{2}\rightarrow 0$, the weight covariance of LL converges to

\alpha^{2}\mathbf{I}_{P}-\alpha^{2}\mathbf{J}_{\theta}(\mathcal{X},\theta^{*})^{\top}\Theta_{\mathcal{XX}}^{\theta^{*}-1}\mathbf{J}_{\theta}(\mathcal{X},\theta^{*}).

With this weight covariance and linearized NN, the predictive uncertainty of LL is

f^{\mathrm{lin}}_{\theta^{*}}(x,\theta)\,|\,p_{\mathrm{LA}}(\theta|\mathcal{S})\sim\mathcal{N}(f(x,\theta^{*}),\ \alpha^{2}\Theta_{xx}^{\theta^{*}}-\alpha^{2}\Theta_{x\mathcal{X}}^{\theta^{*}}\Theta_{\mathcal{XX}}^{\theta^{*}-1}\Theta_{\mathcal{X}x}^{\theta^{*}}).

Similarly, the predictive uncertainty of CL is

f^{\mathrm{lin}}_{\theta^{*}}(x,\theta)\,|\,\mathbb{Q}_{\theta^{*}}(\theta)\sim\mathcal{N}(f(x,\theta^{*}),\ \alpha^{2}\mathbf{C}_{xx}^{\theta^{*}}-\alpha^{2}\mathbf{C}_{x\mathcal{X}}^{\theta^{*}}\mathbf{C}_{\mathcal{XX}}^{\theta^{*}-1}\mathbf{C}_{\mathcal{X}x}^{\theta^{*}}).
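
Given precomputed kernel blocks, the CL predictive variance above reduces to a single linear solve; a small dense sketch (ours, assuming the $NK\times NK$ train kernel fits in memory):

```python
import jax.numpy as jnp

def cl_predictive_variance(C_xx, C_xX, C_XX, alpha):
    """CL predictive covariance: alpha^2 (C_xx - C_xX C_XX^{-1} C_Xx).
    C_xx: (K, K) test block, C_xX: (K, NK) test-train block, C_XX: (NK, NK) train block."""
    return alpha**2 * (C_xx - C_xX @ jnp.linalg.solve(C_XX, C_xX.T))
```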

Appendix H Details on sharpness-generalization experiments

To verify that CS correlates better with generalization performance than existing sharpness measures, we evaluate three metrics: (a) Kendall's rank-correlation coefficient $\tau$ [51], which measures the consistency of a sharpness measure with the generalization gap (i.e., whether a model with higher sharpness also has a higher generalization gap); (b) the granulated Kendall's coefficient [7], which computes Kendall's rank-correlation coefficient w.r.t. individual hyper-parameters to separately evaluate the effect of each hyper-parameter on the generalization gap; and (c) the conditional independence test [7], which captures the causal relationship between a measure and generalization. A small reference sketch of metric (a) is given below.
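
For reference, a minimal implementation sketch of Kendall's $\tau$ for metric (a) (ours; the tie-ignoring $\tau_a$ variant over pairwise orderings):

```python
import jax.numpy as jnp

def kendall_tau(measure, gap):
    """Tie-ignoring Kendall's tau_a between a sharpness measure and the generalization gap."""
    dm = jnp.sign(measure[:, None] - measure[None, :])   # pairwise ordering of the measure
    dg = jnp.sign(gap[:, None] - gap[None, :])           # pairwise ordering of the gap
    n = measure.shape[0]
    return jnp.sum(dm * dg) / (n * (n - 1))              # (concordant - discordant) / #pairs
```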

Table 5: Configuration of hyper-parameters
network depth 1, 2, 3
network width 32, 64, 128
learning rate 0.1, 0.032, 0.001
weight decay 0.0, 1e-4, 5e-4
mini-batch size 256, 1024, 4096

The three metrics are computed for the following baselines: the trace of the Hessian ($\mathrm{tr}(\mathbf{H})$; [19]), the trace of the Fisher information matrix ($\mathrm{tr}(\mathbf{F})$; [52]), the trace of the empirical NTK at $\theta^{*}$ ($\mathrm{tr}(\mathbf{\Theta}^{\theta^{*}})$), and four PAC-Bayes-bound-based measures: sharpness-orig (SO), pacbayes-orig (PO), $1/\alpha^{\prime}$ sharpness mag (SM), and $1/\sigma^{\prime}$ pacbayes mag (PM), which are eqs. (52), (49), (62), (61) in Jiang et al. [7].

For the granulated Kendall's coefficient, we use 5 hyper-parameters: network depth, network width, learning rate, weight decay, and mini-batch size, with 3 options for each hyper-parameter as in Table 5. We use VGG-13 [53] as the base model and adjust the depth and width of each conv block, adding a BN layer after each convolution layer. Specifically, the number of convolution layers in each conv block is the depth, and the number of channels of the convolution layers in the first conv block is the width. For subsequent conv blocks, we follow the original VGG width multipliers ($\times 2$, $\times 4$, $\times 8$). An example with depth 1 and width 128 is depicted in Table 6.

Table 6: Example of a network configuration with depth 1 and width 128, in the style of [53].
ConvNet Configuration
input (224 ×\times 224 RGB image)
Conv3-128
BN
ReLU
MaxPool
Conv3-256
BN
ReLU
MaxPool
Conv3-512
BN
ReLU
MaxPool
Conv3-1024
BN
ReLU
MaxPool
Conv3-1024
BN
ReLU
MaxPool
FC-4096
ReLU
FC-4096
ReLU
FC-1000

We use the SGD optimizer with momentum 0.9. We train each model for 200 epochs and use a cosine learning rate scheduler [38] with the first 30% of epochs as warm-up. The standard CIFAR-10 data augmentations (padding, random crop, random horizontal flip, and normalization) are applied to the training data. For the analysis, we only use models with above 99% training accuracy, following Jiang et al. [7]. As a result, we use 200 out of the 243 trained models for our correlation analysis. For every experiment, we use 8 NVIDIA RTX 3090 GPUs.

Appendix I Details and additional results on BNN experiments

I.1 Experimental Setting

Uncertainty calibration on image classification tasks We pre-train models for 200 epochs on the CIFAR-10/100 datasets [36] with ResNet-18 [37], as mentioned in Section 2.4. We choose the ensemble size $M$ as 8, except for Deep Ensemble [55] and Batch Ensemble [58], for which we use 4 ensemble members due to computational cost.

For evaluation, we define a single-member prediction as the one-hot representation of the network output with label smoothing. We set the label smoothing coefficient to 0.01 for CIFAR-10 and 0.1 for CIFAR-100. We define the ensemble prediction as the average of the single-member predictions. For OOD detection, we use the variance of the predictions in output space, which is competitive with recent OOD detection methods [71, 72]. We use 0.01 for $\sigma$ and select the best $\alpha$ with cross-validation. For every experiment, we use 8 NVIDIA RTX 3090 GPUs.

Appendix J Additional results on bound estimation

Table 7: Results for experiments on PAC-Bayes-NTK estimation.
CIFAR-10 CIFAR-100
Parameter scale 0.5 1.0 2.0 4.0 0.5 1.0 2.0 4.0
$\mathrm{tr}(\Theta^{\theta^{*}}_{\mathcal{X}})$ 18746194.0 6206303.5 3335419.75 2623873.25 12688970.0 3916139.25 2819272.5 2662497.0
Bias 483.86 427.0042 299.0085 197.3149 476.9061 478.1776 440.284 329.8767
Sharpness 579.6815 472.0 402.8186 369.3761 547.2874 434.7583 398.5075 387.3265
KL divergence 531.7708 449.5021 350.9135 283.3455 512.0967 456.4679 419.3957 358.6016
Test err. 0.5617 ± 0.0670 0.4566 ± 0.0604 0.2824 ± 0.0447 0.1530 ± 0.0199 0.6210 ± 0.0096 0.6003 ± 0.0094 0.5499 ± 0.0100 0.4666 ± 0.0093
PAC-Bayes-NTK 0.7985 ± 0.0694 0.6730 ± 0.0626 0.4718 ± 0.0465 0.3186 ± 0.0202 0.8530 ± 0.0140 0.8162 ± 0.0136 0.7602 ± 0.0112 0.6617 ± 0.0114

Appendix K Additional results on image classification

Table 8: Uncertainty calibration results on CIFAR-10 [36] for VGG-13 [53].
CIFAR-10
NLL (\downarrow) ECE (\downarrow) Brier. (\downarrow) AUC (\uparrow)
Deterministic 0.4086 ± 0.0018 0.0490 ± 0.0003 0.1147 ± 0.0005 -
MCDO 0.3889 ± 0.0049 0.0465 ± 0.0009 0.1106 ± 0.0015 0.7765 ± 0.0221
MCBN 0.3852 ± 0.0012 0.0462 ± 0.0002 0.1108 ± 0.0003 0.9051 ± 0.0065
Batch Ensemble 0.3544 ± 0.0036 0.0399 ± 0.0009 0.1064 ± 0.0012 0.9067 ± 0.0030
Deep Ensemble 0.2243 0.0121 0.0776 0.7706
Linearized Laplace 0.3366 ± 0.0013 0.0398 ± 0.0004 0.1035 ± 0.0003 0.8883 ± 0.0017
Connectivity Laplace (Ours) 0.2674 ± 0.0028 0.0234 ± 0.0011 0.0946 ± 0.0010 0.9002 ± 0.0033
Table 9: Uncertainty calibration results on CIFAR-100 [36] for VGG-13 [53].
CIFAR-100
NLL (\downarrow) ECE (\downarrow) Brier. (\downarrow) AUC (\uparrow)
Deterministic 1.8286 ± 0.0066 0.1544 ± 0.0010 0.4661 ± 0.0018 -
MCDO 1.7439 ± 0.0089 0.1363 ± 0.0008 0.4456 ± 0.0017 0.6424 ± 0.0099
MCBN 1.7491 ± 0.0075 0.1399 ± 0.0010 0.4488 ± 0.0015 0.7039 ± 0.0197
Batch Ensemble 1.6142 ± 0.0101 0.1077 ± 0.0020 0.4143 ± 0.0027 0.7232 ± 0.0021
Deep Ensemble 1.2006 0.0456 0.3228 0.6929
Linearized Laplace 1.5806 ± 0.0054 0.1036 ± 0.0004 0.4127 ± 0.0010 0.6893 ± 0.0221
Connectivity Laplace (Ours) 1.4073 ± 0.0039 0.0703 ± 0.0028 0.3827 ± 0.0012 0.7254 ± 0.0136