Trainable Weight Averaging: Accelerating Training and Improving Generalization
Abstract
Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). While existing approaches like stochastic weight averaging (SWA) rely on pre-set weighting schemes, they can be suboptimal when handling diverse weights. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns optimal weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression for the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA’s advantages: (i) it consistently outperforms SWA in generalization performance and flexibility, (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance, and (iii) during fine-tuning, it significantly enhances generalization by weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
Keywords: weight averaging, efficient training, learnable coefficients, optimization
1 Introduction
Weight averaging is a widely used technique for accelerating training and improving the generalization performance of deep neural networks (DNNs) (Izmailov et al., 2018; Gupta et al., 2019; Yang et al., 2019; Kaddour, 2022; Wortsman et al., 2022). It exploits the linear mode connectivity (Frankle et al., 2020) in DNNs’ loss landscapes, effectively mimicking model ensemble performance within a single weight space. Due to its simplicity and effectiveness, weight averaging has gained significant attention from practitioners.
In SWA (Izmailov et al., 2018), weight averaging is given by $w_{\text{SWA}} = \frac{1}{K}\sum_{i=1}^{K} w_i$, where the $K$ candidate solutions of the network collected at the tail stage of training are equally averaged. This strategy has been shown to be effective in enhancing generalization. Other improved strategies include selective averaging methods, such as greedy soup (Wortsman et al., 2022) and latest weight averaging (LAWA) (Kaddour, 2022), as well as pre-set weighting schemes such as the exponential moving average (EMA). However, applying equal or pre-set weighting coefficients to candidate model weights can be inappropriate, especially when the weights differ significantly, for example when they are obtained from different training configurations. This can result in suboptimal performance.
In this paper, we study weight averaging with learnable coefficients. Let us regard each model weight as a basis. When we train the coefficients without considering the dependence among different weights, the number of model weights used for averaging is roughly equal to the dimension of the optimization space. It has been reported that for deep models, the training dynamics happen in a low-dimensional subspace (Gur-Ari et al., 2018; Li et al., 2022a) and that crucial DNN mode-connectivity patterns emerge early in training (You et al., 2020a; Frankle et al., 2020). Therefore, optimizing the weighting coefficients provides great flexibility and can lead to improvements over equal averaging. This is especially important when applying weight averaging in the early stage of training, when the parameters are not yet well trained. These findings suggest a promising direction: effectively utilizing these early explorations may enable rapid composition of the final solution while maintaining high accuracy. Since many model weights are involved, optimizing the weighting coefficients also faces practical challenges in both memory and computation. Efficiency is an obstacle that must be overcome in large-scale model training (Wortsman et al., 2022).
In this paper, we delve into efficient and effective weighted averaging of DNNs. We introduce Trainable Weight Averaging (TWA), a novel approach that enables explicit, trainable adjustments to the weighting coefficients. To facilitate efficient optimization, we reframe the problem as a subspace training task: regarding each weight as a point in the full parameter space, we construct a subspace containing the weight points to be averaged. By optimizing in this subspace, we can adaptively search for a good set of weighting coefficients. This scheme, which involves gradient projection onto the subspace, is computationally efficient and can be accelerated by GPUs. To cope with the severe memory burden when averaging multiple weights (Wortsman et al., 2022), we develop an efficient distributed subspace training scheme that facilitates parallel processing across multiple nodes. This approach handles large-scale problems by evenly distributing memory and computational loads across nodes and can be seamlessly integrated with existing distributed training frameworks, such as Distributed Data Parallel (DDP) (Li et al., 2020). Furthermore, to enhance averaging efficiency, we introduce layer-wise processing to better utilize the model’s structural information, and we find that the matrix associated with gradient projection can be quantized to low-bit representations (e.g., 4 bits) without performance degradation, further reducing memory usage. Consequently, TWA enables efficient subspace training with minimal additional memory overhead, making it scalable for averaging large models.
By exploiting the low-dimensional nature of the subspace, we reduce the degrees of freedom from millions or billions to just dozens or hundreds. This reduction allows us to train with a relatively small number of samples. We, therefore, propose optimizing these coefficients using a small held-out validation set, leading to our variant TWA-v, in contrast to the version using training data, TWA-t. Since validation data are unseen in the training set, they could serve as new data to evaluate generalization capability and have played a crucial role in machine learning. For example, we use the validation set to assess the model checkpoints during the training and select the best one. We now have new ways to better utilize these model checkpoints by appropriately averaging with TWA-v, resulting in improved overall performance.
We demonstrate TWA’s effectiveness in two scenarios through extensive experiments across model architectures and tasks: 1) accelerating training by averaging historical solutions during the head stage of training, and 2) improving generalization by averaging solutions fine-tuned from single or multiple fine-tuning configurations. Our results show over 40% and 30% reductions in training epochs on CIFAR-100 and ImageNet, respectively, while maintaining or improving performance (Tables 4, 5). We also show that TWA makes better use of fine-tuned models, requiring 4x fewer models than previous methods to reach comparable performance (Table 16). Interestingly, we also find that TWA-v is particularly effective with transformer architectures.
Our contributions are summarized as follows:
1. We propose Trainable Weight Averaging (TWA), an efficient approach that enables the averaging of weights with layer-wise learnable coefficients to accelerate training and improve the generalization of DNNs.
2. We design an efficient scheme to handle large-scale problems via subspace training, enabling multi-node parallel training by evenly distributing the memory and computation burden across nodes. We also devise a compression strategy to further reduce the memory footprint, facilitating effective coefficient optimization.
3. Based on the fact that the number of learnable variables is small, we propose TWA-v, which optimizes the coefficients with a small held-out validation set to enable more efficient and effective averaging. Additionally, we discover that TWA-v is particularly effective for transformer-based architectures.
4. We demonstrate the efficiency and effectiveness of TWA through extensive experiments involving various architectures (e.g., CNNs, ViTs, and GPT-2), different tasks (e.g., image classification, machine translation, and language modeling), and multiple training scenarios (from training from scratch to fine-tuning) using different optimizers (SGD and Adam).
Comparison with our conference work.
The content of this paper builds upon the previous ICLR 2023 conference version (Li et al., 2023) and includes the following substantial enhancements: 1) Refined Subspace Construction: We have refined the subspace construction by incorporating simple decentralization and normalization techniques. This avoids heavy and unstable numerical operations in the vector orthogonalization process and makes the construction more amenable to parallel processing. Additionally, we introduce layer-wise processing to better leverage the model’s structural information (Section 4.1). 2) Quantization of the Projection Matrix: We have developed a low-bit compression method that reduces the memory overhead associated with gradient projection by 8x, facilitating efficient large-scale training (Section 5.3.2). 3) Validation Supervised Optimization: We propose using a small held-out validation set to supervise the optimization of the weighting coefficients, leading to TWA-v. This introduces a new way of utilizing validation sets during training: learning weighted averages of historical checkpoints, rather than selecting a single best model by checking the validation performance after each training epoch (Section 3). 4) Expanded Evaluation Scope: We have extended TWA to fine-tuning tasks (Section 5.2), improving fine-tuning performance by averaging weights fine-tuned with different training configurations. We have also conducted experiments across a broader range of tasks and model architectures to demonstrate the efficiency and effectiveness of our approach. Specifically, we evaluate TWA on machine translation (Section 5.1.2) and language modeling (Section 5.2.2) tasks, and on ViTs (Section 5.1.1) and GPT models (Section 5.2.2). 5) New Findings: Our investigations reveal that TWA-v is particularly effective when applied to transformer-based architectures, which are characterized by minimal inductive bias. 
This highlights the great potential of TWA-v since transformers are the core architecture for modern AI applications, such as large language models.
2 Related Work
Training neural networks in subspaces has recently attracted considerable interest from researchers (Vinyals and Povey, 2012; Gur-Ari et al., 2018; Tuddenham et al., 2020). The pioneering work of Li et al. (2018) first proposed training neural networks in a reduced random subspace to measure the intrinsic dimension of loss objectives. Subsequent work (Gressmann et al., 2020) improved the training performance of random bases by considering the layer-wise structure and re-drawing the random bases at each step. Instead of utilizing random bases, Li et al. (2022a) proposed a low-dimensional trajectory hypothesis and extracted subspaces from historical training dynamics, dramatically improving dimensionality efficiency. Li et al. (2022b) then applied subspace training to adversarial training, effectively addressing catastrophic and robust overfitting and thereby significantly improving model robustness. In this paper, we reframe trainable weight averaging as a subspace training problem and develop an efficient training scheme with improved subspace construction that enables multi-node parallel training for large-scale problems.
Many efforts have been made to speed up DNN training (Shen et al., 2023). Apart from well-known optimization methods with adaptive learning rates, e.g., Adam (Kingma and Ba, 2015), or accelerated schemes, e.g., Nesterov momentum (Nesterov, 1983), Zhang et al. (2019) proposed the LookAhead optimizer, which utilizes the search direction generated by another “fast” optimizer, achieving faster convergence and better learning stability. Goyal et al. (2017) adopted large mini-batches to speed up training and introduced a scaling rule for adjusting the learning rate. You et al. (2017, 2020b) proposed layer-wise adaptive learning rates to further scale the batch size and shorten the training time. Gupta et al. (2019) proposed to use large mini-batches to compute an approximate solution quickly and then refine it by averaging the weights of multiple models computed independently and in parallel. In this paper, we improve DNN training efficiency by fully utilizing the historical solutions obtained during training and by conducting training in a subspace of substantially reduced dimensionality. In this way, we significantly reduce the number of required training epochs.
Table 1: Main notations used in this paper.
Description | Notation |
---|---|
the number of model parameters | $D$ |
the number of model layers | $L$ |
the $l$-th layer | $l$ |
the number of model weights | $K$ |
model weights | $w$, $w_i$ |
model gradient | $g$ |
weighting coefficient | $\alpha_i$ |
base vector | $e_i$ |
the coefficients for the base vectors | $\beta$ |
the number of nodes for distributed training | $M$ |
loss objective | $\mathcal{L}$ |
datasets | $\mathcal{D}$ |
data batch | $\mathcal{B}$ |
data input | $x$ |
data label | $y$ |
projection matrix | $P$ |
quantized projection matrix | $\hat{P}$ |
quantization bits | $b$ |
quantization scaling factor | $s$ |
quantization zero factor | $z$ |
the subspace containing the weights | $\mathcal{A}$ |
the subspace containing the weights (layer-wise) | $\tilde{\mathcal{A}}$ |
learning rate | $\eta$ |
Next, we review previous weight-averaging strategies.
Preliminary.
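In this paper, we consider a neural network function $f(x; w)$ parameterized by weights $w$ with input $x$. The loss defined over a data pair $(x, y)$ is denoted as $\mathcal{L}(f(x; w), y)$ (shortened to $\mathcal{L}(w; x, y)$). Given a dataset $\mathcal{D} = \{(x_j, y_j)\}_{j=1}^{n}$ drawn i.i.d. from the data distribution, the empirical loss is defined as $\mathcal{L}_{\mathcal{D}}(w) = \frac{1}{n}\sum_{j=1}^{n}\mathcal{L}(w; x_j, y_j)$, where $\mathcal{D}$ can be $\mathcal{D}_{\text{train}}$, $\mathcal{D}_{\text{val}}$, or $\mathcal{D}_{\text{test}}$. Note that in this paper, we represent the model’s weights as a vector, i.e., $w \in \mathbb{R}^{D}$, where $D$ is the number of model parameters. The main notations used in this paper are listed in Table 1.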
In SWA (Izmailov et al., 2018), weight averaging is simply given by
$$w_{\text{SWA}} = \frac{1}{K}\sum_{i=1}^{K} w_i, \tag{1}$$
where the $K$ solutions of the network are equally weighted. This strategy has proven effective when all the solutions are already well optimized, such as in the tail stage of training. However, as a static averaging approach, it is susceptible to suboptimal solutions and may not be suitable for more general scenarios, such as the head stage of training, when model weights are rapidly evolving, or averaging diverse models fine-tuned with different training configurations.
Recently, LAWA (Kaddour, 2022) proposed to apply SWA to a consecutive segment of the most recent weight checkpoints along the training trajectory, i.e.,
$$w_{\text{LAWA}} = \frac{1}{k}\sum_{i=K-k+1}^{K} w_i, \tag{2}$$
where $k$ is the averaging horizon, which has to be pre-set. LAWA has been shown to be effective in accelerating the training of DNNs (Kaddour, 2022; Sanyal et al., 2023).
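Besides, EMA (Polyak and Juditsky, 1992) averages the model weights using an exponential decay factor $\gamma \in (0, 1)$, i.e.,
$$w_{\text{EMA}}^{(t)} = \gamma\, w_{\text{EMA}}^{(t-1)} + (1 - \gamma)\, w^{(t)}, \tag{3}$$
where $w^{(t)}$ denotes the model weights at iteration $t$. EMA typically averages model weights at each iteration.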
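To make the pre-set rules in Eqns. (1)-(3) concrete, here is a minimal pure-Python sketch that treats each checkpoint as a flat list of floats. The helper names (`swa`, `lawa`, `ema`) are hypothetical, and starting the EMA from the first checkpoint is an assumption made only for illustration.

```python
from typing import List, Sequence

Vector = List[float]  # a flattened weight vector w in R^D


def swa(weights: Sequence[Vector]) -> Vector:
    """Eq. (1): equally average all K checkpoints."""
    k = len(weights)
    return [sum(coords) / k for coords in zip(*weights)]


def lawa(weights: Sequence[Vector], horizon: int) -> Vector:
    """Eq. (2): equally average only the `horizon` most recent checkpoints."""
    return swa(weights[-horizon:])


def ema(weights: Sequence[Vector], gamma: float) -> Vector:
    """Eq. (3): w_ema <- gamma * w_ema + (1 - gamma) * w_t, in checkpoint order."""
    avg = list(weights[0])  # assumption: initialize with the first checkpoint
    for w in weights[1:]:
        avg = [gamma * a + (1.0 - gamma) * x for a, x in zip(avg, w)]
    return avg


ckpts = [[0.0, 4.0], [2.0, 2.0], [4.0, 0.0]]  # toy checkpoints, K = 3, D = 2
print(swa(ckpts))       # [2.0, 2.0]
print(lawa(ckpts, 2))   # [3.0, 1.0]
print(ema(ckpts, 0.5))  # [2.5, 1.5]
```

All three rules fix their weighting coefficients before seeing any data; this is precisely the restriction that TWA relaxes.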
Greedy soup (Wortsman et al., 2022) improves upon SWA by selectively averaging a subset of models. Specifically, it first sorts the fine-tuned models based on their validation accuracies and then sequentially adds models to the soup if doing so improves the validation performance. Greedy soup has been shown to empirically outperform SWA and is widely adopted by practitioners (Rame et al., 2022; Croce et al., 2023).
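The greedy procedure just described can be sketched as follows. This is a minimal illustration on flat weight vectors, not the authors' or Wortsman et al.'s code: `toy_score` is a hypothetical stand-in for validation accuracy, and the strict-improvement test follows the description above.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def average(weights: Sequence[Vector]) -> Vector:
    """Uniform average of flattened weight vectors."""
    k = len(weights)
    return [sum(coords) / k for coords in zip(*weights)]


def greedy_soup(weights: Sequence[Vector], score: Callable[[Vector], float]) -> Vector:
    """Visit models in order of individual validation score and keep each one
    only if adding it strictly improves the score of the averaged soup."""
    ranked = sorted(weights, key=score, reverse=True)
    soup = [ranked[0]]
    best = score(average(soup))
    for w in ranked[1:]:
        candidate = soup + [w]
        cand_score = score(average(candidate))
        if cand_score > best:
            soup, best = candidate, cand_score
    return average(soup)


def toy_score(w: Vector) -> float:
    """Hypothetical validation score: higher is better, peaked at [1.0, 1.0]."""
    return -sum((x - 1.0) ** 2 for x in w)


models = [[1.2, 0.8], [0.8, 1.2], [3.0, 3.0]]
print(greedy_soup(models, toy_score))  # [1.0, 1.0]: the outlier [3.0, 3.0] is rejected
```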
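Table 2: Comparison between TWA and previous weight-averaging approaches.
Method | Solution set | #Dimension | Inference cost |
---|---|---|---|
SWA | $\{\frac{1}{K}\sum_{i=1}^{K} w_i\}$ | 0 | $\mathcal{O}(1)$ |
LAWA | $\{\frac{1}{k}\sum_{i=K-k+1}^{K} w_i\}$ | 0 | $\mathcal{O}(1)$ |
EMA | $\{w_{\text{EMA}}\}$ | 0 | $\mathcal{O}(1)$ |
Greedy soup | $\{\frac{1}{\lvert\mathcal{S}\rvert}\sum_{i \in \mathcal{S}} w_i\}$ | 0 | $\mathcal{O}(1)$ |
TWA (w/o layer-wise) | $\mathcal{A}$ | $K$ | $\mathcal{O}(1)$ |
TWA | $\tilde{\mathcal{A}}$ | $LK$ | $\mathcal{O}(1)$ |
Ensemble | $\{w_1, \dots, w_K\}$ (predictions averaged) | 0 | $\mathcal{O}(K)$ |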
3 Trainable Weight Averaging
In this paper, we propose trainable weight averaging (TWA), a method that optimizes the weighting coefficients of different model weights to improve performance. Given a set of weights $\{w_i\}_{i=1}^{K}$, which can be sampled either from a single training trajectory or collected from multiple fine-tuning configurations, we aim to linearly combine them to create an improved averaged solution. Specifically, the set of possible TWA solutions we consider, denoted $\mathcal{A}$, can be represented as follows:
$$\mathcal{A} = \left\{ \sum_{i=1}^{K} \alpha_i w_i \;\middle|\; \alpha_i \in \mathbb{R} \right\}. \tag{4}$$
One can see that the solution set $\mathcal{A}$ forms a linear space of dimension at most $K$ (exactly $K$ when the weights are linearly independent).
Previous pre-set averaging strategies may fall short for complex weights, potentially leading to averaged solutions that are suboptimal within the subspace . In fact, the solutions obtained by these approaches can be viewed as specific points within . TWA addresses this limitation by optimizing the weighting coefficients to achieve better performance. A comparison between TWA and previous averaging approaches is presented in Table 2.
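We now consider how to search for the optimal solution within the subspace $\mathcal{A}$. First, we identify a set of bases $\{e_i\}_{i=1}^{K}$ to support the solution space, so that a solution can be expressed through the coefficients $\beta = (\beta_1, \dots, \beta_K)^\top$ on these bases; $\beta$ is the variable to be optimized. Rather than using $\{w_i\}_{i=1}^{K}$ directly as bases, we decouple them into $\{e_i\}_{i=1}^{K}$, since the weight vectors are often highly correlated, particularly when fine-tuned from the same pre-trained model. Such correlations complicate optimization and can degrade training performance, as shown in Section 5.3.4. We then search for a solution $w(\beta)$ in $\mathcal{A}$, where $w(\beta)$ denotes the model weights represented by $\beta$, by solving the following problem:
$$\min_{\beta \in \mathbb{R}^{K}} \ \mathcal{L}_{\mathcal{D}}\big(w(\beta)\big) + \frac{\lambda}{2}\lVert \beta \rVert_2^2. \tag{5}$$
The second term serves as a regularizer on $\beta$ with coefficient $\lambda$.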
There are two choices for using the available data to optimize Eqn. (5):
- Using the training data, i.e., $\mathcal{D} = \mathcal{D}_{\text{train}}$, to optimize the coefficients, while still using the validation data to select the best-performing model after each training epoch.
- Using the validation data directly, i.e., $\mathcal{D} = \mathcal{D}_{\text{val}}$, to optimize the weighting coefficients, and taking the final model once training is complete.
The latter approach can be more efficient and effective. Since we only need to optimize a small number of coefficients, a small held-out validation set is sufficient. Because $\mathcal{D}_{\text{val}}$ is unseen during the training of the weights $\{w_i\}_{i=1}^{K}$, a solution $w(\beta)$ that performs well on the validation set is expected to generalize well at test time. Additionally, the validation set is typically much smaller than the training set, which helps reduce training costs. We denote this validation-based approach as TWA-v, distinguishing it from the training-based approach, denoted as TWA-t. This introduces a novel use of validation data: learning weighting coefficients to combine multiple models, rather than selecting a single best-performing model.
4 An Efficient Training Scheme
A straightforward approach to optimizing the coefficients is to construct a computational graph that links the bases $\{e_i\}_{i=1}^{K}$ and coefficients $\beta$ to the model weights $w$, enabling gradient descent on $\beta$ through standard backpropagation in an end-to-end manner. However, as noted by Wortsman et al. (2022), this approach results in a large computational graph, introducing excessive additional memory overhead that becomes prohibitive as the model size and the number of weights increase.
In this paper, we propose a simpler and more efficient method for optimizing the coefficients without constructing an additional computational graph, by leveraging subspace training. We notice that there exists a mapping between the coefficient space $\mathbb{R}^{K}$ and the parameter space $\mathbb{R}^{D}$:
$$\mathbb{R}^{K} \ni \alpha \ \mapsto\ \sum_{i=1}^{K} \alpha_i w_i \in \mathbb{R}^{D}. \tag{6}$$
This implies that each set of coefficients $\alpha$ corresponds to a unique point in the parameter space $\mathbb{R}^{D}$, and these points form a subspace of dimension at most $K$. As a result, we can optimize the coefficients by optimizing within this subspace, which is more efficient than performing direct optimization through a computational graph.
We then present our subspace training scheme for efficient trainable weight averaging, which consists of subspace construction and training. To further enhance performance and efficiency, we introduce a layer-wise extension to incorporate model structure, develop a multi-node distributed training approach for large-scale problems, and quantize the projection matrix to reduce memory overhead. Before detailing the training procedure, we first discuss the challenges of averaging multiple model weights.
- Computation. As the model size and the number of models grow, the costs of both subspace construction and subspace training increase accordingly.
- Memory. Subspace training requires storing a matrix for projecting gradients onto the subspace, which can be challenging in large-scale scenarios where the matrix may exceed a single GPU’s memory capacity.
- Load balancing. Both subspace construction and training are best performed on GPUs, with the computational load distributed across multiple nodes when available. This significantly reduces training time and improves the scalability of the algorithm.
4.1 Subspace Construction
Given candidate weights $\{w_i\}_{i=1}^{K}$, where $w_i \in \mathbb{R}^{D}$, we aim to construct a subspace covering these weight points. To achieve this, we first identify a set of bases to represent the subspace. Since the weights are often interdependent, particularly when sampled from a single training trajectory, we avoid using them directly as bases. Leveraging the facts that the high dimensionality of the weight space naturally encourages independence and that the weights originate from a single linear mode¹, we propose to decouple them by performing simple decentralization and normalization, as follows:
$$e_i = \frac{w_i - \bar{w}}{\lVert w_i - \bar{w} \rVert_2}, \qquad \bar{w} = \frac{1}{K}\sum_{j=1}^{K} w_j, \qquad i = 1, \dots, K. \tag{7}$$
This results in a set of basis vectors $\{e_i\}_{i=1}^{K}$. Next, we optimize the neural network within the subspace $\mathcal{A}$. We represent a weight point in this subspace using the variable $\beta \in \mathbb{R}^{K}$. It is important to note that the dimension of $\beta$ is $K$, the number of models, which is much smaller than $D$, the dimension of the original parameter space.
¹ The concept of linear mode connectivity, as defined by Frankle et al. (2020), refers to the property where, for two sets of weights $w_a$ and $w_b$, their linear interpolation $(1 - t)\, w_a + t\, w_b$ maintains good performance for all $t \in [0, 1]$.
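The construction in Eqn. (7) takes only a few lines. The following pure-Python sketch (hypothetical helper `build_bases`, toy 3-dimensional weights) builds the bases and checks that every candidate is recovered as $\bar{w} + \lVert w_i - \bar{w} \rVert_2\, e_i$, i.e., the bases anchored at the mean reach every weight point. It is an illustration under these assumptions, not the released TWA code.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def build_bases(weights: Sequence[Vector]) -> Tuple[Vector, List[Vector]]:
    """Eq. (7): center each candidate weight on the mean, then normalize it."""
    k = len(weights)
    mean = [sum(coords) / k for coords in zip(*weights)]
    bases = []
    for w in weights:
        diff = [x - m for x, m in zip(w, mean)]          # decentralization
        norm = math.sqrt(sum(d * d for d in diff))
        if norm == 0.0:
            raise ValueError("a candidate weight coincides with the mean")
        bases.append([d / norm for d in diff])           # normalization
    return mean, bases


weights = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # toy, K = D = 3
mean, bases = build_bases(weights)
for w, e in zip(weights, bases):
    # Recover w_i as mean + ||w_i - mean|| * e_i.
    scale = math.sqrt(sum((x - m) ** 2 for x, m in zip(w, mean)))
    print([m + scale * ei for m, ei in zip(mean, e)])
```

Only element-wise operations and one norm per candidate are needed, with no sequential orthogonalization, which is what makes this step easy to split across devices.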
Another approach to decoupling the model weights is to orthogonalize them via Gram-Schmidt orthogonalization (Li et al., 2023). However, this sequential process can be computationally intensive and numerically unstable, potentially compromising the accuracy of the estimated subspace bases. In contrast, our decoupling method involves only decentralization and normalization of the weight vectors, without intensive numerical operations, and does not enforce strict orthogonality. Our experiments show that this relaxation does not significantly affect subspace training performance. This simplified strategy effectively decouples the weight vectors while being readily applicable to distributed training, resulting in more efficient subspace construction. We compare the two approaches in Section 5.3.1.

4.2 Subspace Training
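We then consider how to optimize the variable $\beta$ for subspace training. Let $P = [e_1, e_2, \dots, e_K]^\top \in \mathbb{R}^{K \times D}$. We parameterize the model weights in the subspace as $w(\beta) = w_0 + P^\top \beta$, where $w_0$ is the initialization point (Section 4.3). The optimization target in Eqn. (5) can then be reformulated as follows:
$$\min_{\beta \in \mathbb{R}^{K}} \ \mathcal{L}_{\mathcal{D}}\big(w_0 + P^\top \beta\big) + \frac{\lambda}{2}\lVert \beta \rVert_2^2. \tag{8}$$
The gradient of $\mathcal{L}_{\mathcal{D}}$ w.r.t. the variables $\beta$ can be derived using the chain rule:
$$\frac{\partial \mathcal{L}_{\mathcal{D}}}{\partial \beta} = \left(\frac{\partial w}{\partial \beta}\right)^{\!\top} \frac{\partial \mathcal{L}_{\mathcal{D}}}{\partial w} = P\, g, \qquad g = \frac{\partial \mathcal{L}_{\mathcal{D}}}{\partial w}. \tag{9}$$
Thus, we can first calculate the model gradient $g$ through standard forward and backward propagation, and then project it onto the bases to obtain the gradient w.r.t. $\beta$. This projection ensures that the actual update to the model weights remains within the subspace. Together with the regularization, the gradient descent update rule for $\beta$ is given by:
$$\beta_{t+1} = \beta_t - \eta\,\big(P g_t + \lambda \beta_t\big), \tag{10}$$
where $\eta$ is the learning rate and $g_t$ is the model gradient at $w_t = w(\beta_t)$. For the corresponding TWA solution, we have:
$$w_{t+1} = w_t - \eta\, P^\top \big(P g_t + \lambda \beta_t\big). \tag{11}$$
Note that the gradient projection is achieved by applying the projection matrix $P$.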
Here subspace training involves an additional gradient projection operation in the optimization step, compared to regular training. This operation is a matrix multiplication, which can be efficiently executed on GPUs. As a result, the training speed of subspace training can be nearly the same as that of regular training. We compare the training speed of the two methods in detail in Section 5.3.2, where we find that TWA training is only slightly slower than regular training.
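The update rules in Eqns. (9)-(11) can be sketched in a few lines of pure Python. This is a toy under stated assumptions: the quadratic loss $\frac{1}{2}\lVert w - \text{target} \rVert^2$ is a hypothetical stand-in for the network loss, $P$ is a hand-picked $2 \times 3$ matrix, the weights are parameterized as $w = w_0 + P^\top\beta$, and `subspace_train` is an illustrative helper rather than part of the TWA code.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[Vector]  # K rows of length D; row i is the basis e_i


def project(p: Matrix, g: Vector) -> Vector:
    """P g: gradient w.r.t. beta (Eq. 9), one inner product per basis."""
    return [sum(pi * gi for pi, gi in zip(row, g)) for row in p]


def lift(p: Matrix, c: Vector) -> Vector:
    """P^T c: map K coefficients back to the D-dimensional parameter space."""
    return [sum(c[i] * p[i][j] for i in range(len(p))) for j in range(len(p[0]))]


def subspace_train(p: Matrix, w0: Vector, target: Vector, lr: float,
                   lam: float, steps: int) -> Tuple[Vector, Vector]:
    """Gradient descent on beta (Eq. 10), mirrored in weight space (Eq. 11)."""
    beta = [0.0] * len(p)
    w = list(w0)
    for _ in range(steps):
        g = [wi - ti for wi, ti in zip(w, target)]           # model gradient g_t
        g_beta = [pg + lam * b for pg, b in zip(project(p, g), beta)]
        beta = [b - lr * gb for b, gb in zip(beta, g_beta)]  # Eq. (10)
        step = lift(p, g_beta)
        w = [wi - lr * si for wi, si in zip(w, step)]        # Eq. (11)
    return w, beta


P = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # K = 2 bases in D = 3
w, beta = subspace_train(P, [0.0, 0.0, 0.0], [1.0, 2.0, 3.0],
                         lr=0.5, lam=0.0, steps=50)
print(w, beta)  # w approaches [1, 2, 0]; the third coordinate never moves
```

The third coordinate lies outside the row space of $P$, so it stays at its initial value: the update never leaves the subspace, which is what the projection guarantees.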
4.3 Initialization
An important consideration in TWA training is the choice of initialization. In this paper, we use the SWA solution as the initialization, i.e., $w_0 = w_{\text{SWA}}$. This choice is advantageous because SWA typically delivers good performance, which reduces the effort required for TWA training, and it avoids the additional computation needed to evaluate the performance of individual models $w_i$.
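In practice, we set $\beta_0 = 0$ and use the relationship $P^\top \beta_t = w_t - w_0$, which allows Eqn. (11) to be rewritten as:
$$w_{t+1} = w_t - \eta\,\big(P^\top P\, g_t + \lambda\,(w_t - w_0)\big), \tag{12}$$
where $w_0 = w_{\text{SWA}}$ and $t$ is the iteration number. During training, we therefore only need to maintain the model weights $w_t$ (together with the fixed $w_0$) in memory, rather than $\beta_t$.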
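Assuming the parameterization $w = w_0 + P^\top\beta$ with $\beta_0 = 0$, the identity $P^\top\beta_t = w_t - w_0$ lets the regularized update be carried out on $w$ alone. The sketch below runs both forms side by side on a toy quadratic loss (a hypothetical stand-in for the network loss) with unit-norm but non-orthogonal bases, and checks numerically that they agree; all names are illustrative.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]  # K x D, rows are the bases e_i


def project(p: Matrix, v: Vector) -> Vector:
    """P v."""
    return [sum(a * b for a, b in zip(row, v)) for row in p]


def lift(p: Matrix, c: Vector) -> Vector:
    """P^T c."""
    return [sum(c[i] * p[i][j] for i in range(len(p))) for j in range(len(p[0]))]


def grad(w: Vector, target: Vector) -> Vector:
    """Gradient of the hypothetical toy loss 0.5 * ||w - target||^2."""
    return [wi - ti for wi, ti in zip(w, target)]


P = [[1.0, 0.0, 0.0], [0.6, 0.8, 0.0]]  # unit-norm but not orthogonal bases
w0, target, lr, lam = [0.5, -0.5, 1.0], [1.0, 2.0, 3.0], 0.1, 0.1

beta, w_a = [0.0, 0.0], list(w0)  # form (a): keep beta, with beta_0 = 0
w_b = list(w0)                    # form (b): keep only w (and the fixed w0)
for _ in range(20):
    # (a) Eq. (10) on beta, then w = w0 + P^T beta.
    g_beta = [pg + lam * b for pg, b in zip(project(P, grad(w_a, target)), beta)]
    beta = [b - lr * gb for b, gb in zip(beta, g_beta)]
    w_a = [z + c for z, c in zip(w0, lift(P, beta))]
    # (b) w <- w - lr * (P^T P g + lam * (w - w0)), never touching beta.
    ppg = lift(P, project(P, grad(w_b, target)))
    w_b = [wi - lr * (u + lam * (wi - z)) for wi, u, z in zip(w_b, ppg, w0)]

print(max(abs(a - b) for a, b in zip(w_a, w_b)))  # only floating-point noise
```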
4.4 Distributed Training Scheme
During subspace optimization, TWA requires storing the projection matrix $P$, which has a size of $K \times D$. This poses a storage challenge for large models, as $P$ is ideally stored on the GPU to facilitate efficient matrix operations. As the model size $D$ and the number of weight points $K$ grow, storing $P$ on a single GPU becomes impractical.
To cope with this, we design an efficient parallel distributed training scheme that enables the following:
- Partitioning the memory burden of $P$ across multiple GPUs. We partition the projection matrix $P$ into multiple sub-matrices, each stored on a separate GPU. This reduces the memory burden on each GPU and allows TWA to scale to larger models.
- Efficient parallel computation of gradient projection. We propose a distributed algorithm for computing the gradient projection in subspace optimization, which evenly distributes the computational load across multiple GPUs and significantly reduces the time required for gradient projection.
More specifically, suppose that there are $M$ GPUs for multi-node parallel training. We uniformly divide $P$ column-wise into $M$ sub-matrices as $P = [P_1, P_2, \dots, P_M]$, where $P_m \in \mathbb{R}^{K \times D_m}$ with $D_m \approx D/M$. Each GPU stores a local sub-matrix $P_m$ for $m = 1, \dots, M$. Recall that in each iteration of distributed training, each GPU computes a local gradient and synchronizes it with the other GPUs to obtain the global gradient $g$ through an efficient all-reduce operation (Rabenseifner, 2004). We mimic this process for gradient projection. First, we perform an all-reduce operation so that the local gradient on each GPU is synchronized to $g$. Writing $g = [g_1^\top, g_2^\top, \dots, g_M^\top]^\top$ with blocks matching the partition of $P$, each GPU then computes its local gradient projection $P_m g_m$. Finally, we aggregate these into $P g$ with another all-reduce operation. The distributed approach is mathematically equivalent to the original gradient projection by block matrix multiplication:
$$P g = [P_1, P_2, \dots, P_M] \begin{bmatrix} g_1 \\ g_2 \\ \vdots \\ g_M \end{bmatrix} = \sum_{m=1}^{M} P_m g_m. \tag{13}$$
Note that in this way, the computation for matrix multiplication is also uniformly divided into different nodes. We illustrate such a process in Figure 1.
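To see why the two all-reduce steps reproduce the full projection, the following single-process Python sketch simulates $M$ nodes. It assumes the first all-reduce (synchronizing $g$) has already happened and replaces the second one with a plain sum; in an actual multi-GPU run that sum would be a collective such as `torch.distributed.all_reduce`. The helper names are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]  # K rows, D columns


def matvec(p: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product P v."""
    return [sum(a * b for a, b in zip(row, v)) for row in p]


def split_bounds(d: int, num_nodes: int) -> List[int]:
    """Column boundaries cutting D columns into nearly equal contiguous blocks."""
    return [d * m // num_nodes for m in range(num_nodes + 1)]


def distributed_projection(p: Matrix, g: Vector, num_nodes: int) -> Vector:
    """Simulate Eq. (13): node m holds the column block P_m and the matching
    gradient block g_m, computes P_m g_m locally, and partials are summed."""
    bounds = split_bounds(len(g), num_nodes)
    total = [0.0] * len(p)
    for m in range(num_nodes):
        lo, hi = bounds[m], bounds[m + 1]
        p_m = [row[lo:hi] for row in p]                  # sub-matrix P_m on node m
        partial = matvec(p_m, g[lo:hi])                  # local projection P_m g_m
        total = [t + x for t, x in zip(total, partial)]  # stands in for all-reduce
    return total


P = [[0.5, -1.0, 2.0, 0.0, 1.5],
     [1.0, 1.0, -0.5, 2.0, 0.25]]  # K = 2, D = 5 (toy values)
g = [0.2, -0.4, 1.0, 0.3, -2.0]   # global gradient after the first all-reduce
full = matvec(P, g)
for nodes in (1, 2, 3, 5):
    print(nodes, distributed_projection(P, g, nodes), full)
```

Each node stores only about $K \times D/M$ entries and performs about $1/M$ of the multiply-adds, while the result is identical to the single-device product up to floating-point rounding.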
4.5 Layer-wise Processing
In the above, we have treated the model weights as a whole, where the layer-wise structure is not explored. It would be helpful to take into account the layer-wise information for more delicate subspace training.
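Suppose the network is composed of $L$ layers. To incorporate layer-wise information, we propose constructing subspaces for each layer. We start by partitioning the concatenated weights into $L$ groups that correspond to the layers, i.e., $w_i = [w_i^{(1)}, w_i^{(2)}, \dots, w_i^{(L)}]$. The augmented subspace we consider is
$$\tilde{\mathcal{A}} = \left\{ \left[\sum_{i=1}^{K} \alpha_i^{(1)} w_i^{(1)}, \dots, \sum_{i=1}^{K} \alpha_i^{(L)} w_i^{(L)}\right] \;\middle|\; \alpha_i^{(l)} \in \mathbb{R} \right\}. \tag{14}$$
We then decentralize and normalize each group of weights to produce a set of basis vectors for each layer following Eqn. (7):
$$e_i^{(l)} = \frac{w_i^{(l)} - \bar{w}^{(l)}}{\lVert w_i^{(l)} - \bar{w}^{(l)} \rVert_2}, \qquad \bar{w}^{(l)} = \frac{1}{K}\sum_{j=1}^{K} w_j^{(l)}, \qquad l = 1, \dots, L. \tag{15}$$
Finally, we optimize the variables $\beta^{(l)} \in \mathbb{R}^{K}$ associated with each layer individually using the update rule in Eqn. (10). This layer-wise approach allows the subspace of each layer to be adjusted separately, which can improve performance.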
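The layer-wise construction amounts to running the global construction separately on each layer's slice of the weights. In this minimal sketch (hypothetical helper `layerwise_bases`, a toy two-layer "model"), each layer contributes its own $K$ bases, so the optimization variable grows from $K$ to $LK$ coefficients while remaining tiny compared with $D$.

```python
import math
from typing import List, Sequence

Vector = List[float]
Layered = List[Vector]  # one flattened vector per layer: [w^(1), ..., w^(L)]


def layerwise_bases(weights: Sequence[Layered]) -> List[List[Vector]]:
    """Eq. (15): apply the decentralize-and-normalize step of Eq. (7)
    separately inside every layer. Returns bases[l][i] = e_i^(l)."""
    k, num_layers = len(weights), len(weights[0])
    bases: List[List[Vector]] = []
    for layer_idx in range(num_layers):
        layer = [w[layer_idx] for w in weights]        # {w_i^(l)} for i = 1..K
        mean = [sum(c) / k for c in zip(*layer)]       # layer-wise mean
        per_layer = []
        for w_l in layer:
            diff = [x - m for x, m in zip(w_l, mean)]
            norm = math.sqrt(sum(d * d for d in diff))
            if norm == 0.0:
                raise ValueError(f"layer {layer_idx}: candidate equals the layer mean")
            per_layer.append([d / norm for d in diff])
        bases.append(per_layer)
    return bases


# Toy "model" with L = 2 layers (sizes 2 and 3) and K = 3 checkpoints.
ckpts = [
    [[1.0, 0.0], [0.0, 0.0, 1.0]],
    [[0.0, 1.0], [1.0, 0.0, 0.0]],
    [[1.0, 1.0], [0.0, 1.0, 0.0]],
]
bases = layerwise_bases(ckpts)
num_coefficients = sum(len(per_layer) for per_layer in bases)
print(len(bases), num_coefficients)  # 2 6  (L layers, L * K coefficients)
```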
In our main experiments, we process model weights layer-wise by default, as this approach yields better performance and enables effective quantization of the projection matrix, thereby reducing memory usage during gradient projection. We present an ablation study on layer-wise processing in Section 5.3.3.
4.6 Low-Bit Quantization of the Projection Matrix
Since the weights to be averaged are sampled from a single training trajectory or linear mode, the resulting projection matrix can be expected to contain redundant information that can be further compressed. To further reduce the memory burden introduced by the projection matrix $P$, we resort to quantization, a widely used technique that has been shown to be effective in reducing memory usage and speeding up training (Shen et al., 2020; Dettmers et al., 2022, 2024; Zhang et al., 2024).
Table 3: Memory requirements of the projection matrix in the original 32-bit format and after low-bit quantization.
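We find that the projection matrix $P$ can be efficiently quantized into a low-bit representation $\hat{P}$. In this paper, we employ simple uniform min-max quantization; more sophisticated methods could further improve performance. Mathematically, given the bit width $b$, the projection matrix $P^{(l)}$ of each layer $l$ is quantized into $\hat{P}^{(l)}$ by computing:
$$\hat{P}^{(l)} = \left\lfloor \frac{P^{(l)} - z^{(l)}}{s^{(l)}} \right\rceil, \qquad s^{(l)} = \frac{\max(P^{(l)}) - \min(P^{(l)})}{2^{b} - 1}, \qquad z^{(l)} = \min(P^{(l)}), \tag{16}$$
where $s^{(l)}$ and $z^{(l)}$ are the scaling and zero factors, respectively, and $\lfloor \cdot \rceil$ denotes the integer rounding operation. In this way, all elements of $\hat{P}^{(l)}$ are mapped to the set $\{0, 1, \dots, 2^{b} - 1\}$ and can thus be stored as $b$-bit integers. We present the full algorithm steps of TWA in Algorithm 1.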
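A minimal sketch of a per-layer min-max quantizer in the spirit of Eqn. (16), assuming the zero factor is the minimum entry and that dequantization is $s\hat{P} + z$. The function names are hypothetical, and a real implementation would also pack two 4-bit codes into each byte, which is omitted here.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize(p: Matrix, bits: int) -> Tuple[List[List[int]], float, float]:
    """Uniform min-max quantization of one layer's projection matrix.

    Returns integer codes in {0, ..., 2**bits - 1}, the scaling factor s and
    the zero factor z (the minimum entry), so that P ~= s * codes + z.
    """
    lo = min(min(row) for row in p)
    hi = max(max(row) for row in p)
    levels = 2 ** bits - 1
    s = (hi - lo) / levels if hi > lo else 1.0  # avoid dividing by zero
    codes = [[min(levels, max(0, round((x - lo) / s))) for x in row] for row in p]
    return codes, s, lo


def dequantize(codes: List[List[int]], s: float, z: float) -> Matrix:
    return [[s * q + z for q in row] for row in codes]


P = [[0.7, -0.2, 0.1, 0.4], [-0.5, 0.3, 0.0, 0.6]]  # toy layer block, K = 2
codes, s, z = quantize(P, bits=4)
approx = dequantize(codes, s, z)
err = max(abs(a - b) for ra, rb in zip(P, approx) for a, b in zip(ra, rb))
print(codes)
print(err <= s / 2 + 1e-12)  # True: rounding error is at most half a step
```

Computing $s$ and $z$ per layer keeps each layer's value range tight, which is one reason the layer-wise processing of Section 4.5 pairs naturally with quantization.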
For memory savings, we only need to store the projection matrices $\hat{P}^{(l)}$ as $b$-bit integers, along with the factors $s^{(l)}$ and $z^{(l)}$, whose storage is negligible. Compared to the original 32-bit format, this yields a compression rate of $32/b$. Note that the quantization technique can be combined with the distributed training scheme to further reduce the memory burden. A detailed analysis of the memory requirements for the projection matrix is presented in Table 3. In this paper, we use $b = 4$ by default for TWA, and an ablation of the quantization performance is presented in Section 5.3.6.
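The storage figures behind the $32/b$ compression rate are simple arithmetic: $P$ has $K \times D$ entries of $b$ bits each, and the distributed scheme divides them across $M$ GPUs. The sketch below uses hypothetical sizes (100 checkpoints of a 25M-parameter model) purely for illustration; the numbers are not taken from Table 3.

```python
def projection_bytes(k: int, d: int, bits: int, nodes: int = 1) -> int:
    """Bytes per GPU for the K x D projection matrix when its columns are split
    over `nodes` GPUs, ignoring the negligible per-layer factors s and z."""
    total_bits = k * d * bits
    per_node_bits = -(-total_bits // nodes)  # ceiling division
    return -(-per_node_bits // 8)            # round up to whole bytes


K, D = 100, 25_000_000  # hypothetical: 100 checkpoints of a 25M-parameter model
fp32 = projection_bytes(K, D, 32)
int4 = projection_bytes(K, D, 4)
print(fp32 / 1e9, int4 / 1e9, fp32 / int4)       # 10.0 1.25 8.0
print(projection_bytes(K, D, 4, nodes=4) / 1e9)  # 0.3125
```

Combining 4-bit storage with $M = 4$ GPUs shrinks the per-GPU footprint by $8 \times 4 = 32$ times relative to a single 32-bit copy.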
5 Numerical Experiments
In this section, we conduct numerical experiments on various computer vision and natural language processing tasks to demonstrate the efficiency and effectiveness of our proposed TWA approach. First, we show that TWA can significantly accelerate the training of DNNs by averaging historical solutions from the head stage of training, especially for transformer-based architectures. Next, we demonstrate that TWA significantly improves the performance of fine-tuned models across both single and multiple training configurations. Finally, we present ablation studies to further analyze the properties of TWA training.
5.1 Accelerating Neural Network Training
We first apply TWA to the head stage of training to accelerate the training of DNNs. In the head stage, model weights evolve rapidly, and equal averaging such as SWA may be insufficient and often fails. Since TWA adaptively adjusts the weighting coefficients and reduces the estimation variance, it can be expected to work well in this stage and yield better performance. If so, generalization improvements and training efficiency can be attained simultaneously. We conduct experiments on two representative computer vision and natural language processing tasks, i.e., image classification and machine translation, to evaluate the efficiency and effectiveness of TWA.
5.1.1 Image Classification
Setting.
We experiment on two benchmark image classification datasets, i.e., CIFAR-100 (Krizhevsky and Hinton, 2009) and ImageNet (Deng et al., 2009). Following Izmailov et al. (2018) and Yang et al. (2019), we apply standard data preprocessing for experiments on CIFAR datasets and adopt the preprocessing and data augmentation procedures of the public PyTorch example on ImageNet (Paszke et al., 2017). We use three representative architectures on CIFAR: VGG-16 (Simonyan and Zisserman, 2014), ResNet-18 (He et al., 2016), and ViT-S/4 (Dosovitskiy et al., 2021). For ImageNet, we use ResNet-18/50 (He et al., 2016), ViT-S/32, and ViT-B/16 (Dosovitskiy et al., 2021).
For CIFAR training, we adopt a standard training protocol with a step-wise learning rate schedule. We run all experiments with three random seeds and report the mean test accuracy and standard deviation. We use SGD optimizer with momentum , weight decay , and batch size . We train the models for 200 epochs with an initial learning rate and decay it by at the 100th and the 150th epochs. For ViT-S/4, we use AdamW as the base optimizer and train for 200 epochs with an initial learning rate of 0.001, weight decay of 0.1, and a cosine learning rate schedule. For ImageNet training, we follow the official PyTorch implementation (available at https://github.com/pytorch/examples/tree/main/imagenet). We randomly split out 10% and 2% of the training data for CIFAR-100 and ImageNet, respectively, as validation sets. The validation data are used to select the best-performing model for base training and TWA, and serve as supervision for optimizing the weighting coefficients in TWA-v. For TWA, we sample solutions once after each training epoch for both CIFAR and ImageNet, and we use the base optimizer for TWA training. We list the detailed training settings for the CIFAR experiments as follows. More training details can be found in Appendix A.
- Base: Train the model for 200 epochs and use the validation set to select the best model.
- SWA: Average the model checkpoints from the first 100 epochs.
- TWA-t: Apply TWA to model solutions from the first 100 epochs, optimize coefficients using training data for 10 epochs, and select the optimal model based on validation performance.
- TWA-v: Apply TWA to model solutions from the first 100 epochs, optimize using validation data for 10 epochs, and utilize the final model directly.
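To make the TWA-v protocol above concrete, here is a minimal NumPy sketch of trainable weight averaging on a toy linear-regression problem; it is an illustration under stated assumptions, not the authors' implementation. We assume the checkpoints are flattened weight vectors, the coefficients β are unconstrained, and they are fitted by plain gradient descent on the mean squared error of the averaged weight w(β) = Σ_i β_i w_i. The subspace construction, projection-matrix quantization, and distributed training used in the paper are omitted, and all names (`fit_coefficients`, `averaged_weights`) are hypothetical.

```python
import numpy as np


def averaged_weights(checkpoints: np.ndarray, beta: np.ndarray) -> np.ndarray:
    """Weighted average w(beta) = sum_i beta_i * w_i for checkpoints of shape (n, d)."""
    return beta @ checkpoints


def fit_coefficients(checkpoints: np.ndarray, X: np.ndarray, y: np.ndarray,
                     steps: int = 2000) -> np.ndarray:
    """Learn averaging coefficients by gradient descent on the MSE over (X, y)."""
    n, N = checkpoints.shape[0], len(y)
    A = X @ checkpoints.T                          # (N, n): predictions of each checkpoint
    lipschitz = 2.0 / N * np.linalg.norm(A, 2) ** 2
    lr = 1.0 / lipschitz                           # step size that guarantees monotone decrease
    beta = np.full(n, 1.0 / n)                     # start from equal (SWA-style) weights
    for _ in range(steps):
        residual = A @ beta - y                    # error of the averaged model
        beta -= lr * (2.0 / N) * (A.T @ residual)  # gradient of the MSE w.r.t. beta
    return beta


rng = np.random.default_rng(0)
d = 5
w_true = rng.normal(size=d)
# Hypothetical trajectory: noisy checkpoints drifting toward w_true.
checkpoints = np.stack([w_true * t / 10 + 0.3 * rng.normal(size=d) for t in range(1, 11)])
X_val = rng.normal(size=(200, d))
y_val = X_val @ w_true + 0.1 * rng.normal(size=200)


def val_mse(w: np.ndarray) -> float:
    return float(np.mean((X_val @ w - y_val) ** 2))


beta = fit_coefficients(checkpoints, X_val, y_val)
equal_mse = val_mse(checkpoints.mean(axis=0))
learned_mse = val_mse(averaged_weights(checkpoints, beta))
print(f"equal averaging: {equal_mse:.4f}, learned averaging: {learned_mse:.4f}")
```

Passing training data instead of validation data to `fit_coefficients` gives a TWA-t analogue. Because optimization starts from equal weights with a 1/L step size, the learned average can only improve on SWA-style averaging on the data used for fitting.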
For the ImageNet experiments, we maintain similar training settings while varying the training epochs. For instance, the notation “” denotes that TWA is applied to model solutions collected from the first 60 epochs, followed by 2 additional epochs to optimize the coefficients using training data.
Model | Base Accuracy | Base Gap | SWA Accuracy | TWA-t Accuracy | TWA-t Gap | TWA-v Accuracy | TWA-v Gap |
---|---|---|---|---|---|---|---|
VGG-16 | 21.63 () | () | |||||
ResNet-18 | () | () | |||||
ViT-S/4 | 41.10 | 41.05 () | 30.13 () |
Results.
We first examine the results on CIFAR-100, which are given in Table 4. The base training schedule contains 200 epochs, and TWA uses the solutions explored in the first 100 epochs. TWA achieves performance comparable to or better than regular SGD training while significantly reducing the generalization gap. For instance, TWA-v attains accuracy improvement on CIFAR-100 with VGG-16, while the generalization gap is reduced by . This suggests that a better solution can already be composed by weighted averaging of these historical solutions, without further training at smaller learning rates, which might otherwise lead to overfitting and harm generalization. In comparison, applying SWA to average the same samples degrades performance because of estimation error. TWA-v also outperforms TWA-t in both test accuracy and generalization gap, confirming the benefit of optimizing the coefficients on the validation set. This improvement is most pronounced on ViT-S/4, where TWA-v outperforms TWA-t by 1.93%.
For ImageNet, each training epoch is far more expensive, which makes methods that reduce the number of training epochs highly desirable. The comparison results of SGD/SWA/TWA are presented in Table 5. Beyond narrowing the generalization gap between training and test data, TWA achieves performance comparable to, or even better than, standard SGD training with 90 epochs by averaging the historical solutions from the first 60 epochs. For comparison, Lookahead (Zhang et al., 2019), another recently proposed optimizer for improving convergence, reported accuracy at the 60th epoch (Table 2 in Zhang et al. (2019)) with an aggressive learning rate decay (i.e., the learning rate is decayed at the 30th, 48th, and 58th epochs). In contrast, our TWA-v method achieves accuracy with the same budget while simply employing the conventional learning rate decay. Again, the improvement is particularly pronounced for ViT architectures: compared to standard AdamW training, TWA-v improves accuracy by 3.23% with ViT-B/16 and by 2.32% with ViT-S/32 while requiring over 30% fewer training epochs. In fact, TWA-v with approximately 60 epochs of training outperforms the 300-epoch AdamW baseline reported in Chen et al. (2022) (e.g., 74.6% with ViT-B/16). This shows the great potential of TWA-v for accelerating the training of transformer-based architectures.
Model | Base Accuracy | Base Gap | SWA Accuracy | TWA-t Accuracy | TWA-t Gap | TWA-v Accuracy | TWA-v Gap |
---|---|---|---|---|---|---|---|
ResNet-18 | 69.53 | -0.45 | 64.03 | 69.42 | () | 69.66 | () |
ResNet-50 | 67.66 | () | () | ||||
ViT-S/32 | 66.95 | 18.64 | 62.45 | 67.39 | 10.62 () | 69.27 | 6.31 () |
ViT-B/16 | 72.79 | 20.07 | 67.12 | 73.20 | 14.21 () | 76.02 | 7.67 () |
Comparison of training time.
In Table 6, we compare the training time of different methods. For TWA, the training time comprises two components: Stage 1, which collects historical solutions, and Stage 2, which optimizes the weighting coefficients. In addition to its accuracy gains, TWA-v saves around 45% of the total training time on CIFAR-100 and 32% on ImageNet. Compared to TWA-t, TWA-v further reduces the Stage 2 time by optimizing on the much smaller validation set, while yielding weight averages that generalize much better.
Dataset | Model | Method | Test Accuracy (%) | Stage 1 | Stage 2 | Total | % of Base |
---|---|---|---|---|---|---|---|
CIFAR-100 | ResNet-18 | Base (SGD) | 73.45 | 0.51h | - | 0.51h | 100% |
CIFAR-100 | ResNet-18 | TWA-t | 74.03 | 0.26h | 0.09h | 0.35h | 69% |
CIFAR-100 | ResNet-18 | TWA-v | 74.78 | 0.26h | 0.02h | 0.28h | 55% |
ImageNet | ViT-B/16 | Base (AdamW) | 72.79 | 51.8h | - | 51.8h | 100% |
ImageNet | ViT-B/16 | TWA-t | 73.08 | 34.5h | 3.0h | 37.5h | 72% |
ImageNet | ViT-B/16 | TWA-v | 76.02 | 34.5h | 0.5h | 35.0h | 68% |
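The relative costs in Table 6 follow directly from the stage times. A small Python check (the function and dictionary names are illustrative):

```python
def relative_cost(stage1_h: float, stage2_h: float, base_h: float):
    """Total TWA time (Stage 1 + Stage 2) and its fraction of the base training time."""
    total = stage1_h + stage2_h
    return total, total / base_h


# (Stage 1 hours, Stage 2 hours, base hours), copied from Table 6.
table6 = {
    ("CIFAR-100 / ResNet-18", "TWA-t"): (0.26, 0.09, 0.51),
    ("CIFAR-100 / ResNet-18", "TWA-v"): (0.26, 0.02, 0.51),
    ("ImageNet / ViT-B/16", "TWA-t"): (34.5, 3.0, 51.8),
    ("ImageNet / ViT-B/16", "TWA-v"): (34.5, 0.5, 51.8),
}

fractions = {}
for (setting, method), (s1, s2, base) in table6.items():
    total, frac = relative_cost(s1, s2, base)
    fractions[(setting, method)] = frac
    print(f"{setting} {method}: {total:.2f}h total, {frac:.0%} of base, {1 - frac:.0%} saved")
```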
Comparison with LAWA.
LAWA (Kaddour, 2022) is a recently proposed method that accelerates DNN training by averaging the latest weight checkpoints. We compare LAWA with TWA across different numbers of averaging epochs. For consistency with TWA, we sample model weights for LAWA once after each training epoch. We set the horizon for ImageNet as suggested by Kaddour (2022), and also report the accuracy reached by AdamW before averaging for reference. The results are presented in Figure 2.
Generally, more epochs of exploration provide a better estimate of the final minimum, and the model's performance consistently improves as more epochs of exploration are used. Notably, although the individual solutions from a relatively short period of AdamW exploration may not be good, satisfactory solutions already exist in the subspace they span, and TWA can identify them through proper optimization within this subspace. For instance, on ImageNet with ViT-B/16, averaging over 30 epochs via TWA matches the final performance of regular AdamW training. Moreover, TWA consistently outperforms LAWA across different averaging epochs, and the advantage of TWA-v over both TWA-t and LAWA becomes more pronounced as the number of training epochs increases. In the later stages of training, the performance of TWA-t and LAWA tends to decline, while TWA-v continues to improve, highlighting the effectiveness of validation loss supervision.
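SWA, LAWA, and TWA can all be viewed as choosing a coefficient vector β over the collected checkpoints: SWA and LAWA fix β in advance, while TWA learns it. The sketch below assumes checkpoints are stored oldest to newest and that LAWA takes a uniform average of the latest k checkpoints; the helper names are hypothetical.

```python
import numpy as np


def swa_coefficients(n: int) -> np.ndarray:
    """Equal weight on all n collected checkpoints."""
    return np.full(n, 1.0 / n)


def lawa_coefficients(n: int, k: int) -> np.ndarray:
    """Equal weight on the latest k checkpoints (ordered oldest to newest), zero elsewhere."""
    k = min(k, n)
    beta = np.zeros(n)
    beta[n - k:] = 1.0 / k
    return beta


def combine(checkpoints: np.ndarray, beta: np.ndarray) -> np.ndarray:
    """Average checkpoints of shape (n, d) with coefficients beta of shape (n,)."""
    return beta @ checkpoints


checkpoints = np.arange(12, dtype=float).reshape(6, 2)   # 6 toy checkpoints in R^2
print(lawa_coefficients(6, 3))                           # 1/3 on the last three checkpoints
print(combine(checkpoints, lawa_coefficients(6, 3)))
print(combine(checkpoints, swa_coefficients(6)))
```

With a fixed β, early poor checkpoints (SWA) or discarded older ones (LAWA) cannot be reweighted; a learned β can place weight on any checkpoint in the collected set.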

Comparison with training without splitting out a validation set.
One might be concerned that splitting out a validation set reduces the available training data and thus harms base training. To investigate this, we compare with training without a validation set, i.e., using 100% of the training data for training and evaluating the last model. The results are shown in Table 7. Splitting out a small validation set indeed slightly degrades base training compared to training on the full training set. However, applying TWA-v to average over the entire training trajectory more than compensates for this degradation, e.g., by +0.50% with ResNet-18 and +3.14% with ViT-B/16. This suggests a new way of using the available training data: instead of using the entire dataset to train a single model, we can reserve a portion of it for ensembling historical solutions to achieve better performance.
Model | Method | Validation Ratio (%) | Test Accuracy (%) |
---|---|---|---|
ResNet-18 | Base (SGD) | w/o splitting | 69.82 |
ResNet-18 | Base (SGD) | 2 | 69.53 |
ResNet-18 | TWA-v | 2 | 70.32 |
ViT-B/16 | Base (AdamW) | w/o splitting | 73.0 |
ViT-B/16 | Base (AdamW) | 2 | 72.79 |
ViT-B/16 | TWA-v | 2 | 76.14 |
5.1.2 Machine Translation
Settings.
In this study, we train a Transformer-based model (Vaswani et al., 2017) for English-to-German translation on the WMT 2014 dataset (Bojar et al., 2014). The embedding dimension is set to 512. Following Vaswani et al. (2017), we use byte-pair encoding and construct a shared vocabulary of 32,000 tokens. We use the Adam optimizer with a weight decay of 0.0001, an initial learning rate of 0.0005, a batch size of 256, and a dropout rate of 0.1. We train the model for 100k steps using a ReduceLROnPlateau schedule following Bisla et al. (2022). For TWA, we sample the weights once every 1,000 steps and, for training efficiency, use only the first 50 checkpoints (corresponding to the first 50k training steps). We then optimize the coefficients for 1,000 steps using Adam with a constant learning rate of 0.01. For TWA-v, the validation data are used to optimize the coefficients, whereas for the other methods, the best-performing model is selected based on the validation loss.
Results.
In Table 8, we report the test BLEU scores of different methods. Applying TWA to the historical solutions from the first 50k steps already surpasses regular Adam training with 100k steps, saving around 50% of the training time and computation. Moreover, TWA-v further improves the test BLEU score by 0.18 over TWA-t. We also report the generalization gap between the training and test data in terms of BLEU score, and the results confirm the effectiveness of TWA in improving generalization.
Optimizer | Steps | Test BLEU Score | Gap |
---|---|---|---|
Adam | 100k (100%) | ||
Adam | 50k (50%) | - | |
SWA | 50k (50%) | - | |
TWA-t | 50k+1k (51%) | () | |
TWA-v | 50k+1k (51%) | () |
5.2 Better Fine-tuning Performance
Next, we turn to fine-tuning tasks. In the previous experiments, we focused on a single training run and aimed to improve training efficiency by fully exploiting the historical solutions generated during training. In this section, we instead aim to improve model performance by leveraging the flexibility and good generalization ability of trainable weight averaging, making full use of the solutions from single or multiple fine-tuning configurations. Previously, greedy soup (Wortsman et al., 2022) achieved state-of-the-art performance by adding weights to the weight-averaging “soup” in a greedy manner. TWA, however, can potentially perform even better by learning the weighting coefficients rather than relying on a fixed averaging strategy.
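For reference, the greedy soup recipe mentioned above can be sketched as follows, assuming models are flattened weight vectors and `val_score` is a held-out metric where higher is better (both hypothetical stand-ins for real checkpoints and validation accuracy). The soup is always a uniform average of the selected ingredients; TWA instead learns a separate coefficient for every model.

```python
from typing import Callable, List

import numpy as np


def greedy_soup(models: List[np.ndarray], val_score: Callable[[np.ndarray], float]) -> np.ndarray:
    """Visit models from best to worst validation score and keep a model in the
    uniform average only if adding it does not lower the validation score."""
    order = sorted(range(len(models)), key=lambda i: val_score(models[i]), reverse=True)
    ingredients = [models[order[0]]]
    best = val_score(models[order[0]])
    for i in order[1:]:
        candidate = np.mean(ingredients + [models[i]], axis=0)
        score = val_score(candidate)
        if score >= best:
            ingredients.append(models[i])
            best = score
    return np.mean(ingredients, axis=0)


target = np.array([1.0, 1.0])


def val_score(w: np.ndarray) -> float:
    """Toy validation metric: negative squared distance to a target (higher is better)."""
    return -float(np.sum((w - target) ** 2))


models = [np.array([1.1, 0.9]), np.array([0.9, 1.1]), np.array([5.0, 5.0])]
soup = greedy_soup(models, val_score)
print(soup, val_score(soup))
```

In this toy case the poor third model is rejected, and the soup of the first two beats each of them individually.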
In fine-tuning, the models to be averaged have generally fit the training data well, as evidenced by their low training loss, so optimizing the coefficients on the training data provides little meaningful supervision. Motivated by Wortsman et al. (2022), we therefore train directly on the held-out validation set. This is reasonable because the validation data are “new” to these models, and subspace training involves only a small number of free variables, so a small validation set should suffice to train them well. In this section, we conduct experiments on image classification and language modeling tasks to demonstrate the superiority of our scheme.
5.2.1 Image Classification
Setting.
We conduct experiments on CIFAR-10 (Krizhevsky and Hinton, 2009) and ImageNet (Deng et al., 2009) using the pretrained CLIP ViT-B/32 model (Radford et al., 2021). We use the publicly available fine-tuned checkpoints from Wortsman et al. (2022) (available at https://github.com/mlfoundations/model-soups), which include 5 models for CIFAR-10 and 72 models for ImageNet. These models were obtained by a random hyperparameter search over learning rate, weight decay, training epochs, label smoothing, and data augmentation. For TWA, we optimize the coefficients using the AdamW optimizer with a learning rate of 0.01 and a cosine learning rate schedule. For TWA-t, we train for 2 epochs on the training datasets and select the best model based on the held-out validation set; for TWA-v, we train directly for 5 epochs on the validation set and report the final model's performance on the test set.
Results.
We present the results for CIFAR-10 and ImageNet in Table 9 and Table 10, respectively. We compare TWA with SWA (Izmailov et al., 2018) and greedy soup (Wortsman et al., 2022), and also list the performance of the best and second-best individual models for reference. On CIFAR-10, both SWA and greedy soup significantly improve performance over the best individual model, confirming the effectiveness of weight averaging. TWA further improves on greedy soup: TWA-t achieves an additional improvement, while TWA-v provides a further gain. Note that performance on CIFAR-10 is already very high, making further improvements challenging. Our enhancements represent and increases over the performance gains of greedy soup relative to SWA.
Method | Test Accuracy (%) |
---|---|
Best individual model | |
Second best individual model | |
SWA | |
Greedy soup | |
TWA-t | |
TWA-v |
We then turn to the ImageNet results in Table 10, which involve a total of 72 fine-tuned models. Equal averaging, as in SWA, yields no improvement over the best individual model and in fact slightly degrades performance. This occurs because the model configurations are diverse, and equal averaging can be strongly influenced by poorer solutions, leading to estimation errors. Greedy soup effectively selects a subset of weights for averaging, providing an accuracy gain of 0.65% over the best individual model. Our TWA-v delivers further significant improvements, i.e., over greedy soup, respectively, confirming the superiority of optimizing the weighting coefficients. We also observe that TWA-v offers greater advantages over TWA-t, as models in fine-tuning scenarios are often overfitted and the training loss provides only minimal supervision.
Method | Test Accuracy (%) |
---|---|
Best individual model | |
Second best individual model | |
SWA | |
Greedy soup | |
TWA-t | |
TWA-v |
5.2.2 Language Modeling
Setting.
TWA can enhance fine-tuning performance not only by utilizing solutions fine-tuned with multiple configurations but also by leveraging solutions from a single training trajectory. In this experiment, we fine-tune the pretrained GPT-2 language model (Radford et al., 2019) for causal language modeling. Specifically, we use the raw WikiText-2 dataset (Merity et al., 2016) (no tokens are replaced before tokenization), following the official HuggingFace implementation (available at https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). The model is trained with the AdamW optimizer for 1,000 steps, with a learning rate of 5e-5, a batch size of 8, and a linear learning rate schedule. During training, we sample the weights once every 100 steps for weight averaging, resulting in 10 model checkpoints. For TWA, we optimize the coefficients for 100 steps using AdamW with a learning rate of 0.01 and a weight decay of 0.01.
Results.
We report the perplexity of different methods in Table 11. For comparison, we consider the best and second-best individual models among the 10 collected checkpoints, as well as the SWA and greedy soup solutions. Again, TWA-v significantly improves over SWA and greedy soup, reducing perplexity by 0.30 and 0.19, respectively. Note that the fine-tuned models in this experiment come from a single training run, which limits the performance gains compared to Section 5.2.1. Nevertheless, these results confirm that TWA can be effectively applied to fine-tuning, whether with single or multiple training configurations.
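Perplexity in Table 11 is, by the standard definition, the exponential of the mean per-token negative log-likelihood, so lower is better. A minimal sketch, assuming natural-log token probabilities (the function name is illustrative):

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp(mean negative log-likelihood) over tokens, using natural-log probabilities."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)


# A model that assigns probability 1/4 to every observed token has perplexity 4.
print(perplexity([math.log(0.25)] * 10))
```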
Method | Perplexity (↓) |
---|---|
Best individual model | |
Second best individual model | |
SWA | |
Greedy soup | |
TWA-t | |
TWA-v |
5.3 Ablation Study
5.3.1 Subspace Construction
We compare the execution time of different subspace construction methods in Table 12, using ResNet-50 on ImageNet with a total of 60 historical solutions as specified in Section 5.1.1 and four GPUs. Compared with the orthogonalization approach in the previous conference version (Li et al., 2023), the decentralization and normalization approach reduces the extraction time by roughly 60x (from 382.6 s to 6.3 s) and requires far fewer floating-point operations. This is because orthogonalization is a sequential process that obtains orthogonal vectors one by one, whereas decentralization and normalization can be executed efficiently in parallel.
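The sequential-versus-parallel contrast can be illustrated with a small NumPy sketch. We assume here that "decentralization and normalization" means subtracting the mean of the candidate weight vectors and rescaling each centered vector to unit norm; the paper's exact procedure (e.g., per-layer handling) may differ, and the function names are hypothetical. Gram-Schmidt costs O(n^2 d) operations for n vectors of dimension d and must process them one at a time, whereas centering and normalization costs O(nd) and treats every vector independently.

```python
import numpy as np


def gram_schmidt(W: np.ndarray) -> np.ndarray:
    """Orthonormalize the rows of W (shape (n, d)) with modified Gram-Schmidt.
    Row i can only be processed after rows 0..i-1, so the outer loop is sequential."""
    Q = np.zeros(W.shape)
    for i in range(W.shape[0]):
        v = W[i].astype(float)
        for j in range(i):
            v = v - (Q[j] @ v) * Q[j]      # remove the component along the j-th basis vector
        Q[i] = v / np.linalg.norm(v)
    return Q


def center_and_normalize(W: np.ndarray) -> np.ndarray:
    """Subtract the mean row, then rescale each row to unit norm.
    Every row is handled independently, so this vectorizes (and shards) trivially."""
    centered = W - W.mean(axis=0, keepdims=True)
    return centered / np.linalg.norm(centered, axis=1, keepdims=True)


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 1000)) + 3.0       # hypothetical, strongly correlated checkpoints
Q = gram_schmidt(W)
B = center_and_normalize(W)
print(np.abs(Q @ Q.T - np.eye(8)).max(), np.linalg.norm(B, axis=1))
```

Note that the centered vectors sum to zero, so they span at most n - 1 directions; under this assumed form, the mean checkpoint would have to be kept separately (for example as the origin of the averaging subspace).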
Extraction Method | Time (seconds) |
---|---|
Orthogonalization (Li et al., 2023) | 382.6 |
Decentralization and normalization (ours) | 6.3 |
#GPUs | SGD Time (s) | TWA (w/o quantization) Time (s) | TWA Time (s) |
---|---|---|---|
1 | 1638 | 1692 (+3.3%) | 1725 (+5.3%) |
2 | 824 | 862 (+4.6%) | 865 (+5.0%) |
4 | 420 | 432 (+2.8%) | 435 (+3.5%) |

#GPUs | SGD Memory (MB) | TWA (w/o quantization) Memory (MB) | TWA Memory (MB) |
---|---|---|---|
1 | 20287 | 26287 (+29.6%) | 21037 (+3.7%) |
2 | 20383 | 23383 (+14.7%) | 20758 (+1.9%) |
4 | 20875 | 22375 (+7.2%) | 21062 (+1.0%) |
5.3.2 Training Speed and Memory
We measure the average per-epoch training time and memory footprint of SGD and TWA under the DDP training setting. Specifically, we use ResNet-50 on ImageNet with 1, 2, and 4 GPUs, a batch size of 256 per GPU, and a total of 60 historical solutions as specified in Section 5.1.1. For TWA, we use the TWA-t variant, which optimizes over the training set, so that the training data are the same as for SGD. The experiments are conducted on NVIDIA Tesla A100 40G GPUs. The results in Table 13 show that TWA introduces minimal additional cost compared with regular SGD training, e.g., +5.3% in time and +3.7% in memory with one GPU. With more GPUs, the additional memory burden is further reduced to +1.0%, which shows the effectiveness of our distributed training scheme. Compared to TWA without quantization of the projection matrix, our 4-bit quantization scheme reduces the additional memory overhead by around 8x, at the cost of a slight increase in training time (e.g., +2%). In fact, the gradient projection in our subspace training scheme adds negligible training time compared to regular SGD, as matrix multiplication is highly efficient on GPUs. Overall, these results demonstrate that TWA provides an efficient and scalable weighted averaging approach for large-scale problems.
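The memory overheads in Table 13 are consistent with a back-of-the-envelope estimate for storing the projection matrix. The sketch assumes the matrix holds one entry per model parameter per historical solution, is split evenly across GPUs, and ignores quantization metadata such as scales; it uses the 25,557,032 parameters of torchvision's ResNet-50. The function name is hypothetical.

```python
RESNET50_PARAMS = 25_557_032   # parameter count of torchvision's ResNet-50
N_SOLUTIONS = 60               # historical solutions, as in Section 5.1.1


def projection_matrix_mb(n_params: int, n_solutions: int, bits: int, n_gpus: int = 1) -> float:
    """Per-GPU storage in MB (1 MB = 1e6 bytes) for an n_solutions x n_params matrix
    at the given bit width, assuming an even split across GPUs and no extra metadata."""
    total_bytes = n_params * n_solutions * bits / 8
    return total_bytes / n_gpus / 1e6


for gpus in (1, 2, 4):
    fp32 = projection_matrix_mb(RESNET50_PARAMS, N_SOLUTIONS, 32, gpus)
    four_bit = projection_matrix_mb(RESNET50_PARAMS, N_SOLUTIONS, 4, gpus)
    print(f"{gpus} GPU(s): float32 {fp32:.0f} MB, 4-bit {four_bit:.0f} MB")
```

This gives about 6134 MB in float32 and 767 MB at 4 bits on one GPU, in line with the measured overheads of roughly 6000 MB and 750 MB, and the estimate halves with each doubling of GPUs, as the measurements do.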
5.3.3 Effects of Layer-wise Processing
In this subsection, we study the effects of layer-wise processing. Table 14 compares the performance of TWA-v with and without layer-wise processing and projection matrix quantization. Layer-wise processing improves test accuracy by 0.53% without quantization and, more substantially, by 2.08% with quantization. These findings indicate that layer-wise processing enables more effective quantization of the projection matrix without compromising performance, likely because the magnitude of values varies across layers.
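Why layer-wise processing helps quantization can be seen in a toy NumPy experiment. We assume a uniform min-max quantizer (the paper's quantizer may differ) and two hypothetical layers whose projection values differ in scale by 1000x; the function names are illustrative. With one global range, the small-magnitude layer collapses onto one or two quantization levels; with per-layer ranges, both layers keep a small relative error.

```python
import numpy as np


def quantize_dequantize(x: np.ndarray, bits: int) -> np.ndarray:
    """Uniform min-max quantization of x to 2**bits levels, mapped back to floats."""
    lo, hi = float(x.min()), float(x.max())
    if hi == lo:                                  # constant input: nothing to quantize
        return x.astype(float)
    scale = (hi - lo) / (2 ** bits - 1)
    codes = np.round((x - lo) / scale)            # integer codes in [0, 2**bits - 1]
    return codes * scale + lo


def relative_error(x: np.ndarray, x_hat: np.ndarray) -> float:
    return float(np.linalg.norm(x - x_hat) / np.linalg.norm(x))


rng = np.random.default_rng(0)
# Hypothetical projection-matrix blocks for two layers with very different magnitudes.
layers = [rng.normal(0.0, 1.0, 10_000), rng.normal(0.0, 1e-3, 10_000)]
bits = 4

# Global processing: a single min/max range shared by every layer.
flat_hat = quantize_dequantize(np.concatenate(layers), bits)
global_hat = np.split(flat_hat, [layers[0].size])

# Layer-wise processing: each layer is quantized with its own range.
layerwise_hat = [quantize_dequantize(x, bits) for x in layers]

global_err = [relative_error(x, h) for x, h in zip(layers, global_hat)]
layerwise_err = [relative_error(x, h) for x, h in zip(layers, layerwise_hat)]
print("global:    ", [round(e, 3) for e in global_err])
print("layer-wise:", [round(e, 3) for e in layerwise_err])
```

This toy only models the quantization effect; Table 14 also shows a gain from layer-wise processing without quantization, which the sketch does not capture.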
Method | Test Accuracy (%) |
---|---|
TWA-v | 81.61 |
TWA-v, w/o quantization | 81.67 |
TWA-v, w/o layer-wise processing | 79.53 |
TWA-v, w/o layer-wise processing, w/o quantization | 81.14 |
5.3.4 Effects of Weight Decoupling
We then evaluate the effect of different subspace construction methods. The results in Table 15 show that without applying any orthogonalization to the weights , the cosine similarity between them remains very high (e.g., 0.99), which degrades training performance. Orthogonalization effectively removes the correlations between the bases, but its performance is suboptimal, possibly due to numerical errors introduced by the sequential orthogonalization process. In contrast, our proposed decentralization and normalization approach substantially reduces the correlations between the bases while requiring fewer numerical operations, achieving the best performance.
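The correlation column in Table 15 can be illustrated by measuring the mean absolute pairwise cosine similarity before and after centering hypothetical checkpoints that share a large common component. Normalization alone does not change cosine similarity, so only the centering step matters here; the data and names are illustrative assumptions. For n checkpoints that differ by independent noise of equal scale, centering drives pairwise cosines toward -1/(n-1), which is small but nonzero, qualitatively matching the nonzero correlation reported for decentralization and normalization.

```python
import numpy as np


def mean_abs_pairwise_cosine(V: np.ndarray) -> float:
    """Mean |cos| over all distinct pairs of rows of V."""
    U = V / np.linalg.norm(V, axis=1, keepdims=True)
    C = U @ U.T
    off_diag = ~np.eye(C.shape[0], dtype=bool)
    return float(np.abs(C[off_diag]).mean())


rng = np.random.default_rng(0)
n, d = 10, 10_000
shared = rng.normal(size=d)                       # component common to all checkpoints
W = shared + 0.1 * rng.normal(size=(n, d))        # hypothetical, highly similar checkpoints
W_centered = W - W.mean(axis=0, keepdims=True)    # decentralization step (assumed form)

raw_cos = mean_abs_pairwise_cosine(W)
centered_cos = mean_abs_pairwise_cosine(W_centered)
print(f"raw: {raw_cos:.3f}, centered: {centered_cos:.3f}")
```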
Method | Correlations | Test Accuracy (%) |
---|---|---|
W/o orthogonalization | 0.99 | 81.02 |
Orthogonalization | 0 | 81.35 |
Decentralization and normalization | 0.19 | 81.61 |
5.3.5 Effects of Fine-tuned Models’ Number
We vary the number of fine-tuned models used in Table 10 and study its impact on TWA; the results are in Table 16. As expected, the performance of all methods improves consistently with more fine-tuned models. However, TWA-v uses the fine-tuned models much more efficiently than competing methods. For example, with only 18 models, TWA-v achieves performance comparable to greedy soup with 72 models. Furthermore, with 36 models, TWA-v surpasses greedy soup with 72 models by a notable margin of 0.36%. These results suggest that TWA-v can reduce the cost of fine-tuning multiple models by up to 4x while achieving performance comparable to greedy soup.
Test Accuracy (%) | |||
---|---|---|---|
Method | N=18 | N=36 | N=72 |
Best individual model | |||
SWA | |||
Greedy soup | |||
TWA-t | |||
TWA-v |
5.3.6 Effects of Quantization Bits
In this section, we study the impact of the number of quantization bits on TWA-v performance in two scenarios: accelerating training by averaging historical solutions (CIFAR-100 with ViT-S/4) and improving fine-tuning (ImageNet with CLIP ViT-B/32). By default, the projection matrix is stored in float32, the same precision as the model parameters, and we quantize it to different bit levels among . The results are shown in Figure 3. Overall, TWA-v achieves good performance even with 1-bit quantization, indicating that the projection matrix is remarkably robust to quantization. Moreover, with 4-bit quantization, TWA-v performs comparably to full-precision training, so we adopt 4 bits as the default setting.


6 Conclusion
In this work, we propose Trainable Weight Averaging (TWA), an efficient framework that enables weight averaging with learnable weighting coefficients. It replaces the manually defined weighting coefficients of previous works with trainable ones, which provides much greater flexibility and enables handling weights from different training stages and configurations. We design an efficient parallel training scheme to cope with large-scale training and propose a quantization scheme for the projection matrix to achieve memory efficiency. Additionally, we derive two variants, TWA-t and TWA-v, depending on the data used for training, and show that TWA-v enables more efficient and effective averaging when validation data are available. Extensive experiments on both efficient training and fine-tuning tasks demonstrate the effectiveness and efficiency of our approach.
Limitation and Future Works.
Although TWA can significantly enhance training efficiency and improve performance in DNN training by effectively leveraging historical solutions, the selection of these solutions can influence overall performance, as illustrated in Figure 2. In practice, the strategy for selecting historical solutions can be predetermined together with the original training schedule. Currently, TWA can only handle weights from a single linear mode; however, it holds promise for extending the algorithm to merge weights from different linear modes or even from different architectures. Furthermore, despite the relatively small number of trainable variables, there remains a risk of overfitting, which can degrade generalization performance.
There are many promising future directions for TWA, in terms of both practical applications and theoretical analysis. These include: 1) combining TWA with other lightweight methods, such as LoRA (Hu et al., 2022), to facilitate more efficient adapter training and better fine-tuning performance; 2) applying TWA to other precision-critical scenarios, such as low-precision network training (Yang et al., 2019) and network quantization (Zhou et al., 2017); 3) understanding the relationship between the number of held-out samples required for subspace training and the dimension of the subspace, as well as its impact on generalization performance; and 4) exploring advanced quantization schemes for compressing the projection matrix to enable more efficient TWA training.
References
- Bisla et al. (2022) Devansh Bisla, Jing Wang, and Anna Choromanska. Low-pass filtering sgd for recovering flat optima in the deep learning optimization landscape. In International Conference on Artificial Intelligence and Statistics (AISTATS), 2022.
- Bojar et al. (2014) Ondřej Bojar, Christian Buck, Christian Federmann, Barry Haddow, Philipp Koehn, Johannes Leveling, Christof Monz, Pavel Pecina, Matt Post, Herve Saint-Amand, et al. Findings of the 2014 workshop on statistical machine translation. In Proceedings of the Ninth Workshop on Statistical Machine Translation, 2014.
- Chen et al. (2022) Xiangning Chen, Cho-Jui Hsieh, and Boqing Gong. When vision transformers outperform resnets without pre-training or strong data augmentations. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=LtKcMgGOeLt.
- Croce et al. (2023) Francesco Croce, Sylvestre-Alvise Rebuffi, Evan Shelhamer, and Sven Gowal. Seasoning model soups for robustness to adversarial and natural distribution shifts. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2023.
- Deng et al. (2009) Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 248–255, 2009.
- Dettmers et al. (2022) Tim Dettmers, Mike Lewis, Younes Belkada, and Luke Zettlemoyer. LLM.int8(): 8-bit matrix multiplication for transformers at scale. Advances in Neural Information Processing Systems, 35:30318–30332, 2022.
- Dettmers et al. (2024) Tim Dettmers, Artidoro Pagnoni, Ari Holtzman, and Luke Zettlemoyer. Qlora: Efficient finetuning of quantized llms. Advances in Neural Information Processing Systems, 36, 2024.
- Dosovitskiy et al. (2021) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations (ICLR), 2021.
- Frankle et al. (2020) Jonathan Frankle, Gintare Karolina Dziugaite, Daniel Roy, and Michael Carbin. Linear mode connectivity and the lottery ticket hypothesis. In International Conference on Machine Learning (ICML), 2020.
- Goyal et al. (2017) Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large minibatch sgd: Training imagenet in 1 hour. arXiv preprint arXiv:1706.02677, 2017.
- Gressmann et al. (2020) Frithjof Gressmann, Zach Eaton-Rosen, and Carlo Luschi. Improving neural network training in low dimensional random bases. In Advances in Neural Information Processing Systems (NeurIPS), 2020.
- Gupta et al. (2019) Vipul Gupta, Santiago Akle Serrano, and Dennis DeCoste. Stochastic weight averaging in parallel: Large-batch training that generalizes well. In International Conference on Learning Representations (ICLR), 2019.
- Gur-Ari et al. (2018) Guy Gur-Ari, Daniel A Roberts, and Ethan Dyer. Gradient descent happens in a tiny subspace. arXiv preprint arXiv:1812.04754, 2018.
- He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 770–778, 2016.
- Hu et al. (2022) Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. LoRA: Low-rank adaptation of large language models. In International Conference on Learning Representations (ICLR), 2022.
- Izmailov et al. (2018) Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Averaging weights leads to wider optima and better generalization. arXiv preprint arXiv:1803.05407, 2018.
- Kaddour (2022) Jean Kaddour. Stop wasting my time! saving days of imagenet and bert training with latest weight averaging. arXiv preprint arXiv:2209.14981, 2022.
- Kingma and Ba (2015) Diederik P. Kingma and Jimmy Lei Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations (ICLR), 2015.
- Krizhevsky and Hinton (2009) Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. Technical Report, 2009.
- Li et al. (2018) Chunyuan Li, Heerad Farkhoor, Rosanne Liu, and Jason Yosinski. Measuring the intrinsic dimension of objective landscapes. In International Conference on Learning Representations (ICLR), 2018.
- Li et al. (2020) Shen Li, Yanli Zhao, Rohan Varma, Omkar Salpekar, Pieter Noordhuis, Teng Li, Adam Paszke, Jeff Smith, Brian Vaughan, Pritam Damania, et al. Pytorch distributed: Experiences on accelerating data parallel training. arXiv preprint arXiv:2006.15704, 2020.
- Li et al. (2022a) Tao Li, Lei Tan, Zhehao Huang, Qinghua Tao, Yipeng Liu, and Xiaolin Huang. Low dimensional trajectory hypothesis is true: Dnns can be trained in tiny subspaces. IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI), 2022a.
- Li et al. (2022b) Tao Li, Yingwen Wu, Sizhe Chen, Kun Fang, and Xiaolin Huang. Subspace adversarial training. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2022b.
- Li et al. (2023) Tao Li, Zhehao Huang, Qinghua Tao, Yingwen Wu, and Xiaolin Huang. Trainable weight averaging: Efficient training by optimizing historical solutions. In International Conference on Learning Representations (ICLR), 2023.
- Merity et al. (2016) Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.
- Nesterov (1983) Yurii E Nesterov. A method for solving the convex programming problem with convergence rate O(1/k^2). In Dokl. akad. nauk Sssr, volume 269, pages 543–547, 1983.
- Paszke et al. (2017) Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in pytorch. 2017.
- Polyak and Juditsky (1992) Boris T Polyak and Anatoli B Juditsky. Acceleration of stochastic approximation by averaging. SIAM journal on control and optimization, 30(4):838–855, 1992.
- Rabenseifner (2004) Rolf Rabenseifner. Optimization of collective reduction operations. In International Conference on Computational Science, pages 1–9. Springer, 2004.
- Radford et al. (2019) Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever, et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
- Radford et al. (2021) Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual models from natural language supervision. In International Conference on Machine Learning (ICML), 2021.
- Rame et al. (2022) Alexandre Rame, Matthieu Kirchmeyer, Thibaud Rahier, Alain Rakotomamonjy, Patrick Gallinari, and Matthieu Cord. Diverse weight averaging for out-of-distribution generalization. In Advances in Neural Information Processing Systems (NeurIPS), 2022.
- Sanyal et al. (2023) Sunny Sanyal, Atula Neerkaje, Jean Kaddour, Abhishek Kumar, and Sujay Sanghavi. Early weight averaging meets high learning rates for LLM pre-training. arXiv preprint arXiv:2306.03241, 2023.
- Shen et al. (2023) Li Shen, Yan Sun, Zhiyuan Yu, Liang Ding, Xinmei Tian, and Dacheng Tao. On efficient training of large-scale deep learning models: A literature review. arXiv preprint arXiv:2304.03589, 2023.
- Shen et al. (2020) Sheng Shen, Zhen Dong, Jiayu Ye, Linjian Ma, Zhewei Yao, Amir Gholami, Michael W. Mahoney, and Kurt Keutzer. Q-BERT: Hessian based ultra low precision quantization of BERT. In Proceedings of the AAAI Conference on Artificial Intelligence (AAAI), 2020.
- Simonyan and Zisserman (2014) Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
- Tuddenham et al. (2020) Mark Tuddenham, Adam Prügel-Bennett, and Jonathan Hare. Quasi-Newton's method in the class gradient defined high-curvature subspace. arXiv preprint arXiv:2012.01938, 2020.
- Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems (NeurIPS), 2017.
- Vinyals and Povey (2012) Oriol Vinyals and Daniel Povey. Krylov subspace descent for deep learning. In Artificial Intelligence and Statistics (AISTATS), pages 1261–1268. PMLR, 2012.
- Wortsman et al. (2022) Mitchell Wortsman, Gabriel Ilharco, Samir Ya Gadre, Rebecca Roelofs, Raphael Gontijo-Lopes, Ari S Morcos, Hongseok Namkoong, Ali Farhadi, Yair Carmon, Simon Kornblith, et al. Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time. In International Conference on Machine Learning (ICML), 2022.
- Yang et al. (2019) Guandao Yang, Tianyi Zhang, Polina Kirichenko, Junwen Bai, Andrew Gordon Wilson, and Christopher De Sa. SWALP: Stochastic weight averaging in low precision training. In International Conference on Machine Learning (ICML). PMLR, 2019.
- You et al. (2020a) Haoran You, Chaojian Li, Pengfei Xu, Yonggan Fu, Yue Wang, Xiaohan Chen, Richard G Baraniuk, Zhangyang Wang, and Yingyan Lin. Drawing early-bird tickets: Towards more efficient training of deep networks. In International Conference on Learning Representations (ICLR), 2020a.
- You et al. (2017) Yang You, Igor Gitman, and Boris Ginsburg. Large batch training of convolutional networks. arXiv preprint arXiv:1708.03888, 2017.
- You et al. (2020b) Yang You, Jing Li, Sashank Reddi, Jonathan Hseu, Sanjiv Kumar, Srinadh Bhojanapalli, Xiaodan Song, James Demmel, Kurt Keutzer, and Cho-Jui Hsieh. Large batch optimization for deep learning: Training BERT in 76 minutes. In International Conference on Learning Representations (ICLR), 2020b.
- Zhang et al. (2019) Michael Zhang, James Lucas, Jimmy Ba, and Geoffrey E Hinton. Lookahead optimizer: k steps forward, 1 step back. In Advances in Neural Information Processing Systems (NeurIPS), 2019.
- Zhang et al. (2024) Zhenyu Zhang, Ajay Jaiswal, Lu Yin, Shiwei Liu, Jiawei Zhao, Yuandong Tian, and Zhangyang Wang. Q-GaLore: Quantized GaLore with INT4 projection and layer-adaptive low-rank gradients. arXiv preprint arXiv:2407.08296, 2024.
- Zhou et al. (2017) Aojun Zhou, Anbang Yao, Yiwen Guo, Lin Xu, and Yurong Chen. Incremental network quantization: Towards lossless CNNs with low-precision weights. arXiv preprint arXiv:1702.03044, 2017.
Appendix A Training Details
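Below we list the exact hyper-parameters used in our experiments, one table per benchmark. In each table, the upper block configures the base training run, and the block under "TWA Hyper-parameters" configures the training of the TWA averaging coefficients; for CLIP ViT-B/32 only the latter is listed.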
Hyper-parameters for CIFAR-10/100:

Models | VGG / ResNet | ViTs |
---|---|---|
Base Optimizer | SGD | AdamW |
Epochs | 200 | |
Warm-up Epochs | 8 | |
Data Augmentation | Inception style | |
Peak Learning Rate | 0.1 | 1e-3 |
LR-Scheduler | Step | Cosine |
Batch Size | 128 | |
Weight Decay | 0.0001 | 0.1 |
TWA Hyper-parameters | | |
Optimizer | SGD | AdamW |
Peak Learning Rate | 0.01 | |
LR-Scheduler | Cosine | |
Weight Decay | 0 | |
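
In the TWA rows above, the coefficients are trained with a peak learning rate of 0.01, a cosine schedule and no weight decay. The plain-Python sketch below computes such a cosine schedule. It assumes no warm-up for the TWA stage and a decay from the peak to zero over a fixed number of TWA steps; the function name `cosine_lr` and the step count of 100 are illustrative choices, not values taken from the paper or the released code.

```python
import math


def cosine_lr(step: int, total_steps: int, peak_lr: float = 0.01, min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate at `step` (0-indexed), from peak_lr down to min_lr.

    Assumption (not stated in the table): no warm-up, and the decay ends at total_steps.
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Fraction of the schedule completed, clipped to [0, 1] so later steps stay at min_lr.
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total = 100  # hypothetical number of TWA optimization steps
    for s in (0, 25, 50, 75, 100):
        print(s, round(cosine_lr(s, total), 6))
```

In PyTorch, `torch.optim.lr_scheduler.CosineAnnealingLR` with `T_max=total_steps` and `eta_min=min_lr`, stepped once per iteration, traces the same curve.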

Hyper-parameters for ImageNet:

Models | ResNets | ViTs |
---|---|---|
Base Optimizer | SGD | AdamW |
Epochs | 90 | |
Warm-up Epochs | 8 | |
Data Augmentation | Inception style | |
Peak Learning Rate | 0.4 | 1e-3 |
LR-Scheduler | Step | Cosine |
Batch Size | 1024 | |
Weight Decay | 0.0001 | 0.1 |
TWA Hyper-parameters | | |
Optimizer | SGD | AdamW |
Peak Learning Rate | 0.01 | |
LR-Scheduler | Cosine | |
Weight Decay | 0 | |
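
For the ImageNet ResNets, the base run combines a peak learning rate of 0.4, 8 warm-up epochs and a step schedule over 90 epochs. The table does not give the decay epochs or the decay factor, so the sketch below assumes a hypothetical ×0.1 decay at epochs 30, 60 and 80 after a linear warm-up. It also shows that 0.4 equals 0.1 × 1024 / 256, the value the common linear-scaling heuristic gives for batch size 1024; whether the peak rate was chosen this way is not stated.

```python
from typing import Sequence


def linear_scaled_lr(base_lr: float, batch_size: int, base_batch_size: int = 256) -> float:
    """Linear-scaling heuristic: scale a reference learning rate with the batch size."""
    return base_lr * batch_size / base_batch_size


def warmup_step_lr(epoch: float, peak_lr: float = 0.4, warmup_epochs: int = 8,
                   milestones: Sequence[int] = (30, 60, 80), gamma: float = 0.1) -> float:
    """Learning rate at a (possibly fractional) epoch: linear warm-up, then step decay.

    HYPOTHETICAL: `milestones` and `gamma` are assumptions; the table only says 'Step'.
    """
    if epoch < warmup_epochs:
        # Linear ramp from 0 to peak_lr over the warm-up period.
        return peak_lr * epoch / warmup_epochs
    # Multiply by gamma once for every milestone already reached.
    num_decays = sum(1 for m in milestones if epoch >= m)
    return peak_lr * gamma ** num_decays


if __name__ == "__main__":
    print(linear_scaled_lr(0.1, 1024))  # reference rate 0.1 defined at batch size 256
    for e in (0, 4, 8, 29, 30, 60, 80, 89):
        print(e, warmup_step_lr(e))
```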

Hyper-parameters for the Transformer on WMT2014:

Models | Transformer |
---|---|
Base Optimizer | Adam |
Steps | 100k |
Peak Learning Rate | 0.0001 |
LR-Scheduler | ReduceLROnPlateau |
Batch Size | 256 |
Dropout | 0.1 |
Weight Decay | 0.0001 |
TWA Hyper-parameters | |
Optimizer | Adam |
Peak Learning Rate | 0.01 |
LR-Scheduler | Constant |
Weight Decay | 0 |
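
The base Transformer run uses a ReduceLROnPlateau schedule, which lowers the learning rate only when a monitored loss stops improving; the TWA stage keeps a constant rate of 0.01 and needs no scheduler. The table gives neither the reduction factor nor the patience, so the simplified sketch below uses hypothetical values (`factor=0.5`, `patience=2`). It mirrors the core rule of PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` (reduce once the number of non-improving evaluations exceeds `patience`) but omits that class's improvement threshold and cooldown options; the class name `PlateauLR` is hypothetical.

```python
class PlateauLR:
    """Simplified plateau scheduler (hypothetical helper, not a library API).

    Multiplies the learning rate by `factor` once more than `patience`
    consecutive evaluations fail to improve on the best loss seen so far.
    """

    def __init__(self, lr: float = 1e-4, factor: float = 0.5, patience: int = 2,
                 min_lr: float = 0.0) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_evals = 0  # consecutive evaluations without improvement

    def step(self, val_loss: float) -> float:
        """Record one validation loss and return the (possibly reduced) learning rate."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            if self.bad_evals > self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_evals = 0  # restart counting after a reduction
        return self.lr


if __name__ == "__main__":
    sched = PlateauLR()
    for loss in (3.0, 2.5, 2.6, 2.7, 2.8, 2.9):
        print(loss, sched.step(loss))
```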

Hyper-parameters for CLIP ViT-B/32 (TWA stage only):

Models | CLIP ViT-B/32 |
---|---|
TWA Hyper-parameters | |
Batch Size | 128 |
Optimizer | AdamW |
Peak Learning Rate | 0.01 |
Weight Decay | 0 |
LR-Scheduler | Cosine |

Hyper-parameters for GPT-2 on Wikitext-2:

Models | GPT-2 |
---|---|
Base Optimizer | AdamW |
Steps | 1,000 |
Peak Learning Rate | 0.00005 |
LR-Scheduler | Linear |
Batch Size | 8 |
Weight Decay | 0 |
TWA Hyper-parameters | |
Batch Size | 8 |
Optimizer | AdamW |
Peak Learning Rate | 0.01 |
Weight Decay | 0 |
LR-Scheduler | Cosine |
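
The base GPT-2 run uses AdamW with a peak learning rate of 5e-5 and a linear schedule over 1,000 steps. No warm-up is listed, so the sketch below defaults to zero warm-up steps and decays linearly to zero at step 1,000. This is the shape produced by Hugging Face's `get_linear_schedule_with_warmup`; whether the authors used that helper is not stated, and the function name `linear_decay_lr` is hypothetical.

```python
def linear_decay_lr(step: int, total_steps: int = 1000, peak_lr: float = 5e-5,
                    warmup_steps: int = 0) -> float:
    """Optional linear warm-up, then linear decay from peak_lr to zero at total_steps.

    Assumption: warmup_steps defaults to 0 because the table lists no warm-up.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp linearly from 0 to peak_lr during warm-up.
        return peak_lr * step / warmup_steps
    # Remaining fraction of the post-warm-up schedule; never negative.
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


if __name__ == "__main__":
    for s in (0, 250, 500, 750, 1000):
        print(s, linear_decay_lr(s))
```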
Appendix B Dataset Details
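Table 22 summarizes the number of examples in the training, validation and test splits of each dataset used in this paper. For CIFAR-10/100 and ImageNet, the validation split is held out from the official training set (45,000 + 5,000 = 50,000 and 1,255,167 + 26,000 = 1,281,167 images, respectively), and the official ImageNet validation set of 50,000 images serves as the test set.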
Dataset | Training | Validation | Test |
---|---|---|---|
CIFAR-10/100 | 45,000 | 5,000 | 10,000 |
ImageNet | 1,255,167 | 26,000 | 50,000 |
WMT2014 | 954,000 | 3,000 | 3,000 |
Wikitext-2 | 36,718 | 3,760 | 4,358 |
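
Because TWA-v optimizes the coefficients on a validation split, that split must be disjoint from the training data. The sketch below shows one way to carve the CIFAR split out of the 50,000 official training images. It assumes a uniform random (non-stratified) split with a fixed seed; the paper does not say how the held-out images were chosen, and the helper name `train_val_split` is hypothetical. In PyTorch, `torch.utils.data.random_split` with a seeded `generator` serves the same purpose.

```python
import random
from typing import List, Tuple


def train_val_split(num_samples: int, num_val: int, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Randomly hold out `num_val` indices for validation; the rest are for training.

    Assumption: uniform random split, not stratified by class.
    """
    if not 0 < num_val < num_samples:
        raise ValueError("num_val must lie strictly between 0 and num_samples")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # fixed seed gives a reproducible split
    val_idx = sorted(indices[:num_val])
    train_idx = sorted(indices[num_val:])
    return train_idx, val_idx


if __name__ == "__main__":
    # CIFAR-10/100: 50,000 official training images -> 45,000 train + 5,000 validation.
    train_idx, val_idx = train_val_split(50_000, 5_000)
    print(len(train_idx), len(val_idx))
```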