
SAM operates far from home: eigenvalue regularization as a dynamical phenomenon

Atish Agarwala    Yann Dauphin
Abstract

The Sharpness Aware Minimization (SAM) optimization algorithm has been shown to control large eigenvalues of the loss Hessian and provide generalization benefits in a variety of settings. The original motivation for SAM was a modified loss function which penalized sharp minima; subsequent analyses have also focused on the behavior near minima. However, our work reveals that SAM provides a strong regularization of the eigenvalues throughout the learning trajectory. We show that in a simplified setting, SAM dynamically induces a stabilization related to the edge of stability (EOS) phenomenon observed in large learning rate gradient descent. Our theory predicts the largest eigenvalue as a function of the learning rate and SAM radius parameters. Finally, we show that practical models can also exhibit this EOS stabilization, and that understanding SAM must account for these dynamics far away from any minima.


1 Introduction

Since the dawn of optimization, much effort has gone into developing algorithms which use geometric information about the loss landscape to make optimization more efficient and stable (Nocedal, 1980; Duchi et al., 2011; Lewis & Overton, 2013). In more modern machine learning, control of the large curvature eigenvalues of the loss landscape has been a goal in and of itself (Hochreiter & Schmidhuber, 1997; Chaudhari et al., 2019). There is empirical and theoretical evidence that controlling curvature of the training landscape leads to benefits for generalization (Keskar et al., 2017; Neyshabur et al., 2017), although in general the relationship between the two is complex (Dinh et al., 2017).

Recently the sharpness aware minimization (SAM) algorithm has emerged as a popular choice for regularizing the curvature during training (Foret et al., 2022). SAM has the advantage of being a tractable first-order method; for the cost of a single extra gradient evaluation, SAM can control the large eigenvalues of the loss Hessian and often leads to improved optimization and generalization (Bahri et al., 2022).

However, understanding the mechanisms behind the effectiveness of SAM is an open question. The SAM algorithm itself is a first-order approximation of SGD on a modified loss function $\tilde{\mathcal{L}}(\bm{\theta})=\max_{||\delta\bm{\theta}||<\rho}\mathcal{L}(\bm{\theta}+\delta\bm{\theta})$. Part of the original motivation was that $\tilde{\mathcal{L}}$ explicitly penalizes sharp minima over flatter ones. However, the approximation performs as well as or better than running gradient descent on $\tilde{\mathcal{L}}$ directly. SAM often works better with small batch sizes than with larger ones (Foret et al., 2022; Andriushchenko & Flammarion, 2022). These stochastic effects suggest that studying the deterministic gradient flow dynamics on $\tilde{\mathcal{L}}$ will not capture key features of SAM, since small batch size induces non-trivial differences from gradient flow (Paquette et al., 2021).

In parallel to the development of SAM, experimental and theoretical work has uncovered some of the curvature-controlling properties of first-order methods due to finite step size - particularly in the full batch setting. At intermediate learning rates, a wide variety of models and optimizers show a tendency for the largest Hessian eigenvalues to stabilize near the edge of stability (EOS) for long times (Lewkowycz et al., 2020; Cohen et al., 2022a, b). The EOS is the largest eigenvalue which would lead to convergence for a quadratic loss landscape. This effect can be explained in terms of a non-linear feedback between the large eigenvalue and changes in the parameters in that eigendirection (Damian et al., 2022; Agarwala et al., 2022).

We will show that these two areas of research are in fact intimately linked: under a variety of conditions, SAM displays a modified EOS behavior, which leads to stabilization of the largest eigenvalues at a lower magnitude via non-linear, discrete dynamics. These effects highlight the dynamical nature of eigenvalue regularization, and demonstrate that SAM can have strong effects throughout a training trajectory.

1.1 Related work

Previous experimental work suggested that decreasing batch size causes SAM to display both stronger regularization and better generalization (Andriushchenko & Flammarion, 2022). This analysis also suggested that SAM may induce more sparsity.

A recent theoretical approach studied SAM close to a minimum, where the trajectory oscillates about the minimum and provably decreases the largest eigenvalue (Bartlett et al., 2022). A contemporaneous approach studied the SAM algorithm in the limit of small learning rate and SAM radius, and quantified how the implicit and explicit regularization of SAM differs between full batch and batch size $1$ dynamics (Wen et al., 2023).

1.2 Our contributions

In contrast to other theoretical approaches, we study the behavior of SAM far from minima. We find that SAM regularizes the eigenvalues throughout training through a dynamical phenomenon, and that analysis near convergence alone cannot capture the full picture. In particular, in simplified models we show:

  • Near initialization, full batch SAM provides limited suppression of large eigenvalues (Theorem 2.1).

  • SAM induces a modified edge of stability (EOS) (Theorem 2.2).

  • For full batch training, the largest eigenvalues stabilize at the SAM-EOS, at a smaller value than pure gradient descent (Section 3).

  • As batch size decreases, the effect of SAM is stronger and the dynamics is no longer controlled by the Hessian alone (Theorem 2.3).

We then present experimental results on realistic models which show:

  • The SAM-EOS predicts the largest eigenvalue for WideResnet 28-10 on CIFAR10.

Taken together, our results suggest that SAM can operate throughout the learning trajectory, far from minima, and that it can use non-linear, discrete dynamical effects to stabilize large curvatures of the loss function.

2 Quadratic regression model

2.1 Basic model

We consider a quadratic regression model (Agarwala et al., 2022) which extends a linear regression model to second order in the parameters. Given a $P$-dimensional parameter vector $\bm{\theta}$, the $D$-dimensional output is given by $\mathbf{f}(\bm{\theta})$:

\mathbf{f}(\bm{\theta})=\mathbf{y}+\mathbf{G}^{\top}\bm{\theta}+\frac{1}{2}\bm{\mathsfit{Q}}(\bm{\theta},\bm{\theta})\,. (1)

Here, $\mathbf{y}$ is a $D$-dimensional vector, $\mathbf{G}$ is a $D\times P$-dimensional matrix, and $\bm{\mathsfit{Q}}$ is a $D\times P\times P$-dimensional tensor symmetric in the last two indices - that is, $\bm{\mathsfit{Q}}(\cdot,\cdot)$ takes two $P$-dimensional vectors as input, and outputs a $D$-dimensional vector $\bm{\mathsfit{Q}}(\bm{\theta},\bm{\theta})_{\alpha}=\bm{\theta}^{\top}\bm{\mathsfit{Q}}_{\alpha}\bm{\theta}$. If $\bm{\mathsfit{Q}}=\bm{0}$, the model corresponds to linear regression. $\mathbf{y}$, $\mathbf{G}$, and $\bm{\mathsfit{Q}}$ are all fixed at initialization.

Consider optimizing the model under a squared loss. More concretely, let $\mathbf{y}_{tr}$ be a $D$-dimensional vector of training targets. We focus on the MSE loss

\mathcal{L}(\bm{\theta})=\frac{1}{2}||\mathbf{f}(\bm{\theta})-\mathbf{y}_{tr}||^{2} (2)

We can write the dynamics in terms of the residuals $\mathbf{z}$ and the Jacobian $\mathbf{J}$ defined by

\mathbf{z}\equiv\mathbf{f}(\bm{\theta})-\mathbf{y}_{tr},~{}\mathbf{J}\equiv\frac{\partial\mathbf{f}}{\partial\bm{\theta}}=\mathbf{G}+\bm{\mathsfit{Q}}(\bm{\theta},\cdot)\,. (3)

The loss can be written as $\mathcal{L}(\bm{\theta})=\frac{1}{2}\mathbf{z}\cdot\mathbf{z}$. The full batch gradient descent (GD) dynamics of the parameters are given by

\bm{\theta}_{t+1}=\bm{\theta}_{t}-\eta\mathbf{J}^{\top}_{t}\mathbf{z}_{t} (4)

which leads to

\begin{split}\mathbf{z}_{t+1}-\mathbf{z}_{t}&=-\eta\mathbf{J}_{t}\mathbf{J}_{t}^{\top}\mathbf{z}_{t}+\frac{1}{2}\eta^{2}\bm{\mathsfit{Q}}(\mathbf{J}_{t}^{\top}\mathbf{z}_{t},\mathbf{J}_{t}^{\top}\mathbf{z}_{t})\\ \mathbf{J}_{t+1}-\mathbf{J}_{t}&=-\eta\bm{\mathsfit{Q}}(\mathbf{J}_{t}^{\top}\mathbf{z}_{t},\cdot)\,.\end{split} (5)

The $D\times D$-dimensional matrix $\mathbf{J}\mathbf{J}^{\top}$ is known as the neural tangent kernel (NTK) (Jacot et al., 2018), and controls the dynamics for small $\eta||\mathbf{J}^{\top}\mathbf{z}||$ (Lee et al., 2019).
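
To make the setup concrete, the following is a minimal numpy sketch of Equations 1-5 (array shapes, names, and sizes are our own choices, not taken from the paper's code). It checks that one GD step taken in parameter space (Equation 4) matches the update written directly in $\mathbf{z}$-$\mathbf{J}$ space (Equation 5), which holds exactly because the model is quadratic in $\bm{\theta}$:

```python
import numpy as np

rng = np.random.default_rng(0)
D, P, eta = 20, 40, 1e-2                    # illustrative sizes and learning rate

y = rng.normal(size=D)
# The paper writes G^T theta; here G is stored with shape (D, P) so that
# G @ theta is the D-dimensional linear term.
G = rng.normal(size=(D, P)) / np.sqrt(P)
Q = rng.normal(size=(D, P, P))
Q = 0.5 * (Q + np.transpose(Q, (0, 2, 1)))   # symmetric in the last two indices
y_tr = rng.normal(size=D)

def f(theta):                                # Eq. (1)
    return y + G @ theta + 0.5 * np.einsum('apq,p,q->a', Q, theta, theta)

def residual(theta):                         # z, Eq. (3)
    return f(theta) - y_tr

def jacobian(theta):                         # J = G + Q(theta, .), Eq. (3)
    return G + np.einsum('apq,q->ap', Q, theta)

theta = rng.normal(size=P) / np.sqrt(P)
z, J = residual(theta), jacobian(theta)
g = J.T @ z                                  # gradient of the MSE loss, Eq. (2)

# Parameter-space GD step, Eq. (4)
theta_next = theta - eta * g

# Equivalent step written directly in (z, J) space, Eq. (5)
z_next = z - eta * J @ g + 0.5 * eta**2 * np.einsum('apq,p,q->a', Q, g, g)
J_next = J - eta * np.einsum('apq,q->ap', Q, g)

assert np.allclose(z_next, residual(theta_next))
assert np.allclose(J_next, jacobian(theta_next))
```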

We now consider the dynamics of un-normalized SAM (Andriushchenko & Flammarion, 2022). That is, given a loss function $\mathcal{L}$ we study the update rule

\bm{\theta}_{t+1}-\bm{\theta}_{t}=-\eta\nabla\mathcal{L}(\bm{\theta}_{t}+\rho\nabla\mathcal{L}(\bm{\theta}_{t})) (6)

We are particularly interested in small learning rate and small SAM radius. The dynamics in $\mathbf{z}$-$\mathbf{J}$ space are given by

\begin{split}\mathbf{z}_{t+1}-\mathbf{z}_{t}&=-\eta\mathbf{J}\mathbf{J}^{\top}(1+\rho\mathbf{J}\mathbf{J}^{\top})\mathbf{z}-\eta\rho\mathbf{z}\cdot\bm{\mathsfit{Q}}(\mathbf{J}^{\top}\mathbf{z},\mathbf{J}^{\top}\cdot)\\ &+\eta^{2}\frac{1}{2}\bm{\mathsfit{Q}}(\mathbf{J}^{\top}\mathbf{z},\mathbf{J}^{\top}\mathbf{z})+O(\eta\rho(\eta+\rho)||\mathbf{z}||^{2})\end{split} (7)
\begin{split}\mathbf{J}_{t+1}-\mathbf{J}_{t}&=-\eta\left[\bm{\mathsfit{Q}}((1+\rho\mathbf{J}^{\top}\mathbf{J})\mathbf{J}^{\top}\mathbf{z},\cdot)+\right.\\ &\left.\rho\bm{\mathsfit{Q}}(\mathbf{z}\cdot\bm{\mathsfit{Q}}(\mathbf{J}^{\top}\mathbf{z},\cdot),\cdot)\right]+O(\eta\rho^{2}||\mathbf{z}||^{2})\end{split} (8)

to lowest order in $\eta$ and $\rho$.

From Equation 7 we see that for small $\eta||\mathbf{z}||$ and $\rho||\mathbf{z}||$, the dynamics of $\mathbf{z}$ is controlled by the modified NTK $(1+\rho\mathbf{J}\mathbf{J}^{\top})\mathbf{J}\mathbf{J}^{\top}$. The factor $1+\rho\mathbf{J}\mathbf{J}^{\top}$ shows up in the dynamics of $\mathbf{J}$ as well, and we will show that this effective NTK can lead to dynamical stabilization of large eigenvalues. Note that when $\rho=0$, these dynamics coincide with those of gradient descent.
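
As a reference point for the update rule in Equation 6, here is a minimal sketch of one un-normalized SAM step for a generic differentiable loss supplied through its gradient (the function names are ours); on a quadratic loss $\frac{1}{2}\bm{\theta}^{\top}\mathbf{H}\bm{\theta}$ it reduces to the update analyzed in Section 2.2.2:

```python
import numpy as np

def sam_step(theta, grad_fn, eta, rho):
    """Un-normalized SAM, Eq. (6): descend using the gradient at the ascent point."""
    ascent_point = theta + rho * grad_fn(theta)
    return theta - eta * grad_fn(ascent_point)

# On a quadratic loss 0.5 * theta^T H theta the step is exactly
# theta <- theta - eta * (H + rho * H^2) theta (Eq. 12 below).
H = np.diag([0.5, 2.0, 10.0])
grad = lambda th: H @ th
theta = np.ones(3)
eta, rho = 0.1, 0.05
assert np.allclose(sam_step(theta, grad, eta, rho),
                   theta - eta * (H + rho * H @ H) @ theta)
```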

2.2 Gradient descent theory

2.2.1 Eigenvalue dynamics at initialization

A basic question is: how does SAM affect the eigenvalues of the NTK? We can study this directly for early learning dynamics by using random initializations. We have the following theorem (proof in Appendix A.2):

Theorem 2.1.

Consider a second-order regression model, with $\bm{\mathsfit{Q}}$ initialized randomly with i.i.d. components with $0$ mean and variance $1$. For a model trained with full batch gradient descent, with unnormalized SAM, the change in $\mathbf{J}$ at the first step of the dynamics, averaged over $\bm{\mathsfit{Q}}$, is

{\rm E}_{\bm{\mathsfit{Q}}}[\mathbf{J}_{1}-\mathbf{J}_{0}]=-\rho\eta P\mathbf{z}_{0}\mathbf{z}_{0}^{\top}\mathbf{J}_{0}+O(\rho^{2}\eta^{2}||\mathbf{z}_{0}||^{2})+O(\eta^{3}||\mathbf{z}_{0}||^{3}) (9)

The $\alpha$th singular value $\sigma_{\alpha}$ of $\mathbf{J}_{0}$ associated with left and right singular vectors $\mathbf{w}_{\alpha}$ and $\mathbf{v}_{\alpha}$ can be approximated as

\begin{split}&(\sigma_{\alpha})_{1}-(\sigma_{\alpha})_{0}=\mathbf{w}_{\alpha}^{\top}{\rm E}_{\bm{\mathsfit{Q}}}[\mathbf{J}_{1}-\mathbf{J}_{0}]\mathbf{v}_{\alpha}+O(\eta^{2})\\ &=-\rho\eta P(\mathbf{z}_{0}\cdot\mathbf{w}_{\alpha})^{2}\sigma_{\alpha}+O(\eta^{2})\end{split} (10)

for small $\eta$.

Note that the singular vector $\mathbf{w}_{\alpha}$ is an eigenvector of $\mathbf{J}\mathbf{J}^{\top}$ associated with the eigenvalue $\sigma_{\alpha}^{2}$.

This analysis suggests that on average, at early times, the change in the singular value is negative. However, the change also depends linearly on $(\mathbf{w}_{\alpha}\cdot\mathbf{z}_{0})^{2}$. This suggests that if the component of $\mathbf{z}$ in the direction of the singular vector becomes small, the stabilizing effect of SAM becomes small as well. For large batch size/small learning rate with MSE loss, we in fact expect $\mathbf{z}\cdot\mathbf{w}_{\alpha}$ to decrease rapidly early in training (Cohen et al., 2022a; Agarwala et al., 2022). Therefore the relative regularizing effect can be weaker for larger modes in the GD setting.

Figure 1: Schematic of the SAM-modified EOS. Gradient descent decreases the loss until a high-curvature region is reached, where the largest eigenmode is non-linearly stabilized (orange, solid). SAM causes stabilization to happen earlier, at a smaller value of the curvature (green, dashed).
Figure 2: Trajectories of the largest eigenvalue $\lambda_{max}$ of $\mathbf{J}\mathbf{J}^{\top}$ for the quadratic regression model, $5$ independent initializations. For gradient descent with small learning rate ($\eta=3\cdot 10^{-3}$), SAM ($\rho=4\cdot 10^{-2}$) does not regularize the large NTK eigenvalues (left). For larger learning rate ($\eta=8\cdot 10^{-2}$), SAM controls large eigenvalues (middle). The largest eigenvalue can be predicted by the SAM edge of stability $\eta(\lambda_{max}+\rho\lambda_{max}^{2})=2$ (right).

2.2.2 Edge of stability and SAM

One of the most dramatic consequences of SAM for full batch training is the shift of the edge of stability. We begin by reviewing the EOS phenomenology. Consider full-batch gradient descent training with respect to a twice-differentiable loss. Near a minimum of the loss, the dynamics of the displacement $\mathbf{x}$ from the minimum (in parameter space) are well-approximated by

\mathbf{x}_{t+1}-\mathbf{x}_{t}=-\eta\mathbf{H}\mathbf{x}_{t} (11)

where $\mathbf{H}$ is the positive semi-definite Hessian at the minimum $\mathbf{x}=0$. The dynamics converges exponentially iff the largest eigenvalue of $\mathbf{H}$ satisfies $\eta\lambda_{max}<2$; we refer to $\eta\lambda_{max}$ as the normalized eigenvalue. Otherwise, there is at least one component of $\mathbf{x}$ which is non-decreasing. The value $2/\eta$ is often referred to as the edge of stability (EOS) for the dynamics.
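
A tiny numerical illustration of this threshold (eigenvalues and step sizes below are chosen by us for illustration): iterating Equation 11 shrinks the displacement only when $\eta\lambda_{max}<2$.

```python
import numpy as np

def run_gd(eigenvalues, eta, steps=200):
    """Iterate x <- (I - eta * H) x for a diagonal Hessian and return ||x||."""
    H = np.diag(eigenvalues)
    x = np.ones(len(eigenvalues))
    for _ in range(steps):
        x = x - eta * H @ x
    return np.linalg.norm(x)

eigs = [0.5, 2.0, 9.0]            # lambda_max = 9
print(run_gd(eigs, eta=0.20))     # eta * lambda_max = 1.8 < 2: norm shrinks toward 0
print(run_gd(eigs, eta=0.23))     # eta * lambda_max = 2.07 > 2: the top mode grows
```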

Previous work has shown that for many non-linear models, there is a range of learning rates where the largest eigenvalue of the Hessian stabilizes around the edge of stability (Cohen et al., 2022a). Equivalent phenomenology exists for other gradient-based methods (Cohen et al., 2022b). The stabilization effect is due to feedback between the largest curvature eigenvalue and the displacement in the largest eigendirection (Agarwala et al., 2022; Damian et al., 2022). For MSE loss, EOS behavior occurs for the large NTK eigenvalues as well (Agarwala et al., 2022).

We will show that SAM also induces an EOS stabilization effect, but at a smaller eigenvalue than GD. We can understand the shift intuitively by analyzing un-normalized SAM on a loss $\frac{1}{2}\mathbf{x}^{\top}\mathbf{H}\mathbf{x}$. Direct calculation gives the update rule:

\mathbf{x}_{t+1}-\mathbf{x}_{t}=-\eta(\mathbf{H}+\rho\mathbf{H}^{2})\mathbf{x}_{t} (12)

For positive definite $\mathbf{H}$, $\mathbf{x}_{t}$ converges exponentially to $0$ iff $\eta(\lambda_{max}+\rho\lambda_{max}^{2})<2$. Recall from Section 2.1 that the SAM NTK is $(1+\rho\mathbf{J}\mathbf{J}^{\top})\mathbf{J}\mathbf{J}^{\top}>\mathbf{J}\mathbf{J}^{\top}$. This suggests that $\eta(\lambda_{max}+\rho\lambda_{max}^{2})$ is the SAM normalized eigenvalue. This bound gives a critical $\lambda_{max}$ which is smaller than in the GD case. This leads to the hypothesis that SAM can cause a stabilization at the EOS in a flatter region of the loss, as schematically illustrated in Figure 1.

We can formalize the SAM edge of stability (SAM EOS) for any differentiable model trained on MSE loss. Equation 7 suggests that the matrix $\mathbf{J}\mathbf{J}^{\top}(1+\rho\mathbf{J}\mathbf{J}^{\top})$ - which has larger eigenvalues for larger $\rho$ - controls the low-order dynamics. We make this intuition precise in the following theorem (proof in Appendix B.1):

Theorem 2.2.

Consider a $\mathcal{C}^{\infty}$ model $\mathbf{f}(\bm{\theta})$ trained using Equation 6 with MSE loss. Suppose that there exists a point $\bm{\theta}^{*}$ where $\mathbf{z}(\bm{\theta}^{*})=0$. Suppose that for some ${\epsilon}>0$, we have the lower bound ${\epsilon}<\eta\lambda_{i}(1+\rho\lambda_{i})$ for the eigenvalues of the positive definite symmetric matrix $\mathbf{J}(\bm{\theta}^{*})\mathbf{J}(\bm{\theta}^{*})^{\top}$. Given a bound on the largest eigenvalue, there are two regimes:

Convergent regime. If $\eta\lambda_{i}(1+\rho\lambda_{i})<2-{\epsilon}$ for all eigenvalues $\lambda_{i}$ of $\mathbf{J}(\bm{\theta}^{*})\mathbf{J}(\bm{\theta}^{*})^{\top}$, there exists a neighborhood $U$ of $\bm{\theta}^{*}$ such that $\lim_{t\to\infty}\mathbf{z}_{t}=0$ with exponential convergence for any trajectory initialized at $\bm{\theta}_{0}\in U$.

Divergent regime. If $\eta\lambda_{i}(1+\rho\lambda_{i})>2+{\epsilon}$ for some eigenvector $\mathbf{v}_{i}$ of $\mathbf{J}(\bm{\theta}^{*})\mathbf{J}(\bm{\theta}^{*})^{\top}$, then there exists some $q_{min}$ such that for any $q<q_{min}$, given $B_{q}(\bm{\theta}^{*})$, the ball of radius $q$ around $\bm{\theta}^{*}$, there exists some initialization $\bm{\theta}_{0}\in B_{q}(\bm{\theta}^{*})$ such that the trajectory $\{\bm{\theta}_{t}\}$ leaves $B_{q}(\bm{\theta}^{*})$ at some time $t$.

Note that the theorem is proven for the NTK eigenvalues, which also show EOS behavior for MSE loss in the GD setting (Agarwala et al., 2022).

This theorem gives us the modified edge of stability condition:

\eta\lambda_{max}(1+\rho\lambda_{max})\approx 2 (13)

For larger $\rho$, a smaller $\lambda_{max}$ is needed to meet the edge of stability condition. In terms of the normalized eigenvalue $\tilde{\lambda}=\eta\lambda$, the modified EOS can be written as $\tilde{\lambda}(1+r\tilde{\lambda})=2$ with the ratio $r=\rho/\eta$. Larger values of $r$ lead to stronger regularization effects, and for the quadratic regression model specifically, $\eta$ can be factored out, leaving $r$ as the key dimensionless parameter (Appendix A.1).
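
Equation 13 can be solved explicitly for the predicted stabilized eigenvalue; a small sketch (with illustrative values of $\eta$ and $\rho$ chosen by us):

```python
import numpy as np

def sam_eos_lambda_max(eta, rho):
    """Positive root of eta * lam * (1 + rho * lam) = 2, Eq. (13)."""
    if rho == 0.0:
        return 2.0 / eta                       # recovers the GD edge of stability
    return (-1.0 + np.sqrt(1.0 + 8.0 * rho / eta)) / (2.0 * rho)

eta = 8e-2
for rho in [0.0, 1e-2, 4e-2]:
    print(rho, sam_eos_lambda_max(eta, rho))   # larger rho -> smaller predicted lambda_max
```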

2.3 SGD theory

It has been noted that the effects of SAM have a strong dependence on batch size (Andriushchenko & Flammarion, 2022). While a full analysis of SGD is beyond the scope of this work, we can see some evidence of stronger regularization for SGD in the quadratic regression model.

Consider SGD dynamics, where a random fraction $\beta=B/D$ of the training residuals $\mathbf{z}$ is used to generate the dynamics at each step. We can represent the sampling at each step with a random projection matrix $\mathbf{P}_{t}$, replacing all instances of $\mathbf{z}_{t}$ with $\mathbf{P}_{t}\mathbf{z}_{t}$. Under these dynamics, we can prove the following:

Theorem 2.3.

Consider a second-order regression model, with $\bm{\mathsfit{Q}}$ initialized randomly with i.i.d. components with $0$ mean and variance $1$. For a model trained with SGD, sampling $B$ datapoints independently at each step, the change in $\mathbf{z}$ and $\mathbf{J}$ at the first step, averaged over $\bm{\mathsfit{Q}}$ and the sampling matrix $\mathbf{P}_{t}$, is given by

\begin{split}&{\rm E}[\mathbf{z}_{1}-\mathbf{z}_{0}]_{\bm{\mathsfit{Q}},\mathbf{P}}=-\eta\beta\mathbf{J}_{0}\mathbf{J}_{0}^{\top}(1+\rho[\beta(\mathbf{J}_{0}\mathbf{J}_{0}^{\top})\\ &+(1-\beta){\rm diag}(\mathbf{J}_{0}\mathbf{J}_{0}^{\top})])\mathbf{z}_{0}+O(\eta^{2}||\mathbf{z}||^{2})+O(D^{-1})\end{split} (14)
\begin{split}&{\rm E}_{\bm{\mathsfit{Q}},\mathbf{P}}[\mathbf{J}_{1}-\mathbf{J}_{0}]=-\rho\eta P(\beta^{2}\mathbf{z}_{0}\mathbf{z}_{0}^{\top}+\beta(1-\beta){\rm diag}(\mathbf{z}_{0}\mathbf{z}_{0}^{\top}))\mathbf{J}_{0}\\ &+O(\rho^{2}\eta^{2}||\mathbf{z}||^{2})+O(\eta^{3}||\mathbf{z}||^{3})\end{split} (15)

where $\beta\equiv B/D$ is the batch fraction.

The calculations are detailed in Appendix A.2. This suggests that there are two possible sources of increased regularization for SGD. The first is the additional terms proportional to $\beta(1-\beta)$: in addition to the fact that $\beta(1-\beta)>\beta^{2}$ for $\beta<\frac{1}{2}$, we have

\mathbf{v}_{\alpha}^{\top}{\rm diag}(\mathbf{z}_{0}\mathbf{z}_{0}^{\top})\mathbf{J}_{0}\mathbf{w}_{\alpha}=\sigma_{\alpha}(\mathbf{v}_{\alpha}\circ\mathbf{z}_{0})\cdot(\mathbf{v}_{\alpha}\circ\mathbf{z}_{0}) (16)

for left and right singular vectors $\mathbf{v}_{\alpha}$ and $\mathbf{w}_{\alpha}$ of $\mathbf{J}_{0}$, where $\circ$ is the Hadamard (elementwise) product. This term can be large even if $\mathbf{v}_{\alpha}$ and $\mathbf{z}_{t}$ have small dot product. This is in contrast to $\beta^{2}(\mathbf{v}_{\alpha}\cdot\mathbf{z}_{0})^{2}$, which is small if $\mathbf{z}_{0}$ does not have a large component in the $\mathbf{v}_{\alpha}$ direction. This suggests that at short times, when the large eigenmodes decay quickly, the SGD term can still be large. The second source is that the projection onto the largest eigenmode itself decreases more slowly in the SGD setting (Paquette et al., 2021), which also suggests stronger early-time regularization for small batch size.
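
A quick numerical check of Equation 16 (toy dimensions chosen by us):

```python
import numpy as np

rng = np.random.default_rng(0)
D, P = 10, 15
J0 = rng.normal(size=(D, P))
z0 = rng.normal(size=D)

U, s, Vt = np.linalg.svd(J0, full_matrices=False)
alpha = 0                                     # check the top singular triplet
v, w, sigma = U[:, alpha], Vt[alpha], s[alpha]

lhs = v @ np.diag(z0**2) @ (J0 @ w)           # v_a^T diag(z0 z0^T) J0 w_a
rhs = sigma * np.dot(v * z0, v * z0)          # sigma_a (v_a o z0) . (v_a o z0)
assert np.allclose(lhs, rhs)
```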

Figure 3: Largest eigenvalues of $\mathbf{J}\mathbf{J}^{\top}$ for a fully-connected network trained using MSE loss on 2-class CIFAR. For gradient descent ($\eta=4\cdot 10^{-3}$) the largest eigenvalue stabilizes according to the GD EOS $\eta\lambda_{max}=2$ (solid line, blue). SAM ($\rho=10^{-2}$) stabilizes to a lower value (dashed line, blue), which is well-predicted by the SAM EOS $\eta(\lambda_{max}+\rho\lambda_{max}^{2})=2$ (dashed line, orange).

3 Experiments on basic models

3.1 Quadratic regression model

We can explore the effects of SAM and show the SAM EOS behavior via numerical experiments on the quadratic regression model. We use the update rule in Equation 6, working directly in $\mathbf{z}$ and $\mathbf{J}$ space as in (Agarwala et al., 2022). Experimental details can be found in Appendix A.3.

For small learning rates, we see that SAM does not reduce the large eigenvalues of $\mathbf{J}\mathbf{J}^{\top}$ in the dynamics (Figure 2, left). In fact, in some cases the final eigenvalue is larger with SAM turned on. The projection onto the largest eigenmode of $\mathbf{J}\mathbf{J}^{\top}$ decreases exponentially to $0$ more quickly than any other mode; as suggested by Theorem 2.1, this leaves only a small decreasing pressure from SAM. The primary dynamics of the large eigenvalues is due to the progressive sharpening phenomenology studied in (Agarwala et al., 2022), which tends to increase the eigenmodes.

However, for larger learning rates, SAM has a strong suppressing effect on the largest eigenvalues (Figure 2, middle). The overall dynamics are more non-linear than in the small learning rate case. The eigenvalues stabilize at the modified EOS boundary $\eta(\lambda_{max}+\rho\lambda_{max}^{2})=2$ (Figure 2, right), suggesting non-linear stabilization of the eigenvalues. In Appendix A.3 we conduct additional experiments which confirm that this boundary predicts the largest eigenvalue for a range of $\rho$, and that consequently increasing $\rho$ generally leads to decreased $\lambda_{max}$.
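
For reference, the following is a minimal sketch of one exact SAM step written directly in $\mathbf{z}$-$\mathbf{J}$ space for the quadratic model. It is our own simplified version of the setup described in Appendix A.3; the sizes, scales, and hyperparameters are illustrative only, and producing the EOS behavior of Figure 2 requires running many such steps.

```python
import numpy as np

rng = np.random.default_rng(0)
D, P = 40, 80                                   # much smaller than the paper's D=200, P=400
eta, rho = 8e-2, 4e-2

Q = rng.normal(size=(D, P, P)) / P              # scale is our choice, not the paper's
Q = 0.5 * (Q + np.transpose(Q, (0, 2, 1)))
z = rng.normal(size=D) / np.sqrt(D)
J = rng.normal(size=(D, P)) / np.sqrt(P)

def sam_step_zJ(z, J):
    """One exact un-normalized SAM step (Eq. 6), written without referencing theta."""
    delta1 = rho * (J.T @ z)                    # SAM ascent displacement
    z_a = z + J @ delta1 + 0.5 * np.einsum('apq,p,q->a', Q, delta1, delta1)
    J_a = J + np.einsum('apq,q->ap', Q, delta1)
    delta = -eta * (J_a.T @ z_a)                # descent step evaluated at the ascent point
    z_new = z + J @ delta + 0.5 * np.einsum('apq,p,q->a', Q, delta, delta)
    J_new = J + np.einsum('apq,q->ap', Q, delta)
    return z_new, J_new

for step in range(5):
    z, J = sam_step_zJ(z, J)
print(np.linalg.eigvalsh(J @ J.T).max())        # track the largest NTK eigenvalue
```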

3.2 CIFAR-2 with MSE loss

We can see this phenomenology in more general non-linear models as well. We trained a fully-connected network on the first $2$ classes of CIFAR with MSE loss, with both full batch gradient descent and SAM. We then computed the largest eigenvalues of $\mathbf{J}\mathbf{J}^{\top}$ along the trajectory. We can see that in both GD and SAM the largest eigenvalues stabilize, and the stabilization threshold is smaller for SAM (Figure 3). The threshold is once again well predicted by the SAM EOS.

4 Connection to realistic models

Figure 4: Largest Hessian eigenvalues for CIFAR10 trained with MSE loss. Left: largest eigenvalues increase at late times; larger SAM radius mitigates the eigenvalue increase. Middle: eigenvalues normalized by learning rate decrease at late times, and SGD shows edge of stability (EOS) behavior. Right: for larger $\rho$, SAM-normalized eigenvalues show modified EOS behavior.

In this section, we show that our analysis of quadratic models can bring insights into the behavior of more realistic models.

4.1 Setup

Sharpness For MSE loss, edge of stability dynamics can be shown in terms of either the NTK eigenvalues or the Hessian eigenvalues (Agarwala et al., 2022). For more general loss functions, EOS dynamics takes place with respect to the largest Hessian eigenvalues (Cohen et al., 2022a; Damian et al., 2022). Following these results and the analysis in Equation 12, we chose to measure the largest eigenvalue of the Hessian rather than the NTK. We used a Lanczos method (Ghorbani et al., 2019) to approximately compute $\lambda_{max}$. Any reference to $\lambda_{max}$ in this section refers to eigenvalues computed in this way.
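
As an illustration of this kind of measurement (not the paper's implementation), a Hessian-vector product can be wrapped in a linear operator and handed to a Lanczos-type eigensolver. In the sketch below the HVP is a finite difference of a user-supplied gradient function; in practice it would come from autodiff on the actual network loss.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

def lanczos_lambda_max(grad_fn, theta, eps=1e-4):
    """Estimate the top Hessian eigenvalue from Hessian-vector products."""
    n = theta.shape[0]

    def hvp(v):
        v = np.asarray(v).reshape(-1)
        # central finite-difference HVP; exact for quadratic losses
        return (grad_fn(theta + eps * v) - grad_fn(theta - eps * v)) / (2 * eps)

    H_op = LinearOperator((n, n), matvec=hvp, dtype=np.float64)
    return float(eigsh(H_op, k=1, which='LA', return_eigenvectors=False)[0])

# Toy check on a quadratic loss whose Hessian eigenvalues are known.
eigs = np.linspace(0.1, 7.0, 50)
H = np.diag(eigs)
grad = lambda th: H @ th
print(lanczos_lambda_max(grad, np.zeros(50)))   # ~7.0
```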

CIFAR-10 We conducted experiments on the popular CIFAR-10 dataset (Krizhevsky et al., 2009) using the WideResnet 28-10 architecture (Zagoruyko & Komodakis, 2016). We report results for both the MSE loss and the cross-entropy loss. In the case of the MSE loss, we replace the softmax non-linearity with Tanh and rescale the one-hot labels ${\bf y}\in\{0,1\}$ to $\{-1,1\}$. In both cases, the loss is averaged across the number of elements in the batch and the number of classes. For each setting, we report results for a single configuration of the learning rate $\eta$ and weight decay $\mu$ found from an initial cross-validation sweep. For MSE, we use $\eta=0.3,\mu=0.005$, and $\eta=0.4,\mu=0.005$ for cross-entropy. We use the cosine learning rate schedule (Loshchilov & Hutter, 2016) and SGD instead of Nesterov momentum (Sutskever et al., 2013) to better match the theoretical setup. Despite the changes to the optimizer and the loss, the test error for the models remains in a reasonable range (4.4% for SAM-regularized models with MSE and 5.3% with SGD). In accordance with the theory, we use unnormalized SAM in these experiments. We keep all other hyper-parameters at the default values described in the original WideResnet paper.
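
A small sketch of this loss configuration under our reading of the setup (a generic re-implementation with our own names, not the paper's code; the network itself is omitted):

```python
import numpy as np

def mse_loss(logits, one_hot_labels):
    """Tanh outputs against one-hot labels rescaled from {0,1} to {-1,+1};
    averaged over both the batch and the class dimensions."""
    targets = 2.0 * one_hot_labels - 1.0
    preds = np.tanh(logits)
    return np.mean((preds - targets) ** 2)

logits = np.zeros((4, 10))                 # batch of 4 examples, 10 classes
labels = np.eye(10)[[3, 1, 4, 1]]          # hypothetical labels
print(mse_loss(logits, labels))            # 1.0 at zero logits
```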

4.2 Results

As shown in Figure 4 (left), the maximum eigenvalue increases significantly throughout training for all approaches considered. However, the normalized curvature $\eta\lambda_{max}$, which sets the edge of stability in GD, remains relatively stable early on in training when the learning rate is high, but necessarily decreases as the cosine schedule drives the learning rate to $0$ (Figure 4, middle).

SAM radius drives curvature below GD EOS. As we increase the SAM radius, the largest eigenvalue is more controlled (Figure 4, left) - falling below the gradient descent edge of stability (Figure 4, middle). The stabilizing effect of SAM on the large eigenvalues is evident even early on in training.

Eigenvalues stabilize around SAM-EOS. If we instead plot the SAM-normalized eigenvalue $\eta(\lambda_{max}+\rho\lambda_{max}^{2})$, we see that the eigenvalues stay close to (and often slightly above) the critical value of $2$, as predicted by theory (Figure 4, right). This suggests that there are settings where the control that SAM has on the large eigenvalues of the Hessian comes, in part, from a modified EOS stabilization effect.

Figure 5: Maximum eigenvalues for the CIFAR-10 model trained on MSE loss with a SAM schedule. Starting out with SAM ($\rho=0.05$, solid lines) and turning it off at $2500$ steps leads to initial suppression and eventual increase of $\lambda_{max}$; starting out with SGD and turning SAM on after $2500$ steps leads to the opposite behavior (left). Eigenvalues cross over quickly after the switch. Plotting GD-normalized eigenvalues (blue, right) shows GD EOS behavior in the SGD phase; plotting SAM-normalized eigenvalues (orange, right) shows SAM EOS behavior in the SAM phase.

Altering SAM radius during training can successfully move us between GD-EOS and SAM-EOS. Further evidence for EOS stabilization comes from using a SAM schedule. We trained the model with two settings: early SAM, where SAM is used for the first $2500$ steps ($50$ epochs), after which the training proceeds with SGD ($\rho=0$), and late SAM, where SGD is used for the first $2500$ steps, after which SAM is turned on. The maximum eigenvalue is lower for early SAM before $2500$ steps, after which there is a quick crossover and late SAM gives better control (Figure 5). Both SAM schedules give improvement over SGD-only training. Generally, turning SAM on later or for the full trajectory gave better generalization than turning SAM on early only, consistent with the earlier work of Andriushchenko & Flammarion (2022).

Plotting the eigenvalues for the early SAM and late SAM schedules, we see that when SAM is turned off, the normalized eigenvalues lie above the gradient descent EOS (Figure 5, right, blue curves). However, when SAM is turned on, $\eta\lambda_{max}$ is usually below the edge of stability value of $2$; instead, the SAM-normalized value $\eta(\lambda_{max}+\rho\lambda_{max}^{2})$ lies close to the critical value of $2$ (Figure 5, right, orange curves). This suggests that turning SAM on or off during the intermediate part of training causes the dynamics to quickly reach the appropriate edge of stability.

Networks with cross-entropy loss behave similarly. We found similar results for cross-entropy loss as well, which we detail in Appendix C.1. The mini-batch gradient magnitude and eigenvalues vary more over the learning trajectories; this may be related to effects of logit magnitudes which have been previously shown to affect curvature and general training dynamics (Agarwala et al., 2020; Cohen et al., 2022a).

Minibatch gradient norm varies little. Another quantity of interest is the magnitude of the mini-batch gradients. For SGD, the gradient magnitudes were steady during the first half of training and dropped by a factor of $4$ at late times (Figure 6). Gradient magnitudes were very stable for SAM, particularly for larger $\rho$. This suggests that in practice, there may not be much difference between the normalized and un-normalized SAM algorithms. This is consistent with previous work which showed that the generalization of the two approaches is similar (Andriushchenko & Flammarion, 2022).

Figure 6: Minibatch gradient magnitudes for the CIFAR-10 model trained on MSE loss. Magnitudes are steady early on in SGD training, but decrease at the end of training. The variation in gradient magnitude is smaller for increasing SAM radius $\rho$.

5 Discussion

5.1 SAM as a dynamical phenomenon

Much like the study of EOS before it, our analysis of SAM suggests that sharpness dynamics near minima are insufficient to capture relevant phenomenology. Our analysis of the quadratic regression model suggests that SAM already regularizes the large eigenmodes at early times, and the EOS analysis shows how SAM can have strong effects even in the large-batch setting. Our theory also suggested that SGD has additional mechanisms to control curvature early on in training as compared to full batch gradient descent.

The SAM schedule experiments provided further evidence that multiple phases of the optimization trajectory are important for understanding the relationship between SAM and generalization. If the important effect was the convergence to a particular minimum, then only late SAM would improve generalization. If instead some form of “basin selection” was key, then only early SAM would improve generalization. The fact that both are important (Andriushchenko & Flammarion, 2022) suggests that the entire optimization trajectory matters.

We note that while EOS effects are necessary to understand some aspects of SAM, they are certainly not sufficient. As shown in Appendix A.3, the details of the behavior near the EOS have a complex dependence on ρ\rho (and the model). Later on in learning, especially with a loss like cross entropy, the largest eigenvalues may decrease even without SAM (Cohen et al., 2022a) - potentially leading the dynamics away from the EOS. Small batch size may add other effects, and EOS effects become harder to understand if multiple eigenvalues are at the EOS. Nonetheless, even in more complicated cases the SAM EOS gives a good approximation to the control SAM has on the eigenvalues, particularly at early times.

5.2 Optimization and regularization are deeply linked

This work provides additional evidence that understanding some regularization methods may in fact require analysis of the optimization dynamics - especially those at early or intermediate times. This is in contrast to approaches which seek to understand learning by characterizing minima, or analyzing behavior near convergence only. A similar phenomenology has been observed in evolutionary dynamics - the basic 0th order optimization method - where the details of optimization trajectories are often more important than the statistics of the minima to understand long-term dynamics (Nowak & Krug, 2015; Park & Krug, 2016; Agarwala & Fisher, 2019).

6 Future work

Our main theoretical analysis focused on the dynamics of $\mathbf{z}$ and $\mathbf{J}$ under squared loss; additional complications arise for non-squared losses like cross-entropy. Providing a detailed quantitative characterization of the EOS dynamics under these more general conditions is an important next step.

Another important open question is the analysis of SAM (and the EOS effect more generally) under SGD. While Theorem 2.3 provides some insight into the differences, a full understanding would require an analysis of ${\rm E}_{\mathbf{P}}[(\mathbf{z}\cdot\mathbf{v}_{i})^{2}]$ for the different eigenmodes $\mathbf{v}_{i}$ - which has only recently been analyzed for a quadratic loss function (Paquette et al., 2021, 2022a, 2022b; Lee et al., 2022). Our analysis of the CIFAR10 models showed that the SGD gradient magnitude does not change much over training. Further characterization of the SGD gradient statistics will also be useful in understanding the interaction of SAM and SGD.

More detailed theoretical and experimental analysis of more complex settings may allow for improvements to the SAM algorithm and its implementation in practice. A more detailed theoretical understanding could lead to proposals for ρ\rho-schedules, or improvements to the core algorithm itself - already a field of active research (Zhuang et al., 2022).

Finally, our work focuses on optimization and training dynamics; linking these properties to generalization remains a key goal of any further research into SAM and other optimization methods.

References

  • Agarwala & Fisher (2019) Agarwala, A. and Fisher, D. S. Adaptive walks on high-dimensional fitness landscapes and seascapes with distance-dependent statistics. Theoretical Population Biology, 130:13–49, December 2019. ISSN 0040-5809. doi: 10.1016/j.tpb.2019.09.011.
  • Agarwala et al. (2020) Agarwala, A., Pennington, J., Dauphin, Y., and Schoenholz, S. Temperature check: Theory and practice for training models with softmax-cross-entropy losses, October 2020.
  • Agarwala et al. (2022) Agarwala, A., Pedregosa, F., and Pennington, J. Second-order regression models exhibit progressive sharpening to the edge of stability, October 2022.
  • Andriushchenko & Flammarion (2022) Andriushchenko, M. and Flammarion, N. Towards Understanding Sharpness-Aware Minimization, June 2022.
  • Bahri et al. (2022) Bahri, D., Mobahi, H., and Tay, Y. Sharpness-Aware Minimization Improves Language Model Generalization, March 2022.
  • Bartlett et al. (2022) Bartlett, P. L., Long, P. M., and Bousquet, O. The Dynamics of Sharpness-Aware Minimization: Bouncing Across Ravines and Drifting Towards Wide Minima, October 2022.
  • Chaudhari et al. (2019) Chaudhari, P., Choromanska, A., Soatto, S., LeCun, Y., Baldassi, C., Borgs, C., Chayes, J., Sagun, L., and Zecchina, R. Entropy-SGD: Biasing gradient descent into wide valleys. Journal of Statistical Mechanics: Theory and Experiment, 2019(12):124018, December 2019. ISSN 1742-5468. doi: 10.1088/1742-5468/ab39d9.
  • Cohen et al. (2022a) Cohen, J., Kaur, S., Li, Y., Kolter, J. Z., and Talwalkar, A. Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability. In International Conference on Learning Representations, February 2022a.
  • Cohen et al. (2022b) Cohen, J. M., Ghorbani, B., Krishnan, S., Agarwal, N., Medapati, S., Badura, M., Suo, D., Cardoze, D., Nado, Z., Dahl, G. E., and Gilmer, J. Adaptive Gradient Methods at the Edge of Stability, July 2022b.
  • Damian et al. (2022) Damian, A., Nichani, E., and Lee, J. D. Self-Stabilization: The Implicit Bias of Gradient Descent at the Edge of Stability, September 2022.
  • Dinh et al. (2017) Dinh, L., Pascanu, R., Bengio, S., and Bengio, Y. Sharp minima can generalize for deep nets. In Proceedings of the 34th International Conference on Machine Learning - Volume 70, ICML’17, pp.  1019–1028, Sydney, NSW, Australia, August 2017. JMLR.org.
  • Duchi et al. (2011) Duchi, J., Hazan, E., and Singer, Y. Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12(61):2121–2159, 2011. ISSN 1533-7928.
  • Foret et al. (2022) Foret, P., Kleiner, A., Mobahi, H., and Neyshabur, B. Sharpness-aware Minimization for Efficiently Improving Generalization. In International Conference on Learning Representations, April 2022.
  • Ghorbani et al. (2019) Ghorbani, B., Krishnan, S., and Xiao, Y. An Investigation into Neural Net Optimization via Hessian Eigenvalue Density. In Proceedings of the 36th International Conference on Machine Learning, pp.  2232–2241. PMLR, May 2019.
  • Hochreiter & Schmidhuber (1997) Hochreiter, S. and Schmidhuber, J. Flat Minima. Neural Computation, 9(1):1–42, January 1997. ISSN 0899-7667. doi: 10.1162/neco.1997.9.1.1.
  • Jacot et al. (2018) Jacot, A., Gabriel, F., and Hongler, C. Neural Tangent Kernel: Convergence and Generalization in Neural Networks. In Advances in Neural Information Processing Systems 31, pp.  8571–8580. Curran Associates, Inc., 2018.
  • Keskar et al. (2017) Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima, February 2017.
  • Krizhevsky et al. (2009) Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images. 2009.
  • Lee et al. (2019) Lee, J., Xiao, L., Schoenholz, S., Bahri, Y., Novak, R., Sohl-Dickstein, J., and Pennington, J. Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent. In Advances in Neural Information Processing Systems 32, pp.  8570–8581. Curran Associates, Inc., 2019.
  • Lee et al. (2022) Lee, K., Cheng, A. N., Paquette, C., and Paquette, E. Trajectory of Mini-Batch Momentum: Batch Size Saturation and Convergence in High Dimensions, June 2022.
  • Lewis & Overton (2013) Lewis, A. S. and Overton, M. L. Nonsmooth optimization via quasi-Newton methods. Mathematical Programming, 141(1):135–163, October 2013. ISSN 1436-4646. doi: 10.1007/s10107-012-0514-2.
  • Lewkowycz et al. (2020) Lewkowycz, A., Bahri, Y., Dyer, E., Sohl-Dickstein, J., and Gur-Ari, G. The large learning rate phase of deep learning: The catapult mechanism. March 2020.
  • Loshchilov & Hutter (2016) Loshchilov, I. and Hutter, F. Sgdr: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983, 2016.
  • Neyshabur et al. (2017) Neyshabur, B., Bhojanapalli, S., Mcallester, D., and Srebro, N. Exploring Generalization in Deep Learning. In Advances in Neural Information Processing Systems 30, pp.  5947–5956. Curran Associates, Inc., 2017.
  • Nocedal (1980) Nocedal, J. Updating quasi-Newton matrices with limited storage. Mathematics of Computation, 35(151):773–782, 1980. ISSN 0025-5718, 1088-6842. doi: 10.1090/S0025-5718-1980-0572855-7.
  • Nowak & Krug (2015) Nowak, S. and Krug, J. Analysis of adaptive walks on NK fitness landscapes with different interaction schemes. Journal of Statistical Mechanics: Theory and Experiment, 2015(6):P06014, 2015.
  • Paquette et al. (2021) Paquette, C., Lee, K., Pedregosa, F., and Paquette, E. SGD in the Large: Average-case Analysis, Asymptotics, and Stepsize Criticality. In Proceedings of Thirty Fourth Conference on Learning Theory, pp.  3548–3626. PMLR, July 2021.
  • Paquette et al. (2022a) Paquette, C., Paquette, E., Adlam, B., and Pennington, J. Homogenization of SGD in high-dimensions: Exact dynamics and generalization properties, May 2022a.
  • Paquette et al. (2022b) Paquette, C., Paquette, E., Adlam, B., and Pennington, J. Implicit Regularization or Implicit Conditioning? Exact Risk Trajectories of SGD in High Dimensions, June 2022b.
  • Park & Krug (2016) Park, S.-C. and Krug, J. δ\delta-exceedance records and random adaptive walks. Journal of Physics A: Mathematical and Theoretical, 49(31):315601, 2016.
  • Sutskever et al. (2013) Sutskever, I., Martens, J., Dahl, G., and Hinton, G. On the importance of initialization and momentum in deep learning. In International conference on machine learning, pp. 1139–1147. PMLR, 2013.
  • Wen et al. (2023) Wen, K., Ma, T., and Li, Z. How Does Sharpness-Aware Minimization Minimize Sharpness?, January 2023.
  • Zagoruyko & Komodakis (2016) Zagoruyko, S. and Komodakis, N. Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
  • Zhuang et al. (2022) Zhuang, J., Gong, B., Yuan, L., Cui, Y., Adam, H., Dvornek, N., Tatikonda, S., Duncan, J., and Liu, T. Surrogate Gap Minimization Improves Sharpness-Aware Training, March 2022.

Appendix A Quadratic regression model

A.1 Rescaled dynamics

The learning rate can be rescaled out of the quadratic regression model. In previous work, the rescaling

\tilde{\mathbf{z}}=\eta\mathbf{z},~{}\tilde{\mathbf{J}}=\eta^{1/2}\mathbf{J} (17)

was used, which gave a universal representation of the dynamics. The same rescaling in the SAM case gives us:

𝐳~t+1𝐳~t=(𝐉~t𝐉~t+r(𝐉~t𝐉~t)2)𝐳~tr[(1+r𝐉~t𝐉~t)𝐳~t]𝑸(𝐉~t𝐳~t,𝐉~t)+12𝑸[𝐉~t(1+r𝐉~t𝐉~t)𝐳~t,𝐉~t(1+r𝐉~t𝐉~t)𝐳~t]+O(𝐳~3)\begin{split}\tilde{\mathbf{z}}_{t+1}-\tilde{\mathbf{z}}_{t}&=-(\tilde{\mathbf{J}}_{t}\tilde{\mathbf{J}}_{t}^{\top}+r(\tilde{\mathbf{J}}_{t}\tilde{\mathbf{J}}_{t}^{\top})^{2})\tilde{\mathbf{z}}_{t}-r[(1+r\tilde{\mathbf{J}}_{t}\tilde{\mathbf{J}}_{t}^{\top})\tilde{\mathbf{z}}_{t}]^{\top}\bm{\mathsfit{Q}}(\tilde{\mathbf{J}}_{t}^{\top}\tilde{\mathbf{z}}_{t},\tilde{\mathbf{J}}_{t}^{\top}\cdot)\\ &+\frac{1}{2}\bm{\mathsfit{Q}}[\tilde{\mathbf{J}}_{t}^{\top}(1+r\tilde{\mathbf{J}}_{t}\tilde{\mathbf{J}}_{t}^{\top})\tilde{\mathbf{z}}_{t},\tilde{\mathbf{J}}_{t}^{\top}(1+r\tilde{\mathbf{J}}_{t}\tilde{\mathbf{J}}_{t}^{\top})\tilde{\mathbf{z}}_{t}]+O(||\tilde{\mathbf{z}}^{3}||)\end{split} (18)
𝐉~t+1𝐉~t=𝑸(𝐉~t(1+r𝐉~t𝐉~t)𝐳~t,)r𝑸([(1+r𝐉~t𝐉~t)𝐳~t]𝑸(𝐉~t𝐳~t,),)12r2𝑸[𝐉~t𝑸(𝐉~t𝐳~t,𝐉~t𝐳~t),]+O(𝐳~3)\begin{split}\tilde{\mathbf{J}}_{t+1}-\tilde{\mathbf{J}}_{t}&=-\bm{\mathsfit{Q}}(\tilde{\mathbf{J}}_{t}^{\top}(1+r\tilde{\mathbf{J}}_{t}\tilde{\mathbf{J}}_{t}^{\top})\tilde{\mathbf{z}}_{t},\cdot)-r\bm{\mathsfit{Q}}([(1+r\tilde{\mathbf{J}}_{t}\tilde{\mathbf{J}}_{t}^{\top})\tilde{\mathbf{z}}_{t}]^{\top}\bm{\mathsfit{Q}}(\tilde{\mathbf{J}}_{t}^{\top}\tilde{\mathbf{z}}_{t},\cdot),\cdot)\\ &-\frac{1}{2}r^{2}\bm{\mathsfit{Q}}\left[\tilde{\mathbf{J}}_{t}^{\top}\bm{\mathsfit{Q}}(\tilde{\mathbf{J}}_{t}^{\top}\tilde{\mathbf{z}}_{t},\tilde{\mathbf{J}}_{t}^{\top}\tilde{\mathbf{z}}_{t}),\cdot\right]+O(||\tilde{\mathbf{z}}^{3}||)\end{split} (19)

where $r$ is the rescaled SAM radius $r=\rho/\eta$.

This suggests that, at least for gradient descent, the ratio of the SAM radius to the learning rate determines the amount of regularization which SAM provides.

A.2 Average values, one step SGD

We will prove Theorem 2.3 first, and then derive Theorem 2.1 as a special case.

Theorem 2.3.

Consider a second-order regression model, with $\bm{\mathsfit{Q}}$ initialized randomly with i.i.d. components with $0$ mean and variance $1$. For a model trained with SGD, sampling $B$ datapoints independently at each step, the change in $\mathbf{z}$ and $\mathbf{J}$ at the first step, averaged over $\bm{\mathsfit{Q}}$ and the sampling matrix $\mathbf{P}_{t}$, is given by

{\rm E}[\mathbf{z}_{1}-\mathbf{z}_{0}]_{\bm{\mathsfit{Q}},\mathbf{P}}=-\eta\beta\mathbf{J}_{0}\mathbf{J}_{0}^{\top}(1+\rho[\beta(\mathbf{J}_{0}\mathbf{J}_{0}^{\top})+(1-\beta){\rm diag}(\mathbf{J}_{0}\mathbf{J}_{0}^{\top})])\mathbf{z}_{0}+O(\eta^{2}||\mathbf{z}||^{2})+O(D^{-1}) (20)
{\rm E}_{\bm{\mathsfit{Q}},\mathbf{P}}[\mathbf{J}_{1}-\mathbf{J}_{0}]=-\rho\eta P(\beta^{2}\mathbf{z}_{0}\mathbf{z}_{0}^{\top}+\beta(1-\beta){\rm diag}(\mathbf{z}_{0}\mathbf{z}_{0}^{\top}))\mathbf{J}_{0}+O(\rho^{2}\eta^{2}||\mathbf{z}||^{2})+O(\eta^{3}||\mathbf{z}||^{3}) (21)

where $\beta\equiv B/D$ is the batch fraction.

Proof.

We can write the SGD dynamics of the quadratic regression model as:

\mathbf{z}_{t+1}-\mathbf{z}_{t}=-\eta\mathbf{J}_{t}\mathbf{J}_{t}^{\top}\mathbf{P}_{t}\mathbf{z}_{t}+\frac{1}{2}\eta^{2}\bm{\mathsfit{Q}}(\mathbf{J}_{t}^{\top}\mathbf{P}_{t}\mathbf{z}_{t},\mathbf{J}_{t}^{\top}\mathbf{P}_{t}\mathbf{z}_{t}) (22)
\mathbf{J}_{t+1}-\mathbf{J}_{t}=-\eta\bm{\mathsfit{Q}}(\mathbf{J}_{t}^{\top}\mathbf{P}_{t}\mathbf{z}_{t},\cdot)\,. (23)

where $\mathbf{P}_{t}$ is a projection matrix which chooses the batch. It is a $D\times D$ diagonal matrix with exactly $B$ random $1$s on the diagonal, with all other entries $0$. This corresponds to choosing $B$ random elements, without replacement, at each timestep.

For SAM with SGD, the SAM step is replaced with $\rho\mathbf{J}_{t}^{\top}\mathbf{P}_{t}\mathbf{z}_{t}$ as well. Expanding to lowest order, we have:

\mathbf{z}_{t+1}-\mathbf{z}_{t}=-\eta(\mathbf{J}_{t}\mathbf{J}_{t}^{\top}+\rho(\mathbf{J}_{t}\mathbf{J}_{t}^{\top})\mathbf{P}_{t}(\mathbf{J}_{t}\mathbf{J}_{t}^{\top}))\mathbf{P}_{t}\mathbf{z}_{t}+O(||\mathbf{z}||^{2}) (24)
\begin{split}\mathbf{J}_{t+1}-\mathbf{J}_{t}&=-\eta\bm{\mathsfit{Q}}(\mathbf{J}_{t}^{\top}(1+\rho\mathbf{P}_{t}\mathbf{J}_{t}\mathbf{J}_{t}^{\top})\mathbf{P}_{t}\mathbf{z}_{t},\cdot)-\rho\eta\bm{\mathsfit{Q}}([\mathbf{P}_{t}\mathbf{z}_{t}]^{\top}\bm{\mathsfit{Q}}(\mathbf{J}_{t}^{\top}\mathbf{P}_{t}\mathbf{z}_{t},\cdot),\cdot)\\ &+O(\rho^{2}\eta^{2}||\mathbf{z}||^{2})+O(\eta^{3}||\mathbf{z}||^{3})\end{split} (25)

Consider the dynamics of $\mathbf{z}$. Taking the average over $\mathbf{P}_{t}$, we note that ${\rm E}[\mathbf{P}]=\beta\mathbf{I}$. For any fixed $D\times D$ matrix $\mathbf{M}$, we also have:

{\rm E}[\mathbf{P}_{t}\mathbf{M}\mathbf{P}_{t}]=\beta^{2}\mathbf{M}+\beta(1-\beta){\rm diag}(\mathbf{M})+O(D^{-1}) (26)

Substituting, we have:

{\rm E}_{\mathbf{P}_{t}}[\mathbf{z}_{t+1}-\mathbf{z}_{t}]=-\eta\beta\mathbf{J}_{t}\mathbf{J}_{t}^{\top}(1+\rho[\beta(\mathbf{J}_{t}\mathbf{J}_{t}^{\top})+(1-\beta){\rm diag}(\mathbf{J}_{t}\mathbf{J}_{t}^{\top})])\mathbf{z}_{t}+O(||\mathbf{z}||^{2})+O(D^{-1}) (27)
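
The projection identity in Equation 26 can be checked numerically; a small Monte Carlo sketch using an explicit construction of $\mathbf{P}_{t}$ (sizes and sample counts are ours):

```python
import numpy as np

rng = np.random.default_rng(0)
D, B, n_samples = 200, 50, 5000
beta = B / D
M = rng.normal(size=(D, D))

acc = np.zeros((D, D))
for _ in range(n_samples):
    mask = np.zeros(D)
    mask[rng.choice(D, size=B, replace=False)] = 1.0   # diagonal of P_t
    acc += (mask[:, None] * M) * mask[None, :]          # P_t M P_t without forming P_t
empirical = acc / n_samples

prediction = beta**2 * M + beta * (1 - beta) * np.diag(np.diag(M))
# Remaining discrepancy is Monte Carlo noise plus the O(1/D) correction.
print(np.abs(empirical - prediction).max())
```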

Now consider the dynamics of $\mathbf{J}$. We first average over the random initial $\bm{\mathsfit{Q}}$. At the first step we have:

{\rm E}_{\bm{\mathsfit{Q}}}[\mathbf{J}_{1}-\mathbf{J}_{0}]_{\alpha i}=-\rho\eta{\rm E}[\bm{\mathsfit{Q}}_{\alpha ij}(\mathbf{P}\mathbf{z})_{\beta}\bm{\mathsfit{Q}}_{\beta jk}\mathbf{J}_{\gamma k}(\mathbf{P}\mathbf{z})_{\gamma}]+O(\rho^{2}\eta^{2}||\mathbf{z}||^{2})+O(\eta^{3}||\mathbf{z}||^{3}) (28)
{\rm E}_{\bm{\mathsfit{Q}}}[\mathbf{J}_{1}-\mathbf{J}_{0}]_{\alpha i}=-\rho\eta P(\mathbf{P}\mathbf{z})_{\alpha}\mathbf{J}_{\gamma i}(\mathbf{P}\mathbf{z})_{\gamma}+O(\rho^{2}\eta^{2}||\mathbf{z}||^{2})+O(\eta^{3}||\mathbf{z}||^{3}) (29)

Averaging over 𝐏\mathbf{P} as well, we have:

E𝑸,𝐏[𝐉1𝐉0]=ρηP(β2𝐳𝐳+β(1β)diag(𝐳𝐳))𝐉+O(ρ2η2𝐳2)+O(η3𝐳3)+O(D1){\rm E}_{\bm{\mathsfit{Q}},\mathbf{P}}[\mathbf{J}_{1}-\mathbf{J}_{0}]=-\rho\eta P(\beta^{2}\mathbf{z}\mathbf{z}^{\top}+\beta(1-\beta){\rm diag}(\mathbf{z}\mathbf{z}^{\top}))\mathbf{J}+O(\rho^{2}\eta^{2}||\mathbf{z}||^{2})+O(\eta^{3}||\mathbf{z}||^{3})+O(D^{-1}) (30)

Theorem 2.1 can be derived by first setting $\beta=1$. Given a singular value $\sigma_{\alpha}$ corresponding to singular vectors $\mathbf{w}_{\alpha}$ and $\mathbf{v}_{\alpha}$, we have $\sigma_{\alpha}=\mathbf{w}_{\alpha}^{\top}\mathbf{J}\mathbf{v}_{\alpha}$. For small learning rates, the singular value of $\mathbf{J}_{1}$ can be written in terms of the SVD of $\mathbf{J}_{0}$ as

\sigma_{\alpha}(\mathbf{J}_{1})=\mathbf{w}_{\alpha}(\mathbf{J}_{0})^{\top}\mathbf{J}_{1}\mathbf{v}_{\alpha}(\mathbf{J}_{0})+O(\eta^{2}) (31)

Therefore we can write

\sigma_{\alpha}(\mathbf{J}_{1})-\sigma_{\alpha}(\mathbf{J}_{0})=\mathbf{w}_{\alpha}(\mathbf{J}_{0})^{\top}(\mathbf{J}_{1}-\mathbf{J}_{0})\mathbf{v}_{\alpha}(\mathbf{J}_{0})+O(\eta^{2}) (32)

Averaging over 𝑸\bm{\mathsfit{Q}} and 𝐏\mathbf{P} completes the theorem.

A.3 Numerical results

The numerical results in Figure 2 were obtained by training the model defined by the update Equation 6 in $\mathbf{z}$ and $\mathbf{J}$ space directly. The tensor $\bm{\mathsfit{Q}}$ was randomly initialized with i.i.d. Gaussian elements at initialization, and $\mathbf{z}$ and $\mathbf{J}$ were randomly initialized as well, following the approach in (Agarwala et al., 2022). The figures correspond to $5$ independent initializations with the same statistics for $\bm{\mathsfit{Q}}$, $\mathbf{z}$, and $\mathbf{J}$. All plots used $D=200$ datapoints with $P=400$ parameters.

For small $\eta$, the loss converges exponentially to $0$. In particular, the projection onto the largest eigenmode decreases quickly, which by the analysis of Theorem 2.1 suggests that SAM has only a small effect on the largest eigenvalues.

For larger $\eta$, increasing $\rho$ allows the SAM dynamics to access the edge of stability regime, where non-linear effects can stabilize the large eigenvalues of the curvature. The original edge of stability dynamics was explored in part by examining trajectories at different learning rates (Cohen et al., 2022a). At small learning rate, training loss decreases monotonically; at intermediate learning rates, the edge of stability behavior causes non-monotonic but still stable learning trajectories; and at large learning rate, training is unstable.

We can similarly increase the SAM radius $\rho$ for fixed learning rate, and see an analogous transition. If we pick $\eta$ such that the trajectory doesn't reach the non-linear edge of stability regime, and increase $\rho$, we see that SAM eventually leads to a decrease in the largest eigenvalues, before leading to unstable behavior (Figure 7, left). If we plot $\eta(\lambda_{max}+\rho\lambda_{max}^{2})$, we see that this normalized, effective eigenvalue stabilizes very close to $2$ for a range of $\rho$, and for larger $\rho$ stabilizes near but not exactly at $2$ (Figure 7, right). We leave a more detailed understanding of this stabilization for future work.

Figure 7: For fixed $\eta$, as $\rho$ increases the largest eigenvalue of $\mathbf{J}\mathbf{J}^{\top}$ decreases, until training is no longer stable (left). For intermediate $\rho$, the eigenvalue is very well predicted by $\eta(\lambda_{max}+\rho\lambda_{max}^{2})=2$ (right); however, there is also a range of $\rho$ where training is still stable but $\eta(\lambda_{max}+\rho\lambda_{max}^{2})>2$ (purple curve).

Appendix B SAM edge of stability

B.1 Proof of Theorem 2.2

We prove the following theorem, which gives us an alternate edge of stability for SAM:

Theorem 2.2.

Consider a $\mathcal{C}^{\infty}$ model $\mathbf{f}(\bm{\theta})$ trained using Equation 6 with MSE loss. Suppose that there exists a point $\bm{\theta}^{*}$ where $\mathbf{z}(\bm{\theta}^{*})=0$. Suppose that for some ${\epsilon}>0$, we have the lower bound ${\epsilon}<\eta\lambda_{i}(1+\rho\lambda_{i})$ for the eigenvalues of the positive-definite symmetric matrix $\mathbf{J}(\bm{\theta}^{*})\mathbf{J}(\bm{\theta}^{*})^{\top}$. Given bounds on the largest eigenvalues, there are two regimes:

Convergent regime. If $\eta\lambda_{i}(1+\rho\lambda_{i})<2-{\epsilon}$ for all eigenvalues $\lambda_{i}$ of $\mathbf{J}(\bm{\theta}^{*})\mathbf{J}(\bm{\theta}^{*})^{\top}$, there exists a neighborhood $U$ of $\bm{\theta}^{*}$ such that $\lim_{t\to\infty}\mathbf{z}_{t}=0$ with exponential convergence for any trajectory initialized at $\bm{\theta}_{0}\in U$.

Divergent regime. If $\eta\lambda_{i}(1+\rho\lambda_{i})>2+{\epsilon}$ for some eigenvector $\mathbf{v}_{i}$ of $\mathbf{J}(\bm{\theta}^{*})\mathbf{J}(\bm{\theta}^{*})^{\top}$, then there exists some $q_{min}$ such that for any $q<q_{min}$, given $B_{q}(\bm{\theta}^{*})$, the ball of radius $q$ around $\bm{\theta}^{*}$, there exists some initialization $\bm{\theta}_{0}\in B_{q}(\bm{\theta}^{*})$ such that the trajectory $\{\bm{\theta}_{t}\}$ leaves $B_{q}(\bm{\theta}^{*})$ at some time $t$.

Proof.

The SAM update for MSE loss can be written as:

\bm{\theta}_{t+1}-\bm{\theta}_{t}=-\eta\mathbf{J}^{\top}(\bm{\theta}_{t}+\rho\mathbf{J}^{\top}_{t}\mathbf{z}_{t})\mathbf{z}(\bm{\theta}_{t}+\rho\mathbf{J}^{\top}_{t}\mathbf{z}_{t}) (33)

We will use the differentiability of $f(\bm{\theta})$ to Taylor expand the update step, and divide it into a dominant linear piece which leads to convergence, and a higher-order term which can be safely ignored in the long-term dynamics.

Since the model $f(\bm{\theta})$ is analytic at $\bm{\theta}^{*}$, there exist a radius $r>0$ and a neighborhood $U_{r}$ of $\bm{\theta}^{*}$ such that, for $\bm{\theta}\in U_{r}$, $\mathbf{z}$ and $\mathbf{J}$ have the expansions

𝐳(𝜽+Δ𝜽)𝐳(𝜽)=𝐉Δ𝜽+122𝐳𝜽𝜽(Δ𝜽,Δ𝜽)+\mathbf{z}(\bm{\theta}+\Delta\bm{\theta})-\mathbf{z}(\bm{\theta})=\mathbf{J}\Delta\bm{\theta}+\frac{1}{2}\frac{\partial^{2}\mathbf{z}}{\partial\bm{\theta}\partial\bm{\theta}^{\prime}}(\Delta\bm{\theta},\Delta\bm{\theta})+\ldots (34)
𝐉(𝜽+Δ𝜽)𝐉(𝜽)=2𝐳𝜽𝜽(Δ𝜽,)+123𝐳𝜽1𝜽2𝜽3(Δ𝜽,Δ𝜽,)+\mathbf{J}(\bm{\theta}+\Delta\bm{\theta})-\mathbf{J}(\bm{\theta})=\frac{\partial^{2}\mathbf{z}}{\partial\bm{\theta}\partial\bm{\theta}^{\prime}}(\Delta\bm{\theta},\cdot)+\frac{1}{2}\frac{\partial^{3}\mathbf{z}}{\partial\bm{\theta}_{1}\partial\bm{\theta}_{2}\partial\bm{\theta}_{3}}(\Delta\bm{\theta},\Delta\bm{\theta},\cdot)+\ldots (35)

for all Δ𝜽<r||\Delta\bm{\theta}||<r. The derivatives which define the power series are taken at 𝜽\bm{\theta}. By Taylor’s theorem, there exists some b>0b>0 such that we have the bounds

𝐳(𝜽+Δ𝜽)𝐳(𝜽)𝐉Δ𝜽bΔ𝜽2||\mathbf{z}(\bm{\theta}+\Delta\bm{\theta})-\mathbf{z}(\bm{\theta})-\mathbf{J}\Delta\bm{\theta}||\leq b||\Delta\bm{\theta}||^{2} (36)
𝐉(𝜽+Δ𝜽)𝐉(𝜽)2𝐳𝜽𝜽(Δ𝜽,)bΔ𝜽2||\mathbf{J}(\bm{\theta}+\Delta\bm{\theta})-\mathbf{J}(\bm{\theta})-\frac{\partial^{2}\mathbf{z}}{\partial\bm{\theta}\partial\bm{\theta}^{\prime}}(\Delta\bm{\theta},\cdot)||\leq b||\Delta\bm{\theta}||^{2} (37)

for Δ𝜽<r||\Delta\bm{\theta}||<r uniformly over UrU_{r}.

In addition, since $\mathbf{J}(\bm{\theta}^{*})\mathbf{J}(\bm{\theta}^{*})^{\top}$ satisfies ${\epsilon}<\eta\lambda_{i}(1+\rho\lambda_{i})$ for all eigenvalues $\lambda_{i}$, there exists a neighborhood $V_{{\epsilon},1/2}$ of $\bm{\theta}^{*}$ such that, for any $\bm{\theta}\in V_{{\epsilon},1/2}$, all eigenvalues $\lambda_{i}$ of $\mathbf{J}\mathbf{J}^{\top}$ satisfy ${\epsilon}/2<\eta\lambda_{i}(1+\rho\lambda_{i})$, and the largest eigenvalue $\lambda_{max}$ of $\mathbf{J}\mathbf{J}^{\top}$ satisfies $\eta\lambda_{max}(1+\rho\lambda_{max})<2-{\epsilon}/2$ in the convergent case, and $\lambda_{max}<2\lambda_{max}(\bm{\theta}^{*})$ in the divergent case. Finally, for any $d>0$, there exists a connected neighborhood $T_{d}$ of $\bm{\theta}^{*}$ given by the set of points where $||\mathbf{z}||<d$.

Consider the (non-empty) neighborhood $X_{r,d}=T_{d}\cap U_{r}\cap V_{\epsilon,1/2}$ given by the intersection of these sets. To recap, points $\bm{\theta}$ in $X_{r,d}$ have the following properties:

  • $\mathbf{z}$ and $\mathbf{J}$ have power series representations around $\bm{\theta}$ with radius $r>0$.

  • The second-order and higher terms are bounded by $b||\Delta\bm{\theta}||^{2}$ uniformly on $X_{r,d}$, independently of $d$.

  • $||\mathbf{z}(\bm{\theta})||<d$.

  • The eigenvalues $\lambda_{i}$ of $\mathbf{J}(\bm{\theta})\mathbf{J}(\bm{\theta})^{\top}$ satisfy $\eta\lambda_{i}(1+\rho\lambda_{i})>\epsilon/2$, and the largest eigenvalue satisfies $\eta\lambda_{max}(1+\rho\lambda_{max})<2-\epsilon/2$ (convergent case) or $\lambda_{max}<2\lambda_{max}(\bm{\theta}^{*})$ (divergent case).

We now proceed with analyzing the dynamics. If $||\rho\mathbf{J}^{\top}_{t}\mathbf{z}_{t}||<r$, then we have:

\bm{\theta}_{t+1}-\bm{\theta}_{t}=-\eta(\mathbf{J}_{t}^{\top}+\rho\mathbf{J}_{t}^{\top}\mathbf{J}_{t}\mathbf{J}_{t}^{\top})\mathbf{z}_{t}+O(b||\rho\mathbf{J}_{t}^{\top}\mathbf{z}_{t}||^{2}) (38)

We note that $||\rho\mathbf{J}^{\top}_{t}\mathbf{z}_{t}||<A||\mathbf{z}_{t}||$ on $X_{r,d}$ for some constant $A$ independent of $d$, since the singular values of $\mathbf{J}_{t}$ are bounded uniformly from above. Therefore, if we choose $d<r/A$, the expansion above is valid for all $\bm{\theta}_{t}\in X_{r,d}$.

If $||\bm{\theta}_{t+1}-\bm{\theta}_{t}||<r$, then both $\mathbf{z}(\bm{\theta}_{t+1})-\mathbf{z}(\bm{\theta}_{t})$ and $\mathbf{J}(\bm{\theta}_{t+1})-\mathbf{J}(\bm{\theta}_{t})$ can be represented as power series centered around $\bm{\theta}_{t}$. We can again use the uniform bound on the singular values of $\mathbf{J}$, as well as the uniform bound on the error terms, to choose $d$ small enough that $||\bm{\theta}_{t+1}-\bm{\theta}_{t}||<r$ always holds for $\bm{\theta}_{t}\in X_{r,d}$.

Therefore, for sufficiently small $d$, we have:

\mathbf{z}(\bm{\theta}_{t+1})-\mathbf{z}(\bm{\theta}_{t})=\mathbf{z}_{t+1}-\mathbf{z}_{t}=-\eta\mathbf{J}_{t}\mathbf{J}_{t}^{\top}[(1+\rho\mathbf{J}_{t}\mathbf{J}_{t}^{\top})\mathbf{z}_{t}]+O(h||\mathbf{z}_{t}||^{2}) (39)
\mathbf{J}(\bm{\theta}_{t+1})-\mathbf{J}(\bm{\theta}_{t})=-\eta\frac{\partial^{2}\mathbf{z}}{\partial\bm{\theta}\partial\bm{\theta}^{\prime}}(\mathbf{J}_{t}^{\top}\mathbf{z}_{t},\cdot)+O(h||\mathbf{z}_{t}||^{2}) (40)

for some constant $h$ independent of $d$.
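The leading-order dynamics of Eq. (39) can be sanity-checked numerically. The sketch below iterates the linearized recursion mode by mode in the eigenbasis of $\mathbf{J}\mathbf{J}^{\top}$; the eigenvalues and hyperparameters are arbitrary illustrative placeholders, not values from any experiment. Modes with $\eta\lambda(1+\rho\lambda)<2$ contract, while a mode with $\eta\lambda(1+\rho\lambda)>2$ grows.

# Sketch of the linearized residual dynamics from Eq. (39): in the eigenbasis
# of J J^T each mode obeys z_{t+1} = (1 - eta*lam*(1 + rho*lam)) z_t, so it
# contracts when eta*lam*(1 + rho*lam) < 2 and grows when it exceeds 2.
# Eigenvalues and hyperparameters are illustrative placeholders.
import jax.numpy as jnp

eta, rho = 0.1, 0.05
lams = jnp.array([1.0, 10.0, 25.0])      # eigenvalues of J J^T (arbitrary)
crit = eta * lams * (1 + rho * lams)     # stability quantity from the theorem
factor = 1.0 - crit                      # per-step multiplier on each mode

z = jnp.ones_like(lams)
for _ in range(20):
    z = factor * z

print(crit)        # approx [0.105, 1.5, 5.625]: first two < 2, last one > 2
print(jnp.abs(z))  # first two modes have shrunk; the last has grown rapidly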

We first analyze the convergent case. We establish that $||\mathbf{z}||^{2}$ decreases exponentially at each step, and then confirm that the trajectory remains in $X_{r,d}$. Consider a single step in the eigenbasis of $\mathbf{J}_{t}\mathbf{J}_{t}^{\top}$. Let $z(i)$ be the projection $\mathbf{v}_{i}\cdot\mathbf{z}$ for eigenvector $\mathbf{v}_{i}$ corresponding to eigenvalue $\lambda_{i}$. Then we have:

z(i)^{2}_{t+1}-z(i)^{2}_{t}=\left(-\eta\lambda_{i}(1+\rho\lambda_{i})z(i)_{t}+O(||\mathbf{z}_{t}||^{2})\right)\left([2-\eta\lambda_{i}(1+\rho\lambda_{i})]z(i)_{t}+O(||\mathbf{z}_{t}||^{2})\right) (41)

From our bounds, we have

z(i)^{2}_{t+1}-z(i)^{2}_{t}\leq-\frac{1}{2}\epsilon z(i)_{t}^{2}+c||\mathbf{z}_{t}||^{3} (42)

By uniformity of the Taylor approximation, the constant $c$ is again uniform and independent of $d$. Summing over the eigenmodes (there are at most $D$ of them, where $D$ is the dimension of $\mathbf{z}$), we have:

||\mathbf{z}_{t+1}||^{2}-||\mathbf{z}_{t}||^{2}\leq-\frac{1}{2}\epsilon||\mathbf{z}_{t}||^{2}+Dc||\mathbf{z}_{t}||^{3} (43)

If we choose $d<\frac{\epsilon}{4cD}$, then we have

||\mathbf{z}_{t+1}||^{2}-||\mathbf{z}_{t}||^{2}\leq-\frac{1}{4}\epsilon||\mathbf{z}_{t}||^{2} (44)

Therefore $||\mathbf{z}_{t}||^{2}$ decreases by a factor of at least $1-\epsilon/4$ at each step.
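As a concrete numerical illustration (the value of $\epsilon$ here is arbitrary and not tied to any experiment): taking $\epsilon=0.4$ in Eq. (44) gives $||\mathbf{z}_{t+1}||^{2}\leq 0.9\,||\mathbf{z}_{t}||^{2}$, so $||\mathbf{z}_{t}||\leq 0.9^{t/2}||\mathbf{z}_{0}||\approx 0.95^{t}||\mathbf{z}_{0}||$; the residual norm shrinks by at least roughly $5\%$ per step.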

In order to complete the proof in the convergent regime, we need to show that $\mathbf{J}_{t}$ changes by a small enough amount that the upper and lower bounds on the eigenvalues remain valid - that is, the trajectory remains in $X_{r,d}$. Suppose the dynamics begins with initial residuals $\mathbf{z}_{0}$ and remains within $X_{r,d}$ for $t$ steps, and consider the $(t+1)$-th step. The total change in $\mathbf{J}$ can be bounded by:

||\mathbf{J}_{t+1}-\mathbf{J}_{0}||\leq B\sum_{s=0}^{t}||\mathbf{z}_{s}||+C\sum_{s=0}^{t}||\mathbf{z}_{s}||^{2} (45)

for some constants $B$ and $C$ independent of $d$. The first term comes from a uniform upper bound on $-\eta\frac{\partial^{2}\mathbf{z}}{\partial\bm{\theta}\partial\bm{\theta}^{\prime}}(\mathbf{J}_{t}^{\top}\mathbf{z}_{t},\cdot)$, and the second term comes from the uniform upper bound on the higher-order corrections to the change in $\mathbf{J}$ at each step. Using the bound on $||\mathbf{z}_{t}||$, we have:

||\mathbf{J}_{t+1}-\mathbf{J}_{0}||\leq\frac{4(B+C)}{\epsilon}||\mathbf{z}_{0}|| (46)

If the right-hand side of this inequality is less than $\epsilon^{(1+\delta)/2}$ for some $\delta>0$, then the change in the singular values of $\mathbf{J}$ is $o(\epsilon^{1/2})$, the change in the eigenvalues of $\mathbf{J}\mathbf{J}^{\top}$ is $o(\epsilon)$, and the trajectory remains in $V_{\epsilon,1/2}$ at time $t+1$. Let $d\leq\frac{1}{4(B+C)}\epsilon^{(3+\delta)/2}$. Then, since $||\mathbf{z}_{0}||<d$, Eq. (46) gives $||\mathbf{J}_{t+1}-\mathbf{J}_{0}||\leq\frac{4(B+C)}{\epsilon}d\leq\epsilon^{(1+\delta)/2}$ for all $t$.

Therefore, for any sufficiently small $d$, the trajectory remains within $X_{r,d}$ and $||\mathbf{z}_{t}||$ converges exponentially to $0$. In particular, there is a neighborhood of $\bm{\theta}^{*}$ on which $||\mathbf{z}||$ converges exponentially to $0$, which proves the convergent case.

Now we consider the divergent regime. We will show that we can find initializations with arbitrarily small $||\mathbf{z}||$ and $||\bm{\theta}-\bm{\theta}^{*}||$ which eventually have increasing $||\mathbf{z}||$.

Since $\mathbf{J}\mathbf{J}^{\top}$ is full rank, there exists some $\bm{\theta}_{0}$ in any neighborhood of $\bm{\theta}^{*}$ such that $|\mathbf{v}_{m}\cdot\mathbf{z}(\bm{\theta}_{0})|>0$, where $\mathbf{v}_{m}$ is the eigenvector corresponding to the largest eigenvalue of $\mathbf{J}\mathbf{J}^{\top}$. Consider such a $\bm{\theta}_{0}$ in $X_{r,d}$ (and therefore in $T_{d}$ as well). The change in the magnitude of this component $m$ of $\mathbf{z}$ is bounded from below by

z(m)^{2}_{1}-z(m)^{2}_{0}\geq\frac{1}{2}\epsilon z(m)_{0}^{2}-c||\mathbf{z}_{0}||^{3} (47)

Again, the correction term is uniformly bounded, independent of $d$. Therefore, for $d$ sufficiently small, the bound becomes

z(m)^{2}_{1}-z(m)^{2}_{0}\geq\frac{1}{4}\epsilon z(m)_{0}^{2} (48)

Choose $d_{min}$ such that the above bound holds for all $d<d_{min}$. Furthermore, choose $q_{min}$ so that the ball $B_{q_{min}}(\bm{\theta}^{*})\subset X_{r,d_{min}}$. Pick an initialization $\bm{\theta}_{0}\in B_{q}(\bm{\theta}^{*})$ for some $q<q_{min}$.

After a single step, there are two possibilities. The first is that $\bm{\theta}_{1}$ is no longer in $B_{q_{min}}(\bm{\theta}^{*})$. In this case the trajectory has left $B_{q}(\bm{\theta}^{*})$ (since $B_{q}(\bm{\theta}^{*})\subset B_{q_{min}}(\bm{\theta}^{*})$) and the proof is complete.

The second is that $\bm{\theta}_{1}$ remains in $B_{q_{min}}(\bm{\theta}^{*})$. In this case, $z(m)_{1}^{2}$ is bounded from below by $(1+\epsilon/4)z(m)_{0}^{2}$. If the trajectory remains in $B_{q_{min}}(\bm{\theta}^{*})$ for $t$ steps, we have the bound:

||\mathbf{z}_{t}||^{2}\geq(1+\epsilon/4)^{t}z(m)_{0}^{2} (49)

Therefore, at some finite time $t$, $||\mathbf{z}_{t}||\geq d_{min}$, so $\bm{\theta}_{t}$ leaves $X_{r,d_{min}}$ and therefore leaves $B_{q}(\bm{\theta}^{*})$. This is true for any $q<q_{min}$, which completes the proof for the divergent case.
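To illustrate the timescale involved (with arbitrary illustrative values of $\epsilon$ and $z(m)_{0}$, not tied to any experiment): if $\epsilon=0.4$ and $|z(m)_{0}|=d_{min}/10$, Eq. (49) gives $||\mathbf{z}_{t}||^{2}\geq 1.1^{t}\,d_{min}^{2}/100$, which exceeds $d_{min}^{2}$ once $1.1^{t}\geq 100$, i.e. after roughly $t\approx\ln(100)/\ln(1.1)\approx 49$ steps.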

Appendix C CIFAR-10 experiment details

C.1 Cross-entropy loss

Many of the trends observed using MSE loss in Section 4 can also be observed for cross-entropy loss. Eigenvalues generally increase at late times, and there is still a regime where SGD shows EOS behavior in $\eta\lambda_{max}$, while SAM shows EOS behavior in $\eta(\lambda_{max}+\rho\lambda_{max}^{2})$ (Figure 8). In addition, the gradient norm is still stable for much of training, with the SGD gradient norm decreasing at the end of training while SAM gradient norms stay relatively constant (Figure 9).
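For reference, the sketch below shows one way the normalized quantity $\eta(\lambda_{max}+\rho\lambda_{max}^{2})$ could be tracked during training. This is not the paper's measurement code: it estimates $\lambda_{max}$ by power iteration on Hessian-vector products, and loss_fn, params, eta, and rho are placeholders for the actual training setup (the quadratic loss at the end is a purely illustrative toy).

# Hedged sketch (not the paper's measurement code) of the quantity tracked in
# Figure 8: estimate the largest Hessian eigenvalue lambda_max by power
# iteration on Hessian-vector products, then form the SAM-normalized value
# eta * (lambda_max + rho * lambda_max**2). `loss_fn`, `params`, `eta`, `rho`
# are placeholders for the actual training setup.
import jax
import jax.numpy as jnp

def hvp(loss_fn, params, v):
    # Hessian-vector product via forward-over-reverse differentiation.
    return jax.jvp(jax.grad(loss_fn), (params,), (v,))[1]

def lambda_max(loss_fn, params, dim, steps=100, seed=0):
    v = jax.random.normal(jax.random.PRNGKey(seed), (dim,))
    for _ in range(steps):
        hv = hvp(loss_fn, params, v)
        v = hv / jnp.linalg.norm(hv)
    return jnp.dot(v, hvp(loss_fn, params, v))   # Rayleigh quotient

# Toy quadratic loss with Hessian eigenvalues 1, 4, 9; eta and rho arbitrary.
A = jnp.diag(jnp.array([1.0, 4.0, 9.0]))
loss_fn = lambda p: 0.5 * p @ A @ p
lam = lambda_max(loss_fn, jnp.zeros(3), dim=3)
eta, rho = 0.1, 0.05
print(lam, eta * (lam + rho * lam**2))   # ~9.0 and the SAM-EOS diagnostic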

There are also qualitative differences in behavior. For example, the eigenvalue decrease starts earlier in training. Decreasing eigenvalues for cross-entropy loss have been previously observed (Cohen et al., 2022a), and there is evidence that this effect originates from the interaction of the logit magnitudes with the softmax function. The gradient magnitudes also show an initial period of rapid decrease. Overall, more study is needed to understand how the effects and mechanisms of SAM depend on the loss function.

Figure 8: Largest Hessian eigenvalues for CIFAR-10 trained with cross-entropy loss. Trends are similar to MSE loss (Figure 4), with the exception that normalized eigenvalues decrease from an earlier time.
Figure 9: Minibatch gradient magnitudes for CIFAR-10 model trained on cross-entropy loss. Trends are similar to MSE loss (Figure 8), with larger overall variation in gradient values.