G-TRACER: Expected Sharpness Optimization

John Williams [email protected]
Machine Learning Research Group
University of Oxford
Stephen Roberts [email protected]
Machine Learning Research Group
University of Oxford
Abstract

We propose a new regularization scheme for the optimization of deep learning architectures, G-TRACER ("Geometric TRACE Ratio"), which promotes generalization by seeking flat minima, and has a sound theoretical basis as an approximation to a natural-gradient descent based optimization of a generalized Bayes objective. By augmenting the loss function with a TRACER term, curvature-regularized optimizers (e.g. SGD-TRACER and Adam-TRACER) are simple to implement as modifications to existing optimizers and don’t require extensive tuning. We show that the method converges to a neighborhood (depending on the regularization strength) of a local minimum of the unregularized objective, and demonstrate competitive performance on a number of benchmark computer vision and NLP datasets, with a particular focus on challenging low signal-to-noise ratio problems.

1 Introduction

1.1 Problem setting

The connection between generalization performance and the loss-surface geometry of deep-learning architectures in the neighborhood of local minima has long been the subject of interest and speculation, dating back to the MDL-based arguments of Hinton & van Camp (1993) and Hochreiter & Schmidhuber (1997). The connection is an intuitively appealing one, in that the sharp local minima of the highly nonlinear, non-convex optimization problems associated with modern overparameterized deep learning architectures are more likely to be brittle and sensitive to perturbations in the parameters and training data, and thus lead to worse performance on unseen data. We can build some intuition for this from a probabilistic modelling perspective. Consider a dataset $\mathcal{D}=\{(x_{i},y_{i})\}_{i=1}^{n}$ consisting of $n$ independent input random variables $x_{i}$ with distribution $p(x)$ and corresponding targets (or labels) $y_{i}$ with distribution $p(y|x)$, and treat the parameters $w\in\Theta\subseteq\mathbb{R}^{p}$ of a deep neural network (DNN) $f(\cdot,w):\mathbb{R}^{d_{x}}\rightarrow\mathbb{R}^{d_{y}}$ as a random variable. Given a loss function $l(y_{i},f(x_{i},w))$, our goal is to find a $w^{*}$ that minimizes the expected loss $\mathbb{E}_{p(x,y)}[l(y,f(x,w))]$. Writing the finite-sample version of this expected loss as $L(w)=\sum_{i=1}^{n}l(y_{i},f(x_{i},w))$, we can form a generalized posterior distribution (Bissiri et al., 2016) $p(w|\mathcal{D})=p(w)\frac{1}{Z}\exp\{-L(w)\}$ (with normalizer $Z$) over the weights, which coincides with the Bayesian posterior in the special case that the loss is the negative log-likelihood $L(w)=-\frac{1}{n}\sum_{i=1}^{n}\log p(y_{i}|x_{i},w)$. Then, together with an output (conditional predictive) probability distribution $p(y|x,w)$, we can form the predictive distribution by marginalization:

$$p(y|x,\mathcal{D})=\int p(y|x,w)\,p(w|\mathcal{D})\,dw \tag{1}$$

At a local maximum (or mode) $w_{k}$ of $p(w|\mathcal{D})$ we can form the Laplace approximation (valid asymptotically, for large $n$):

$$p_{k}(w|\mathcal{D})\approx\frac{1}{Z_{k}}p(w_{k}|\mathcal{D})\exp\left(-\frac{1}{2}(w-w_{k})^{T}H(w-w_{k})\right) \tag{2}$$

where (assuming, for simplicity, a flat prior) $H=\nabla_{w}^{2}L|_{w=w_{k}}$ and the normalizer (which for the negative log-likelihood loss is the evidence, or marginal likelihood, and which we will also refer to as the pseudo-marginal likelihood) is given by $Z_{k}=p(w_{k}|\mathcal{D})(2\pi)^{\frac{p}{2}}\det(H)^{-\frac{1}{2}}$. Thus, in the neighborhood of each local maximum of $p(w|\mathcal{D})$, we approximate the posterior by a multivariate Gaussian with covariance given by the inverse Hessian of the loss: $p_{k}(w|\mathcal{D})\sim\mathcal{N}(w_{k},H^{-1})$. Modern DNNs are characterized by multimodal losses (Wilson & Izmailov, 2020), and so, informally, we can decompose the posterior predictive distribution into disjoint contributions from each of the local maxima, the sum of which dominates the overall integral:

$$p(y|x,\mathcal{D})\approx\frac{1}{Z}\sum_{k^{\prime}\in\{k\}}\int p(y|x,w)\,Z_{k^{\prime}}\,p_{k^{\prime}}(w|\mathcal{D})\,dw \tag{3}$$

where $Z=\sum_{k^{\prime}\in\{k\}}Z_{k^{\prime}}$, which is an expectation with respect to a probability measure with density $\frac{1}{Z}\sum_{k^{\prime}\in\{k\}}Z_{k^{\prime}}p_{k^{\prime}}(w|\mathcal{D})$, and which, by writing:

$$\frac{1}{Z}\sum_{k^{\prime}\in\{k\}}Z_{k^{\prime}}p_{k^{\prime}}(w|\mathcal{D})=\sum_{k^{\prime}\in\{k\}}\pi_{k^{\prime}}p_{k^{\prime}}(w|\mathcal{D}) \tag{4}$$

can be viewed as a Gaussian mixture model (GMM) with mixing coefficients:

$$\pi_{k}=\frac{Z_{k}}{\sum_{k^{\prime}\in\{k\}}Z_{k^{\prime}}} \tag{5}$$

Thus the relative contribution of each component is given by the relative size of the pseudo-marginal likelihoods $Z_{k}$. For very high-dimensional $w\in\mathbb{R}^{p}$ (modern DNN architectures often have billions or even trillions of parameters) even small differences in the width of the Gaussian approximation will have exponentially large effects on the magnitude of $Z_{k}$ (which can be thought of as the volume associated with the local maximum). How does all this relate to flatness? The Gaussian curvature $K$, which provides an intrinsic (and thus coordinate-free) measure of curvature, is given in terms of the eigenvalues $\lambda_{i}$ of $H$ by:

$$K=\prod_{i}\lambda_{i}=\det(H) \tag{6}$$

and contributions to the mixture thus scale inversely with $\sqrt{K}$. In other words, the flatter the solution, the more it contributes to the mixture model against which the output probability distribution is integrated in order to form the posterior predictive distribution. In a typical high-dimensional setting, the effect of small differences in curvature (or flatness) is exponentially magnified. To see this, we can consider two local minima: i) one with Hessian $H$, and ii) an $\epsilon$-flattened minimum ($0<\epsilon\ll 1$) with Hessian $H^{\prime}=(1-\epsilon)H$ (so that each eigenvalue of $H$ is simply shrunk by a constant factor $(1-\epsilon)$). We have $K^{\prime}=\det(H^{\prime})=(1-\epsilon)^{p}\det(H)$, so that the $\epsilon$-flattened minimum with curvature $K^{\prime}$ has exponentially lower Gaussian curvature. The corresponding Gaussian approximations have covariances $\Sigma^{\prime}\approx(1+\epsilon)\Sigma$ and $\Sigma$, and the ratio of the corresponding pseudo-marginal likelihoods scales as:

$$\frac{\det((1+\epsilon)\Sigma)}{\det(\Sigma)}=(1+\epsilon)^{p}\xrightarrow{p\rightarrow\infty}\infty \tag{7}$$

Thus we can see that the contribution from the flatter minimum dominates in the high-dimensional limit. In an empirical study, Huang et al. (2019) train a ResNet18 architecture on the Street View House Number (SVHN) dataset and estimate the volume around local minima using Monte-Carlo integration, finding that the volumes of basins surrounding minima that generalize well are at least 10,000 orders of magnitude larger than those of minima that generalize poorly.
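To get a feel for the scaling in Equation (7), the following minimal sketch evaluates the determinant ratio $(1+\epsilon)^{p}$ in log space. The function name `log10_volume_ratio` and the value $\epsilon=10^{-3}$ are illustrative assumptions, not quantities taken from the experiments.

```python
import math


def log10_volume_ratio(eps: float, p: int) -> float:
    """Return log10 of (1 + eps)**p, the determinant ratio in Equation (7)."""
    # Work in log space: (1 + eps)**p overflows a float for large p.
    return p * math.log1p(eps) / math.log(10.0)


if __name__ == "__main__":
    eps = 1e-3  # hypothetical 0.1% widening of every principal axis
    for p in (10**3, 10**6, 10**9):
        exponent = log10_volume_ratio(eps, p)
        print(f"p = {p:>10d}: ratio is about 10^{exponent:.1f}")
```

Under these assumptions, a 0.1% widening already corresponds to a ratio of several hundred orders of magnitude at a million parameters, which is in the spirit of the very large basin-volume differences reported empirically.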

A complementary approach is to characterize the loss-surface Hessian, since, at a local minimum of the loss (where the gradient vanishes), for a perturbation $\Delta w$, we have:

$$L(w+\Delta w)-L(w)=\frac{1}{2}\Delta w^{T}\nabla^{2}L(w)\Delta w+O(\|\Delta w\|^{3}) \tag{8}$$

There has therefore been a large literature attempting to characterize the loss-surface Hessian $\nabla^{2}L(w)$ and to relate these characteristics to generalization. In many practically relevant cases, multiple minima are associated with zero (or close to zero) training error, and explicit or implicit regularization is needed to find solutions with the best generalization error. Overparameterization is associated with the bulk of the Hessian spectrum lying close to zero and thus with highly degenerate minima (Sagun et al., 2017). Wei & Schwab (2020) further show that, given a degenerate valley in the loss surface, SGD on average decreases the trace of the Hessian, which is strongly suggestive of a connection between locally flat minima, overparameterization and generalization.

1.2 Sharpness-Aware Minimization

Despite the intuitive appeal and plausible justifications for flat solutions to be a goal of DNN optimization algorithms, there have been few practical unqualified successes in exploiting this connection to improve generalization performance. A notable exception is a recent algorithm, Sharpness-Aware Minimization (SAM) (Foret et al., 2020), which seeks to improve generalization by optimizing a saddle-point problem of the form:

$$\min_{w}\max_{\|\Delta w\|\leq\rho}L(w+\Delta w) \tag{9}$$

An approximate solution to this problem is obtained by differentiating through the inner maximization, so that, given an approximate solution $\Delta w^{*}:=\rho\frac{\nabla L(w^{k})}{\|\nabla L(w^{k})\|_{2}}$ to the inner maximization (dual norm) problem:

$$\arg\max_{\|\Delta w\|\leq\rho}L(w+\Delta w) \tag{10}$$

the gradient of the SAM objective is approximated as follows:

$$\nabla_{w}\left(\max_{\|\Delta w\|\leq\rho}L(w+\Delta w)\right)\approx\nabla_{w}L(w+\Delta w^{*})\approx\nabla_{w}L(w)|_{w+\Delta w^{*}} \tag{11}$$
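The two-stage approximation in Equations (10) and (11) can be sketched as follows: take a normalized gradient-ascent step of length $\rho$, then evaluate the ordinary gradient at the perturbed point. The quadratic `loss`, the helper names and the hyperparameter values are hypothetical choices made only for illustration; this is not the reference SAM implementation.

```python
import math


def loss(w):
    # Hypothetical toy loss: an anisotropic quadratic bowl.
    return 0.5 * (10.0 * w[0] ** 2 + 0.1 * w[1] ** 2)


def grad(w):
    # Analytic gradient of the toy loss above.
    return [10.0 * w[0], 0.1 * w[1]]


def sam_gradient(w, rho):
    """Approximate SAM gradient: the plain gradient evaluated at w + dw*."""
    g = grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]  # dw* = rho * g / ||g||
    return grad(w_adv)


def sam_step(w, lr=0.05, rho=0.05):
    g = sam_gradient(w, rho)
    return [wi - lr * gi for wi, gi in zip(w, g)]


def run(steps=200):
    w = [1.0, 1.0]
    for _ in range(steps):
        w = sam_step(w)
    return w
```

Note that each step needs two gradient evaluations (one at $w$, one at $w+\Delta w^{*}$), which is the main computational overhead of this scheme relative to plain SGD.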

While the method has gained widespread attention, and state-of-the-art performance has been demonstrated on several benchmark datasets, it remains relatively poorly understood, and the motivation and connection to sharpness is questionable given that the Euclidean norm-ball isn’t invariant to changes in coordinates. Given a 1-1 mapping $g:\Theta^{\prime}\rightarrow\Theta$ we can reparameterize our DNN $f(\cdot,w)$ using the "pullback" $g^{*}(f)(\cdot,\nu):=f(\cdot,g(\nu))$ under which, crucially, the underlying prediction function $f(\cdot,w):\mathbb{R}^{d_{x}}\rightarrow\mathbb{R}^{d_{y}}$ (and therefore the loss) itself is invariant, since, for $\nu=g^{-1}(w)$, we have $f(\cdot,w)=f(\cdot,g(\nu))$. Under this coordinate transformation, however, the Hessian at a critical point transforms as (Dinh et al., 2017):

$$\nabla^{2}L(\nu)=\nabla g(\nu)^{T}\,\nabla^{2}L\,\nabla g(\nu) \tag{12}$$

In particular, Dinh et al. (2017) explicitly show, using layer-wise transformations $T_{\alpha}:(w_{1},w_{2})\rightarrow(\alpha w_{1},\alpha^{-1}w_{2})$, that deep rectifier feedforward networks possess large numbers of symmetries which can be exploited to control sharpness without changing the network output. The existence of these symmetries in the loss function, under which the geometry of the local loss (and in particular, the spectral norm and trace of the Hessian) can be substantially modified, means that the relationship between the local flatness of the loss landscape and generalization is a subtle one.
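The rescaling symmetry $T_{\alpha}$ can be checked numerically on the smallest possible rectifier network. In the sketch below, the two-parameter network `predict`, the data set `DATA` and the finite-difference `hessian_trace` are all hypothetical constructions for illustration: the loss is unchanged by the rescaling while the Hessian trace changes by more than an order of magnitude.

```python
def relu(z):
    return z if z > 0.0 else 0.0


def predict(w, x):
    # Tiny two-layer rectifier network with one weight per layer.
    w1, w2 = w
    return w2 * relu(w1 * x)


DATA = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9)]  # hypothetical (x, y) pairs


def loss(w):
    return sum((predict(w, x) - y) ** 2 for x, y in DATA) / len(DATA)


def hessian_trace(w, h=1e-4):
    """Trace of the loss Hessian via central second differences."""
    total = 0.0
    for i in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[i] += h
        w_minus[i] -= h
        total += (loss(w_plus) - 2.0 * loss(w) + loss(w_minus)) / (h * h)
    return total


def rescale(w, alpha):
    # The layer-wise transformation T_alpha: (w1, w2) -> (alpha * w1, w2 / alpha).
    return [alpha * w[0], w[1] / alpha]


if __name__ == "__main__":
    w = [1.0, 2.0]
    w_scaled = rescale(w, 10.0)
    print("loss:         ", loss(w), loss(w_scaled))
    print("Hessian trace:", hessian_trace(w), hessian_trace(w_scaled))
```

Any flatness measure built directly from the Hessian trace or spectral norm in the original coordinates would therefore rank these two functionally identical parameter settings very differently.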

It’s instructive to consider the PAC-Bayesian generalization bound (McAllester, 1999; Dziugaite & Roy, 2017) from which the derivation of SAM starts:

Theorem 1.

For any distribution $\mathcal{D}$ and prior $p$ over the parameters $w$, with probability $1-\delta$ over the choice of the training set $\mathcal{S}\sim\mathcal{D}$, and for any posterior $q$ over the parameters:

$$\mathbb{E}_{q}[L_{\mathcal{D}}(w)]\leq\mathbb{E}_{q}[L_{\mathcal{S}}(w)]+\sqrt{\frac{KL(q\|p)+\log\frac{n}{\delta}}{2(n-1)}} \tag{13}$$

where the KL divergence:

$$\mathbb{D}_{KL}[q,p]=\mathbb{E}_{q(w)}\left[\log\left(\frac{q(w)}{p(w)}\right)\right] \tag{14}$$

defines a statistical distance $\mathbb{D}_{KL}[q,p]$ (though not a metric, as it’s symmetric only to second order) on the space of probability distributions. Assuming an isotropic prior $p=N(0,\sigma_{p}^{2}I)$ for some $\sigma_{p}$ and an isotropic posterior $q=N(w,\sigma_{q}^{2}I)$, so that $\mathbb{E}_{q}[L_{\mathcal{D}}(w)]=\mathbb{E}_{\epsilon\sim N(0,\sigma_{q}^{2}I)}[L_{\mathcal{D}}(w+\epsilon)]$, and applying the covering approach of Langford & Caruana (2001) to select the best (closest to $q$ in the sense of KL divergence) from a set of pre-defined data-independent prior distributions satisfying the PAC generalization bound, Foret et al. (2020) show that the bound in Theorem 1 can be written in the following form:

$$\mathbb{E}_{\epsilon\sim N(0,\sigma_{q}^{2}I)}[L_{\mathcal{D}}(w+\epsilon)]\leq\mathbb{E}_{\epsilon\sim N(0,\sigma_{q}^{2}I)}[L_{\mathcal{S}}(w+\epsilon)]+g\left(\frac{\|w\|_{2}^{2}}{\rho^{2}}\right) \tag{15}$$

(for a monotone function $g$) and then, crucially, apply a well-known tail bound for a chi-square random variable to bound $\|\epsilon\|_{2}$, thus bounding the expectation over $q$, with probability $1-1/\sqrt{n}$, by the maximum value over a Euclidean norm-ball and deriving the following generalization bound:

Theorem 2.

For any $\rho>0$ and any distribution $\mathcal{D}$, with probability $1-\delta$ over the choice of the training set $\mathcal{S}\sim\mathcal{D}$,

$$L_{\mathcal{D}}(w)\leq\max_{\|\epsilon\|_{2}\leq\rho}L_{\mathcal{S}}(w+\epsilon)+g\left(\frac{\|w\|_{2}^{2}}{\rho^{2}}\right) \tag{16}$$

where $\rho=\sigma\sqrt{k}\left(1+\sqrt{\frac{\ln(n)}{k}}\right)$, $n=|\mathcal{S}|$, and $k$ is the number of parameters.

This bound justifies and motivates the SAM objective:

$$\max_{\|\Delta w\|\leq\rho}L(w+\Delta w)+\lambda\|w\|_{2}^{2} \tag{17}$$

and the resulting algorithm. While the bound in Theorem 2 suggests that the ridge penalty should vary with the radius of the perturbation, in practice (Foret et al., 2020) the penalty term is fixed (or simply set to zero) even when different perturbation radii are searched over. Subsequent refinements of SAM (Kim et al., 2022) ignore the ridge penalty term altogether, and the choice of an optimal perturbation radius is what drives the success of the method. It is not clear, however, why this adversarial parameter-space perturbation should help generalization more than evaluating (and approximating) the expectation in the very bound which motivates the SAM procedure in the first place, which would lead instead to an objective (ignoring, for now, the ridge penalty term) of the following form:

$$\mathbb{E}_{\epsilon\sim N(0,\sigma^{2}I)}[L_{\mathcal{S}}(w+\epsilon)] \tag{18}$$

Moreover, the worst-case adversarial perturbation used by SAM is likely to be noisier and is also naturally a significantly looser bound than the expectation-based bound.
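The gap between the expectation in Equation (18) and the worst-case value over the norm-ball of Theorem 2 can be illustrated on a toy problem. The sketch below assumes a hypothetical quadratic loss centred at its minimum, with made-up curvatures and made-up values of $\sigma$ and $n$; the radius $\rho$ is set from the formula stated in Theorem 2.

```python
import math
import random


def loss(w, curvatures):
    # Hypothetical quadratic loss with its minimum at the origin.
    return 0.5 * sum(a * wi * wi for a, wi in zip(curvatures, w))


def expected_perturbed_loss(curvatures, sigma, n_samples=20000, seed=0):
    """Monte Carlo estimate of E[L(eps)] for eps ~ N(0, sigma^2 I)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        eps = [rng.gauss(0.0, sigma) for _ in curvatures]
        total += loss(eps, curvatures)
    return total / n_samples


def worst_case_loss(curvatures, rho):
    # For this quadratic the maximum over the ball of radius rho is attained
    # along the sharpest axis.
    return 0.5 * max(curvatures) * rho * rho


curvatures = [10.0, 1.0, 0.1, 0.1]
sigma, n = 0.1, 1000
k = len(curvatures)
rho = sigma * math.sqrt(k) * (1.0 + math.sqrt(math.log(n) / k))

if __name__ == "__main__":
    print("expected loss:  ", expected_perturbed_loss(curvatures, sigma))
    print("worst-case loss:", worst_case_loss(curvatures, rho))
```

In this toy setting the expected perturbed loss equals $\frac{1}{2}\sigma^{2}\sum_{i}a_{i}$, which depends on all curvatures, whereas the worst-case value is governed entirely by the single sharpest direction and is considerably larger.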

2 Generalized variational posterior

Our starting point is a similar, but more general, optimization objective, which arises in the variational optimization of a generalized posterior distribution, $q$, over the space of probability measures $\mathcal{P}(\Theta)$ on the parameter space $\Theta$ (Bissiri et al., 2016) given by:

$$q^{*}(w)=\arg\min_{q\in\mathcal{P}(\Theta)}\left\{\mathbb{E}_{q(w)}[L(w)]+\mathbb{D}_{KL}[q,p]\right\} \tag{19}$$

to which, when $Z=\int_{\Theta}\exp\left\{-\sum_{i=1}^{n}l(w,x_{i})\right\}p(w)\,dw<\infty$, the solution is given by the generalized posterior:

$$q^{*}(w)\propto p(w)\prod_{i=1}^{n}\exp\{-l(w,x_{i})\} \tag{20}$$

The terms $\exp\{-l(w,x_{i})\}$ are to be interpreted as quasi-likelihoods, and for the particular choice $l(w,x_{i})=-\log p(x_{i}|w)$, we recover the standard Bayesian posterior. As this infinite-dimensional optimization is, in general, intractable, it is usual to assume that the posterior belongs to a parametric family $\mathcal{Q}\subset\mathcal{P}$:

$$q^{*}(w)=\arg\min_{q\in\mathcal{Q}(\Theta)}\left\{\mathbb{E}_{q(w)}[L(w)]+\mathbb{D}_{KL}[q,p]\right\} \tag{21}$$

which, for the choice $l(w,x_{i})=-\log p(x_{i}|w)$, is the same objective (up to a constant factor) as the evidence lower bound (ELBO) used in variational Bayes.

In practice, it is often found that tempering the KL divergence term by a positive factor $\rho<1$ gives the best performance, giving rise to:

$$q^{*}(w)=\arg\min_{q\in\mathcal{Q}(\Theta)}\left\{\mathbb{E}_{q(w)}[L(w)]+\rho\mathbb{D}_{KL}[q,p]\right\} \tag{22}$$

2.1 TRACER: flatness-inducing regularization

Ignoring for simplicity the contribution from the prior term (which would correspond to a ridge-regularization term under the assumption $p(w)\sim N(0,\sigma_{p}^{2}I)$) leads to the following objective, which we seek to minimize over $w$:

$$\mathbb{E}_{q(w)}[L(w)]-\rho\mathcal{H}(q) \tag{23}$$

where $\mathcal{H}(q)=-\mathbb{E}_{q(w)}[\log q(w)]$ is the entropy of $q$. For the choice $q(w)\sim N(w,\sigma^{2}I)$, the optimization problem associated with the variational objective becomes (absorbing some constants into $\rho$):

$$\arg\min_{q,\Sigma}\mathbb{E}_{q}[L(w)]-\rho\mathcal{H}(q)=\arg\min_{w,\sigma^{2}}\mathbb{E}_{q}[L(w)]+\rho\log\frac{1}{\sigma^{2}} \tag{24}$$

so that we can see that $\rho$ determines the variance of the Gaussian perturbation over which the loss is averaged. More generally, choosing $q\sim N(w,\Sigma)$ leads to the following variational objective:

$$\arg\min_{q,\Sigma}\mathbb{E}_{q}[L(w)]-\rho\mathcal{H}(q)=\arg\min_{w,\Sigma}\mathbb{E}_{q}[L(w)]+\rho\log\frac{1}{\det(\Sigma)} \tag{25}$$

so that large values of $\rho$ will correspond to distributions with larger volume, since for $x\sim N(0,\Sigma)$, $x$ lies within the ellipsoid $x^{T}\Sigma^{-1}x=\chi^{2}(\alpha)$ with probability $1-\alpha$, with the volume of the ellipsoid proportional to $\det(\Sigma)^{\frac{1}{2}}$ (Anderson, 2003). We show in Section 2.3 that, expanding the expectation under $q$ to second order, we have:

$$\mathbb{E}_{q(w)}[L(w)]\approx L(w)+\frac{1}{2}\text{Tr}(\Sigma\nabla_{w}^{2}L(w)) \tag{26}$$

so that, in the neighborhood of a local minimum, where the Hessian is positive-definite, the curvature of the loss-surface is penalized over a region whose volume is determined by $\rho$. While intuitively appealing, this flatness-inducing penalty is not invariant to coordinate transformations, so that scale changes (such as occur, for example, when applying batch-normalization or weight-normalization) which have no effect on the output of the learned probability distribution can nevertheless result in arbitrary changes to the penalty. More generally, as discussed above, any geometric notion of loss surface flatness must be independent of arbitrary rescaling of the network parameters. Motivated by these considerations, we apply steepest descent in the KL-metric (also known as natural gradient descent (Amari, 1998)) to our variational objective, under the assumption $q(w)\sim N(w,\Sigma)$, in order to find solutions to:

$$\arg\min_{\mu,\Sigma}\mathbb{E}_{q}[L(w)]+\rho\mathbb{D}_{KL}[q,p] \tag{27}$$

where $\rho$ is a positive real-valued regularization parameter.

We show in the sequel that, assuming an isotropic Gaussian prior, $p(w)\sim N(0,\eta I)$, performing gradient descent w.r.t. the natural gradient then leads to the following iterative update equations:

$$\mu\leftarrow\mu-\alpha_{t}\Lambda^{-1}\left(\mathbb{E}_{q}[\nabla_{w}L(w)]+\frac{\rho}{\eta}w\right) \tag{28}$$
$$\Lambda\leftarrow(1-\beta)\Lambda+\beta\left(\frac{\mathbb{E}_{q}[\nabla_{w}^{2}L(w)]}{\rho}+\eta^{-1}I\right) \tag{29}$$

where $\alpha_{t}$ and $\beta$ are the learning rates for the mean and precision updates, respectively, and $\Lambda:=\Sigma^{-1}$ is the precision matrix. Approximating the expectations to second order and further simplifying leads to the following update equations (see below for a detailed derivation):

$$\mu\leftarrow\mu-\alpha_{t}{\overline{H}}^{-1}\left(\nabla_{w}[L(w)+\rho\text{Tr}(H\overline{H}^{-1})]\right) \tag{30}$$
$${\overline{H}}\leftarrow(1-\beta){\overline{H}}+\beta H \tag{31}$$

where $H=\nabla^{2}_{w}L(w)$ is the Hessian. The update rule for $\overline{H}$ is an exponential smoothing, and the update for the mean consists of a preconditioned (by the inverse smoothed Hessian) gradient, together with, crucially, a penalty term proportional to the trace of the (affine-invariant) ratio of the Hessian and the smoothed Hessian. Finally, via an empirical Fisher (diagonal) Hessian approximation (see below for details and a discussion of alternatives) and dropping the preconditioner, we arrive at a modified SGD-type update which we call SGD-TRACER.

2.2 SGD-TRACER

SGD-TRACER is given by Algorithm 1, in which the usual stochastic gradient update is modified with a term which penalizes the trace of the ratio between the diagonal of the empirical FIM and an exponentially weighted average of the empirical FIM diagonal. By augmenting the loss with a TRACER term and maintaining a smoothed squared-gradient estimate, in principle, any optimization scheme can be modified in the same way. In our experiments we use SGD-TRACER (with momentum) for vision tasks and Adam-TRACER for NLP tasks, based on standard practice in each problem domain.

Algorithm 1 SGD-TRACER
Require: $\alpha_{t}$: step size
Require: $\beta$: exponential smoothing constant for the online Fisher estimate
Require: $\rho$: strength of the flatness-inducing penalty term
Require: $\delta$: small positive constant
  Initialize $\mathbf{w}_{0}$, $\mathbf{f}_{0}$, $t=0$
  while not converged do
     Sample batch $\mathcal{B}=\{(\bm{x}_{1},\bm{y}_{1}),\ldots,(\bm{x}_{b},\bm{y}_{b})\}$
     $\mathbf{w}_{t+1}=\mathbf{w}_{t}-\alpha_{t}\nabla_{\mathbf{w}}\left[L_{\mathcal{B}}(\mathbf{w}_{t})+\rho\left\langle\left(\nabla_{\mathbf{w}}L_{\mathcal{B}}(\mathbf{w}_{t})\right)^{2},(\mathbf{f}_{t}+\delta)^{-1}\right\rangle\right]$
     $\mathbf{f}_{t+1}=(1-\beta)\cdot\mathbf{f}_{t}+\beta\cdot\left(\nabla_{\mathbf{w}}L_{\mathcal{B}}(\mathbf{w}_{t})\right)^{2}$
  end while
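A minimal, dependency-free sketch of one iteration of Algorithm 1 is given below. It rests on several assumptions that are not part of the algorithm as stated: the smoothed estimate $\mathbf{f}_{t}$ is treated as a constant when differentiating the penalty, so that the penalty gradient is $2\rho H(g\odot(\mathbf{f}_{t}+\delta)^{-1})$; the Hessian-vector product is approximated by central finite differences of the gradient (an autodiff framework would normally be used instead); and `toy_grad`, the zero initialization of `f` and all hyperparameter values are hypothetical.

```python
def hvp(grad_fn, w, v, h=1e-5):
    """Finite-difference Hessian-vector product H(w) v (assumed stand-in for autodiff)."""
    g_plus = grad_fn([wi + h * vi for wi, vi in zip(w, v)])
    g_minus = grad_fn([wi - h * vi for wi, vi in zip(w, v)])
    return [(a - b) / (2.0 * h) for a, b in zip(g_plus, g_minus)]


def sgd_tracer_step(grad_fn, w, f, lr, beta, rho, delta):
    """One SGD-TRACER update of the weights w and the smoothed squared gradient f."""
    g = grad_fn(w)
    # Gradient of rho * <g(w)^2, 1 / (f + delta)> with f held fixed:
    # 2 * rho * H(w) (g / (f + delta)).
    v = [gi / (fi + delta) for gi, fi in zip(g, f)]
    penalty_grad = [2.0 * rho * hv for hv in hvp(grad_fn, w, v)]
    w_new = [wi - lr * (gi + pg) for wi, gi, pg in zip(w, g, penalty_grad)]
    # Exponential smoothing of the squared gradient (diagonal empirical Fisher).
    f_new = [(1.0 - beta) * fi + beta * gi * gi for fi, gi in zip(f, g)]
    return w_new, f_new


def toy_grad(w):
    # Gradient of the hypothetical loss 0.5 * (4 * w0^2 + 0.25 * w1^2).
    return [4.0 * w[0], 0.25 * w[1]]


def run(steps=300):
    w, f = [1.0, 1.0], [0.0, 0.0]
    for _ in range(steps):
        w, f = sgd_tracer_step(toy_grad, w, f, lr=0.05, beta=0.1, rho=1e-3, delta=1e-2)
    return w, f
```

Setting `rho=0.0` recovers a plain gradient step, which makes the modification easy to test against an existing optimizer. The value of `delta` is deliberately large here so that the toy run stays stable when `f` starts at zero.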

2.3 Derivation of the TRACER flatness-inducing regularizer

Following Khan & Rue and Zhang et al., we make the assumption $q(w)\sim N(\mu,\Sigma)$ and seek to optimize the variational objective in Equation (22) w.r.t. the variational parameters $\phi=(\mu,\Sigma)$ using natural gradient descent, which respects the intrinsic geometry of the parameter space and thus allows us to derive an algorithm that seeks flat minima in an approximately coordinate-independent way.

Thus we aim to minimize:

$$\mathcal{L}(\phi):=\mathbb{E}_{q}[L(w)]+\rho\mathbb{D}_{KL}[q,p] \tag{32}$$

where $\rho$ is a positive real-valued regularization parameter. The negative gradient corresponds to the steepest descent direction in the Euclidean metric:

$$\frac{-\nabla_{\phi}\mathcal{L}}{\|\nabla_{\phi}\mathcal{L}\|}=\lim_{\epsilon\to 0}\frac{1}{\epsilon}\underset{\Delta\phi:\|\Delta\phi\|_{2}<\epsilon}{\operatorname{argmin}}\mathcal{L}(\phi+\Delta\phi) \tag{33}$$

and thus depends on the chosen coordinates $\phi$. In contrast, the so-called natural gradient update corresponds to steepest descent in the KL-divergence metric:

$$\frac{-F^{-1}\nabla_{\phi}\mathcal{L}}{\|\nabla_{\phi}\mathcal{L}\|}=\lim_{\epsilon\to 0}\frac{1}{\epsilon}\underset{\Delta\phi:\mathbb{D}_{KL}[q_{\phi},q_{\phi+\Delta\phi}]<\epsilon}{\operatorname{argmin}}\mathcal{L}(\phi+\Delta\phi) \tag{34}$$

where $F$ is the Fisher Information Matrix (FIM):

$$F:=\mathbb{E}_{q_{\phi}(w)}\left[\nabla_{\phi}\log q_{\phi}(w)^{T}\nabla_{\phi}\log q_{\phi}(w)\right]=\mathbb{E}_{q_{\phi}(w)}\left[-\nabla_{\phi}^{2}\log q_{\phi}(w)\right] \tag{35}$$

which defines a Riemannian metric on the parameter manifold $\Phi$, where $\mathcal{Q}(\Theta)=\{q_{\phi}(w):\phi\in\Phi\}$. Expanding to second order in a small neighbourhood of $\phi$ we have:

$$\mathbb{D}_{KL}[q_{\phi},q_{\phi+\Delta\phi}]=\mathbb{E}_{q_{\phi}(w)}\left[-{\Delta\phi}^{T}\nabla_{\phi}\log q_{\phi}(w)-\frac{1}{2}{\Delta\phi}^{T}\nabla_{\phi}^{2}\log q_{\phi}(w)\Delta\phi\right]+O(\|\Delta\phi\|^{3}) \tag{36}$$

and since:

$$\mathbb{E}_{q_{\phi}(w)}\nabla_{\phi}\log q_{\phi}(w)=\mathbb{E}_{q_{\phi}(w)}\left[\frac{\nabla_{\phi}q_{\phi}(w)}{q_{\phi}(w)}\right]=\nabla_{\phi}\mathbb{E}_{q_{\phi}(w)}[1]=0 \tag{37}$$

the FIM (under certain regularity conditions) can be seen to be the Hessian (or curvature) of the KL divergence:

$$\mathbb{D}_{KL}[q_{\phi},q_{\phi+\Delta\phi}]=-\frac{1}{2}{\Delta\phi}^{T}\mathbb{E}_{q_{\phi}(w)}\left[\nabla_{\phi}^{2}\log q_{\phi}(w)\right]\Delta\phi+O(\|\Delta\phi\|^{3})=\frac{1}{2}{\Delta\phi}^{T}F\Delta\phi+O(\|\Delta\phi\|^{3}) \tag{38}$$
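Equation (38) can be checked numerically for a univariate Gaussian, for which both the KL divergence and the Fisher information are available in closed form. The sketch below uses the $(\mu,\sigma)$ parameterization, in which the Fisher information is $\mathrm{diag}(1/\sigma^{2},2/\sigma^{2})$; the function names and the particular step sizes are illustrative assumptions.

```python
import math


def kl_gauss(mu1, s1, mu2, s2):
    """KL[N(mu1, s1^2) || N(mu2, s2^2)] for univariate Gaussians."""
    return math.log(s2 / s1) + (s1 * s1 + (mu1 - mu2) ** 2) / (2.0 * s2 * s2) - 0.5


def fisher_quadratic(sigma, d_mu, d_sigma):
    """0.5 * dphi^T F dphi with F = diag(1/sigma^2, 2/sigma^2) in (mu, sigma) coordinates."""
    return 0.5 * (d_mu ** 2 / sigma ** 2 + 2.0 * d_sigma ** 2 / sigma ** 2)


if __name__ == "__main__":
    mu, sigma = 0.0, 2.0
    for step in (1e-1, 1e-2, 1e-3):
        exact = kl_gauss(mu, sigma, mu + step, sigma + step)
        approx = fisher_quadratic(sigma, step, step)
        print(f"step = {step:g}: KL = {exact:.3e}, quadratic form = {approx:.3e}")
```

The two quantities agree increasingly well as the step shrinks, consistent with the cubic remainder term in Equation (38).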

The following proposition gives an expression for the natural gradient vector (for proof see Appendix A.5):

Proposition 1.

For a probability distribution with pdf $q_{\phi}(w)\sim N(\mu,\Lambda^{-1})$ with the parameterization $\phi=\begin{bmatrix}\mu\\ \mathrm{vec}(\Lambda)\end{bmatrix}$, the natural gradient $\tilde{\nabla}_{\phi}$ of $\mathcal{L}(\phi)$ is given by:

$$\tilde{\nabla}_{\phi}\mathcal{L}(\phi)=\begin{bmatrix}\tilde{\nabla}_{\mu}\mathcal{L}\\ \mathrm{vec}(\tilde{\nabla}_{\Lambda}\mathcal{L})\end{bmatrix} \tag{39}$$

where

$$\tilde{\nabla}_{\mu}\mathcal{L}=\Sigma\,\mathbb{E}_{q}[\nabla_{w}L(w)+\rho\nabla_{w}p(w)] \tag{40}$$
$$\tilde{\nabla}_{\Sigma^{-1}}\mathcal{L}=-\mathbb{E}_{q}[\nabla_{w}^{2}L(w)-\rho\nabla_{w}^{2}p(w)]+\rho\Sigma^{-1} \tag{41}$$

Assuming an isotropic Gaussian prior, $p(w)\sim N(0,\eta I)$, performing gradient descent w.r.t. this natural gradient then leads to the following iterative update equations:

$$\mu\leftarrow\mu-\alpha_{t}\Lambda^{-1}\left(\mathbb{E}_{q}[\nabla_{w}L(w)]+\frac{\rho}{\eta}w\right) \tag{42}$$
$$\Lambda\leftarrow(1-\beta)\Lambda+\beta\left(\frac{\mathbb{E}_{q}[\nabla_{w}^{2}L(w)]}{\rho}+\eta^{-1}I\right) \tag{43}$$

where $\alpha_{t}$ and $\beta$ are the learning rates for the mean and precision updates, respectively. We work with each of these update equations in turn. Starting with the update equation for the mean $\mu$, the key observation is that the expectation $\mathbb{E}_{q}[\nabla_{w}L(w)]$ is taken with respect to the distribution $q(w)$, whose precision is an exponential moving average of the expected Hessian $\mathbb{E}_{q}[\nabla_{w}^{2}L(w)]$. This updating happens naturally as a consequence of taking natural gradient steps, and is what leads to an approximately coordinate-free algorithm in the sequel. Applying Bonnet’s theorem (Khan & Rue, 2021) and forming the second-order approximation to the loss:

$$\mathbb{E}_{q}[\nabla_{w}L(w)]=\nabla_{\mu}\mathbb{E}_{q}[L(w)]\approx\nabla_{\mu}\mathbb{E}_{q}\left[L(\mu)+(w-\mu)^{T}\nabla^{2}_{w}L(w)|_{w=\mu}(w-\mu)\right] \tag{44}$$

Writing $w-\mu=\Sigma^{\frac{1}{2}}\nu$, where $\nu\sim\mathcal{N}(0,I)$, we have:

$$\mathbb{E}_{q}\left[(w-\mu)^{T}\nabla^{2}_{w}L(w)|_{w=\mu}(w-\mu)\right]=\mathbb{E}_{\nu\sim\mathcal{N}(0,I)}\left[\nu^{T}{\Sigma^{\frac{1}{2}}}^{T}\nabla^{2}_{w}L(w)|_{w=\mu}\Sigma^{\frac{1}{2}}\nu\right]=\text{Tr}\left({\Sigma^{\frac{1}{2}}}^{T}\nabla^{2}_{w}L(w)|_{w=\mu}\Sigma^{\frac{1}{2}}\right)=\text{Tr}(H\Sigma) \tag{45}$$

where we used the fact that $\mathbb{E}_{\nu\sim\mathcal{N}(0,I)}[\nu^{T}\Omega\nu]=\sum_{i,j}\Omega_{i,j}\mathbb{E}[\nu_{i}\nu_{j}]=\text{Tr}(\Omega)$, and where $H$ is the Hessian $\nabla^{2}_{w}L(w)$. We therefore have that:

$$\mathbb{E}_{q}[\nabla_{w}L(w)]\approx\nabla_{\mu}[L(\mu)+\text{Tr}(H\Sigma)] \tag{46}$$
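The trace identity in Equation (45) is easy to verify by simulation. The sketch below draws $w-\mu$ from $N(0,\Sigma)$ in two dimensions and compares the sample mean of the quadratic form with $\text{Tr}(H\Sigma)$. As an assumption made for convenience, a Cholesky factor is used in place of the symmetric square root $\Sigma^{\frac{1}{2}}$ (any factor $A$ with $AA^{T}=\Sigma$ gives the same distribution); the matrices `H` and `S` are arbitrary illustrative values.

```python
import math
import random


def cholesky_2x2(S):
    """Lower-triangular factor A with A A^T = S for a 2x2 positive-definite S."""
    a = math.sqrt(S[0][0])
    b = S[1][0] / a
    c = math.sqrt(S[1][1] - b * b)
    return [[a, 0.0], [b, c]]


def trace_of_product(A, B):
    """Tr(A B) for 2x2 matrices."""
    return sum(A[i][j] * B[j][i] for i in range(2) for j in range(2))


def mc_quadratic_form(H, S, n_samples=50000, seed=0):
    """Monte Carlo estimate of E[(w - mu)^T H (w - mu)] for w - mu ~ N(0, S)."""
    rng = random.Random(seed)
    A = cholesky_2x2(S)
    total = 0.0
    for _ in range(n_samples):
        nu = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
        d = [A[0][0] * nu[0], A[1][0] * nu[0] + A[1][1] * nu[1]]
        total += sum(d[i] * H[i][j] * d[j] for i in range(2) for j in range(2))
    return total / n_samples


H = [[3.0, 0.5], [0.5, 1.0]]    # hypothetical Hessian at the mean
S = [[0.2, 0.05], [0.05, 0.1]]  # hypothetical covariance Sigma

if __name__ == "__main__":
    print("Monte Carlo:", mc_quadratic_form(H, S))
    print("Tr(H Sigma):", trace_of_product(H, S))
```

For an exactly quadratic loss the identity is exact; for a general loss it holds up to the higher-order terms dropped in Equation (44).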

Choosing the prior variance $\eta$ to be infinite and thus ignoring terms involving $\eta$, corresponding to an improper prior (and consistent with the discussion above), leads to the following update for the mean:

$$\mu\leftarrow\mu-\alpha_{t}\Lambda^{-1}\left(\nabla_{\mu}[L(\mu)+\text{Tr}(H\Sigma)]\right) \tag{47}$$

Thus, in order to blur the loss with multivariate Gaussian noise in a way that aligns with the intrinsic geometry of the parameter space, we can (to second order) augment the loss with a term involving the trace of the Hessian. Turning now to the update equation for the precision, we can simplify using Price’s theorem (Khan & Rue, 2021) together with a Taylor expansion to get, to second order (see Appendix for details), $\mathbb{E}_{q}[\nabla_{w}^{2}L(w)]\approx\nabla_{w}^{2}L(w)|_{w=\mu}$, so that:

$$\Lambda\leftarrow(1-\beta)\Lambda+\beta\left(\frac{\nabla_{w}^{2}L(w)|_{w=\mu}}{\rho}+\eta^{-1}I\right) \tag{48}$$

We next substitute, as is common in the literature on approximate second-order optimization (Martens, 2020), the Generalized Gauss-Newton matrix (GGN) for the Hessian, given by:

$$G(w)=\frac{1}{|S|}\sum_{(x,y)\in S}\left[J^{T}_{f}H_{L}J_{f}\right] \tag{49}$$

where $J_{f}$ is the Jacobian of the output function $f$ and $H_{L}$ is the Hessian of the loss w.r.t. the network output. The GGN is a positive semi-definite approximation to the Hessian which converges to the Hessian as the fitted residuals go to zero (Kunstner et al., 2019). The most practically relevant losses, cross-entropy (classification) and squared error (regression), correspond to the log-loss $l(y,f(x,w))=-\log p(y|x,w)$ for exponential family output distributions with natural parameters given by $f(x,w)$, and for these choices the GGN is equivalent to the Fisher Information Matrix. While the evaluation of the GGN matrix, in particular the matrix multiplies involving the Jacobians $J_{f}$, can be relatively costly, the FIM can be expressed as an expectation of outer products of gradients w.r.t. the output distribution $p(y|x,w)$:

$$\frac{1}{n}\sum_{i=1}^{n}\mathbb{E}_{p(y|x_{i},w)}\left[\nabla_{w}\log p(y|x_{i},w)^{T}\nabla_{w}\log p(y|x_{i},w)\right]\approx\frac{1}{n}\sum_{i=1}^{n}\nabla_{w}\log p(\tilde{y}_{i}|x_{i},w)^{T}\nabla_{w}\log p(\tilde{y}_{i}|x_{i},w):=\tilde{F} \tag{50}$$

which, following Martens (2020), can be estimated using a single Monte Carlo sample from the output distribution, $\tilde{y}_{i}\sim p(y|x_{i},w)$. Using this biased Fisher approximation in our setting thus requires gradients to be calculated through an expectation, $\nabla_{w}\mathbb{E}_{p(y|x,w)}[L(w;y)]$, approximated using a Monte Carlo sample from the model’s output distribution. Since the expectation is taken w.r.t. a distribution which depends on $w$, it is necessary to reparameterize so that the discrete Monte Carlo sample is expressed as a deterministic transformation $g_{w}(z)$ (depending on $w$) of a sample $z\sim h_{\theta}(z)$ from a distribution not depending on $w$, so that $\mathbb{E}_{p(y|x,w)}[L(w;y)]=\mathbb{E}_{z\sim h_{\theta}(z)}[L(w;g_{w}(z))]$. In the discrete case (corresponding to classification), since the argmax function is non-differentiable, the standard approach is the Gumbel-Softmax reparameterization (Jang et al., 2016), which uses the softmax function as a continuous relaxation of the argmax function together with i.i.d. samples distributed as Gumbel(0,1).
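The Gumbel-Softmax reparameterization mentioned above can be sketched in a few lines. The function names, the temperature value and the example class probabilities are illustrative assumptions; in practice the relaxed sample would be produced inside an autodiff framework so that gradients can flow through the logits.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) sample by inverse-transform sampling."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # guard against log(0)
    return -math.log(-math.log(u))


def gumbel_softmax(logits, temperature, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / temperature)."""
    perturbed = [(l + sample_gumbel(rng)) / temperature for l in logits]
    m = max(perturbed)  # subtract the maximum for numerical stability
    exps = [math.exp(p - m) for p in perturbed]
    z = sum(exps)
    return [e / z for e in exps]


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [math.log(p) for p in (0.7, 0.2, 0.1)]  # hypothetical class probabilities
    print(gumbel_softmax(logits, temperature=0.5, rng=rng))
```

The noise is sampled independently of the logits, so the sample is a deterministic, differentiable function of the logits given the noise. By the Gumbel-max property, the largest entry of the relaxed sample lands on class $i$ with probability equal to the softmax probability of class $i$, whatever the temperature.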

It’s important to note that this is different from simply evaluating $\log p(y|x,w)$ on the training labels, a widely-used approximation known as the empirical Fisher $F_{\text{emp}}$:

$$F_{\text{emp}}:=\sum_{i=1}^{n}\nabla_{w}\log p(y_{i}|x_{i},w)^{T}\nabla_{w}\log p(y_{i}|x_{i},w) \tag{51}$$

which, despite lacking the same convergence guarantees, performs competitively in many settings (Kunstner et al., 2019). We find empirically in our experiments that the empirical Fisher performs competitively with the MC approximation to the GGN (Khan et al., 2018; Kingma & Ba, 2014) and has the advantage of being straightforward and cheap to compute from the already computed gradients (in the case of Adam-TRACER, the smoothed squared gradients are already computed and maintained for use as a preconditioner). Given the conceptual and computational simplicity of this approach, and despite its known suboptimality, in the following we substitute the empirical Fisher for the Hessian. Recent advances in approximate second-order methods in optimization, notably Yao et al. (2020), suggest avenues for improvement, and we leave investigations of alternatives, such as the smoothed (Hessian-free) Hessian diagonal sketch used in AdaHessian, for future work.
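To make the distinction between the Monte Carlo Fisher and the empirical Fisher concrete, the sketch below compares them on a one-parameter logistic model. It is an illustration under stated assumptions, not the paper's code: the model $p(y=1|x,w)=\text{sigmoid}(wx)$, the data `xs`, `ys` and all function names are hypothetical, and every estimate is normalized by $n$.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def score(w, x, y):
    """d/dw log p(y | x, w) for the toy model p(y=1 | x, w) = sigmoid(w * x)."""
    return (y - sigmoid(w * x)) * x


def exact_fisher(w, xs):
    # Expectation over y ~ p(y | x, w) of score**2 equals p * (1 - p) * x**2.
    return sum(sigmoid(w * x) * (1.0 - sigmoid(w * x)) * x * x for x in xs) / len(xs)


def mc_fisher(w, xs, rng):
    # One label per input, sampled from the model's own output distribution.
    total = 0.0
    for x in xs:
        y_tilde = 1.0 if rng.random() < sigmoid(w * x) else 0.0
        total += score(w, x, y_tilde) ** 2
    return total / len(xs)


def empirical_fisher(w, xs, ys):
    # Scores evaluated on the observed training labels instead.
    return sum(score(w, x, y) ** 2 for x, y in zip(xs, ys)) / len(xs)


xs = [-2.0, -1.0, 1.0, 2.0]
ys = [1.0, 0.0, 0.0, 1.0]  # hypothetical (partly noisy) training labels
w = 0.5
rng = random.Random(1)
f_exact = exact_fisher(w, xs)
f_mc_avg = sum(mc_fisher(w, xs, rng) for _ in range(2000)) / 2000
f_emp = empirical_fisher(w, xs, ys)
```

Averaged over many draws, the sampled-label estimate approaches the exact Fisher, whereas the empirical Fisher depends on how well the model currently fits the training labels and need not agree with it.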

Substituting the empirical Fisher approximation for the Hessian in the update equation for the precision, we have the following update:

$$\overline{F}\leftarrow(1-\beta)\overline{F}+\beta\tilde{F} \tag{52}$$

Rewriting the update equations in terms of this exponentially smoothed FIM $\overline{F}$, absorbing a factor $\rho$ into $\alpha_{t}$, and writing the iteration in terms of the parameter $w$, we obtain:

$$w\leftarrow w-\alpha_{t}\overline{F}^{-1}\left(\nabla_{w}\left[L(w)+\rho\,\text{Tr}(F\overline{F}^{-1})\right]\right) \tag{53}$$
$$\overline{F}\leftarrow(1-\beta)\overline{F}+\beta\tilde{F} \tag{54}$$

Crucially, the penalty term $\rho\,\text{Tr}(F\overline{F}^{-1})$ can be seen to be invariant to affine coordinate transformations, since it is the trace of the ratio of two (0,2) tensors which transform in the same way. Indeed, under an affine coordinate transformation with Jacobian $J$ we have $F\rightarrow J^{T}F^{\prime}J$ and $\overline{F}\rightarrow J^{T}\overline{F}^{\prime}J$, so that:

$$\text{Tr}(F\overline{F}^{-1})=\text{Tr}(J^{T}F^{\prime}JJ^{-1}\overline{F}^{\prime-1}J^{-T})=\text{Tr}(J^{-T}J^{T}F^{\prime}\overline{F}^{\prime-1})=\text{Tr}(F^{\prime}\overline{F}^{\prime-1}) \tag{55}$$

By penalizing the ratio of the (squared) gradients to the exponentially smoothed (squared) gradients, the trace ratio penalty in effect penalizes the change in (squared) gradient in a coordinate-free way. More generally, consider a coordinate change given by a diffeomorphism $\Phi:\mathbb{R}^{p}\rightarrow\mathbb{R}^{p}$ with Jacobian $J(w)$. Given the exponential decay in the update equation for the Fisher, subject to $\Phi$ having sufficient regularity, and for sufficiently small $\beta$, the penalty term is approximately coordinate-free under general smooth diffeomorphisms $\Phi$ (see SM for details).
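The affine invariance of the trace ratio can be checked numerically. The sketch below is a minimal illustration in plain Python for the $2\times 2$ case; the matrices `f_prime`, `f_bar_prime` and the Jacobian `jac` are arbitrary values chosen for the example, and the helper names are hypothetical.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def inverse_2x2(a):
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [[a[1][1] / det, -a[0][1] / det], [-a[1][0] / det, a[0][0] / det]]


def trace_ratio(f, f_bar):
    """Tr(F F_bar^{-1}) for 2x2 matrices."""
    product = matmul(f, inverse_2x2(f_bar))
    return product[0][0] + product[1][1]


def transform(m, jac):
    """Apply the (0,2)-tensor transformation rule M' -> J^T M' J."""
    return matmul(matmul(transpose(jac), m), jac)


f_prime = [[2.0, 0.5], [0.5, 1.0]]      # Fisher in the primed coordinates
f_bar_prime = [[1.5, 0.2], [0.2, 0.8]]  # smoothed Fisher in the primed coordinates
jac = [[1.0, 2.0], [0.0, 3.0]]          # invertible Jacobian of the affine map

ratio_primed = trace_ratio(f_prime, f_bar_prime)
ratio_unprimed = trace_ratio(transform(f_prime, jac), transform(f_bar_prime, jac))
```

Both quantities agree up to floating-point error, because the factors of $J$ cancel inside the trace exactly as in (55).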

We now make two simplifications. First, we use a mean-field approximation of the FIM by its diagonal, as is done in Adam (Kingma & Ba, 2014) and Adagrad (Duchi et al., 2011), thus:

$$F\approx\frac{1}{n}\sum_{i=1}^{n}\left(\nabla_{w}\log p(y_{i}|x_{i},w)\right)^{2} \tag{56}$$

Secondly, it is standard practice to add Tikhonov regularization or damping by a small positive real constant $\delta$ when using second-order optimization methods, giving in this case the preconditioner $(\overline{F}+\delta I)^{-1}$. This is justified by recognizing that the local quadratic model from which the second-order update is ultimately derived is a second-order approximation to the KL divergence and is thus only valid locally. For directions corresponding to small eigenvalues, parameter updates can lie outside the region where the approximation is reasonable (Martens, 2020). This is true, a fortiori, when diagonal approximations are used, as is the case here. As our emphasis here is on geometric regularization, we drop the preconditioner entirely by choosing $\delta$ to be sufficiently large that the preconditioner is equal to the identity (up to a constant, which is absorbed into the learning rate).

Finally, most current deep learning frameworks don’t straightforwardly support access to per-example gradients, although these can in principle be obtained at negligible additional cost (see, for example, the BackPACK second-order PyTorch extensions; Dangel et al., 2020). For simplicity and efficiency, we therefore use the gradient magnitude (GM) approximation (Bottou et al., 2016), as used in the standard optimizers Adam and RMSprop, replacing the sum of squared gradients with the square of summed gradients:

$$\frac{1}{n}\sum_{i=1}^{n}\left[\nabla_{w}\log p(y_{i}|x_{i},w)\right]^{2}\approx\left[\frac{1}{n}\sum_{i=1}^{n}\nabla_{w}\log p(y_{i}|x_{i},w)\right]^{2} \tag{57}$$
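A small numerical illustration of the GM approximation follows. The per-example gradients are made-up values and the function names are hypothetical; the point is only to show the two sides of (57) computed elementwise.

```python
def mean_of_squares(per_example_grads):
    """Left-hand side of (57): average of elementwise squared per-example gradients."""
    n = len(per_example_grads)
    dim = len(per_example_grads[0])
    return [sum(g[j] ** 2 for g in per_example_grads) / n for j in range(dim)]


def square_of_mean(per_example_grads):
    """Right-hand side of (57): elementwise square of the averaged (mini-batch) gradient."""
    n = len(per_example_grads)
    dim = len(per_example_grads[0])
    return [(sum(g[j] for g in per_example_grads) / n) ** 2 for j in range(dim)]


# Three hypothetical per-example gradients for a two-parameter model.
grads = [[1.0, -2.0], [3.0, 2.0], [-1.0, 0.0]]
fisher_diag = mean_of_squares(grads)
gm_diag = square_of_mean(grads)
```

By Jensen's inequality the GM estimate never exceeds the per-example estimate in any coordinate, and the gap grows when per-example gradients cancel within the batch (as in the second coordinate here), which is one reason larger batches may benefit from true per-example gradients.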

Writing the resulting FIM diagonal as $(\nabla_{w}L(w))^{2}$, we finally end up with the following simple update equations:

$$\begin{split}w_{t+1}&=w_{t}-\alpha_{t}\nabla_{w}\left[L(w_{t})+\rho\left\langle\left(\nabla_{w}L(w_{t})\right)^{2},\overline{f}^{-1}_{t}\right\rangle\right]\\ \overline{f}_{t+1}&=(1-\beta)\overline{f}_{t}+\beta\left(\nabla_{w}L(w_{t})\right)^{2}\end{split} \tag{58}$$

which are summarized in Algorithm 1. We show in Appendix A.3 that the algorithm converges to a neighborhood of a local minimum of $L(w)$ of size $\mathcal{O}(\rho^{2})$. We note in passing that, in this simplest form (after applying the gradient magnitude approximation), the update equations amount to regularizing with a (scale-adjusted) gradient norm. In principle (particularly for the large batch case) we would expect to see significant improvements by moving to per-example gradient calculations (which are in principle no more expensive but require additional work in most current ML frameworks).
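The update equations (58) can be sketched on a toy problem. The example below is an assumption-laden illustration, not Algorithm 1 itself: it uses a diagonal quadratic loss so that the Hessian is known analytically (a real implementation would obtain the penalty gradient by automatic differentiation), it treats $\overline{f}_{t}$ as a constant when differentiating the penalty, and it adds the damping $\delta$ that appears in Appendix A.3. All names and hyperparameter values are hypothetical.

```python
def tracer_step(w, f_bar, curvature, lr, rho, beta, delta):
    """One damped SGD-TRACER-style step on L(w) = 0.5 * sum_j curvature[j] * w[j]**2."""
    # Gradient of the toy loss; its Hessian is diag(curvature).
    grad = [a * wj for a, wj in zip(curvature, w)]
    # Gradient of the penalty <grad**2, 1 / (f_bar + delta)> with f_bar held
    # fixed: 2 * Hessian @ (grad / (f_bar + delta)).
    penalty_grad = [2.0 * a * g / (f + delta) for a, g, f in zip(curvature, grad, f_bar)]
    new_w = [wj - lr * (g + rho * pg) for wj, g, pg in zip(w, grad, penalty_grad)]
    # Exponentially smoothed squared gradient (diagonal Fisher estimate).
    new_f_bar = [(1.0 - beta) * f + beta * g * g for f, g in zip(f_bar, grad)]
    return new_w, new_f_bar


def loss(w, curvature):
    return 0.5 * sum(a * wj * wj for a, wj in zip(curvature, w))


curvature = [1.0, 4.0]
w = [1.0, -2.0]
f_bar = [0.0, 0.0]
initial_loss = loss(w, curvature)
for _ in range(200):
    w, f_bar = tracer_step(w, f_bar, curvature, lr=0.1, rho=0.01, beta=0.1, delta=1.0)
final_loss = loss(w, curvature)
```

On this quadratic the penalty gradient vanishes at the minimizer, so the iterates reach it exactly; for general losses the theory only guarantees convergence to a neighborhood whose size depends on $\rho$.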

2.4 Results

We first examine a challenging variant of a standard benchmark in computer vision, CIFAR-100. We compare SGD, SAM and SGD-TRACER using none of the standard regularizations (no data augmentation, no weight decay) and a standard training protocol (200 epochs, initial learning rate set to $0.1$, cosine learning-rate decay). Further, we randomly flip 50% of the labels so that 50% of examples are incorrectly labeled. The results in Table 1 show that SGD-TRACER significantly improves on SAM in this challenging setting. In Figure 1 we highlight results for the same problem over different values of the regularization parameter $\rho$.

Figure 1: CIFAR 100: ResNet20, no weight-decay, 50% noise, accuracy vs regularization strength. GTRACER dominates the baseline and SAM across a wide range of regularization strengths.

In Figure 2 we compare the training curves on this problem.

Figure 2: CIFAR 100: ResNet20, 50% noise, test-accuracy training curves. On a standard 200 epoch training protocol with cosine learning-rate decay, SGD-TRACER converges to a solution that generalizes better than SGD and SAM.

Table 1: CIFAR 100: ResNet20, no weight-decay, 50% noise, accuracy (standard error)

| Method | No aug |
| --- | --- |
| SGD | 17.5% (2.41) |
| SAM | 34.63% (1.85) |
| SGD-TRACER | 47.55% (1.51) |

We next run SGD-TRACER on CIFAR-100 with and without data augmentation, and with label noise (random label flipping), using a standard ridge penalty of $5\times 10^{-4}$. The results in Table 2 show that SGD-TRACER performs consistently well, with a particularly strong advantage in the presence of noise and/or without additional regularization in the form of data augmentation.

Table 2: CIFAR-100: ResNet20, accuracy (standard error)

| Method | no aug | with aug | 50% noise & no aug |
| --- | --- | --- | --- |
| SGD | 51.43% (0.41) | 70.02% (0.36) | 21.96% (0.36) |
| SAM | 58.98% (0.52) | 70.33% (0.22) | 49.89% (0.32) |
| SGD-TRACER | 63.47% (0.32) | 70.71% (0.36) | 51.62% (0.18) |

For NLP tasks we fine-tune the Hugging Face BERT-base-uncased checkpoint using Adam-TRACER, with a standard protocol of 5 epochs and initial learning rate $2\times 10^{-5}$. Each run is repeated 20 times. Performance is uniformly strong across the 3 benchmark tasks (taken from the challenging SuperGLUE benchmark), and Adam-TRACER has the additional property of producing more stable results across runs (as reflected in the standard errors). See the SM for details of (standard) experiment hyperparameters.

Table 3: SuperGLUE tasks, BERT base-uncased results, accuracy (standard error)

| Method | BOOLQ | WIC | RTE |
| --- | --- | --- | --- |
| Adam | 73.84% (0.14) | 69.36% (0.08) | 69.18% (0.33) |
| SAM | 73.95% (0.13) | 69.06% (0.07) | 69.54% (0.28) |
| Adam-TRACER | 75.09% (0.04) | 70.01% (0.06) | 70.13% (0.18) |

3 Conclusion

Motivated by the notable empirical success of SAM, a prior that flat (in expectation, and in an intrinsic, geometric sense) minima should generalize better than sharp minima, and noting the connections between the generalized Bayes objective and SAM, we have derived a new algorithm that is simple to implement and understand, cheap to evaluate, provably convergent, naturally scale-independent (and approximately coordinate-free) and which is competitive with SAM on key benchmark problems. Performance is particularly strong for challenging low signal-to-noise ratio and large batch problems. Crucially the algorithm is straightforwardly derived from an approximate natural gradient optimization of an ELBO-type objective and doesn’t rely on "m-sharpness" (Foret et al., 2020) or other poorly understood (and expensive to compute) heuristics.

Broader Impact Statement

We present a novel method with sound theoretical motivation which delivers competitive results on challenging benchmark and low signal-to-noise-ratio problems in vision and NLP.

Appendix A

A.1 Multivariate Gaussian Fisher Information Matrix, $[\mu,\mathrm{vec}(\Lambda)]^{T}$ parameterization

For a probability distribution with density $q$ with parameters $\phi$, the Fisher Information Matrix (FIM) can be written as the expected negative log-likelihood Hessian:

$$F=\mathbb{E}_{q}\left[-\nabla_{\phi}^{2}\log q\right] \tag{59}$$

In particular, for a multivariate Gaussian with pdf $q(x)\sim N(\mu,\Lambda^{-1})$, parameterized by $\phi=\begin{bmatrix}\mu\\ \text{vec}(\Lambda)\end{bmatrix}$, the negative log-likelihood is, up to constant terms:

$$-\log q(x)=\frac{1}{2}(x-\mu)^{T}\Lambda(x-\mu)+\frac{1}{2}\log|\Lambda^{-1}| \tag{60}$$

Taking gradients w.r.t. $\mu$, we have $-\nabla_{\mu}\log q(x)=-\Lambda(x-\mu)$ and therefore $\mathbb{E}_{q}\left[-\nabla^{2}_{\mu}\log q(x)\right]=\Lambda$. Taking gradients w.r.t. the precision $\Lambda$, and using $\nabla_{\Lambda}(x-\mu)^{T}\Lambda(x-\mu)=(x-\mu)(x-\mu)^{T}$ and $\nabla_{\Lambda}\log|\Lambda^{-1}|=\nabla_{\Lambda}\log|\Lambda|^{-1}=-\nabla_{\Lambda}\log|\Lambda|=-(\Lambda^{T})^{-1}=-\Lambda^{-1}$, we have:

$$-\nabla_{\Lambda}\log q(x)=\frac{1}{2}(x-\mu)(x-\mu)^{T}-\frac{1}{2}\Lambda^{-1} \tag{61}$$

Finally, writing $\nabla_{\Lambda}\Lambda^{-1}$ as $-\Lambda^{-1}\otimes\Lambda^{-1}$ and $\Lambda^{-1}:=\Sigma$ we have:

$$-\nabla^{2}_{\Lambda}\log q(x)=\frac{1}{2}\Sigma\otimes\Sigma \tag{62}$$

so that the FIM is given by:

$$F=\mathbb{E}_{q}\left[-\nabla_{\phi}^{2}\log q\right]=\begin{bmatrix}\Sigma^{-1}&0\\ 0&\frac{1}{2}\Sigma\otimes\Sigma\end{bmatrix} \tag{63}$$

A.2 Approximate expected Hessian

Lemma 1.

To second order, we can approximate the expected Hessian w.r.t. a multivariate Gaussian with pdf $q(x)\sim N(\mu,\Lambda^{-1})$ by its value at the mean:

$$\mathbb{E}_{q}[\nabla_{w}^{2}L(w)]\approx\nabla_{w}^{2}L(w)|_{w=\mu} \tag{64}$$
Proof.

Following Khan & Rue (2021), by Price’s theorem, we have:

$$\mathbb{E}_{q}[\nabla_{w}^{2}L(w)]=2\nabla_{\Lambda^{-1}}\mathbb{E}_{q}[L(w)] \tag{65}$$

Expanding the r.h.s. to second order using a Taylor series about $\mu$ (the zeroth-order term does not depend on $\Lambda^{-1}$ and the first-order term has zero expectation), this is approximately:

$$2\nabla_{\Lambda^{-1}}\mathbb{E}_{q}\left[\frac{1}{2}(w-\mu)^{T}\nabla^{2}_{w}L(w)|_{w=\mu}(w-\mu)\right] \tag{66}$$

Finally, noting that $\mathbb{E}_{q}\left[\frac{1}{2}(w-\mu)^{T}\nabla^{2}_{w}L(w)|_{w=\mu}(w-\mu)\right]=\text{Tr}\left[\frac{1}{2}\Lambda^{-1}\nabla^{2}_{w}L(w)|_{w=\mu}\right]$, we have, to second order:

$$\mathbb{E}_{q}[\nabla_{w}^{2}L(w)]\approx 2\nabla_{\Lambda^{-1}}\text{Tr}\left[\frac{1}{2}\Lambda^{-1}\nabla^{2}_{w}L(w)|_{w=\mu}\right]=\nabla_{w}^{2}L(w)|_{w=\mu} \tag{67}$$
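As an illustrative check (not taken from the source), consider the scalar case $L(w)=w^{4}$ with $q=N(\mu,\sigma^{2})$, for which both sides of (64) are available in closed form. The gap between them is $12\sigma^{2}$, which comes from the fourth-order Taylor term dropped in the proof and vanishes as the posterior concentrates around its mean:

```latex
\begin{align*}
\mathbb{E}_{q}[L(w)] &= \mu^{4}+6\mu^{2}\sigma^{2}+3\sigma^{4},\\
\mathbb{E}_{q}[\nabla_{w}^{2}L(w)] &= 2\,\frac{\partial}{\partial\sigma^{2}}\mathbb{E}_{q}[L(w)] = 12\mu^{2}+12\sigma^{2},\\
\nabla_{w}^{2}L(w)|_{w=\mu} &= 12\mu^{2}.
\end{align*}
```

The second line is Price's theorem applied directly, and agrees with computing $\mathbb{E}_{q}[12w^{2}]=12(\mu^{2}+\sigma^{2})$ by hand.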

A.3 Convergence analysis

With $T(w_{t}):=\left\langle\left(\nabla_{w}L(w_{t})\right)^{2},(\overline{f}+\delta)^{-1}_{t}\right\rangle$, as $\rho\rightarrow 0$, the iterates $w_{t+1}=w_{t}-\alpha_{t}\left[\nabla_{w}L(w_{t})+\rho\nabla_{w}T(w_{t})\right]$ will converge to those of SGD. For $\rho>0$, the algorithm is biased away from a pure descent direction, and convergence then depends on the magnitude of $\rho$. The key assumption in the following convergence proof is $\|\rho\nabla_{w}T(w_{t})\|_{2}^{2}\leq\kappa\|\nabla_{w}L(w_{t})\|_{2}^{2}+\zeta$, which controls the bias and which follows from the standard assumptions of twice-differentiability of $L(w)$ and Lipschitz continuity of $\nabla_{w}L(w)$, which together imply that the Hessian has a bounded spectral norm:

$$\begin{split}\|\rho\nabla_{w}T(w_{t})\|_{2}^{2}&\leq 4\rho^{2}\|\nabla_{w}^{2}L(w_{t})\|^{2}_{2}\|(\overline{f}+\delta)^{-1}_{t}\|_{2}^{2}\\ &\leq 4\left(\frac{\rho}{\delta}\right)^{2}C^{2}p\end{split} \tag{68}$$

so that $\zeta$ depends on the Lipschitz constant $C$ and the ratio $\frac{\rho}{\delta}$.

Theorem 3.

Let $T(w_{t}):=\left\langle\left(\nabla_{w}L(w_{t})\right)^{2},\overline{f}^{-1}_{t}\right\rangle$, and assume the objective (loss) $L:\mathbb{R}^{p}\rightarrow\mathbb{R}$ is Lipschitz continuous, twice differentiable, and has Lipschitz-continuous gradient. Let us assume, following Bottou et al. (2016) and Ajalloeian & Stich (2021), that we have a stochastic direction $g(w_{t},\xi_{t})$ which has the following properties, $\forall t$:

$$\mathbb{E}\left[g(w_{t},\xi_{t})\right]=\nabla_{w}L(w_{t})+\rho\nabla_{w}T(w_{t}) \tag{69}$$

and further assuming that there exist $M$, $M_{G}$ such that, $\forall t$,

$$\mathbb{E}\left[\|g(w_{t},\xi_{t})\|^{2}\right]\leq M+M_{G}\|\nabla_{w}L(w_{t})+\rho\nabla_{w}T(w_{t})\|^{2} \tag{70}$$

and the following bound on the bias:

$$\|\rho\nabla_{w}T(w_{t})\|^{2}\leq\kappa\|\nabla_{w}L(w_{t})\|_{2}^{2}+\zeta \tag{71}$$

then the iteration:

$$\begin{split}w_{t+1}&=w_{t}-\alpha_{t}\nabla_{w}\left[L(w_{t})+\rho T(w_{t})\right]\\ \overline{f}_{t+1}&=(1-\beta)\overline{f}_{t}+\beta\left(\nabla_{w}L(w_{t})\right)^{2}\end{split} \tag{72}$$

converges to a neighbourhood of a stationary point with $\|\nabla L(w)\|_{2}^{2}=\mathcal{O}(\zeta)$.

Proof.

By the Lipschitz continuity of the gradient of the objective function (with constant $C$) we have the quadratic bound:

$$L(y)\leq L(x)+\langle\nabla_{w}L(x),y-x\rangle+\frac{C}{2}\|y-x\|^{2} \tag{73}$$

By the quadratic upper bound, the iterates generated by the algorithm satisfy:

$$L(w_{t+1})-L(w_{t})\leq-\alpha_{t}\langle\nabla_{w}L(w_{t}),g(w_{t},\xi_{t})\rangle+\frac{1}{2}\alpha_{t}^{2}C\|g(w_{t},\xi_{t})\|_{2}^{2} \tag{74}$$

Taking expectations and applying the variance bound we have:

$$\begin{split}\mathbb{E}L(w_{t+1})-L(w_{t})&\leq-\alpha_{t}\|\nabla_{w}L(w_{t})\|^{2}-\alpha_{t}\rho\nabla_{w}L(w_{t})^{T}\nabla_{w}T(w_{t})+\frac{1}{2}\alpha_{t}^{2}C\,\mathbb{E}\left[\|g(w_{t},\xi_{t})\|_{2}^{2}\right]\\ &\leq-\alpha_{t}\|\nabla_{w}L(w_{t})\|^{2}-\alpha_{t}\rho\nabla_{w}L(w_{t})^{T}\nabla_{w}T(w_{t})+\frac{1}{2}\alpha_{t}^{2}C\left[M+M_{G}\|\nabla_{w}L(w_{t})+\rho\nabla_{w}T(w_{t})\|_{2}^{2}\right]\\ &=-\alpha_{t}\|\nabla_{w}L(w_{t})\|^{2}-\alpha_{t}(1-\alpha_{t}CM_{G})\rho\nabla_{w}L(w_{t})^{T}\nabla_{w}T(w_{t})+\frac{1}{2}\alpha_{t}^{2}CM+\frac{1}{2}\alpha_{t}^{2}CM_{G}\left(\|\nabla_{w}L(w_{t})\|_{2}^{2}+\|\rho\nabla_{w}T(w_{t})\|_{2}^{2}\right)\end{split} \tag{75}$$

So that, choosing $\alpha_{t}<\frac{1}{CM_{G}}$, bounding the cross term using $-\langle a,b\rangle\leq\frac{1}{2}\|a\|^{2}+\frac{1}{2}\|b\|^{2}$, and applying the bound on $\|\rho\nabla_{w}T(w_{t})\|$, we have:

$$\begin{split}\mathbb{E}L(w_{t+1})-L(w_{t})&\leq-\frac{1}{2}\alpha_{t}\|\nabla_{w}L(w_{t})\|^{2}+\frac{1}{2}\alpha_{t}^{2}CM+\frac{1}{2}\alpha_{t}\|\rho\nabla_{w}T(w_{t})\|_{2}^{2}\\ &\leq-\frac{1}{2}\alpha_{t}(1-\kappa)\|\nabla_{w}L(w_{t})\|^{2}+\frac{1}{2}\alpha_{t}^{2}CM+\frac{\alpha_{t}}{2}\zeta\end{split} \tag{76}$$

Taking the total expectation, for a fixed $\alpha$, we then have:

$$L_{inf}-L(w_{1})\leq\mathbb{E}\left[L(w_{K+1})\right]-L(w_{1})\leq-\frac{1}{2}\alpha(1-\kappa)\sum_{t=1}^{K}\|\nabla_{w}L(w_{t})\|^{2}+\frac{1}{2}K\alpha^{2}CM+\frac{K\alpha}{2}\zeta \tag{77}$$

Finally, rearranging (for $\kappa<1$) gives:

$$\frac{1}{K}\sum_{t=1}^{K}\|\nabla_{w}L(w_{t})\|^{2}\leq\frac{\alpha CM}{1-\kappa}+\frac{\zeta}{1-\kappa}+2\frac{L(w_{1})-L_{inf}}{K\alpha(1-\kappa)}\xrightarrow{K\to\infty}\frac{\alpha CM}{1-\kappa}+\frac{\zeta}{1-\kappa} \tag{78}$$

A.4 Objective function gradient

Lemma 2.

The gradient of the objective (32) with respect to $\phi^{\prime}=\begin{bmatrix}\mu\\ \mathrm{vec}(\Sigma)\end{bmatrix}$ is given by:

$$\nabla_{\mu}\mathcal{L}=\mathbb{E}_{q}[\nabla_{w}L(w)-\rho\nabla_{w}\log p(w)] \tag{79}$$
$$\nabla_{\Sigma}\mathcal{L}=\frac{1}{2}\mathbb{E}_{q}[\nabla_{w}^{2}L(w)-\rho\nabla_{w}^{2}\log p(w)]-\frac{\rho}{2}\Sigma^{-1} \tag{80}$$
Proof.

Taking the gradient of the objective w.r.t. $\mu$, and applying Bonnet’s theorem (Khan & Rue, 2021) and the fact that the expectation of the score is 0, we have:

$$\nabla_{\mu}\left(\mathbb{E}_{q}[L(w)]+\rho\mathbb{D}_{KL}[q(w),p(w)]\right)=\mathbb{E}_{q}[\nabla_{w}L(w)]-\rho\mathbb{E}_{q}\left[\nabla_{w}\log p(w)\right] \tag{81}$$

Taking the gradient w.r.t. $\Sigma$ and applying Price’s theorem, we have:

$$\nabla_{\Sigma}\left(\mathbb{E}_{q}[L(w)]+\rho\mathbb{D}_{KL}[q(w),p(w)]\right)=\frac{1}{2}\mathbb{E}_{q}\left[\nabla_{w}^{2}L(w)+\rho\nabla_{w}^{2}\log q(w)-\rho\nabla_{w}^{2}\log p(w)\right] \tag{82}$$

and since:

$$\mathbb{E}_{q}\left[\nabla_{w}^{2}\log q(w)\right]=-\frac{1}{2}\mathbb{E}_{q}\left[\nabla_{w}^{2}\left(\log|\Sigma|+(w-\mu)^{T}\Sigma^{-1}(w-\mu)\right)\right]=-\Sigma^{-1} \tag{83}$$

we obtain:

$$\nabla_{\mu}\mathcal{L}=\mathbb{E}_{q}[\nabla_{w}L(w)-\rho\nabla_{w}\log p(w)] \tag{84}$$
$$\nabla_{\Sigma}\mathcal{L}=\frac{1}{2}\mathbb{E}_{q}[\nabla_{w}^{2}L(w)-\rho\nabla_{w}^{2}\log p(w)]-\frac{\rho}{2}\Sigma^{-1} \tag{85}$$

A.5 Objective function natural gradient

Proposition 2.
$$\tilde{\nabla}_{\mu}\mathcal{L}=\Sigma\,\mathbb{E}_{q}[\nabla_{w}L(w)-\rho\nabla_{w}\log p(w)] \tag{86}$$
$$\tilde{\nabla}_{\Sigma^{-1}}\mathcal{L}=-\mathbb{E}_{q}[\nabla_{w}^{2}L(w)-\rho\nabla_{w}^{2}\log p(w)]+\rho\Sigma^{-1} \tag{87}$$
Proof.

By Lemma 2, the gradients $\nabla_{\phi^{\prime}}$ of the objective $\mathcal{L}(\phi)$ w.r.t. $\phi^{\prime}=\begin{bmatrix}\mu\\ \text{vec}(\Sigma)\end{bmatrix}$ are given by:

$$\nabla_{\mu}\mathcal{L}=\mathbb{E}_{q}[\nabla_{w}L(w)-\rho\nabla_{w}\log p(w)] \tag{88}$$

and

$$\nabla_{\Sigma}\mathcal{L}=\frac{1}{2}\mathbb{E}_{q}[\nabla_{w}^{2}L(w)-\rho\nabla_{w}^{2}\log p(w)]-\frac{\rho}{2}\Sigma^{-1} \tag{89}$$

The gradient $\tilde{\nabla}_{\mu}\mathcal{L}$ then follows immediately from the definition of the natural gradient operator. Using the chain rule for matrix derivatives we can also show that:

$$\nabla_{\Lambda}\mathcal{L}=-\Lambda^{-1}\left(\nabla_{\Sigma}\mathcal{L}\right)\Lambda^{-1} \tag{90}$$

So that

$$\nabla_{\Lambda}\mathcal{L}=-\frac{1}{2}\Lambda^{-1}\mathbb{E}_{q}[\nabla_{w}^{2}L(w)-\rho\nabla_{w}^{2}\log p(w)]\Lambda^{-1}+\frac{\rho}{2}\Lambda^{-1} \tag{91}$$

Thus, the gradients $\nabla_{\phi}\mathcal{L}(\phi)=\begin{bmatrix}\nabla_{\mu}\mathcal{L}\\ \text{vec}(\nabla_{\Lambda}\mathcal{L})\end{bmatrix}$ are given by:

$$\begin{bmatrix}\mathbb{E}_{q}[\nabla_{w}L(w)-\rho\nabla_{w}\log p(w)]\\ -\frac{1}{2}\Lambda^{-1}\mathbb{E}_{q}[\nabla_{w}^{2}L(w)-\rho\nabla_{w}^{2}\log p(w)]\Lambda^{-1}+\frac{\rho}{2}\Lambda^{-1}\end{bmatrix} \tag{92}$$

The Fisher Information Matrix is given by (63):

$$F=\mathbb{E}_{q_{\phi}}\left[-\nabla_{\phi}^{2}\log q_{\phi}\right]=\begin{bmatrix}\Sigma^{-1}&0\\ 0&\frac{1}{2}\Sigma\otimes\Sigma\end{bmatrix} \tag{93}$$

Therefore

$$F^{-1}\nabla_{\phi}\mathcal{L}(\phi)=\begin{bmatrix}\Lambda^{-1}&0\\ 0&2\Lambda\otimes\Lambda\end{bmatrix}\begin{bmatrix}\nabla_{\mu}\mathcal{L}\\ \text{vec}(\nabla_{\Lambda}\mathcal{L})\end{bmatrix}=\begin{bmatrix}\Lambda^{-1}\nabla_{\mu}\mathcal{L}\\ \text{vec}(2\Lambda\nabla_{\Lambda}\mathcal{L}\Lambda)\end{bmatrix} \tag{94}$$

where we used the identities $(B^{T}\otimes A)\text{vec}(X)=\text{vec}(AXB)$ and $(A\otimes B)^{-1}=A^{-1}\otimes B^{-1}$. Since $\text{vec}(2\Lambda\nabla_{\Lambda}\mathcal{L}\Lambda)=\text{vec}(-\mathbb{E}_{q}[\nabla_{w}^{2}L(w)-\rho\nabla_{w}^{2}\log p(w)]+\rho\Lambda)$, we have the required updates. ∎
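The Kronecker identity $(B^{T}\otimes A)\text{vec}(X)=\text{vec}(AXB)$ used in the last step is easy to verify numerically. The sketch below is a plain-Python illustration with arbitrary $2\times 2$ matrices; it assumes the column-stacking convention for $\text{vec}$, and the helper names are hypothetical.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def kron(a, b):
    """Kronecker product: block (i, j) of the result is a[i][j] * b."""
    return [[a[i][j] * b[k][l] for j in range(len(a[0])) for l in range(len(b[0]))]
            for i in range(len(a)) for k in range(len(b))]


def vec(x):
    """Stack the columns of x into a single flat list."""
    return [x[i][j] for j in range(len(x[0])) for i in range(len(x))]


def matvec(m, v):
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[0.0, 1.0], [-1.0, 2.0]]
x = [[2.0, -1.0], [0.5, 3.0]]

lhs = matvec(kron(transpose(b), a), vec(x))  # (B^T kron A) vec(X)
rhs = vec(matmul(matmul(a, x), b))           # vec(A X B)
```

Setting $A=B=\Lambda$ (symmetric) and $X=\nabla_{\Lambda}\mathcal{L}$ recovers the step $2(\Lambda\otimes\Lambda)\text{vec}(\nabla_{\Lambda}\mathcal{L})=\text{vec}(2\Lambda\nabla_{\Lambda}\mathcal{L}\Lambda)$ in (94).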

References

  • Ajalloeian & Stich (2021) Ahmad Ajalloeian and Sebastian U. Stich. On the convergence of sgd with biased gradients, 2021.
  • Amari (1998) S. Amari. Natural gradient works efficiently in learning. Neural Computation, 10(2):251–276, 1998.
  • Anderson (2003) T. W. Anderson. An introduction to multivariate statistical analysis. Wiley Series in Probability and Statistics, 2003.
  • Bissiri et al. (2016) P. G. Bissiri, C. C. Holmes, and S. G. Walker. A general framework for updating belief distributions. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 78(5):1103–1130, feb 2016. doi: 10.1111/rssb.12158. URL https://doi.org/10.1111%2Frssb.12158.
  • Bottou et al. (2016) Léon Bottou, Frank E. Curtis, and Jorge Nocedal. Optimization methods for large-scale machine learning. 2016. doi: 10.48550/ARXIV.1606.04838. URL https://arxiv.org/abs/1606.04838.
  • Dangel et al. (2020) Felix Dangel, Frederik Kunstner, and Philipp Hennig. Backpack: Packing more into backprop, 2020.
  • Dinh et al. (2017) Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp minima can generalize for deep nets. CoRR, abs/1703.04933, 2017. URL http://arxiv.org/abs/1703.04933.
  • Duchi et al. (2011) John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12(61):2121–2159, 2011. URL http://jmlr.org/papers/v12/duchi11a.html.
  • Dziugaite & Roy (2017) Gintare Karolina Dziugaite and Daniel M. Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. 2017. doi: 10.48550/ARXIV.1703.11008. URL https://arxiv.org/abs/1703.11008.
  • Foret et al. (2020) Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. CoRR, abs/2010.01412, 2020. URL https://arxiv.org/abs/2010.01412.
  • Hinton & van Camp (1993) Geoffrey E. Hinton and Drew van Camp. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, COLT ’93, pp.  5–13, New York, NY, USA, 1993. Association for Computing Machinery. ISBN 0897916115. doi: 10.1145/168304.168306. URL https://doi.org/10.1145/168304.168306.
  • Hochreiter & Schmidhuber (1997) Sepp Hochreiter and Jürgen Schmidhuber. Flat Minima. Neural Computation, 9(1):1–42, 01 1997. ISSN 0899-7667. doi: 10.1162/neco.1997.9.1.1. URL https://doi.org/10.1162/neco.1997.9.1.1.
  • Huang et al. (2019) W. Ronny Huang, Zeyad Emam, Micah Goldblum, Liam Fowl, Justin K. Terry, Furong Huang, and Tom Goldstein. Understanding generalization through visualizations. CoRR, abs/1906.03291, 2019. URL http://arxiv.org/abs/1906.03291.
  • Jang et al. (2016) Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with Gumbel-Softmax, 2016. URL https://arxiv.org/abs/1611.01144.
  • Khan & Rue (2021) Mohammad Emtiyaz Khan and Håvard Rue. The Bayesian learning rule, 2021. URL https://arxiv.org/abs/2107.04562.
  • Khan et al. (2018) Mohammad Emtiyaz Khan, Didrik Nielsen, Voot Tangkaratt, Wu Lin, Yarin Gal, and Akash Srivastava. Fast and scalable Bayesian deep learning by weight-perturbation in adam. 2018. doi: 10.48550/ARXIV.1806.04854. URL https://arxiv.org/abs/1806.04854.
  • Kim et al. (2022) Minyoung Kim, Da Li, Shell Xu Hu, and Timothy M. Hospedales. Fisher SAM: Information geometry and sharpness aware minimisation. 2022. doi: 10.48550/ARXIV.2206.04920. URL https://arxiv.org/abs/2206.04920.
  • Kingma & Ba (2014) Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization, 2014. URL https://arxiv.org/abs/1412.6980.
  • Kunstner et al. (2019) Frederik Kunstner, Lukas Balles, and Philipp Hennig. Limitations of the empirical fisher approximation. CoRR, abs/1905.12558, 2019. URL http://arxiv.org/abs/1905.12558.
  • Langford & Caruana (2001) John Langford and Rich Caruana. (not) bounding the true error. In T. Dietterich, S. Becker, and Z. Ghahramani (eds.), Advances in Neural Information Processing Systems, volume 14. MIT Press, 2001. URL https://proceedings.neurips.cc/paper/2001/file/98c7242894844ecd6ec94af67ac8247d-Paper.pdf.
  • Martens (2020) James Martens. New insights and perspectives on the natural gradient method. Journal of Machine Learning Research, 21(146):1–76, 2020. URL http://jmlr.org/papers/v21/17-678.html.
  • McAllester (1999) David A. McAllester. Some PAC-Bayesian theorems. 1999. URL https://doi.org/10.1023/A:1007618624809.
  • Sagun et al. (2017) Levent Sagun, Utku Evci, V. Ugur Güney, Yann N. Dauphin, and Léon Bottou. Empirical analysis of the Hessian of over-parameterized neural networks. CoRR, abs/1706.04454, 2017. URL http://arxiv.org/abs/1706.04454.
  • Wei & Schwab (2020) Ming-Wei Wei and David Schwab. Implicit regularization of SGD via thermophoresis. In Advances in Neural Information Processing Systems, Machine Learning for Physical Sciences Workshop, 2020.
  • Wilson & Izmailov (2020) Andrew Gordon Wilson and Pavel Izmailov. Bayesian deep learning and a probabilistic perspective of generalization. CoRR, abs/2002.08791, 2020. URL https://arxiv.org/abs/2002.08791.
  • Yao et al. (2020) Zhewei Yao, Amir Gholami, Sheng Shen, Kurt Keutzer, and Michael W. Mahoney. ADAHESSIAN: an adaptive second order optimizer for machine learning. CoRR, abs/2006.00719, 2020. URL https://arxiv.org/abs/2006.00719.
  • Zhang et al. (2017) Guodong Zhang, Shengyang Sun, David Duvenaud, and Roger B. Grosse. Noisy natural gradient as variational inference. CoRR, abs/1712.02390, 2017. URL http://arxiv.org/abs/1712.02390.