
AdaFisher: Adaptive Second Order Optimization via Fisher Information

Damien Martins Gomes
Concordia University and IPSA Toulouse

Yanlei Zhang
Université de Montréal and Mila

Eugene Belilovsky
Concordia University and Mila

Guy Wolf
Université de Montréal and Mila

Mahdi S. Hosseini
Concordia University and Mila

To my father and grandmother, whose strength and love continue to inspire me, this work is dedicated.
Corresponding Author

Abstract

First-order optimization methods are currently the mainstream in training deep neural networks (DNNs). Optimizers like Adam incorporate limited curvature information by employing diagonal matrix preconditioning of the stochastic gradient during training. Despite the widespread use of first-order methods, second-order optimization algorithms exhibit superior convergence properties compared to their first-order counterparts, e.g. Adam and SGD. However, their practicality in training DNNs is still limited due to increased per-iteration computations and suboptimal accuracy compared to first-order methods. We present AdaFisher, an adaptive second-order optimizer that leverages a block-diagonal approximation to the Fisher information matrix for adaptive gradient preconditioning. AdaFisher aims to bridge the gap between enhanced convergence capabilities and computational efficiency in the second-order optimization framework for training DNNs. Despite the slow pace of second-order optimizers, we showcase that AdaFisher can be reliably adopted for image classification and language modelling, and that it stands out for its stability and robustness in hyperparameter tuning. We demonstrate that AdaFisher outperforms the SOTA optimizers in terms of both accuracy and convergence speed. Code is available from https://github.com/AtlasAnalyticsLab/AdaFisher.

1 Background

We consider a supervised learning framework with a dataset $\mathbf{D}$ containing $N$ i.i.d. samples, $\mathbf{D}:=\{x_{n},y_{n}\}_{n=1}^{N}$, where $x_{n}\in\mathbb{R}^{d}$ and $y_{n}\in\mathbb{R}^{C}$. Let $f_{\theta}:\mathbb{R}^{d}\rightarrow\mathbb{R}^{C}$ be an $L$-layer neural network parametrized by $\theta$, where $\theta_{i}=\text{concat}(W_{i},b_{i})\in\mathbb{R}^{P_{i}}$ and $P_{i}=P_{i}^{out}\times(P_{i}^{in}+1)$. Let $\mathcal{L}:\mathbb{R}^{C}\times\mathbb{R}^{C}\rightarrow\mathbb{R}$ be the loss function defined by the negative log-likelihood, i.e. $\mathcal{L}(y,f_{\theta}(x)):=-\log p_{\theta}(y|x)$, where $p_{\theta}(y|x)$ is the likelihood of the neural network $f_{\theta}$. The network computes its output $h_{L}=f_{\theta}(x)$ according to: $a_{i}=\theta_{i}\bar{h}_{i-1}$, $h_{i}=\phi_{i}(a_{i})$, $\forall\,i\in\{1,\dots,L\}\,\,|\,\,h_{0}=x_{n}$, where $\bar{h}_{i}=[1,h_{i}^{T}]^{T}\in\mathbb{R}^{P_{i}^{in}+1}$, terminated by $z:=h_{L}\in\mathbb{R}^{P_{L}^{out}}$. For a given input-target pair $(x,y)$, the gradient of the loss $\mathcal{L}(y,f_{\theta}(x))$ with respect to the weights is computed by the backpropagation algorithm (Lecun, 2001). For convenience, we adopt the special symbol $s_{i}=\nabla_{a_{i}}\mathcal{L}$ for the pre-activation derivative. Starting from $\nabla_{h_{L}}\mathcal{L}=\partial_{z}\mathcal{L}(y,z=h_{L})$, we perform: $s_{i}:=\nabla_{a_{i}}\mathcal{L}=\nabla_{h_{i}}\mathcal{L}\odot\phi_{i}^{\prime}(a_{i}),\,\nabla_{\theta_{i}}\mathcal{L}=s_{i}\bar{h}_{i-1}^{T},\,\nabla_{\bar{h}_{i-1}}\mathcal{L}=\theta_{i}^{T}s_{i}\quad|\,\,\forall i\in\{L,\dots,1\}$, where $\odot$ denotes the element-wise product.
Finally, the gradient $\nabla_{\theta}\mathcal{L}$ is retrieved by: $\nabla_{\theta}\mathcal{L}=[\text{vec}(\nabla_{\theta_{1}}\mathcal{L})^{T},\text{vec}(\nabla_{\theta_{2}}\mathcal{L})^{T},\dots,\text{vec}(\nabla_{\theta_{L}}\mathcal{L})^{T}]^{T}$.
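
The forward and backward recursions above can be sketched in NumPy as follows. This is a minimal illustration rather than the paper's implementation: it assumes tanh hidden activations, an identity last-layer activation with a softmax negative log-likelihood, and an integer class label `y`; the function name `forward_backward` and the layer sizes are hypothetical.

```python
import numpy as np

def forward_backward(thetas, x, y):
    """One forward/backward pass for a small MLP (illustrative assumptions:
    tanh hidden layers, softmax negative log-likelihood, integer label y).

    thetas[i] has shape (P_out, P_in + 1); its first column multiplies the
    constant 1 in h_bar = [1, h^T]^T, i.e. it plays the role of the bias.
    """
    L = len(thetas)
    h_bars, pre_acts = [], []
    h = x
    for i, theta in enumerate(thetas):
        h_bar = np.concatenate(([1.0], h))      # h_bar_{i-1} = [1, h_{i-1}^T]^T
        a = theta @ h_bar                       # a_i = theta_i h_bar_{i-1}
        h_bars.append(h_bar)
        pre_acts.append(a)
        h = np.tanh(a) if i < L - 1 else a      # last layer returns logits z
    z = h
    shifted = z - z.max()                       # numerically stable log-softmax
    log_p = shifted - np.log(np.sum(np.exp(shifted)))
    loss = -log_p[y]                            # -log p_theta(y | x)

    s = np.exp(log_p)                           # s_L = softmax(z) - onehot(y)
    s[y] -= 1.0
    grads = [None] * L
    for i in reversed(range(L)):
        grads[i] = np.outer(s, h_bars[i])       # grad_theta_i = s_i h_bar_{i-1}^T
        if i > 0:
            grad_h_bar = thetas[i].T @ s        # grad wrt h_bar_{i-1}
            grad_h = grad_h_bar[1:]             # drop the constant-1 entry
            s = grad_h * (1.0 - np.tanh(pre_acts[i - 1]) ** 2)  # s_{i-1}
    return loss, grads

rng = np.random.default_rng(0)
thetas = [0.5 * rng.normal(size=(4, 3 + 1)), 0.5 * rng.normal(size=(2, 4 + 1))]
x, y = rng.normal(size=3), 1
loss, grads = forward_backward(thetas, x, y)
# Stack the per-layer gradients into one vector using column-major vec(.)
full_grad = np.concatenate([g.flatten(order="F") for g in grads])
```

The column-major flattening in the last line mirrors the vec operator used when the per-layer gradients are concatenated into $\nabla_{\theta}\mathcal{L}$.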

Optimization of a DNN can be recast as a problem of finding the parameter set $\theta$ that maximizes the likelihood, or equivalently, minimizes the negative log-likelihood of the observed data. This Maximum Likelihood Estimation approach can be expressed as an unconstrained optimization problem: $\min_{\theta}\,J(\theta)=\sum_{n=1}^{N}\mathcal{L}(y_{n},f_{\theta}(x_{n}))$, where $J(\theta)$ denotes the objective function, corresponding to the negative log-likelihood of the data. The Fisher Information Matrix (FIM), utilized in lieu of the Hessian for Newton’s method (Holmgren, 1996), approximates the curvature of the log-likelihood function (Amari, 1998). It is defined as:

$$F=\sum_{n=1}^{N}\mathbb{E}_{y\sim p(y|f_{\theta}(x_{n}))}\left[\nabla_{\theta}\log p_{\theta}(y|x_{n})\nabla_{\theta}\log p_{\theta}(y|x_{n})^{T}\right]=\mathbb{E}\left[\nabla_{\theta}\mathcal{L}(\nabla_{\theta}\mathcal{L})^{T}\right], \tag{1}$$

where $F$ measures the expected information that an observable $y$ conveys about the parameter $\theta$. For brevity, we write $\mathbb{E}$ instead of $\mathbb{E}_{y\sim p(y|f_{\theta}(x_{n}))}$. The K-FAC approach further simplifies FIM calculation using a block-diagonal approximation in DNNs, known as the Empirical FIM (EFIM), denoted $\hat{F}$. In Eq. (1), $F$ is construed as an $L\times L$ block matrix, where the $(i,j)$th block is $F_{i,j}=\mathbb{E}[\text{vec}(\nabla_{\theta_{i}}\mathcal{L})\text{vec}(\nabla_{\theta_{j}}\mathcal{L})^{T}]$. Using the vectorization identity $\text{vec}(uv^{T})=v\otimes u$ (Petersen & Pedersen, 2008) together with $\nabla_{\theta_{i}}\mathcal{L}=s_{i}\bar{h}_{i-1}^{T}$, we express $\text{vec}(\nabla_{\theta_{i}}\mathcal{L})$ as $\bar{h}_{i-1}\otimes s_{i}$. By segmenting the FIM into layer-specific blocks, we can systematically factorize each block:

$$\hat{F}_{i,j}=\mathbb{E}[\text{vec}(\nabla_{\theta_{i}}\mathcal{L})\text{vec}(\nabla_{\theta_{j}}\mathcal{L})^{T}]=\mathbb{E}[\bar{h}_{i-1}\bar{h}_{j-1}^{T}\otimes s_{i}s_{j}^{T}]\approx\mathbb{E}[\bar{h}_{i-1}^{T}\bar{h}_{j-1}]\otimes\mathbb{E}[s_{i}^{T}s_{j}],$$
Figure 1: Illustration of EFIM computation using K-FAC for a given layer ii.

where $i,j$ span the layer indices from 1 to $L$. Here, $\mathbb{E}[\bar{h}_{i-1}^{T}\bar{h}_{j-1}]$ and $\mathbb{E}[s_{i}^{T}s_{j}]$ are empirically approximated using batch statistics, simplifying the computation for large-scale DNNs (Tang et al., 2021). Notably, $\bar{h}_{i-1}\in\mathbb{R}^{M\times(P_{i}^{in}+1)}$ and $s_{i}\in\mathbb{R}^{M\times P_{i}^{out}}$, where $M$ is the size of a batch, rendering $\hat{F}_{i,j}\in\mathbb{R}^{(P_{i}^{in}+1)P_{i}^{out}\times(P_{i}^{in}+1)P_{i}^{out}}$. Initially, K-FAC estimates the expectation of the Kronecker product under the presumption that activations and pre-activation derivatives are mutually independent, succinctly represented as the Kronecker product of the individual expectations: $\hat{F}_{i,j}=\mathcal{H}_{i-1,j-1}\otimes\mathcal{S}_{i,j}$, where $\mathcal{H}_{i,j}=\mathbb{E}[\bar{h}_{i}^{T}\bar{h}_{j}]$ and $\mathcal{S}_{i,j}=\mathbb{E}[s_{i}^{T}s_{j}]$ denote the Kronecker factors. The assumption for the block-diagonal structure posits that weight derivatives across distinct layers are uncorrelated, expressed as: $F\approx\hat{F}=\text{diag}(\hat{F}_{1,1},\ldots,\hat{F}_{L,L})=\text{diag}(\hat{F}_{1},\ldots,\hat{F}_{L})$. Figure 1 shows EFIM computation via K-FAC.
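
To make the factorization concrete, the sketch below checks the vectorization identity numerically and builds one diagonal block $\hat{F}_{i}$ from batch statistics. It is an illustrative example on synthetic random data: the sizes `M`, `p_in`, `p_out` are arbitrary, and the normalization of the batch averages by `M` is an assumption of this sketch, not a detail taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
M, p_in, p_out = 64, 3, 2          # batch size and layer widths (illustrative)

# Batch of bias-augmented activations and pre-activation derivatives
h_bar = np.hstack([np.ones((M, 1)), rng.normal(size=(M, p_in))])  # M x (p_in + 1)
s = rng.normal(size=(M, p_out))                                   # M x p_out

# vec(u v^T) = v (kron) u, with vec(.) stacking columns (Fortran order)
u, v = s[0], h_bar[0]
assert np.allclose(np.outer(u, v).flatten(order="F"), np.kron(v, u))

# Kronecker factors estimated from batch statistics (assumed 1/M averaging)
H = h_bar.T @ h_bar / M            # (p_in + 1) x (p_in + 1)
S = s.T @ s / M                    # p_out x p_out
F_kfac = np.kron(H, S)             # factored block, H (kron) S

# Unfactored counterpart: average of vec(grad) vec(grad)^T over the batch
g = np.stack([np.kron(h_bar[m], s[m]) for m in range(M)])
F_emp = g.T @ g / M

rel_gap = np.linalg.norm(F_emp - F_kfac) / np.linalg.norm(F_emp)
print(F_kfac.shape, f"relative gap between the two estimates: {rel_gap:.3f}")
```

The design point is storage: the factored form only requires the two small matrices `H` and `S`, whereas the unfactored block has $((P_{i}^{in}+1)P_{i}^{out})^{2}$ entries.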

2 Conclusion, Limitations and Future Research

In this work, we introduced AdaFisher, an adaptive optimizer that leverages a novel diagonal block-Kronecker approximation of the FIM to improve gradient rescaling and descent directions. Incorporated within the Adam framework, AdaFisher speeds up training, reduces hyperparameter sensitivity, and delivers higher accuracy and stability across image classification and language modeling tasks. Empirical and theoretical analyses demonstrate its superiority over current optimizers, with efficient space and time usage facilitating its application across diverse tasks. Notably, AdaFisher excels in SOTA comparisons on ImageNet under both single and multi-GPU setups.

Although optimized for statistical tasks, AdaFisher is less suited for tasks involving non-exponential loss families due to its reliance on statistical data for FIM computation.

Future work will expand testing to other models and areas, such as generative modeling (diffusion models) and graph neural networks. In addition, developing CUDA kernels for the Kronecker factors could greatly improve AdaFisher’s scalability and performance.

Impact Statement

AdaFisher represents a significant advancement in training efficiency, achieving superior accuracy on the ImageNet dataset using only a single GPU. This optimization is particularly beneficial for academia and students who may not have access to extensive computational resources. By enabling effective training with fewer GPUs, AdaFisher offers an accessible yet powerful solution, reducing hardware costs and making advanced machine learning more attainable for those with limited resources. This capability underscores AdaFisher’s potential as a valuable tool in democratizing machine learning technology.

Acknowledgment

This research is funded by the Natural Sciences & Engineering Research Council (NSERC) Discovery Grant RGPIN‐2022‐05378 [M.H.]; FRQNT-NSERC grant 2023-NOVA-329125 [E.B. & G.W.]; Canada CIFAR AI Chair, NSF DMS grant 2327211 and NSERC Discovery grant 03267 [G.W.].

References

  • rz7 (2018) iris, 2018. URL https://dx.doi.org/10.21227/rz7n-kj20.
  • Amari (1998) Shun-ichi Amari. Natural gradient works efficiently in learning. Neural Computation, 10(2):251–276, Feb 1998. ISSN 0899-7667. doi: 10.1162/089976698300017746.
  • Amari & Nagaoka (2000) Shun‐ichi Amari and Hiroshi Nagaoka. Methods of information geometry. 2000. URL https://api.semanticscholar.org/CorpusID:116976027.
  • Ba et al. (2016) Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization, 2016.
  • Boyd & Vandenberghe (2004) Stephen Boyd and Lieven Vandenberghe. Convex Optimization. Cambridge University Press, 2004.
  • Braeken & Van Assen (2017) Johan Braeken and Marcel ALM Van Assen. An empirical kaiser criterion. Psychological methods, 22(3):450, 2017.
  • Cha et al. (2021) Junbum Cha, Sanghyuk Chun, Kyungjae Lee, Han-Cheol Cho, Seunghyun Park, Yunsung Lee, and Sungrae Park. Swad: Domain generalization by seeking flat minima. Advances in Neural Information Processing Systems, 34:22405–22418, 2021.
  • Chen et al. (2019) Xiangyi Chen, Sijia Liu, Ruoyu Sun, and Mingyi Hong. On the convergence of a class of adam-type algorithms for non-convex optimization. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=H1x-x309tm.
  • Deng et al. (2009) Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp.  248–255, 2009. doi: 10.1109/CVPR.2009.5206848.
  • Deng (2012) Li Deng. The mnist database of handwritten digit images for machine learning research [best of the web]. IEEE Signal Processing Magazine, 29(6):141–142, 2012. doi: 10.1109/MSP.2012.2211477.
  • DeVries & Taylor (2017) Terrance DeVries and Graham W Taylor. Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552, 2017.
  • Dubey et al. (2022) Shiv Ram Dubey, SH Shabbeer Basha, Satish Kumar Singh, and Bidyut Baran Chaudhuri. Adainject: Injection based adaptive gradient descent optimizers for convolutional neural networks. IEEE Transactions on Artificial Intelligence, 2022.
  • Duchi et al. (2011) John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12(61):2121–2159, 2011. URL http://jmlr.org/papers/v12/duchi11a.html.
  • ElNokrashy et al. (2022) Muhammad ElNokrashy, Badr AlKhamissi, and Mona Diab. Depth-wise attention (dwatt): A layer fusion method for data-efficient classification. arXiv preprint arXiv:2209.15168, 2022.
  • Eschenhagen et al. (2023) Runa Eschenhagen, Alexander Immer, Richard E Turner, Frank Schneider, and Philipp Hennig. Kronecker-factored approximate curvature for modern neural network architectures. arXiv preprint arXiv:2311.00636, 2023.
  • Eschenhagen et al. (2024) Runa Eschenhagen, Alexander Immer, Richard Turner, Frank Schneider, and Philipp Hennig. Kronecker-factored approximate curvature for modern neural network architectures. Advances in Neural Information Processing Systems, 36, 2024.
  • Foret et al. (2021) Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization, 2021.
  • F.R.S. (1901) Karl Pearson F.R.S. Liii. on lines and planes of closest fit to systems of points in space. The London, Edinburgh, and Dublin Philosophical Magazine and Journal of Science, 2(11):559–572, 1901. doi: 10.1080/14786440109462720.
  • George (2021) Thomas George. NNGeometry: Easy and Fast Fisher Information Matrices and Neural Tangent Kernels in PyTorch, February 2021. URL https://doi.org/10.5281/zenodo.4532597.
  • Goyal et al. (2017) Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch sgd: Training imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.
  • Grosse & Martens (2016) Roger Grosse and James Martens. A kronecker-factored approximate fisher matrix for convolution layers, 2016.
  • Gupta et al. (2018) Vineet Gupta, Tomer Koren, and Yoram Singer. Shampoo: Preconditioned stochastic tensor optimization, 2018.
  • Hassani et al. (2021) Ali Hassani, Steven Walton, Nikhil Shah, Abulikemu Abuduweili, Jiachen Li, and Humphrey Shi. Escaping the big data paradigm with compact transformers. 2021. URL https://arxiv.org/abs/2104.05704.
  • He et al. (2015) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition, 2015.
  • He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp.  770–778, 2016.
  • Heo et al. (2020) Byeongho Heo, Sanghyuk Chun, Seong Joon Oh, Dongyoon Han, Sangdoo Yun, Gyuwan Kim, Youngjung Uh, and Jung-Woo Ha. Adamp: Slowing down the slowdown for momentum optimizers on scale-invariant weights. arXiv preprint arXiv:2006.08217, 2020.
  • (27) Geoffrey Hinton, Nitish Srivastava, and Kevin Swersky. Neural networks for machine learning lecture 6a overview of mini-batch gradient descent.
  • Holmgren (1996) Richard A. Holmgren. Newton’s Method, pp.  127–151. Springer New York, New York, NY, 1996. ISBN 978-1-4419-8732-7. doi: 10.1007/978-1-4419-8732-7_12. URL https://doi.org/10.1007/978-1-4419-8732-7_12.
  • Horn & Johnson (2012) Roger A. Horn and Charles R. Johnson. Matrix Analysis. Cambridge University Press, 2 edition, 2012.
  • Howard et al. (2019) Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, and Hartwig Adam. Searching for mobilenetv3, 2019.
  • Huang et al. (2024) Keke Huang, Ruize Gao, Bogdan Cautis, and Xiaokui Xiao. Scalable continuous-time diffusion framework for network inference and influence estimation, 2024.
  • Huo et al. (2021) Zhouyuan Huo, Bin Gu, and Heng Huang. Large batch optimization for deep learning using new complete layer-wise adaptive rate scaling. In Proceedings of the AAAI conference on artificial intelligence, volume 35, pp.  7883–7890, 2021.
  • Ioffe & Szegedy (2015) Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift, 2015.
  • Jiang et al. (2024a) Kaiqi Jiang, Dhruv Malik, and Yuanzhi Li. How does adaptive optimization impact local neural network geometry? Advances in Neural Information Processing Systems, 36, 2024a.
  • Jiang et al. (2024b) Zixuan Jiang, Jiaqi Gu, Hanqing Zhu, and David Pan. Pre-rmsnorm and pre-crmsnorm transformers: equivalent and efficient pre-ln transformers. Advances in Neural Information Processing Systems, 36, 2024b.
  • Kiefer & Wolfowitz (1952) Jack Kiefer and Jacob Wolfowitz. Stochastic estimation of the maximum of a regression function. The Annals of Mathematical Statistics, pp.  462–466, 1952.
  • Kingma & Ba (2017) Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization, 2017.
  • Krizhevsky et al. (2009) Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. 2009.
  • Kunstner et al. (2019) Frederik Kunstner, Philipp Hennig, and Lukas Balles. Limitations of the empirical fisher approximation for natural gradient descent. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019. URL https://proceedings.neurips.cc/paper_files/paper/2019/file/46a558d97954d0692411c861cf78ef79-Paper.pdf.
  • Le & Yang (2015) Ya Le and Xuan Yang. Tiny imagenet visual recognition challenge. CS 231N, 7(7):3, 2015.
  • Lecun (2001) Yann Lecun. A theoretical framework for back-propagation. 08 2001.
  • LeCun et al. (2012) Yann A. LeCun, Léon Bottou, Genevieve B. Orr, and Klaus Robert Müller. Efficient backprop, pp.  9–48. Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and Lecture Notes in Bioinformatics). Springer Verlag, 2012. ISBN 9783642352881. doi: 10.1007/978-3-642-35289-8_3.
  • Leplat et al. (2022) Valentin Leplat, Daniil Merkulov, Aleksandr Katrutsa, Daniel Bershatsky, Olga Tsymboi, and Ivan Oseledets. Nag-gs: Semi-implicit, accelerated and robust stochastic optimizer. arXiv preprint arXiv:2209.14937, 2022.
  • Lin et al. (2024a) Wu Lin, Felix Dangel, Runa Eschenhagen, Juhan Bae, Richard E Turner, and Alireza Makhzani. Can we remove the square-root in adaptive gradient methods? a second-order perspective. arXiv preprint arXiv:2402.03496, 2024a.
  • Lin et al. (2024b) Wu Lin, Felix Dangel, Runa Eschenhagen, Juhan Bae, Richard E. Turner, and Alireza Makhzani. Can we remove the square-root in adaptive gradient methods? a second-order perspective, 2024b.
  • Liu et al. (2019) Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han. On the variance of the adaptive learning rate and beyond. arXiv preprint arXiv:1908.03265, 2019.
  • Liu et al. (2021) Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, and Baining Guo. Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF international conference on computer vision, pp.  10012–10022, 2021.
  • Loshchilov & Hutter (2016) Ilya Loshchilov and Frank Hutter. Sgdr: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983, 2016.
  • Loshchilov & Hutter (2019) Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization, 2019.
  • Luo et al. (2019) Liangchen Luo, Yuanhao Xiong, Yan Liu, and Xu Sun. Adaptive gradient methods with dynamic bound of learning rate, 2019.
  • Ma et al. (2019) Linjian Ma, Gabe Montague, Jiayu Ye, Zhewei Yao, Amir Gholami, Kurt Keutzer, and Michael W. Mahoney. Inefficiency of k-fac for large batch size training, 2019.
  • Malladi et al. (2022) Sadhika Malladi, Kaifeng Lyu, Abhishek Panigrahi, and Sanjeev Arora. On the sdes and scaling rules for adaptive gradient algorithms. Advances in Neural Information Processing Systems, 35:7697–7711, 2022.
  • Marcus et al. (1993) Mitchell P. Marcus, Beatrice Santorini, and Mary Ann Marcinkiewicz. Building a large annotated corpus of English: The Penn Treebank. Computational Linguistics, 19(2):313–330, 1993. URL https://aclanthology.org/J93-2004.
  • Martens (2020) James Martens. New insights and perspectives on the natural gradient method. Journal of Machine Learning Research, 21(146):1–76, 2020.
  • Martens & Grosse (2015) James Martens and Roger Grosse. Optimizing neural networks with kronecker-factored approximate curvature. In Francis Bach and David Blei (eds.), Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pp.  2408–2417, Lille, France, 07–09 Jul 2015. PMLR. URL https://proceedings.mlr.press/v37/martens15.html.
  • Martens & Grosse (2020) James Martens and Roger Grosse. Optimizing neural networks with kronecker-factored approximate curvature, 2020.
  • Merity et al. (2016) Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.
  • Mishchenko & Stich (2023) Konstantin Mishchenko and Sebastian U Stich. Noise injection irons out local minima and saddle points. In OPT 2023: Optimization for Machine Learning, 2023.
  • Oppenheim et al. (1999) Alan V. Oppenheim, Ronald W. Schafer, and John R. Buck. Discrete-Time Signal Processing. Prentice-hall Englewood Cliffs, second edition, 1999.
  • Osawa et al. (2023) Kazuki Osawa, Satoki Ishikawa, Rio Yokota, Shigang Li, and Torsten Hoefler. Asdl: A unified interface for gradient preconditioning in pytorch, 2023.
  • Park et al. (2000) H Park, S.-I Amari, and K Fukumizu. Adaptive natural gradient learning algorithms for various stochastic models. Neural Networks, 13(7):755–764, 2000. ISSN 0893-6080. doi: https://doi.org/10.1016/S0893-6080(00)00051-4. URL https://www.sciencedirect.com/science/article/pii/S0893608000000514.
  • Paszke et al. (2019) Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, et al. Pytorch: An imperative style, high-performance deep learning library. Advances in neural information processing systems, 32, 2019.
  • Patro & Sahu (2015) SGOPAL Patro and Kishore Kumar Sahu. Normalization: A preprocessing stage. arXiv preprint arXiv:1503.06462, 2015.
  • Petersen & Pedersen (2008) K. B. Petersen and M. S. Pedersen. The matrix cookbook, October 2008. URL http://www2.imm.dtu.dk/pubdb/p.php?3274. Version 20081110.
  • Radford et al. (2019) Alec Radford, Jeff Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. 2019.
  • Reddi et al. (2019) Sashank J Reddi, Satyen Kale, and Sanjiv Kumar. On the convergence of adam and beyond. arXiv preprint arXiv:1904.09237, 2019.
  • Roux et al. (2007) Nicolas Roux, Pierre-Antoine Manzagol, and Yoshua Bengio. Topmoumoute online natural gradient algorithm. Advances in neural information processing systems, 20, 2007.
  • Ruder (2016) Sebastian Ruder. An overview of gradient descent optimization algorithms. arXiv preprint arXiv:1609.04747, 2016.
  • Russakovsky et al. (2015) Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet Large Scale Visual Recognition Challenge. International Journal of Computer Vision (IJCV), 115(3):211–252, 2015. doi: 10.1007/s11263-015-0816-y.
  • Ryu & Boyd (2016) Ernest Ryu and Stephen Boyd. A primer on monotone operator methods survey. Applied and computational mathematics, 15:3–43, 01 2016.
  • Schaul et al. (2013) Tom Schaul, Sixin Zhang, and Yann LeCun. No more pesky learning rates, 2013.
  • Sutskever et al. (2013) Ilya Sutskever, James Martens, George Dahl, and Geoffrey Hinton. On the importance of initialization and momentum in deep learning. In Sanjoy Dasgupta and David McAllester (eds.), Proceedings of the 30th International Conference on Machine Learning, volume 28 of Proceedings of Machine Learning Research, pp.  1139–1147, Atlanta, Georgia, USA, 17–19 Jun 2013. PMLR. URL https://proceedings.mlr.press/v28/sutskever13.html.
  • Takahashi et al. (2020) Ryo Takahashi, Takashi Matsubara, and Kuniaki Uehara. Data augmentation using random image cropping and patching for deep cnns. IEEE Transactions on Circuits and Systems for Video Technology, 30(9):2917–2931, September 2020. ISSN 1558-2205. doi: 10.1109/tcsvt.2019.2935128. URL http://dx.doi.org/10.1109/TCSVT.2019.2935128.
  • Tang et al. (2021) Zedong Tang, Fenlong Jiang, Maoguo Gong, Hao Li, Yue Wu, Fan Yu, Zidong Wang, and Min Wang. Skfac: Training neural networks with faster kronecker-factored approximate curvature. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  13479–13487, 2021.
  • Wightman et al. (2021) Ross Wightman, Hugo Touvron, and Hervé Jégou. Resnet strikes back: An improved training procedure in timm. arXiv preprint arXiv:2110.00476, 2021.
  • Wilson et al. (2018) Ashia C. Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin Recht. The marginal value of adaptive gradient methods in machine learning, 2018.
  • Yang et al. (2022) Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, and Jianfeng Gao. Focal modulation networks, 2022.
  • Yao et al. (2021) Zhewei Yao, Amir Gholami, Sheng Shen, Mustafa Mustafa, Kurt Keutzer, and Michael Mahoney. Adahessian: An adaptive second order optimizer for machine learning. In proceedings of the AAAI conference on artificial intelligence, volume 35, pp.  10665–10673, 2021.
  • You et al. (2017) Yang You, Igor Gitman, and Boris Ginsburg. Large batch training of convolutional networks, 2017.
  • You et al. (2019) Yang You, Jing Li, Sashank Reddi, Jonathan Hseu, Sanjiv Kumar, Srinadh Bhojanapalli, Xiaodan Song, James Demmel, Kurt Keutzer, and Cho-Jui Hsieh. Large batch optimization for deep learning: Training bert in 76 minutes. arXiv preprint arXiv:1904.00962, 2019.
  • Zaheer et al. (2018) Manzil Zaheer, Sashank Reddi, Devendra Sachan, Satyen Kale, and Sanjiv Kumar. Adaptive methods for nonconvex optimization. Advances in neural information processing systems, 31, 2018.
  • Zeiler (2012) Matthew D. Zeiler. Adadelta: An adaptive learning rate method, 2012.
  • Zhang et al. (2024) Yushun Zhang, Congliang Chen, Tian Ding, Ziniu Li, Ruoyu Sun, and Zhi-Quan Luo. Why transformers need adam: A hessian perspective. arXiv preprint arXiv:2402.16788, 2024.
  • Zhang & Sabuncu (2018) Zhilu Zhang and Mert R. Sabuncu. Generalized cross entropy loss for training deep neural networks with noisy labels, 2018.
  • Zhuang et al. (2020) Juntang Zhuang, Tommy Tang, Yifan Ding, Sekhar C Tatikonda, Nicha Dvornek, Xenophon Papademetris, and James Duncan. Adabelief optimizer: Adapting stepsizes by the belief in observed gradients. In H. Larochelle, M. Ranzato, R. Hadsell, M.F. Balcan, and H. Lin (eds.), Advances in Neural Information Processing Systems, volume 33, pp.  18795–18806. Curran Associates, Inc., 2020. URL https://proceedings.neurips.cc/paper_files/paper/2020/file/d9d4f495e875a2e075a1a4a6e1b9770f-Paper.pdf.

Appendix A Theory

A.1 Kronecker Factors: A Structural Examination (Continued)

In the realm of matrix theory, Gersgorin’s Circle Theorem offers a principle for localizing the eigenvalues of a complex square matrix, asserting that each eigenvalue is situated within at least one Gersgorin disk. These disks are defined by the matrix’s diagonal elements and the sum of the absolute values of the respective off-diagonal row entries. Formally, the theorem is stated as follows:

Theorem A.1 (Gersgorin Circle Theorem).

Let $\mathcal{A}$ be a complex $n\times n$ matrix with eigenvalues $\lambda$. For each $\lambda$, there exists an index $i$ such that:

$$|\lambda-\mathcal{A}_{ii}|\leq\sum_{\begin{subarray}{c}j=1\\ j\neq i\end{subarray}}^{n}|\mathcal{A}_{ij}|,$$

where the summation excludes the diagonal entry $\mathcal{A}_{ii}$.

For a detailed proof of Theorem A.1, the reader is referred to Horn & Johnson (2012). Extending the application of Gersgorin’s Circle Theorem to the study of Kronecker factors within deep neural networks, we analyze these factors from both the convolutional (37th) and linear (41st) layers of a ResNet-18 network after training on the CIFAR10 dataset. As elucidated in the main-text examination of the Kronecker factors, Theorem A.1 shows that the eigenvalues of the Kronecker factors from the convolutional layer are predominantly concentrated along the diagonal. This observation applies analogously to the linear layer. Figure 2 showcases the Gersgorin disks for the 41st (linear) layer, with the eigenvalues (red crosses) clustered within these disks (centered at the black circles), underscoring a pronounced diagonal dominance. Moreover, we introduce Gaussian noise to the off-diagonal elements following the scheme $\hat{\mathcal{M}}=\mathcal{A}+\mathcal{E}$, where $\mathcal{E}=[e_{ij}]$ and $e_{ij}\sim\mathcal{N}(0,\sigma^{2})$ for $i\neq j$. The perturbation analysis shows that such stochastic variations cause only marginal displacements in the eigenvalues. Notably, the eigenvalues fulfilling the Kaiser criterion are minimally affected, substantiating the resilience of the diagonal dominance against noise-induced perturbations.
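
The disk construction and the off-diagonal perturbation scheme can be illustrated with a short NumPy sketch. It uses a small synthetic, diagonally dominant symmetric matrix rather than a trained Kronecker factor; the helper names `gersgorin_disks` and `perturb_off_diagonal`, the matrix size, and the noise level `sigma` are assumptions made for illustration only.

```python
import numpy as np

def gersgorin_disks(A):
    """Return disk centers A_ii and radii sum_{j != i} |A_ij|."""
    centers = np.diag(A)
    radii = np.sum(np.abs(A), axis=1) - np.abs(centers)
    return centers, radii

def perturb_off_diagonal(A, sigma, rng):
    """Add N(0, sigma^2) noise to the off-diagonal entries only."""
    E = rng.normal(0.0, sigma, size=A.shape)
    np.fill_diagonal(E, 0.0)
    return A + E

rng = np.random.default_rng(0)
n = 6
B = 0.05 * rng.normal(size=(n, n))
A = np.diag(np.arange(1.0, n + 1)) + (B + B.T) / 2   # synthetic stand-in matrix

centers, radii = gersgorin_disks(A)
eigvals = np.linalg.eigvalsh(A)                      # ascending order
# Every eigenvalue must lie in at least one disk (Theorem A.1)
in_some_disk = [bool(np.any(np.abs(lam - centers) <= radii + 1e-12)) for lam in eigvals]

A_hat = perturb_off_diagonal(A, sigma=0.01, rng=rng)
eigvals_hat = np.sort(np.linalg.eigvals(A_hat).real)
shift = np.abs(eigvals_hat - eigvals)                # displacement per eigenvalue
print(all(in_some_disk), shift.max())
```

For a strongly diagonally dominant matrix the radii are small relative to the spacing of the centers, so the eigenvalues stay close to the diagonal entries and move little under small off-diagonal noise.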

Figure 2: Gersgorin disks and eigenvalue perturbation analysis for matrices $\mathcal{H}$ and $\mathcal{S}$ at training steps 5200 (middle of training) and 9800 (end of training) in the linear layer (41st layer) of a ResNet-18 network. The left panel depicts the Gersgorin disks in the complex plane, while the right panel illustrates the magnitude spectrum of the eigenvalues with and without Gaussian noise.

Our next analysis examines the behavior of the matrices in the frequency domain at successive training steps, highlighting the patterns and transformations that emerge during training. By applying a Fast Fourier Transform (FFT) to $\mathcal{H}$ and $\mathcal{S}$, along with their noise-infused variants $\hat{\mathcal{H}}$ and $\hat{\mathcal{S}}$, we aim to dissect the spectral characteristics of these factors. The deliberate addition of noise to the off-diagonal entries serves as a probe to validate our hypothesis that the pivotal information of the Kronecker factors is predominantly concentrated along their diagonals. The minimal impact of such noise perturbations observed empirically underscores this diagonal dominance. Our analysis compares the frequency-domain representations of the clean and noise-affected matrices at several training iterations, thereby illustrating the stability of the Kronecker structures under stochastic disturbances.
Let $A$ be a two-dimensional $m\times n$ matrix. The FFT of $A$, denoted as $\mathcal{F}(A)$, is computed as:

$$\mathcal{F}(A)_{kl}=\sum_{p=0}^{m-1}\sum_{q=0}^{n-1}A_{pq}\cdot e^{-2\pi i\left(\frac{pk}{m}+\frac{ql}{n}\right)}, \tag{2}$$

where $\mathcal{F}(A)_{kl}$ is the value of the FFT at the $k$-th row and $l$-th column of the transformed matrix, $A_{pq}$ is the value of the original matrix at the $p$-th row and $q$-th column, and $i$ is the imaginary unit (Oppenheim et al., 1999).
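
As a sanity check on Eq. (2), the double sum can be evaluated directly and compared against NumPy's `np.fft.fft2`, which uses the same sign and (absent) normalization convention. The function name `dft2_direct` and the test matrix are hypothetical; this is a small sketch, not the analysis code used for the figures.

```python
import numpy as np

def dft2_direct(A):
    """Evaluate the 2-D DFT of Eq. (2) as two matrix products."""
    m, n = A.shape
    p = np.arange(m)
    q = np.arange(n)
    W_m = np.exp(-2j * np.pi * np.outer(p, p) / m)   # W_m[k, p] = exp(-2*pi*i*p*k/m)
    W_n = np.exp(-2j * np.pi * np.outer(q, q) / n)   # W_n[q, l] = exp(-2*pi*i*q*l/n)
    # (W_m A W_n)[k, l] = sum_p sum_q W_m[k, p] A[p, q] W_n[q, l]
    return W_m @ A @ W_n

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 5))
F_direct = dft2_direct(A)
F_fast = np.fft.fft2(A)
print(np.allclose(F_direct, F_fast))
```

The direct form costs $O(mn(m+n))$ operations, which is why the fast transform is preferred for matrices the size of actual Kronecker factors.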

Figure 3: Comparative visualization of FFT outputs for Kronecker factors in the convolutional and linear layers of a ResNet-18 network. (A) FFT results for Kronecker factors $\mathcal{H}$ and $\hat{\mathcal{H}}$ from the 37th (convolutional) layer under noise-free conditions (top) and with Gaussian noise (bottom) at iterations 5200 (middle of training) and 9800 (end of training). (B) Analogous FFT results for Kronecker factors $\mathcal{S}$ and $\hat{\mathcal{S}}$ from the 41st (linear) layer, also contrasted between noise-free (top) and noisy conditions (bottom) at the same iterations.

Figure 3 shows the Fourier spectra of the Kronecker factors $\mathcal{H}$ and $\mathcal{S}$ at two training iterations, 5200 and 9800, for a convolutional and a linear layer (the 37th and 41st layers of a ResNet-18 network, respectively). Each Kronecker factor is analyzed via FFT both in its noise-free form and after adding Gaussian noise, with the associated Signal-to-Noise Ratio (SNR) defined in Eq. (3). In the noise-free spectra, the $\mathcal{H}$ and $\mathcal{S}$ factors of the convolutional layer show a pronounced concentration of energy along the diagonal. The linear layer shows a less pronounced but still discernible diagonal energy distribution, suggesting a more diffuse diagonal structure. After noise is added, the matrices $\hat{\mathcal{H}}$ and $\hat{\mathcal{S}}$ still display a clear diagonal pattern, indicating minimal SNR deterioration. This supports the proposition that the Kronecker factors primarily encode their information along the diagonal, and that noise in the off-diagonal elements has limited impact. The SNR between a matrix $\mathcal{M}$ and $\hat{\mathcal{M}}$ is computed using the formula:

$$\text{SNR}=10\cdot\log_{10}\left(\frac{\sum_{i=1}^{N}|\mathcal{M}_{ii}|^{2}}{\sum_{j>i}^{N}|\hat{\mathcal{M}}_{ij}|^{2}}\right), \tag{3}$$

where $\mathcal{M}_{ii}$ denotes the diagonal elements of $\mathcal{M}$, and $\hat{\mathcal{M}}_{ij}$ represents the upper triangular elements of $\hat{\mathcal{M}}$ excluding the diagonal (Oppenheim et al., 1999). The reduction in SNR from step 5200 to step 9800 for the Kronecker factor $\mathcal{S}$ in the convolutional layer under noisy conditions could suggest a gradual accumulation of noise effects across iterations. For the remaining factors, the SNR increases during training, which may indicate improved signal clarity. In all cases, the diagonal concentration of energy remains largely intact, demonstrating the robustness of the network’s feature extraction capability against noise perturbations. Overall, the spectral analyses support the hypothesis that the information in the Kronecker factors is predominantly diagonal and resistant to off-diagonal Gaussian noise. This property persists across iterations, preserving the primary spectral characteristics of the Kronecker factors.
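The SNR in Eq. (3) can be sketched in a few lines of plain Python. The helper names, the toy matrix, and the noise level `sigma` below are assumptions chosen for illustration; they are not the settings used to produce Figure 3.

```python
import math
import random


def snr_db(m, m_hat):
    """Eq. (3): diagonal energy of m over strictly-upper-triangular energy of m_hat, in dB.

    Assumes m_hat has at least one non-zero entry above the diagonal.
    """
    n = len(m)
    signal = sum(abs(m[i][i]) ** 2 for i in range(n))
    noise = sum(abs(m_hat[i][j]) ** 2 for i in range(n) for j in range(i + 1, n))
    return 10.0 * math.log10(signal / noise)


def add_offdiag_noise(m, sigma, rng):
    """Return a copy of m with Gaussian noise added to the off-diagonal entries only."""
    n = len(m)
    return [
        [m[i][j] + (rng.gauss(0.0, sigma) if i != j else 0.0) for j in range(n)]
        for i in range(n)
    ]


# Hypothetical diagonally dominant factor.
M_clean = [[2.0, 0.1, 0.1], [0.1, 2.0, 0.1], [0.1, 0.1, 2.0]]
M_noisy = add_offdiag_noise(M_clean, sigma=0.05, rng=random.Random(0))

snr_clean = snr_db(M_clean, M_clean)
snr_noisy = snr_db(M_clean, M_noisy)
```

Here `snr_clean` equals $10\log_{10}(12/0.03)$, since the diagonal energy is $3\cdot 2^2$ and the upper-triangular energy is $3\cdot 0.1^2$.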

Refer to caption
Figure 4: Visualization of Kronecker Factors $\mathcal{H}$ and $\mathcal{S}$ for convolutional (A) and linear (B) layers at different iteration steps within a ResNet-18 network. For the convolutional layer (37th layer), the first two plots in (A) represent factor $\mathcal{H}$ at steps 5200 (middle of training) and 9800 (end of training), showing the matrix’s structure at these stages. The subsequent two plots display factor $\mathcal{S}$, highlighting changes in granularity and contrast as training progresses. Similarly, in (B) for the linear layer (41st layer), we observe the structural evolution of factors $\mathcal{H}$ and $\mathcal{S}$ over the same iterations, with variations in pattern density and clarity. These visualizations collectively show how the structure of the Kronecker factors changes as training advances.

Figure 4 visualizes the Kronecker factors $\mathcal{H}$ and $\mathcal{S}$ at steps 5200 and 9800 for a convolutional and a linear layer (the 37th and 41st layers of a ResNet-18 network, respectively). In both (A) and (B), the first two plots show the factor $\mathcal{H}$ at these two steps, and the next two plots show the factor $\mathcal{S}$ at the same steps. Together with the preceding spectral analysis, this visual examination gives a consistent picture of how the Kronecker factors evolve during training. The persistent diagonal dominance observed in both $\mathcal{H}$ and $\mathcal{S}$ supports the claim that the energy of the Kronecker factors is concentrated along the diagonal, and it points to the structural stability of the information encoded within the network’s layers.

A.2 Proofs

Proposition A.1.

Consider a neural network layer indexed by $i$, and a mini-batch $\mathbb{B}\subset\mathbf{D}$ of size $M$ ($|\mathbb{B}|=M$). The empirical statistics of the Kronecker factors for the normalization layers can be characterized as follows:

$$\mathcal{H}_{i}=\frac{\left(\sum_{\mathbb{B}}\sum_{\mathcal{T}}\bar{h}_{i}\right)^{T}\left(\sum_{\mathbb{B}}\sum_{\mathcal{T}}\bar{h}_{i}\right)}{(M|\mathcal{T}|)^{2}},\quad\mathcal{S}_{i}=\frac{\left(\sum_{\mathbb{B}}\sum_{\mathcal{T}}s_{i}\right)\left(\sum_{\mathbb{B}}\sum_{\mathcal{T}}s_{i}\right)^{T}}{M} \tag{4}$$

Here, $\mathcal{T}$ denotes the spatial size dimension for a BatchNorm layer; for a LayerNorm layer, it denotes the product of the number of heads and the per-head dimension.

Proof.

The justification of our approach will be split into two parts: the first for the Batch Normalization layer and the second for the Layer Normalization.

Part 1: Batch Normalization

Batch Normalization addresses the problem of internal covariate shift by normalizing the layer inputs. For each activation $\bar{h}_{i}$, it introduces a pair of parameters $\nu_{i}$ and $\beta_{i}$ which scale and shift the normalized value:

$$y_{i}=\nu_{i}\bar{h}_{i}+\beta_{i} \tag{5}$$

The FIM for Batch Normalization captures the sensitivity of the output with respect to the parameters $\nu_{i}$ and $\beta_{i}$. We introduce rescaling and shifting operations into the FIM formulation to account for the BatchNorm parameters, enabling efficient FIM approximation. For the multiplication operation in BatchNorm (the scaling factor), we adjust the FIM calculation using Eq. (4), incorporating the batch size and spatial dimension. This normalization ensures that the FIM accounts for the BatchNorm scaling factors. Similarly, for the addition operation involving bias terms, we integrate the biases into the FIM formulation, capturing their impact on gradient computation. The shapes of the Kronecker factors are:

$$\mathcal{H}_{i}\in\mathbb{R}^{2\times 2},\quad\mathcal{S}_{i}\in\mathbb{R}^{c_{i}\times c_{i}}$$

where $c_{i}$ refers to the channel dimension of layer $i$.
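The shape bookkeeping in Eq. (4) can be checked with a small sketch. It assumes, purely for illustration, that each $\bar{h}_{i}$ is a length-2 row vector (a normalized activation paired with a constant 1 for the shift parameter) and that each $s_{i}$ is a length-$c_{i}$ vector, which is consistent with the stated $2\times 2$ and $c_{i}\times c_{i}$ shapes but is not spelled out in the text. All names and values are hypothetical.

```python
def outer(u, v):
    """Outer product of two vectors as a list of rows."""
    return [[ui * vj for vj in v] for ui in u]


def norm_layer_factors(h_bar, s, batch_size, spatial_size):
    """Eq. (4): sum over the batch and over T, then form scaled outer products.

    h_bar: list of length-2 vectors, one per (sample, position) pair (assumed layout).
    s:     list of length-c vectors, one per (sample, position) pair (assumed layout).
    """
    h_sum = [sum(col) for col in zip(*h_bar)]
    s_sum = [sum(col) for col in zip(*s)]
    scale_h = (batch_size * spatial_size) ** 2
    H = [[x / scale_h for x in row] for row in outer(h_sum, h_sum)]
    S = [[x / batch_size for x in row] for row in outer(s_sum, s_sum)]
    return H, S


# Toy data: M = 2 samples, |T| = 2 positions, c = 3 channels.
h_bar = [[0.5, 1.0], [-0.5, 1.0], [1.0, 1.0], [0.0, 1.0]]
s = [[1, 0, 2], [0, 1, 0], [1, 1, 0], [0, 0, 2]]
H, S = norm_layer_factors(h_bar, s, batch_size=2, spatial_size=2)
```

With these inputs `H` is $2\times 2$ and `S` is $3\times 3$, matching the shapes given above.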

Part 2: Layer Normalization

Layer Normalization normalizes the inputs across the features instead of the batch dimension, and the same equation, Eq. (5), is also used for Layer Normalization. As for Batch Normalization, we introduce rescaling and shifting operations into the FIM formulation, but adapted to normalization across features rather than across the batch. Here, $\mathcal{T}$ refers to the product of the number of heads and the per-head dimension rather than the spatial size dimension. The shapes of the Kronecker factors for LayerNorm are:

$$\mathcal{H}_{i}\in\mathbb{R}^{2\times 2},\quad\mathcal{S}_{i}\in\mathbb{R}^{|\mathcal{T}|\times|\mathcal{T}|}$$

Proposition A.2.

Let $\mathcal{H}_{i}$ and $\mathcal{S}_{i}$ represent the Kronecker factors for a given layer index $i$ within a neural network, where these factors exhibit semi-diagonal characteristics, with energy concentrated predominantly along the diagonal, as elaborated in Section LABEL:sec:kroneckerdetail. Define $g_{i}$ as the gradient obtained through backpropagation at layer $i$. Assume that $\mathcal{H}_{i}$ and $\mathcal{S}_{i}$ can be closely approximated by diagonal matrices, denoted by $\mathcal{H}_{D_{i}}$ and $\mathcal{S}_{D_{i}}$ respectively, such that $\mathcal{H}_{D_{i}}=\text{Diag}(\mathcal{H}_{i})$ and $\mathcal{S}_{D_{i}}=\text{Diag}(\mathcal{S}_{i})$, where $\text{Diag}(\mathcal{M})$ denotes the diagonal approximation of a matrix $\mathcal{M}$, which retains only the main diagonal. We then define the empirical FIM as:

$$\tilde{F}_{D_{i}}\triangleq\mathcal{H}_{D_{i}}^{\prime}\otimes\mathcal{S}_{D_{i}}^{\prime}+\lambda, \tag{6}$$

where $\mathcal{M}^{\prime}$ denotes the Min-Max normalization (Patro & Sahu, 2015) of $\mathcal{M}=\mathcal{H}_{D_{i}}$ or $\mathcal{S}_{D_{i}}$. The regularization parameter $\lambda$, set to $0.001$, serves as a damping factor, in line with the principles of Tikhonov regularization, to enhance computational stability and improve the conditioning of the matrix. The foundational aspects of the K-FAC optimization approach are detailed in Martens & Grosse (2015). For a comprehensive account of the methodology and construction details, please consult Appendix A.2. Then, the closed-form solution for the augmented gradient $\hat{g}_{i}$, derived from the diagonal approximation of the FIM, is given by $\hat{g}_{i}=\tilde{F}_{D_{i}}^{-1}g_{i}$.

Proof.

The justification of our approach comprises two principal components: the rationale for adopting a diagonal approximation of the Kronecker factors, and the methodology for normalization and regularization of these factors.

Part 1: Diagonalization of Kronecker Factors

The assumption of independent neuronal activity within layers is foundational to our approach. This assumption posits that the covariance matrices $\mathcal{H}$ and $\mathcal{S}$, encapsulating the second-order statistics of activations and sensitivities, respectively, are diagonal. This diagonal nature arises because independence among random variables implies a covariance of zero for any pair of distinct variables, thereby nullifying all off-diagonal elements of these covariance matrices.

Consider diagonal matrices $A\in\mathbb{R}^{n\times n}$ and $B\in\mathbb{R}^{m\times m}$ with diagonal elements $a_{ii}$ and $b_{jj}$, respectively. By definition, the Kronecker product $A\otimes B$ is the block matrix whose $(i,k)$ block is $a_{ik}B$. For diagonal $A$, every off-diagonal block vanishes, and each diagonal block $a_{ii}B$ is itself diagonal because $B$ is diagonal. The product therefore has non-zero values only on its main diagonal, where the products $a_{ii}b_{jj}$ appear:

$$A\otimes B=\text{diag}(a_{11}b_{11},a_{11}b_{22},\ldots,a_{nn}b_{mm}), \tag{7}$$

yielding a purely diagonal matrix. Moreover, we have empirically demonstrated that the energy of the Kronecker factors is concentrated along the diagonal, as detailed in Sections LABEL:sec:kroneckerdetail and A.1. These arguments support our initial premise.
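The diagonal structure in Eq. (7) is easy to verify numerically. The sketch below builds the Kronecker product of two small diagonal matrices from its definition; the helper names and the values are chosen only for illustration.

```python
def diag(values):
    """Dense diagonal matrix with the given diagonal entries."""
    n = len(values)
    return [[values[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def kron(a, b):
    """Kronecker product of square matrices a (n x n) and b (m x m)."""
    n, m = len(a), len(b)
    size = n * m
    return [
        [a[r // m][c // m] * b[r % m][c % m] for c in range(size)]
        for r in range(size)
    ]


a_diag = [1.0, 2.0]
b_diag = [3.0, 4.0, 5.0]
K = kron(diag(a_diag), diag(b_diag))

# The diagonal of K is the list of all products a_ii * b_jj, ordered with j varying fastest.
k_diagonal = [K[r][r] for r in range(len(K))]
expected = [x * y for x in a_diag for y in b_diag]
off_diagonal_is_zero = all(
    K[r][c] == 0.0 for r in range(len(K)) for c in range(len(K)) if r != c
)
```

Because only the diagonal is non-zero, it suffices in practice to store the vector `expected` of length $nm$ rather than the full $nm\times nm$ matrix.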

Part 2: Normalization and Regularization

Normalization plays an important role in machine learning, particularly in ensuring numerical stability and improving the convergence of optimization algorithms. For diagonal matrices such as $\mathcal{H}_{D_{i}}$ and $\mathcal{S}_{D_{i}}$, normalization adjusts the scale of the matrix values and addresses the issue of varying scales among different features. Min-Max normalization of the diagonal elements $\mathcal{A}_{i}$ is especially advantageous because it maps the data to a fixed interval, commonly $[0,1]$, which is crucial for many gradient-based optimization methods. The transformed matrix $\tilde{\mathcal{A}}_{i}$ is defined as:

$$\tilde{\mathcal{A}}_{i}=\frac{\mathcal{A}_{i}-\min(\mathcal{A}_{i})}{\max(\mathcal{A}_{i})-\min(\mathcal{A}_{i})}, \tag{8}$$

where $\mathcal{A}_{i}$ represents the diagonal elements from either $\mathcal{H}_{D_{i}}$ or $\mathcal{S}_{D_{i}}$. This approach scales all elements uniformly, preserving their ordering and relative distances. The numerator, $\mathcal{A}_{i}-\min(\mathcal{A}_{i})$, shifts the values so that the minimum element is zero. The denominator, $\max(\mathcal{A}_{i})-\min(\mathcal{A}_{i})$, scales the range of values to fit between zero and one. Compared to other normalization methods, such as z-score normalization (Patro & Sahu, 2015), Min-Max normalization has the distinct benefit of bounding the values, which prevents problems associated with unbounded ranges that can adversely affect learning, particularly in networks sensitive to input magnitude. Min-Max normalization is also advantageous when the parameters are influenced by activation functions like sigmoid or tanh, which are sensitive to input scale and function best within a defined range of $[0,1]$ or $[-1,1]$. Thus, Min-Max normalization helps maintain computational stability by ensuring that all input features contribute comparably, without undue influence from outliers or disproportionately large feature values. This uniformity facilitates faster convergence during training and mitigates the risk of vanishing or exploding gradients in neural networks.

Together, these components substantiate the proposition, demonstrating that our methodological innovations not only adhere to theoretical expectations but also offer practical advantages in computational stability and efficiency. ∎
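The two parts of the proof can be combined into a minimal sketch of Eqs. (6) and (8) acting on a gradient. It assumes that $\lambda$ is added to every diagonal entry, that the flattened gradient is ordered consistently with the Kronecker product, and that a constant vector is mapped to zeros by the normalization; these choices, the helper names, and the toy values are assumptions for illustration rather than the reference implementation.

```python
LAMBDA = 1e-3  # damping value stated in Proposition A.2


def min_max(values):
    """Eq. (8): rescale values to [0, 1]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # Illustrative choice for a constant vector; the text does not cover this case.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


def efim_diagonal(h_diag, s_diag, lam=LAMBDA):
    """Eq. (6): diagonal of (H' kron S') + lambda, stored as a vector."""
    h_norm, s_norm = min_max(h_diag), min_max(s_diag)
    return [h * s + lam for h in h_norm for s in s_norm]


def precondition(grad, f_diag):
    """g_hat = F^{-1} g; for a diagonal F this is an element-wise division."""
    return [g / f for g, f in zip(grad, f_diag)]


h_diag = [1.0, 3.0, 5.0]   # hypothetical diagonal of H
s_diag = [2.0, 4.0]        # hypothetical diagonal of S
grad = [1.0] * 6           # hypothetical flattened layer gradient

f_diag = efim_diagonal(h_diag, s_diag)
g_hat = precondition(grad, f_diag)
```

The damping term keeps every entry of `f_diag` strictly positive, so the element-wise division is well defined even where the normalized factors are zero.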

Proposition A.3.

For the FIM defined in Eq. (6), the updating scheme $\triangle\theta_{t}=\tilde{F}^{-1}_{t}\nabla J(\theta_{t})$ converges. Moreover, if $\nabla J$ is Lipschitz, i.e., $||\nabla J(\theta)-\nabla J(\theta^{\prime})||_{2}\leq L||\theta-\theta^{\prime}||$ for any $\theta$ and $\theta^{\prime}$, then for the $k$-step iteration with a fixed step size $\delta\leq 1/L$,

$$J(\theta^{(k)})-J(\theta^{*})\leq\frac{||\theta^{(0)}-\theta^{*}||_{2}^{2}}{2\delta k},$$

where $J(\theta^{*})$ is the optimal value.

Proof.

For convenience, we denote $g_{t}:=\nabla J(\theta_{t})$. We follow the same proof as in Yao et al. (2021). Assume that $J(\theta)$ is a strongly convex and strictly smooth function in $\mathbb{R}^{d}$, such that there exist positive constants $\alpha$ and $\beta$ so that $\alpha I\preceq\nabla^{2}J(\theta)\preceq\beta I$ for all $\theta$. We can show that the update formulation $\triangle\theta_{t}=\tilde{F}^{-1}_{t}g_{t}$ converges by showing that, with the proper learning rate,

$$J(\theta_{t+1})-J(\theta_{t})\leq-\frac{\alpha}{2\beta^{2}}||g_{t}||^{2}.$$

Note that when $k=0$ or $1$, the convergence rate is the same as gradient descent or the Newton method, respectively. Our proof is similar to that of Boyd & Vandenberghe (2004) for the Newton method. We denote $\lambda(\theta_{t})=(g_{t}^{T}\tilde{F}_{t}^{-1}g_{t})^{1/2}$. Since $J(\theta)$ is strongly convex with $\nabla^{2}J(\theta)\preceq\beta I$, we have

$$\begin{aligned}
J(\theta_{t}-\eta\triangle\theta_{t}) &\leq J(\theta_{t})-\eta g_{t}^{T}\triangle\theta_{t}+\frac{\eta^{2}\beta||\triangle\theta_{t}||^{2}}{2}\\
&\leq J(\theta_{t})-\eta\lambda(\theta_{t})^{2}+\frac{\beta}{2\alpha}\eta^{2}\lambda(\theta_{t})^{2}.
\end{aligned}$$

The second inequality comes from the fact that

$$\lambda(\theta_{t})^{2}=\triangle\theta_{t}^{T}\tilde{F}_{t}\triangle\theta_{t}\geq\alpha||\triangle\theta_{t}||^{2}.$$

Therefore, the step size $\hat{\eta}=\alpha/\beta$ will make $J$ decrease as follows:

$$J(\theta_{t}-\hat{\eta}\triangle\theta_{t})-J(\theta_{t})\leq-\frac{1}{2}\hat{\eta}\lambda(\theta_{t})^{2}.$$

Since $\alpha I\preceq\tilde{F}_{t}\preceq\beta I$, we have

$$\lambda(\theta_{t})^{2}=g_{t}^{T}\tilde{F}_{t}^{-1}g_{t}\geq\frac{1}{\beta}||g_{t}||^{2}.$$

Therefore,

$$J(\theta_{t}-\hat{\eta}\triangle\theta_{t})-J(\theta_{t})\leq-\frac{1}{2\beta}\hat{\eta}||g_{t}||^{2}=-\frac{\alpha}{2\beta^{2}}||g_{t}||^{2} \tag{9}$$

Since $\tilde{F}_{D_{t}}$ is positive definite, Eq. (9) holds. For the bound on the convergence rate, we refer to Ryu & Boyd (2016) for the details of the complete proof. ∎
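The decrease guarantee in Eq. (9) can be sanity-checked on a toy problem. The sketch below uses a diagonal quadratic objective whose Hessian eigenvalues lie in $[\alpha,\beta]$ and a hypothetical diagonal preconditioner with entries in the same interval; all numerical values are assumptions chosen for illustration.

```python
alpha, beta = 1.0, 4.0
hessian_diag = [1.0, 2.0, 4.0]   # J(theta) = 0.5 * sum(a_j * theta_j^2)
f_diag = [1.5, 3.0, 2.0]         # hypothetical preconditioner, alpha <= f_j <= beta


def objective(theta):
    return 0.5 * sum(a * t * t for a, t in zip(hessian_diag, theta))


def gradient(theta):
    return [a * t for a, t in zip(hessian_diag, theta)]


theta = [1.0, -2.0, 0.5]
g = gradient(theta)

# Preconditioned direction and the step size eta_hat = alpha / beta from the proof.
direction = [gi / fi for gi, fi in zip(g, f_diag)]
eta_hat = alpha / beta
theta_next = [t - eta_hat * d for t, d in zip(theta, direction)]

decrease = objective(theta_next) - objective(theta)
bound = -alpha / (2.0 * beta ** 2) * sum(gi * gi for gi in g)
```

On this example `decrease` is about $-1.75$ while `bound` is $-0.65625$, so the inequality in Eq. (9) holds with room to spare.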

Proposition A.4 (Convergence in nonconvex stochastic optimization).

Under the assumptions:
(i) $J$ is lower bounded and differentiable; $||\nabla J(\theta)-\nabla J(\theta^{\prime})||_{2}\leq L||\theta-\theta^{\prime}||_{2}$, $||\tilde{F}_{D_{t}}||_{\infty}<L,\,\forall t,\theta,\theta^{\prime}$.
(ii) Both the true and stochastic gradient are bounded, i.e. $||\nabla J(\theta_{t})||_{2}\leq\lambda$ and $||g_{t}||_{2}\leq\lambda$, $\forall t$ for some $\lambda>0$.
(iii) Unbiased and independent noise in $g_{t}$, i.e. $g_{t}=\nabla J(\theta_{t})+\zeta_{t}$, $\mathbb{E}[\zeta_{t}]=0$, and $\zeta_{i}\perp\zeta_{j}$, $\forall i\neq j$.

Assume $\eta_{t}=\frac{\eta}{\sqrt{t}}$, $\beta_{t}\leq\beta\leq 1$ is non-increasing, and $\frac{\tilde{F}_{D_{t-1}}[j]}{\eta_{t-1}}\leq\frac{\tilde{F}_{D_{t}}[j]}{\eta_{t}}$, $\forall t\in[T],j\in[d]$. We then have

$$\min_{t\in[T]}\mathbb{E}[||\nabla J(\theta_{t})||_{2}^{2}]\leq\frac{L}{\sqrt{T}}(C_{1}\eta^{2}\lambda^{2}(1+\log T)+C_{2}d\eta+C_{3}d\eta^{2}+C_{4}) \tag{10}$$

where $C_{1},C_{2},C_{3}$ are constants independent of $d$ and $T$, $C_{4}$ is a constant independent of $T$, and the expectation is taken w.r.t. all the randomness corresponding to $\{g_{t}\}$.

Proof.

Following Chen et al. (2019), as AdaFisher is an Adam-type method with the condition $||\eta_{t}m_{t}/\tilde{F}_{D_{t}}||_{2}\leq G$ for some $G$ (which can be obtained from $\eta_{t}<\eta$, $||g_{t}||_{2}\leq\lambda$ and $||\tilde{F}_{D_{t}}||_{2}\geq 1$), we have

$$\begin{aligned}
\mathbb{E}\Bigg[\sum_{t=1}^{T}\eta_{t}\langle\nabla J(\theta_{t}),\nabla J(\theta_{t})/\tilde{F}_{D_{t}}\rangle\Bigg]\leq{} &\mathbb{E}\Bigg[C_{1}\sum_{t=1}^{T}\left\|\frac{\eta_{t}g_{t}}{\tilde{F}_{D_{t}}}\right\|_{2}^{2}+C_{2}\sum_{t=1}^{T}\left\|\frac{\eta_{t}}{\tilde{F}_{D_{t}}}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}}\right\|_{1}\\
&+C_{3}\sum_{t=1}^{T}\left\|\frac{\eta_{t}}{\tilde{F}_{D_{t}}}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}}\right\|_{2}^{2}\Bigg]+C_{4}.
\end{aligned} \tag{11}$$

We first bound the non-constant terms on the RHS of Eq. (11). For the term with $C_{1}$, since $||\tilde{F}_{D_{t}}||_{2}\geq 1$, we have

$$\begin{aligned}
\mathbb{E}\Bigg[\sum_{t=1}^{T}\left\|\frac{\eta_{t}g_{t}}{\tilde{F}_{D_{t}}}\right\|_{2}^{2}\Bigg] &\leq\mathbb{E}\Bigg[\sum_{t=1}^{T}||\eta_{t}g_{t}||_{2}^{2}\Bigg]\\
&=\mathbb{E}\Bigg[\sum_{t=1}^{T}\left\|\frac{\eta}{\sqrt{t}}g_{t}\right\|_{2}^{2}\Bigg]\\
&\leq\eta^{2}\lambda^{2}\sum_{t=1}^{T}\frac{1}{t}\leq\eta^{2}\lambda^{2}(1+\log T).
\end{aligned}$$

For the term with $C_{2}$, we have

$$\begin{aligned}
\mathbb{E}\Bigg[\sum_{t=1}^{T}\left\|\frac{\eta_{t}}{\tilde{F}_{D_{t}}}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}}\right\|_{1}\Bigg] &=\mathbb{E}\Bigg[\sum_{j=1}^{d}\sum_{t=2}^{T}\left(\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}[j]}-\frac{\eta_{t}}{\tilde{F}_{D_{t}}[j]}\right)\Bigg]\\
&=\mathbb{E}\Bigg[\sum_{j=1}^{d}\frac{\eta_{1}}{\tilde{F}_{D_{1}}[j]}-\frac{\eta_{T}}{\tilde{F}_{D_{T}}[j]}\Bigg]\\
&\leq\mathbb{E}\Bigg[\sum_{j=1}^{d}\frac{\eta_{1}}{\tilde{F}_{D_{1}}[j]}\Bigg]\leq d\eta
\end{aligned}$$

where the first equality is due to $\frac{\tilde{F}_{D_{t-1}}[j]}{\eta_{t-1}}\leq\frac{\tilde{F}_{D_{t}}[j]}{\eta_{t}}$, $\forall t\in[T],j\in[d]$.

For the term with $C_{3}$, we have

$$\begin{aligned}
\mathbb{E}\Bigg[\sum_{t=1}^{T}\left\|\frac{\eta_{t}}{\tilde{F}_{D_{t}}}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}}\right\|_{2}^{2}\Bigg] &=\mathbb{E}\Bigg[\sum_{t=1}^{T}\sum_{j=1}^{d}\left(\frac{\eta_{t}}{\tilde{F}_{D_{t}}[j]}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}[j]}\right)^{2}\Bigg]\\
&=\mathbb{E}\Bigg[\sum_{t=1}^{T}\sum_{j=1}^{d}\left|\frac{\eta_{t}}{\tilde{F}_{D_{t}}[j]}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}[j]}\right|\cdot\left|\frac{\eta_{t}}{\tilde{F}_{D_{t}}[j]}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}[j]}\right|\Bigg]\\
&\leq\mathbb{E}\Bigg[\sum_{t=1}^{T}\sum_{j=1}^{d}\left|\frac{\eta_{t}}{\tilde{F}_{D_{t}}[j]}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}[j]}\right|\cdot\left|\frac{\eta}{\sqrt{t}\tilde{F}_{D_{t}}[j]}-\frac{\eta}{\sqrt{t-1}\tilde{F}_{D_{t-1}}[j]}\right|\Bigg]\\
&\leq\mathbb{E}\Bigg[\eta\sum_{t=1}^{T}\sum_{j=1}^{d}\left|\frac{\eta_{t}}{\tilde{F}_{D_{t}}[j]}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}[j]}\right|\Bigg]\\
&=\eta\mathbb{E}\Bigg[\sum_{t=1}^{T}\left\|\frac{\eta_{t}}{\tilde{F}_{D_{t}}}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}}\right\|_{1}\Bigg]\\
&\leq d\eta^{2}
\end{aligned}$$

Hence

$$\begin{aligned}
\mathbb{E}\Bigg[C_{1}\sum_{t=1}^{T}\left\|\frac{\eta_{t}g_{t}}{\tilde{F}_{D_{t}}}\right\|_{2}^{2} &+C_{2}\sum_{t=1}^{T}\left\|\frac{\eta_{t}}{\tilde{F}_{D_{t}}}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}}\right\|_{1}+C_{3}\sum_{t=1}^{T}\left\|\frac{\eta_{t}}{\tilde{F}_{D_{t}}}-\frac{\eta_{t-1}}{\tilde{F}_{D_{t-1}}}\right\|_{2}^{2}\Bigg]+C_{4}\\
&\leq C_{1}\eta^{2}\lambda^{2}(1+\log T)+C_{2}d\eta+C_{3}d\eta^{2}+C_{4}
\end{aligned} \tag{12}$$

Now we lower bound the LHS of Eq. (11). With the assumption $||\tilde{F}_{D_{t}}||_{\infty}\leq L$, we have

$$(\eta_{t}/\tilde{F}_{D_{t}})_{j}\geq\frac{\eta}{L\sqrt{t}}.$$

Thus

$$\mathbb{E}\Bigg[\sum_{t=1}^{T}\eta_{t}\langle\nabla J(\theta_{t}),\nabla J(\theta_{t})/\tilde{F}_{D_{t}}\rangle\Bigg]\geq\mathbb{E}\Bigg[\sum_{t=1}^{T}\frac{\eta}{L\sqrt{t}}||\nabla J(\theta_{t})||_{2}^{2}\Bigg]\geq\frac{\sqrt{T}}{L}\min_{t\in[T]}\mathbb{E}[||\nabla J(\theta_{t})||_{2}^{2}] \tag{13}$$

Combining Eqs. (11), (12), and (13) gives the desired result. ∎

A.3 Computation of Kronecker Factors

The Kronecker factors $\mathcal{H}$ and $\mathcal{S}$, which are integral to the AdaFisher optimizer, are computed following methodologies similar to those described in Grosse & Martens (2016). This section revisits the key equations used for this computation. For a given layer $i$ in a neural network, consider a mini-batch $\mathbb{B}\subset\mathbf{D}$, where $|\mathbb{B}|=M$. The empirical Kronecker factors are computed as follows:

  • For fully connected layers, the Kronecker factors are:

    $$\mathcal{H}_{D_{i}}=\text{diag}\left(\frac{\overline{h}_{i}^{T}\overline{h}_{i}}{M}\right),\quad\mathcal{S}_{D_{i}}=\text{diag}\left(\frac{s_{i}^{T}s_{i}}{M}\right);$$
  • For convolutional layers, the computation accounts for the spatial positions within the layer, denoted as $\mathcal{T}$:

    $$\mathcal{H}_{D_{i}}=\text{diag}\left(\frac{\llbracket\overline{h}_{i}\rrbracket^{T}\llbracket\overline{h}_{i}\rrbracket}{M|\mathcal{T}|}\right),\quad\mathcal{S}_{D_{i}}=\text{diag}\left(\frac{s_{i}^{T}s_{i}}{M|\mathcal{T}|}\right);$$

    The algorithm employs the expansion operation denoted by $\llbracket\cdot\rrbracket$ (Grosse & Martens, 2016). This operation takes the patches surrounding spatial locations, stretches them into vectors, and compiles these vectors into a matrix.

  • For normalization layers (BatchNorm and LayerNorm), please refer to Proposition A.1.

  • For all other types of layers, the Kronecker factors are:

    $$\mathcal{H}_{D_{i}}=\mathbf{I}_{d_{i}},\quad\mathcal{S}_{D_{i}}=\mathbf{I}_{d_{i}};$$

    where $d_{i}$ denotes the dimension of the $i$th layer and $\mathbf{I}$ is the identity matrix.
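For the fully connected case above, the diagonal of $\overline{h}_{i}^{T}\overline{h}_{i}/M$ is simply the per-column mean of squared entries, so the full matrix product never needs to be formed. The sketch below illustrates this under the assumption that $\overline{h}_{i}$ and $s_{i}$ are stored as $M\times d$ matrices with one row per sample; the function name and the toy values are hypothetical.

```python
def diag_second_moment(x):
    """diag(x^T x / M) for an M x d matrix x given as a list of rows."""
    batch_size = len(x)
    width = len(x[0])
    return [sum(row[j] ** 2 for row in x) / batch_size for j in range(width)]


# Toy mini-batch with M = 2 samples.
h_bar = [[1.0, 2.0], [3.0, 4.0]]            # layer inputs, d_in = 2
s = [[0.5, -0.5, 1.0], [1.5, 0.5, 0.0]]     # pre-activation gradients, d_out = 3

h_diag = diag_second_moment(h_bar)
s_diag = diag_second_moment(s)
```

The convolutional case follows the same pattern once the expanded patch matrix is available, with the divisor $M$ replaced by $M|\mathcal{T}|$.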

Table 1: AdaFisher training time per epoch (s) across various numbers of GPUs on ResNet-50 ImageNet.
Number of GPUs | Batch Size | AdaFisher training time per epoch (s)
1 | 256 | 2882
2 | 512 | 1438
3 | 768 | 963
4 | 1024 | 720

A.4 Distributed AdaFisher

The efficacy of AdaFisher hinges on its approximation of the FIM, denoted as $\tilde{F}$, which leverages Kronecker factors for computation. In a distributed setting, these Kronecker factors must be aggregated across multiple GPUs before the model parameters are updated. Consider a training environment consisting of $N$ GPUs. For any given layer $i$, the Kronecker factors are computed and aggregated across all GPUs as

$$(\mathbf{H}_{D_{i}}^{\prime})^{\text{SUM}}=\frac{1}{N}\sum_{n=1}^{N}(\mathbf{H}_{D_{i}}^{\prime})^{n},\quad(\mathbf{S}_{D_{i}}^{\prime})^{\text{SUM}}=\frac{1}{N}\sum_{n=1}^{N}(\mathbf{S}_{D_{i}}^{\prime})^{n} \tag{14}$$

The theoretical justification for this aggregation lies in the linearity of expectation and the unbiasedness of the local Kronecker factor estimates. Specifically, if each $(\mathbf{H}_{D_{i}}^{\prime})^{n}$ and $(\mathbf{S}_{D_{i}}^{\prime})^{n}$ is an unbiased estimator of its respective true factor $\mathbf{H}_{D_{i}}^{\prime}$ and $\mathbf{S}_{D_{i}}^{\prime}$, then the averaged factors $(\mathbf{H}_{D_{i}}^{\prime})^{\text{SUM}}$ and $(\mathbf{S}_{D_{i}}^{\prime})^{\text{SUM}}$ remain unbiased estimators of $\mathbf{H}_{D_{i}}^{\prime}$ and $\mathbf{S}_{D_{i}}^{\prime}$. Consequently, the aggregated EFIM for layer $i$ can be calculated as

$$\tilde{F}_{D_{i}}^{\text{SUM}}=(\mathbf{H}_{D_{i}}^{\prime})^{\text{SUM}}\otimes(\mathbf{S}_{D_{i}}^{\prime})^{\text{SUM}}+\lambda$$

where $\lambda$ is a regularization parameter added to ensure numerical stability. This ensures that each GPU contributes to the model update, which benefits both convergence and performance in large-scale distributed training. We assessed the distributed version of AdaFisher on ImageNet, using batch sizes of 512 and 1024 (refer to Table LABEL:tab:imagenetresults and Figure 10 for details). Our findings indicate that AdaFisher scales nearly linearly with the number of GPUs, as shown in Table 1. There remains scope for additional low-level optimizations within the implementation to further enhance performance.
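The aggregation in Eq. (14) amounts to an element-wise average of the per-GPU diagonal factors, followed by the same Kronecker-plus-damping step as in the single-GPU case. The single-process sketch below mimics this with plain lists; in an actual distributed run the average would be obtained with an all-reduce collective. The function names, the number of workers, and the values are assumptions for illustration.

```python
LAMBDA = 1e-3  # damping, as in Proposition A.2


def average_factors(per_gpu):
    """Eq. (14): element-wise mean of the diagonal factors held by each GPU."""
    n_gpus = len(per_gpu)
    return [sum(values) / n_gpus for values in zip(*per_gpu)]


def aggregated_efim(h_per_gpu, s_per_gpu, lam=LAMBDA):
    """Diagonal of (H')^SUM kron (S')^SUM + lambda, stored as a vector."""
    h_avg = average_factors(h_per_gpu)
    s_avg = average_factors(s_per_gpu)
    return [h * s + lam for h in h_avg for s in s_avg]


# Hypothetical normalized diagonal factors from N = 2 GPUs.
h_per_gpu = [[0.0, 0.5, 1.0], [1.0, 0.5, 0.0]]
s_per_gpu = [[0.25, 1.0], [0.75, 0.0]]

h_avg = average_factors(h_per_gpu)
f_sum = aggregated_efim(h_per_gpu, s_per_gpu)
```

Only the diagonal vectors are exchanged between workers, so the communication cost per layer grows with the sum of the two factor dimensions rather than with their product.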

Appendix B Ablation Studies

Building on the ablative studies detailed in Section LABEL:sec:stabilityanalysis, this section extends our stability analysis to explore the impact of various learning rate schedulers and convergence efficiency, as discussed in Section B.1. Additionally, we conduct an in-depth examination of the key components of AdaFisher. This includes analyzing the effects of the EMA, the use of square root transformations, our novel approximation of the FIM, and the critical role of computing the FIM for normalization layers, all of which are detailed in Section B.2. We have consolidated the key findings of each ablation study in Table 2.

Table 2: Summary of Ablation Studies for AdaFisher Optimizer.
Ablation Study | Component Studied | Key Findings
Learning rate schedulers | Impact of Cosine Annealing, StepLR, and no scheduler on AdaFisher | AdaFisher maintains stable and efficient performance across various schedulers, demonstrating its robustness and adaptability in diverse training environments. For further details please refer to Section B.1.
Convergence Efficiency | Performance and alignment of FIM with Hessian | AdaFisher shows marked performance improvements towards the end of training, with FIM alignment to the Hessian enhancing rapid convergence and stable generalization across training and testing phases. For further details please refer to Section B.1.
Square Root Utilization | Effect of omitting square root in update rules | Eliminating the square root enhances AdaFisher’s performance and stability, outperforming both its own version with the square root and Adam without the square root, while also improving computational efficiency. For further details please refer to Section B.2.
EMA of Kronecker Factors | Utilization of EMA for curvature estimation | Using EMA on Kronecker factors enhances AdaFisher’s curvature estimation, leveraging data from multiple mini-batches for continuous updates, demonstrating significant benefits in methods with diagonal or block-diagonal curvature approximations. For further details please refer to Section B.2.
Importance of Fisher Computation for Normalization Layers | Impact of EFIM in normalization layers | Incorporating Fisher computation in normalization layers significantly improves AdaFisher’s generalization and stability by enhancing parameter sensitivity and gradient variability insights, crucial for optimizing training dynamics and model convergence. For further details please refer to Section B.2.
New Approximation of the FIM | Diagonal approximation of the FIM | Our novel method focuses on the diagonal elements of the FIM, enhancing computation efficiency without losing critical information. Validation shows our approximation closely aligns with the true Fisher, confirming its efficacy. For further details please refer to Section B.2.
Refer to caption
Figure 5: Performance comparison of AdaFisher using the ResNet50 on the CIFAR10 with batch size of 256 with different learning rate schedulers.
Refer to caption
Figure 6: Comparison of FIM Diagonal Histograms during ResNet18 Training on CIFAR10 with Adam and AdaFisher over 1,000 training iterations. Panel (A) displays the FIM diagonal elements for the first convolutional layer; Panel (B) illustrates the FIM diagonal elements for the middle convolutional layer; Panel (C) shows the FIM diagonal elements for the last Linear layer.

B.1 Evaluating Stability Across Learning Rate Schedulers, and Assessing Convergence Efficiency

Learning rate schedulers. This analysis evaluates the impact of different learning rate schedulers (Cosine Annealing, StepLR, and no scheduler) on the performance of AdaFisher, as depicted in Figure 5. AdaFisher is robust across these scheduling strategies: its performance remains stable and efficient whether it is paired with the gradual adjustments of Cosine Annealing, the abrupt changes of StepLR, or no scheduler at all. This underscores AdaFisher’s adaptability and effectiveness in diverse training environments.
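For reference, the three schedules compared here can be written as simple functions of the training step. The sketch below uses the standard cosine annealing and step decay formulas; the initial learning rate, decay factor, and step size are hypothetical values for illustration and are not the settings behind Figure 5.

```python
import math


def cosine_annealing(step, total_steps, lr_max, lr_min=0.0):
    """Smoothly decay from lr_max at step 0 to lr_min at total_steps."""
    cosine = math.cos(math.pi * step / total_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + cosine)


def step_lr(step, lr_init, step_size, gamma):
    """Multiply the learning rate by gamma every step_size steps."""
    return lr_init * gamma ** (step // step_size)


def constant_lr(step, lr_init):
    """No scheduler: the learning rate never changes."""
    return lr_init


total = 200
schedule = [
    (
        cosine_annealing(t, total, lr_max=1e-3),
        step_lr(t, lr_init=1e-3, step_size=50, gamma=0.5),
        constant_lr(t, lr_init=1e-3),
    )
    for t in range(total + 1)
]
```

Cosine annealing changes the learning rate a little at every step, whereas the step schedule keeps it constant between abrupt drops, which is the contrast drawn above.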

Convergence Efficiency. As training progresses, the AdaFisher optimizer shows a significant performance gain over its counterparts, most evident towards the end of training (see Appendix D.2.4). We attribute this rapid convergence to AdaFisher’s use of the FIM. Early and mid-training, the FIM serves as an approximation to the Hessian matrix, equivalent to the Generalized Gauss-Newton matrix (Eschenhagen et al., 2024). However, as the model approaches a local minimum, the FIM increasingly aligns with the Hessian (Martens, 2020). This alignment accelerates convergence, markedly improving the optimizer’s efficiency in the final phases of training. Additionally, AdaFisher’s tendency to converge to flat local minima leads to more stable generalization when transitioning from training to testing distributions (Cha et al., 2021), in contrast to other optimizers. To support these points, we analyze the distribution of our diagonal block-Kronecker FIM during the training of ResNet18 on CIFAR10. Specifically, we examine the FIM distribution for the first (Panel A) and middle (Panel B) convolutional layers and for the last linear layer (Panel C), as shown in Figure 6. For each layer, the FIM distribution with AdaFisher narrows to smaller values with fewer variations compared to that with Adam. This pattern indicates that AdaFisher converges toward flatter local minima, since the Fisher information, as an approximation of the Hessian, contains crucial curvature information.

B.2 Component Analysis: Evaluating the Significance of AdaFisher’s Elements

AdaFisher incorporates several key components, including a novel approximation of the FIM, the EMA of the Kronecker factors, the omission of the square root in the update rule, and a new EFIM formula for normalization layers. In this part, we elucidate each component and its significance within the AdaFisher optimizer.

Square Root Utilization. Recent studies, such as Lin et al. (2024b), have reevaluated the necessity of the square root operation in the Adam family’s update rules. These studies suggest that eliminating the square root does not affect convergence and may even narrow the generalization gap compared to SGD in CNN models. Our analysis, shown in panel (A) of Figure 7, investigates this aspect by comparing the performance of AdaFisher and Adam, both with and without the square root operation. The findings reveal that removing the square root not only boosts the performance and stability of both optimizers but also significantly enhances computational efficiency. Specifically, AdaFisher without the square root not only outperforms the version with the square root but also surpasses Adam without the square root. However, Adam without the square root typically requires an additional scaling factor proportional to the batch size, denoted as $f \propto \text{batch size}$, to function correctly. Without this factor, Adam without the square root fails to learn effectively, making direct comparisons with AdaFisher invalid.
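A small scalar example shows how dropping the square root changes the scale of an Adam-style step. This is a generic sketch, not AdaFisher's update rule: the function names and the values of `lr` and `eps` are assumptions, and the second moment is set to its steady-state value for a constant gradient.

```python
import math

def preconditioned_step(grad, second_moment, lr=1e-3, eps=1e-8, use_sqrt=True):
    """Scalar Adam-style step: lr * grad / (sqrt(v) + eps), or lr * grad / (v + eps)."""
    denom = (math.sqrt(second_moment) if use_sqrt else second_moment) + eps
    return lr * grad / denom

def steady_state_step(grad, use_sqrt):
    # For a constant gradient, the EMA of squared gradients converges to grad**2.
    return preconditioned_step(grad, grad * grad, use_sqrt=use_sqrt)

# Hypothetical gradient magnitudes, differing by a factor of 10.
for g in (0.5, 5.0):
    print(g, steady_state_step(g, use_sqrt=True), steady_state_step(g, use_sqrt=False))
```

With the square root, the step is essentially independent of the gradient scale. Without it, the step is inversely proportional to the gradient magnitude, so the effective step size depends on how the gradients are scaled. That may be one reason a rescaling factor is needed when the square root is dropped, although this example does not establish it.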

EMA of Kronecker Factors. As elucidated in Section LABEL:sec:effcomputFIM, employing an EMA over the Kronecker factors facilitates a more sophisticated curvature estimation. This technique leverages data across multiple mini-batches, enabling continuous updates to the Fisher information rather than relying solely on the data from a single batch. Panel (B) of Figure 7 underscores, using ResNet-50 on CIFAR10 over 200 epochs, the benefits of using EMA on Kronecker factors, a strategy particularly advantageous in methods that utilize diagonal or block-diagonal approximations of the curvature matrix.
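The effect of an exponential moving average on per-batch curvature estimates can be sketched as follows. The code treats each Kronecker factor as a plain list holding its diagonal, initializes the running state with the first batch estimate, and uses a decay of 0.9; all three choices are assumptions made for illustration rather than the paper's implementation.

```python
def ema_update(previous, current, gamma):
    """Elementwise EMA: gamma * previous + (1 - gamma) * current."""
    return [gamma * p + (1.0 - gamma) * c for p, c in zip(previous, current)]

def smooth_factor(batch_estimates, gamma=0.9):
    """Run the EMA over a sequence of per-batch diagonal estimates."""
    state = list(batch_estimates[0])  # assumed initialization: first batch estimate
    history = [state]
    for estimate in batch_estimates[1:]:
        state = ema_update(state, estimate, gamma)
        history.append(state)
    return history

# Hypothetical noisy estimates of a single diagonal entry across six mini-batches.
noisy = [[0.0], [2.0], [0.0], [2.0], [0.0], [2.0]]
for step, value in enumerate(smooth_factor(noisy)):
    print(step, value)
```

The smoothed sequence moves far less than the raw per-batch values, which is the sense in which the EMA aggregates information across mini-batches instead of relying on a single batch.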

Importance of Fisher Computation for Normalization Layers. The integration of the EFIM in normalization layers, as detailed in Proposition LABEL:prop:proposition_normalization, significantly enhances the generalization process. Panel (C) of Figure 7 illustrates the impact of incorporating Fisher computation in these layers during the training of AdaFisher with ResNet-50 on CIFAR10 over 200 epochs. In contrast, the identity matrix is employed when Fisher computation is omitted. The superior performance of AdaFisher when incorporating Fisher computation can be attributed to the critical role normalization layers play in adjusting the input distribution for each mini-batch. This adjustment substantially enhances the neural network’s learning stability (Jiang et al., 2024b). By quantifying the information each output $y$ carries about the parameters $\theta$ under the model distribution $p(y \mid x; \theta)$, the computation of the FIM in these layers provides valuable insights into parameter sensitivity and gradient variability. This insight is crucial for optimizing training dynamics and enhancing model convergence, areas that are often inadequately addressed by existing optimizers.

New Approximation of the FIM. In Proposition LABEL:prop:proposition1, we introduce a new methodology for approximating the FIM that diverges from the K-FAC optimizer. Unlike K-FAC, which utilizes the full Kronecker product, our approach focuses solely on the diagonal elements of the FIM, where, as demonstrated in Section LABEL:sec:kroneckerdetail, the energy of the Kronecker factors is predominantly concentrated. This method enables a more efficient computation of the FIM without sacrificing critical information. To validate our approach, we compare the true FIM diagonal with our approximation in convolutional and dense layers using a toy model composed of 2 convolutional layers and 2 linear layers on a subset of the MNIST dataset (Deng, 2012) over 50 epochs. Specifically, we calculate the true Fisher using the NNgeometry Python package (George, 2021), which facilitates the computation of the FIM, Gauss-Newton Matrix, or Neural Tangent Kernels applied to neural networks. We estimate $p(y \mid x)$ through Monte-Carlo sampling. During each epoch, we collect both the empirical and true Fisher information and calculate the Mean Absolute Error (MAE) between these two measures. Panel (D) of Figure 7 showcases the close approximation of AdaFisher’s empirical diagonal to the true Fisher, thus validating the efficacy of our approximation method.
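The distinction between the true Fisher (labels sampled from the model) and the empirical Fisher (observed labels), and the MAE used to compare two diagonals, can be illustrated on a toy categorical model whose parameters are the logits themselves. This is a simplified stand-in, not the NNgeometry-based procedure used in the experiment; the logits, sample count and helper names are hypothetical.

```python
import math
import random

def softmax(logits):
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def score(probs, label):
    """Gradient of log p(label) with respect to the logits: onehot(label) - probs."""
    return [(1.0 if k == label else 0.0) - p for k, p in enumerate(probs)]

def true_fisher_diag_mc(probs, num_samples, rng):
    """Monte Carlo estimate of E_{y ~ p}[score(y)^2], with labels sampled from the model."""
    acc = [0.0] * len(probs)
    labels = rng.choices(range(len(probs)), weights=probs, k=num_samples)
    for y in labels:
        for k, s in enumerate(score(probs, y)):
            acc[k] += s * s
    return [a / num_samples for a in acc]

def empirical_fisher_diag(probs, observed_label):
    """Squared score evaluated at the observed label instead of a sampled one."""
    return [s * s for s in score(probs, observed_label)]

def mean_absolute_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

rng = random.Random(0)
probs = softmax([2.0, 0.5, -1.0])          # hypothetical logits
exact = [p * (1.0 - p) for p in probs]     # closed-form diagonal for this toy model
sampled = true_fisher_diag_mc(probs, 20000, rng)
print("MAE(MC, exact):", mean_absolute_error(sampled, exact))
print("MAE(empirical, exact):", mean_absolute_error(empirical_fisher_diag(probs, 0), exact))
```

For this toy model the exact diagonal is $p_k(1 - p_k)$, so the Monte Carlo estimate can be checked directly. In a real network the exact value is not available in closed form, which is why sampling is used.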

Figure 7: AdaFisher Component Analysis. (A) Comparison of MAE between the true FIM $F_D$ and our approximation $\tilde{F}_D$ across convolutional and dense layers. (B) Performance comparison of AdaFisher with and without the EMA of Kronecker factors. (C) Assessment of AdaFisher’s performance with and without the computation of EFIM for Batch Normalization (BN) layers.

Appendix C Visualization

The convergence rate of an optimizer is crucial, serving as an indicator of its robustness against saddle points and its ability to generalize effectively. In this section, we introduce a novel methodology for visualizing the convergence behavior of optimizers through a statistical model, as depicted in Figure LABEL:fig:heatloss. Initially, our process employs Principal Component Analysis (PCA) for dimensionality reduction, reducing the dataset dimensions from $\mathcal{D} \in \mathbb{R}^{m \times n}$ to $\hat{\mathcal{D}} \in \mathbb{R}^{m \times 2}$, following the protocol established in F.R.S. (1901). We then apply this reduced dataset to a toy model composed of an $L$-layer multi-layer perceptron (MLP). Notably, we focus on the first weight matrix $W_1^e$ of this MLP, which resides in $\mathbb{R}^2$, where $e$ denotes the epoch number. For consistency and to ensure comparability, all layers’ weights are initialized identically across different optimizers.
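The PCA reduction from $n$ features to 2 can be sketched with a singular value decomposition. The snippet assumes NumPy is available and uses random data in place of a real dataset; `pca_project` is a hypothetical helper name, not code from the paper.

```python
import numpy as np

def pca_project(data, num_components=2):
    """Project rows of data (m x n) onto the top principal directions (m x num_components)."""
    centered = data - data.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:num_components].T

rng = np.random.default_rng(0)
data = rng.normal(size=(150, 4))   # synthetic stand-in for an m x n dataset
reduced = pca_project(data)
print(reduced.shape)
```

Because the projected inputs are two-dimensional, a first-layer weight vector acting on them has two components, and its trajectory can be drawn in a plane.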

Figure 8: Pipeline for visualization of optimization paths for various algorithms on a loss surface, comparing their convergence efficiency.

Following the training phase with various optimizers, whose results we denote by the set $\mathcal{O}$, we analyze both the collection of first-layer weights, $\mathcal{W}$, and the evolution of the loss function, $\mathfrak{L}$, defined as:

$$\mathcal{W}=\begin{bmatrix}(W_{1}^{1})^{\top}\\ (W_{1}^{2})^{\top}\\ \vdots\\ (W_{1}^{E})^{\top}\end{bmatrix},\quad\mathfrak{L}=[\mathcal{L}_{1}^{1},\mathcal{L}_{1}^{2},\ldots,\mathcal{L}_{1}^{E}]^{\top}$$

where $(W_{1}^{e})^{\top}$ represents the weight vector at the $e$-th epoch, and $\mathcal{L}_{1}^{e}$ represents the loss at the $e$-th epoch, extracted from the optimization results $\mathcal{O}$. We construct a grid $(\mathbf{X},\mathbf{Y})$ spanning the range of weight parameters, discretized into 200 linearly spaced points along each axis:

$$\mathbf{X},\mathbf{Y}=\text{meshgrid}\left(\min(\mathcal{W}_{:,1}),\max(\mathcal{W}_{:,1}),\min(\mathcal{W}_{:,2}),\max(\mathcal{W}_{:,2}),200\right)$$

Finally, we interpolate the loss values $\mathcal{L}$ over the grid using cubic interpolation to obtain a smooth loss surface $\mathbf{Z}$:

$$\mathbf{Z}=\text{griddata}(\mathcal{W},\mathcal{L},(\mathbf{X},\mathbf{Y}),\text{method}=\text{'cubic'})$$

These elements are integral to the visualization process, which elucidates the optimizer’s trajectory through the parameter space across training epochs. It is important to note that while we focus on the first layer’s weight matrix for clarity, the methodology can be adapted to visualize the weights of any layer within the network. Figure 8 summarizes the pipeline.
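The grid construction and cubic interpolation described above map naturally onto NumPy's `linspace`/`meshgrid` and SciPy's `scipy.interpolate.griddata`. The sketch below assumes both libraries are installed and uses synthetic scattered points with a quadratic loss in place of recorded weights and losses; `loss_surface` is a hypothetical helper, not the authors' code.

```python
import numpy as np
from scipy.interpolate import griddata

def loss_surface(weights, losses, resolution=200):
    """weights: (E, 2) array of recorded 2-D weights; losses: (E,) array of recorded losses."""
    xs = np.linspace(weights[:, 0].min(), weights[:, 0].max(), resolution)
    ys = np.linspace(weights[:, 1].min(), weights[:, 1].max(), resolution)
    grid_x, grid_y = np.meshgrid(xs, ys)
    # Cubic interpolation of the scattered (weight, loss) samples onto the regular grid.
    grid_z = griddata(weights, losses, (grid_x, grid_y), method="cubic")
    return grid_x, grid_y, grid_z

rng = np.random.default_rng(0)
weights = rng.uniform(-1.0, 1.0, size=(60, 2))   # synthetic stand-in for recorded weights
losses = (weights ** 2).sum(axis=1)              # synthetic stand-in for recorded losses
X, Y, Z = loss_surface(weights, losses)
print(X.shape, Z.shape, int(np.isnan(Z).sum()))
```

Note that `griddata` returns NaN for grid points outside the convex hull of the recorded weights, so a plot of `Z` only covers the region actually visited by the optimizers.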

In the experiment depicted in Figure LABEL:fig:heatloss, we selected the IRIS dataset (rz7, 2018), owing to its widespread recognition and compatibility with PCA application. Our model employs a 2-layer MLP architecture. We specifically attend to the weight matrix of the first layer, denoted by $W_1 \in \mathbb{R}^2$. This particular focus is informed by the empirical observation that the parameters of the first layer tend to exhibit a faster convergence rate compared to those of subsequent layers in the network. Such a phenomenon can be attributed to the more direct influence of the input features on the first layer’s weights, which often results in a more pronounced and expedited learning dynamic. Given the classification nature of the task, we employed the Cross-Entropy loss function (Zhang & Sabuncu, 2018). The network was trained over 20 epochs using a suite of optimizers: Adam, AdaHessian, K-FAC, Shampoo, and AdaFisher. We standardized the learning rate across all optimizers at $1 \times 10^{-3}$ to ensure comparability of results. Examination of Figure LABEL:fig:heatloss reveals that AdaFisher’s convergence is markedly superior to that of its counterparts, achieving rapid convergence to the local minimum of the loss landscape concerning the first weight parameter within a few iterations. Conversely, the alternative optimizers demonstrate convergence to less optimal local minima.

Appendix D Experiments

D.1 Hardware

In total, we had a server with 6 NVIDIA RTX 6000 Ada Generation GPUs with 48 gigabytes of VRAM, and 128 gigabytes of RAM available for all experiments. All experiments described in this report were conducted on a system equipped with a single NVIDIA RTX 6000 Ada Generation GPU and 64 gigabytes of RAM, except for training AdaFisher on ImageNet with batch sizes of 512 and 1024, where four GPUs were utilized.

D.2 Image Classification

We provide further results and detailed descriptions of our image classification experiments in this section. We conducted five trials with random initializations for the CIFAR experiments, and one trial each for Tiny ImageNet and ImageNet. We present the mean and standard deviation of the results for these trials.
Note on training time. Given that various optimizers demonstrate significantly different epoch durations, we have standardized our comparisons by restricting training to the total WCT consumed by 200 epochs using AdaFisher for both CIFAR and Tiny ImageNet experiments. Conversely, for ImageNet, we report the results based on 90 WCT training epochs using Adam, as, surprisingly, AdaFisher and Adam exhibited the same duration in this experiment. The final selected number of epochs for each optimizer is detailed in Table 3. Please note that we were unable to train AdaHessian on ImageNet due to the significant computational resources required by this optimizer.

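Table 3: Comparison of the final epoch counts for various optimizers across different datasets.
Dataset                        Optimizer               Epochs
CIFAR10/100 & Tiny ImageNet    Adam/AdamW              210
CIFAR10/100 & Tiny ImageNet    AdaHessian              89
CIFAR10/100 & Tiny ImageNet    K-FAC                   107
CIFAR10/100 & Tiny ImageNet    Shampoo                 36
CIFAR10/100 & Tiny ImageNet    AdaFisher/AdaFisherW    200
ImageNet                       Adam                    90
ImageNet                       K-FAC                   60
ImageNet                       Shampoo                 26
ImageNet                       AdaFisher               90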
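The wall-clock-time (WCT) protocol above amounts to fixing a time budget from a reference optimizer and counting how many epochs each other optimizer completes within it. The sketch below shows that arithmetic. The per-epoch times are hypothetical, and rounding down to whole epochs is an assumption, since the text does not say how partial epochs were handled.

```python
def epochs_within_budget(reference_epochs, reference_epoch_seconds, epoch_seconds):
    """Whole epochs an optimizer completes within the reference wall-clock budget."""
    budget_seconds = reference_epochs * reference_epoch_seconds
    return int(budget_seconds // epoch_seconds)

# Hypothetical timings: the reference optimizer takes 10 s per epoch for 200 epochs.
reference_epochs, reference_time = 200, 10.0
for name, seconds in (("reference", 10.0), ("slower optimizer", 25.0), ("faster optimizer", 8.0)):
    print(name, epochs_within_budget(reference_epochs, reference_time, seconds))
```

An optimizer that is 2.5 times slower per epoch therefore receives 80 epochs under a 200-epoch reference budget, which is how epoch counts such as those in Table 3 can differ so widely across optimizers.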

D.2.1 Hyperparameter Tuning

Effective hyperparameter tuning is crucial for optimizing the performance of deep learning models. In this study, we systematically explored various hyperparameters for both CNNs and ViTs across multiple image classification tasks. The following subsections detail the tuning strategies employed for each model architecture and dataset.

CNNs. For all image classification tasks involving CNNs, we utilized ResNet18 as the backbone architecture and evaluated its performance on the CIFAR-10 dataset with a fixed batch size of 256. The hyperparameter tuning process encompassed the following components:

  • Optimizer Selection and Learning Rate Tuning: Each optimizer was fine-tuned using ResNet18 on CIFAR-10. We performed a grid search to identify the optimal learning rate from the set $\{0.0001, 0.0003, 0.0005, 0.0009, \dots, 0.1, 0.3, 0.5, 0.9\}$.

  • Learning Rate Scheduling: A cosine annealing learning rate decay strategy was employed, aligning with the number of training epochs specified for each optimizer in Table 3. This approach follows the methodology proposed by Loshchilov & Hutter (2016) and was determined to be optimal for our experimental setup.

  • Weight Decay: We applied a uniform weight decay of $5 \times 10^{-4}$ across all optimizers for CIFAR-10 and Tiny ImageNet. An exception was made for MobileNetV3, where the weight decay was set to $1 \times 10^{-5}$. For experiments on ImageNet, the weight decay was established at $1 \times 10^{-4}$.

  • Damping Parameter Tuning:

    • AdaFisher, K-FAC, and Shampoo:

      • K-FAC and AdaFisher: The damping parameter was searched within $\{0.0001, 0.0003, 0.0005, 0.0009, 0.001, 0.003, 0.005, 0.009, 0.01, 0.03, 0.05, 0.09\}$. This range was chosen based on prior research (Martens & Grosse, 2015) and our own experiments, which indicated optimal damping values around $1 \times 10^{-3}$.

      • Shampoo: The damping parameter was tuned within $\{1 \times 10^{-6}, 3 \times 10^{-6}, 5 \times 10^{-6}, 9 \times 10^{-6}, 1 \times 10^{-5}, 3 \times 10^{-5}, 5 \times 10^{-5}, 9 \times 10^{-5}, 1 \times 10^{-4}, 3 \times 10^{-4}, 5 \times 10^{-4}, 9 \times 10^{-4}\}$, as optimal values typically reside around $1 \times 10^{-5}$.

    • AdaHessian: The Hessian power was tuned within the range $\{0.1, 0.2, \dots, 0.9, 1.0\}$.

    • AdaFisher Decay Factors: The decay factors $\gamma_1$ and $\gamma_2$ for AdaFisher were tuned within $\{0.1, 0.2, \dots, 0.9, 0.99\}$.

  • Implementation Details: For the Shampoo and K-FAC optimizers, we utilized the ASDL library as implemented in PyTorch provided by Osawa et al. (2023).
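The two fully enumerated damping grids above follow a simple pattern, mantissas {1, 3, 5, 9} at consecutive powers of ten, so they can be generated programmatically and fed to a grid search. The sketch below does so. `toy_validation_accuracy` is a hypothetical stand-in for training a model and reading its validation accuracy, and none of the helper names come from the paper or from ASDL.

```python
import math

def log_spaced_grid(mantissas, exponents):
    """Build values m * 10^e, parsed from decimal strings to avoid rounding drift."""
    return [float(f"{m}e{e}") for e in exponents for m in mantissas]

kfac_adafisher_damping = log_spaced_grid((1, 3, 5, 9), (-4, -3, -2))
shampoo_damping = log_spaced_grid((1, 3, 5, 9), (-6, -5, -4))

def grid_search(candidates, evaluate):
    """Return the candidate with the highest validation score."""
    return max(candidates, key=evaluate)

def toy_validation_accuracy(damping):
    # Hypothetical score that peaks at 1e-3; a real run would train and validate a model.
    return -abs(math.log10(damping) + 3.0)

print(kfac_adafisher_damping)
print(shampoo_damping)
print(grid_search(kfac_adafisher_damping, toy_validation_accuracy))
```

In an actual tuning run, `evaluate` would wrap a full training job per candidate, so the cost of the search grows linearly with the grid size.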

ViTs. For ViT-based image classification tasks, we employed the Tiny Swin Transformer on the CIFAR-10 dataset with a batch size of 256. The hyperparameter tuning strategy for ViTs included the following elements:

  • Weight Decay: Weight decay values were set as indicated in the respective original publications for each model:

    • Tiny Swin: $1 \times 10^{-2}$

    • FocalNet: $5 \times 10^{-2}$

    • CCT-2/3×2: $6 \times 10^{-2}$

  • Learning Rate Tuning: For AdaFisher, AdaHessian, K-FAC, and Shampoo optimizers, we conducted a grid search over the learning rates $\{0.3, 0.15, 0.1, 0.05, 0.03, 0.015, 0.01, 0.005, 0.003, 0.0015, 0.001\}$, as these optimizers typically operate with higher learning rates compared to Adam-based optimizers. For AdamW, the learning rates were adopted from the original publications:

    • Tiny Swin and FocalNet: $1 \times 10^{-4}$

    • CCT-2/3×2: $5.5 \times 10^{-5}$

  • Damping Parameter Tuning: We performed the same grid search over the damping parameter for K-FAC, Shampoo and AdaFisher, the Hessian power for AdaHessian, and the decay factors for AdaFisher as explained in the CNNs part.

This meticulous hyperparameter tuning process ensures that each optimizer is optimally configured for the respective model architectures and datasets, thereby facilitating a fair and comprehensive comparison of their performance across different image classification tasks. The final learning rates for all optimizers and models are detailed in Table 4.

Table 4: Final selected learning rates for each optimizer, tuned using ResNet18 (for CNN) and Tiny Swin (for ViT) on CIFAR10 using a batch size of 256. We selected based on final validation top-1 accuracy.
Architecture    Adam     AdamW              AdaHessian    K-FAC    Shampoo    AdaFisher    AdaFisherW
CNNs            0.001    -                  0.15          0.3      0.3        0.001        -
ViTs            -        0.0001/0.000055    0.01          0.003    0.003      -            0.001

D.2.2 Dataset Details

CIFAR. The training/test sets for the CIFAR10/100 datasets contain 50k/10k images, respectively. We consider a batch size of 256. For CIFAR-related experiments, we perform $32 \times 32$ random-resize cropping, random horizontal flipping and cutout (DeVries & Taylor, 2017) as data augmentations. Please refer to Takahashi et al. (2020) for more details.
Tiny ImageNet. The training/test sets for Tiny ImageNet (Le & Yang, 2015) contain 100k/10k images. We perform $64 \times 64$ random-resize cropping and random horizontal flipping. The batch size is set to 256.
ImageNet. The training/test sets for ImageNet (Russakovsky et al., 2015) contain 1,281,167/150k images. We consider a batch size of 256, as we performed experiments on a single GPU instance without any GPU parallelism. We follow He et al. (2016) and perform random resized cropping to $224 \times 224$ and random horizontal flipping on the train set and $256 \times 256$ resizing with $224 \times 224$ center cropping on the test set.
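For readers unfamiliar with cutout, the augmentation masks a random square region of each training image. The NumPy sketch below follows the general idea in DeVries & Taylor (2017). The patch size of 16 pixels, the zero fill value and the function name are assumptions for illustration, as the text above does not state the settings used.

```python
import numpy as np

def cutout(image, patch_size, rng):
    """Zero a square patch centred at a random pixel; the patch is clipped at the borders."""
    height, width = image.shape[:2]
    cy = int(rng.integers(0, height))
    cx = int(rng.integers(0, width))
    top, bottom = max(cy - patch_size // 2, 0), min(cy + patch_size // 2, height)
    left, right = max(cx - patch_size // 2, 0), min(cx + patch_size // 2, width)
    out = image.copy()           # leave the input image untouched
    out[top:bottom, left:right] = 0
    return out

rng = np.random.default_rng(0)
image = np.ones((32, 32, 3), dtype=np.float32)   # synthetic stand-in for a CIFAR image
augmented = cutout(image, patch_size=16, rng=rng)
print(augmented.shape, int((augmented == 0).sum()))
```

Because the patch centre may fall near the image border, the masked area varies from sample to sample and is at most `patch_size` squared pixels per channel.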

D.2.3 Transfer Learning

Weights are initialized to the values provided by the publicly available PyTorch checkpoints, except for the first convolutional layer of the ResNet architecture and the last dense layers of all networks, which are randomly initialized because they change size to accommodate the new kernel size and number of classes, respectively. We train all models with weight decay $1 \times 10^{-4}$ as suggested in Wightman et al. (2021), except for MobileNetV3, where weight decay is set to $1 \times 10^{-5}$. Moreover, we did a grid search for each optimizer to select the best learning rate from the range $\{0.3, 0.15, 0.1, 0.03, 0.015, 0.01, \dots, 1 \times 10^{-5}\}$; the selected learning rate for each optimizer is tabulated in Table 5. We use a batch size of 256 and cosine learning rate decay. We use the same augmentation policy (without Cutout) as in the previous experiments. The results were obtained using the WCT technique over 50 training epochs of AdaFisher, with the final epoch count detailed in Table 6. All other parameters remained unchanged.

Table 5: Final selected learning rates for each optimizer with ImageNet-1k pretrained weights, tuned using ResNet50 on CIFAR10 using a batch size of 256. We tuned by completing a full WCT epoch training cycle, and selected based on final validation top-1 accuracy.
Adam      AdaHessian    K-FAC    Shampoo    AdaFisher
0.0001    0.15          0.3      0.03       0.001

Table 6: Final selected epoch counts for various optimizers across transfer learning task.
Adam/AdamW    AdaHessian    K-FAC    Shampoo    AdaFisher/AdaFisherW
55            22            27       18         50

Table 7: Performance metrics (mean, with std as subscript) of different networks and optimizers on CIFAR10 and CIFAR100 using batch size 256 (a) without Cutout and (b) with Cutout. Reported using WCT of 200 AdaFisher training epochs as the cutoff.
(a) Without Cutout
CIFAR10 (first five value columns) and CIFAR100 (last five value columns)
Network  Adam  AdaHessian  K-FAC  Shampoo  AdaFisher  Adam  AdaHessian  K-FAC  Shampoo  AdaFisher
ResNet18  $93.64_{0.02}$  $94.05_{0.08}$  $94.04_{0.16}$  $94.52_{0.12}$  $\mathbf{95.02_{0.11}}$  $72.71_{0.24}$  $73.64_{0.21}$  $74.79_{0.19}$  $76.53_{0.11}$  $\mathbf{77.10_{0.21}}$
ResNet50  $93.89_{0.19}$  $94.26_{0.11}$  $94.25_{0.08}$  $94.92_{0.09}$  $\mathbf{95.42_{0.21}}$  $73.12_{0.73}$  $75.29_{0.31}$  $75.49_{0.17}$  $77.81_{0.22}$  $\mathbf{78.91_{0.91}}$
ResNet101  $93.14_{0.12}$  $94.73_{0.85}$  $94.23_{0.13}$  $94.22_{0.04}$  $\mathbf{95.51_{0.13}}$  $73.23_{0.37}$  $72.19_{0.23}$  $75.46_{0.26}$  $78.82_{0.11}$  $\mathbf{79.74_{0.28}}$
DenseNet121  $93.74_{0.16}$  $94.54_{0.08}$  $94.97_{0.05}$  $94.99_{0.10}$  $\mathbf{95.29_{0.06}}$  $75.38_{0.34}$  $72.54_{0.89}$  $77.09_{0.26}$  $78.70_{0.27}$  $\mathbf{79.03_{0.17}}$
MobileNetV3  $91.95_{0.115}$  $91.4_{3.06}$  $91.92_{0.14}$  $91.91_{0.19}$  $\mathbf{92.89_{0.11}}$  $65.64_{0.19}$  $60.78_{3.61}$  $69.87_{0.262}$  $68.01_{0.24}$  $\mathbf{73.15_{0.24}}$
Tiny Swin  $87.47_{0.16}$  $78.34_{0.23}$  $66.84_{0.25}$  $68.44_{0.17}$  $\mathbf{89.08}_{0.11}$  $62.20_{0.22}$  $54.12_{0.25}$  $36.12_{0.28}$  $33.75_{0.27}$  $\mathbf{66.47_{0.19}}$
FocalNet  $85.65_{0.12}$  $71.03_{0.25}$  $42.92_{0.18}$  $41.49_{0.23}$  $\mathbf{86.92}_{0.14}$  $52.88_{0.25}$  $38.05_{0.28}$  $11.23_{0.31}$  $11.06_{0.32}$  $\mathbf{52.9_{0.13}}$
CCT-2/3×2  $83.95_{0.12}$  -  $34.63_{1.03}$  $35.1_{0.78}$  $\mathbf{84.63_{0.25}}$  $60.14_{1.06}$  -  $8.06_{0.59}$  $9.76_{0.28}$  $\mathbf{60.63_{0.65}}$
(b) With Cutout
CIFAR10 (first five value columns) and CIFAR100 (last five value columns)
Network  Adam  AdaHessian  K-FAC  Shampoo  AdaFisher  Adam  AdaHessian  K-FAC  Shampoo  AdaFisher
ResNet18  $94.85_{0.10}$  $95.44_{0.08}$  $95.17_{0.16}$  $94.08_{0.20}$  $\mathbf{96.25_{0.17}}$  $75.74_{0.09}$  $71.79_{0.21}$  $76.03_{0.33}$  $76.78_{0.16}$  $\mathbf{77.28_{0.21}}$
ResNet50  $94.45_{0.18}$  $95.54_{0.11}$  $95.66_{0.14}$  $94.59_{0.09}$  $\mathbf{96.34_{0.21}}$  $74.65_{0.46}$  $75.81_{0.31}$  $77.40_{0.38}$  $78.07_{0.35}$  $\mathbf{79.77_{0.35}}$
ResNet101  $94.57_{0.14}$  $95.29_{0.64}$  $96.01_{0.13}$  $94.63_{0.1}$  $\mathbf{96.39_{0.09}}$  $75.56_{0.34}$  $73.38_{0.23}$  $77.01_{0.39}$  $78.83_{0.16}$  $\mathbf{80.65_{0.48}}$
DenseNet121  $94.86_{0.14}$  $96.11_{0.05}$  $96.12_{0.07}$  $95.66_{0.07}$  $\mathbf{96.72_{0.04}}$  $75.87_{0.43}$  $74.80_{0.89}$  $79.79_{0.22}$  $80.24_{0.25}$  $\mathbf{81.36_{0.28}}$
MobileNetV3  $93.32_{0.13}$  $92.86_{3.06}$  $94.34_{0.14}$  $93.81_{0.19}$  $\mathbf{95.28_{0.10}}$  $70.62_{0.34}$  $56.58_{4.54}$  $73.75_{0.262}$  $70.85_{0.27}$  $\mathbf{77.56_{0.13}}$
Tiny Swin  $87.37_{0.62}$  $84.15_{0.23}$  $64.79_{0.47}$  $63.91_{0.43}$  $\mathbf{88.74_{0.39}}$  $60.21_{0.41}$  $56.86_{0.45}$  $34.45_{0.41}$  $30.39_{1.21}$  $\mathbf{66.05_{0.46}}$
FocalNet  $86.23_{0.06}$  $64.18_{0.16}$  $38.94_{0.81}$  $37.96_{0.65}$  $\mathbf{87.90}_{0.14}$  $52.71_{0.48}$  $32.33_{0.28}$  $9.98_{0.57}$  $9.18_{0.14}$  $\mathbf{53.69_{0.37}}$
CCT-2/3×2  $83.89_{0.38}$  -  $33.08_{2.31}$  $35.16_{0.35}$  $\mathbf{84.94}_{0.28}$  $59.78_{0.51}$  -  $7.17_{0.21}$  $8.60_{0.13}$  $\mathbf{62.91_{0.54}}$
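Each cell in Table 7 summarizes several trials as a mean with the standard deviation as a subscript. The sketch below shows one way such a cell could be produced from per-trial accuracies. The five accuracy values are hypothetical, and the use of the sample standard deviation (`statistics.stdev`) rather than the population version is an assumption, since the text does not specify which was reported.

```python
import statistics

def summarize(accuracies):
    """Return (mean, sample standard deviation) over independent trials."""
    return statistics.mean(accuracies), statistics.stdev(accuracies)

def format_cell(accuracies):
    """Format as mean_{std}, matching the subscript notation used in the table."""
    mean, std = summarize(accuracies)
    return f"{mean:.2f}_{{{std:.2f}}}"

# Hypothetical top-1 accuracies from five trials with different random initializations.
trials = [95.0, 95.1, 94.9, 95.2, 94.8]
print(format_cell(trials))
```

With only five trials the estimated standard deviation is itself noisy, so small differences between subscripts should be read with caution.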

D.2.4 Results

Table 8 displays the results for the Tiny ImageNet dataset using ResNet50 and Big Swin networks, with visualizations provided in Figure 9. AdaFisher and AdaFisherW consistently outperform current SOTA optimizers. Notably, Figure 9 illustrates that although AdaFisher converges slower than K-FAC during ResNet50 training, it achieves superior generalization. This is evidenced by lower testing errors, suggesting that AdaFisher tends to converge to a flatter local minimum, enabling smoother transitions between training and testing datasets with minimal generalization loss. For further explanation, please see Cha et al. (2021). Please note that due to AdaHessian’s high memory consumption, we were unable to train it on Big Swin.

Table 7 presents the performance of various networks on CIFAR10/100 datasets using different optimizers, both with and without the cutout augmentation technique. AdaFisher and AdaFisherW consistently outperform their counterparts in both scenarios, demonstrating stable training and robustness to the augmentation techniques. The training losses and test errors for the CIFAR experiments, both with and without cutout, are visually represented in Figures 11, 12, 13, and 14.

Table 8: Performance of various networks and optimizers on Tiny ImageNet using batch size 256. Reported using wall clock time of 200 AdaFisher training epochs as the cutoff.
      Network       Adam       AdaHessian       K-FAC       Shampoo       AdaFisher
      ResNet50       53.06       50.21       50.05       53.53       $\mathbf{57.41}$
      Big Swin       48.11       -       8.89       4.11       $\mathbf{48.86}$
Figure 9: WCT training loss and testing error curves of several optimizers on Tiny ImageNet dataset, ResNet-50 and Big Swin with batch size of 256. AdaFisher consistently achieves lower test error as compared to Adam, AdaHessian, K-FAC and Shampoo. The final accuracy results are reported in Table 8.

Figure 10 illustrates the training and validation error of the distributed version of AdaFisher on ImageNet across various batch sizes. AdaFisher not only outperforms its counterparts with smaller batch sizes (256), but it also continues to achieve superior generalization as batch sizes increase. Furthermore, these results reinforce the stability analysis concerning batch sizes presented in Section LABEL:sec:stabilityanalysis, extending it to a more challenging dataset.

Figure 10: Performance of distributed AdaFisher using ResNet50 on ImageNet with different batch sizes for 90 epochs. The final accuracy results are reported in Table LABEL:tab:imagenetresults.
Figure 11: WCT training loss, test error, for CNNs and ViTs on CIFAR10 experiments, without Cutout. A batch size of 256 was used and all networks were tuned using ResNet18 applied on CIFAR10. The final accuracy results are reported in Table 7 (a).
Figure 12: WCT training loss, test error, for CNNs and ViTs on CIFAR100 experiments, without Cutout. A batch size of 256 was used and all networks were tuned using ResNet18 applied on CIFAR10. The final accuracy results are reported in Table 7 (a).
Figure 13: WCT training loss, test error, for CNNs and ViTs on CIFAR10 experiments, with Cutout. A batch size of 256 was used and all networks were tuned using ResNet18 applied on CIFAR10. The final accuracy results are reported in Table 7 (b).
Figure 14: WCT training loss, test error, for CNNs and ViTs on CIFAR100 experiments, with Cutout. A batch size of 256 was used and all networks were tuned using ResNet18 applied on CIFAR10. The final accuracy results are reported in Table 7 (b).
Figure 15: WCT training loss, test error, for CNNs on CIFAR10/100 experiments. A batch size of 256 was used and all networks were tuned using ResNet50 applied on CIFAR10. The final accuracy results are reported in Table LABEL:pretrained_cifar.

D.2.5 Comparison of Training Speed and Memory Utilization

As discussed in Section LABEL:sec:stabilityanalysis, AdaFisher emerges as a balanced trade-off between time complexity and performance. Similarly, its memory footprint is comparable to that of Adam, showcasing efficient VRAM utilization. We extend our stability analysis to the CIFAR-10 dataset to provide a dataset-independent evaluation of performance metrics, as depicted in Figure 16. Additionally, we analyze the memory usage for different batch sizes using the ResNet-50 model on CIFAR-10/100, as presented in Figure 17. The analysis reveals that AdaFisher, while maintaining high accuracy levels, uses memory comparably to Adam, which is especially evident at higher batch sizes. This suggests that AdaFisher can achieve competitive performance without excessive VRAM consumption, making it a suitable choice for scenarios with memory constraints.

Figure 16: Performance comparison of AdaFisher and other optimizers across various batch sizes, epoch times and learning rates (with batch size of 256), evaluated using ResNet50 on CIFAR-10.
Figure 17: Performance comparison of AdaFisher and other optimizers regarding the memory used, assessed using ResNet50 and CIFAR10/100 across different batch sizes. This figure highlights how AdaFisher competes closely with Adam in terms of memory efficiency and performance.

Epoch Times. Continuing our analysis of the time complexity for each optimizer, we present in Figure 18 the epoch times for various network architectures and datasets. Specifically, we compare the epoch times of Adam, AdaFisher, K-FAC, AdaHessian, and Shampoo optimizers on CIFAR10 and CIFAR100 datasets. As depicted in Figure 18 panel (A), AdaFisher demonstrates a comparable training time to Adam across multiple network architectures on the CIFAR10 dataset. This indicates that AdaFisher achieves efficient optimization without incurring significant additional computational cost. Similarly, in Figure 18 panel (B), we observe that the epoch times for AdaFisher remain close to those of Adam on the CIFAR100 dataset. While K-FAC and AdaHessian exhibit increased training times, Shampoo shows the highest epoch times across all tested networks. This further highlights the efficiency of AdaFisher as an optimizer, combining the advantages of advanced optimization techniques with practical training times.

Figure 18: Epoch times for various networks on CIFAR10 (A) and CIFAR100 (B) using Adam, AdaFisher, K-FAC, AdaHessian and Shampoo.

D.3 Language Modelling

D.3.1 Dataset Details

The WikiText-2 dataset, derived from high-quality Wikipedia articles, contains over two million words and is structured into training, validation, and test sets. It is widely used for benchmarking language models in natural language processing, especially assessing perplexity to evaluate predictive performance. This dataset offers a balance between computational efficiency and linguistic complexity, making it ideal for practical language model training and evaluation.
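Since perplexity is the evaluation metric for the language modelling experiments, a minimal sketch of its standard definition may help: it is the exponential of the mean negative log-likelihood per token. The function name and the example probabilities are illustrative assumptions, and natural logarithms are assumed.

```python
import math

def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood (natural log) per token."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)

# Hypothetical model that assigns probability 1/4 to every observed token.
uniform_log_probs = [math.log(0.25)] * 10
print(perplexity(uniform_log_probs))
```

A model that spreads its probability uniformly over four candidate tokens has a perplexity of about 4, so lower perplexity corresponds to more confident correct predictions.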

D.3.2 Network Details

Network. We utilize a streamlined GPT-1 architecture which incorporates four self-attention layers, a reduction from the original twelve. This configuration retains core modeling capabilities while reducing complexity, encompassing a total of 28,351,488 learnable parameters.
Embeddings & Parameter Sharing. To expedite training, we employ pretrained embeddings from OpenAI’s GPT, leveraging the benefits of parameter sharing for enhanced efficiency and faster convergence.

D.3.3 Hyperparameters

The model underwent training for 50 WCT epochs using AdaFisher on the WikiText-2 and PTB datasets, with the final epoch counts for each optimizer detailed in Table 9.

Table 9: Final selected epoch counts for various optimizers across language modelling task
          AdamW           AdaHessian           Shampoo           AdaFisherW
          55           18           12           50

For AdamW, we follow the learning rate setting in ElNokrashy et al. (2022). For the other optimizers, we select the learning rate by doing a grid search over $\{0.3, 0.15, 0.1, 0.05, 0.03, 0.015, 0.01, \dots, 1 \times 10^{-5}\}$. We tabulate the learning rates that we use in Table 10. The batch size was configured to 32, and the weight decay was established at 0.1. Despite optimizing the configuration of hyperparameters, Shampoo failed to converge, and K-FAC could not be trained at all.

Table 10: Final selected learning rates for each optimizer, tuned using GPT1 on WikiText-2 and PTB using a batch size of 32. We selected based on final validation PPL.
          AdamW           AdaHessian           Shampoo           AdaFisherW
          $5 \times 10^{-5}$           0.015           0.003           $1 \times 10^{-4}$

D.3.4 Results

Figure 19 displays the training loss and testing error curves, clearly showing that AdaFisher surpasses both Adam and AdaHessian in performance on the WikiText-2 and PTB datasets.

Figure 19: Training Loss and Test Perplexity of Small GPT-1 Model on WikiText-2 and PTB Datasets. Experiments were conducted using a batch size of 32 and optimal settings for all optimizers.