Towards a General Framework for
Continual Learning with Pre-training
Abstract
In this work, we present a general framework for continual learning of sequentially arrived tasks with the use of pre-training, which has emerged as a promising direction for artificial intelligence systems to accommodate real-world dynamics. From a theoretical perspective, we decompose its objective into three hierarchical components, including within-task prediction, task-identity inference, and task-adaptive prediction. Then we propose an innovative approach to explicitly optimize these components with parameter-efficient fine-tuning (PEFT) techniques and representation statistics. We empirically demonstrate the superiority and generality of our approach in downstream continual learning, and further explore the applicability of PEFT techniques in upstream continual learning. We also discuss the biological basis of the proposed framework with recent advances in neuroscience. Our code is available at https://github.com/thu-ml/HiDe-Prompt.
1 Introduction
To cope with real-world dynamics, continual learning has received widespread attention, especially in the context of pre-training. By adapting the pre-trained knowledge effectively to downstream tasks, it provides not only positive knowledge transfer but also robustness to catastrophic forgetting [10, 8, 16, 20]. An emerging direction is the use of parameter-efficient fine-tuning (PEFT) techniques (e.g., Prompt [3], Adapter [11], LoRA [2], FiLM [9], etc.), which usually freeze a pre-trained transformer backbone and additionally employ a few parameters to steer representation learning. In particular, recent prompt-based approaches [19, 18, 17, 12, 15] focus on the construction and inference of appropriate prompts for each task, and achieve outstanding performance under strong supervised pre-training. However, existing methods usually degrade in performance when challenged by upstream knowledge (e.g., different pre-training paradigms) and downstream tasks (e.g., out-of-distribution and fine granularity), leaving their generality to be desired.
In this work, we provide an in-depth theoretical analysis of the continual learning objective in the context of pre-training, which can be decomposed into three hierarchical components: within-task prediction, task-identity inference, and task-adaptive prediction. By leveraging the well-distributed pre-trained representations, we then propose an innovative approach applicable to various PEFT techniques to explicitly optimize the hierarchical components. We perform extensive experiments on downstream continual learning to demonstrate the superiority and generality of our approach, and further explore the applicability of PEFT techniques in upstream continual learning. We also provide neurological insights into the proposed framework for the acquisition of open-world knowledge.
2 Hierarchical Decomposition of Continual Learning Objective
Continual learning aims to master a sequence of tasks represented by their respective training sets $\mathcal{D}_1, \dots, \mathcal{D}_T$ and excel on their corresponding test sets. Each training set $\mathcal{D}_t = \{(\mathbf{x}_{t,n}, y_{t,n})\}_{n=1}^{N_t}$, where $N_t$ denotes the size of $\mathcal{D}_t$, and $\mathbf{x}_{t,n} \in \mathcal{X}_t$ and $y_{t,n} \in \mathcal{Y}_t$ indicate the sample and label elements, respectively. Consider a neural network model with a backbone $f_{\theta}$ parameterized by $\theta$, and an output layer $h_{\psi}$ parameterized by $\psi$. This model seeks to learn the projection from $\mathcal{X} = \bigcup_{t} \mathcal{X}_t$ to $\mathcal{Y} = \bigcup_{t} \mathcal{Y}_t$, aiming to predict the label $y$ of an unseen test sample $\mathbf{x}$ drawn from previous tasks. The backbone function $f_{\theta}$ is assumed to be pre-trained with a substantial quantity of additional training samples external to each $\mathcal{D}_t$. There are commonly three distinct settings for continual learning [13]: task-, domain-, and class-incremental learning (TIL, DIL, and CIL). Specifically, the label spaces $\{\mathcal{Y}_t\}_t$ are identical for DIL while disjoint for TIL and CIL. The task identity is provided for TIL at test time but is not available for DIL and CIL.
Here we take CIL as a typical scenario for theoretical analysis, where the label spaces of different tasks are disjoint, i.e., $\mathcal{Y}_i \cap \mathcal{Y}_{i'} = \emptyset$ for $i \neq i'$, and $\mathcal{D} = \bigcup_{i} \mathcal{D}_i$. Let $\mathcal{X}_{i,j}$ and $\mathcal{Y}_{i,j}$ denote the domain and label of the $j$-th class in task $i$, respectively, such that $\mathcal{X}_i = \bigcup_{j} \mathcal{X}_{i,j}$. Now assume we have a ground event denoted as $\{\mathbf{x} \in \mathcal{X}_{i,j}\}$ and a pre-trained model parameterized by $\theta$. For any sample $\mathbf{x}$, a general goal of the CIL problem is to learn $P(\mathbf{x} \in \mathcal{X}_{i,j} \mid \mathcal{D}, \theta)$, where $i \in \{1, \dots, t\}$ and $j \in \{1, \dots, |\mathcal{Y}_i|\}$. This can be decomposed into two probabilities, including task-identity inference (TII) and within-task prediction (WTP), denoted as $P(\mathbf{x} \in \mathcal{X}_{i} \mid \mathcal{D}, \theta)$ and $P(\mathbf{x} \in \mathcal{X}_{i,j} \mid \mathbf{x} \in \mathcal{X}_{i}, \mathcal{D}, \theta)$, respectively. Based on Bayes' theorem, we have
$$P(\mathbf{x} \in \mathcal{X}_{i,j} \mid \mathcal{D}, \theta) = P(\mathbf{x} \in \mathcal{X}_{i,j} \mid \mathbf{x} \in \mathcal{X}_{i}, \mathcal{D}, \theta)\, P(\mathbf{x} \in \mathcal{X}_{i} \mid \mathcal{D}, \theta). \quad (1)$$
Let $\bar{i}$ and $\bar{j}$ be the ground truth of a test sample $\mathbf{x}$ w.r.t. the task identity and within-task index, respectively. Eq. (1) shows that if we can improve either the WTP performance $P(\mathbf{x} \in \mathcal{X}_{\bar{i},\bar{j}} \mid \mathbf{x} \in \mathcal{X}_{\bar{i}}, \mathcal{D}, \theta)$, the TII performance $P(\mathbf{x} \in \mathcal{X}_{\bar{i}} \mid \mathcal{D}, \theta)$, or both, then the CIL performance $P(\mathbf{x} \in \mathcal{X}_{\bar{i},\bar{j}} \mid \mathcal{D}, \theta)$ would be improved. However, such an improvement is limited, since the CIL performance is upper-bounded by that of WTP or TII. To further improve the CIL performance, we propose a hierarchical decomposition of its objective. That is, besides improving $P(\mathbf{x} \in \mathcal{X}_{\bar{i},\bar{j}} \mid \mathcal{D}, \theta)$, we also need to improve the performance of task-adaptive prediction (TAP), denoted as $P(\mathbf{x} \in \mathcal{X}_{\bar{c}} \mid \mathcal{D}, \theta)$, where $\mathcal{X}_{c}$ represents the domain of class $c$ in all previous tasks and $\bar{c}$ is the ground-truth label of $\mathbf{x}$. The final goal of CIL is then formulated as a multi-objective optimization problem, i.e., $\max\,[P(\mathbf{x} \in \mathcal{X}_{\bar{i},\bar{j}} \mid \mathcal{D}, \theta),\, P(\mathbf{x} \in \mathcal{X}_{\bar{c}} \mid \mathcal{D}, \theta)]$. Notice that the WTP probability is a categorical distribution over the classes within a task, while the TAP probability is over all observed classes.
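To make this limitation concrete, consider a toy case with assumed numbers: if the WTP probability of the ground-truth class is 0.9 and the TII probability of the ground-truth task is 0.8, then by Eq. (1),
$$P(\mathbf{x} \in \mathcal{X}_{\bar{i},\bar{j}} \mid \mathcal{D}, \theta) = 0.9 \times 0.8 = 0.72 \le \min\{0.9, 0.8\},$$
so the CIL probability can never exceed the weaker of the two factors, which motivates additionally optimizing TAP over all observed classes.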
To resolve the problems above, we derive the sufficient and necessary conditions in the context of the widely-used cross-entropy loss. Specifically, we define
$$H_{\text{WTP}}(\mathbf{x}) = H\big(\mathbf{1}_{\bar{j}},\, \{P(\mathbf{x} \in \mathcal{X}_{\bar{i},j} \mid \mathbf{x} \in \mathcal{X}_{\bar{i}}, \mathcal{D}, \theta)\}_{j}\big), \quad (2)$$
$$H_{\text{TII}}(\mathbf{x}) = H\big(\mathbf{1}_{\bar{i}},\, \{P(\mathbf{x} \in \mathcal{X}_{i} \mid \mathcal{D}, \theta)\}_{i}\big), \quad (3)$$
$$H_{\text{TAP}}(\mathbf{x}) = H\big(\mathbf{1}_{\bar{c}},\, \{P(\mathbf{x} \in \mathcal{X}_{c} \mid \mathcal{D}, \theta)\}_{c}\big), \quad (4)$$
where $H_{\text{WTP}}$, $H_{\text{TII}}$, and $H_{\text{TAP}}$ are the cross-entropy values of WTP, TII, and TAP, respectively, $H(p, q) = -\sum_{k} p_k \log q_k$ denotes the cross-entropy between two distributions, and $\mathbf{1}_{(\cdot)}$ is a one-hot encoding function.
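As a concrete illustration of Eqs. (2)-(4), the sketch below computes the three cross-entropy terms for a single test sample from a toy TAP distribution over four classes grouped into two tasks; the numerical values and the class-to-task layout are assumed purely for illustration.

```python
import torch

# Toy setup: 4 classes observed so far, grouped into 2 tasks of 2 classes each.
task_of_class = torch.tensor([0, 0, 1, 1])              # class index -> task identity
probs = torch.tensor([0.10, 0.15, 0.60, 0.15])          # TAP distribution over all 4 classes
true_class, true_task, true_within_idx = 2, 1, 0        # ground truth: class 2 is the 1st class of task 1

# TAP cross-entropy (Eq. 4): -log P(x in X_{\bar{c}} | D, theta)
h_tap = -torch.log(probs[true_class])

# TII cross-entropy (Eq. 3): marginalize classes within each task, then -log P(x in X_{\bar{i}} | D, theta)
task_probs = torch.zeros(2).scatter_add_(0, task_of_class, probs)
h_tii = -torch.log(task_probs[true_task])

# WTP cross-entropy (Eq. 2): renormalize within the true task, then take the within-task -log probability
within = probs[task_of_class == true_task]
h_wtp = -torch.log(within[true_within_idx] / within.sum())

print(h_wtp.item(), h_tii.item(), h_tap.item())
```

Because the WTP and TII probabilities are derived here by marginalizing and renormalizing the same TAP distribution, the printed values satisfy $H_{\text{WTP}} + H_{\text{TII}} = H_{\text{TAP}}$ (up to floating-point error), mirroring Eq. (1); in the proposed approach, the three terms are instead optimized by separate components.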
We now present the first theorem under the CIL scenario:
Theorem 1
For continual learning with pre-training, if $H_{\text{WTP}}(\mathbf{x}) \le \delta$, $H_{\text{TII}}(\mathbf{x}) \le \epsilon$, and $H_{\text{TAP}}(\mathbf{x}) \le \eta$, we have the loss error $\mathcal{L} \in [0, \max\{\delta + \epsilon, \eta\}]$, regardless of whether WTP, TII, and TAP are trained together or separately.
With the use of cross-entropy, the continual learning performance tends to improve as these bounds are tightened. Theorem 1 shows that good performance of WTP, TII, and TAP is sufficient to guarantee good CIL performance. For completeness, we now study the necessary conditions for a well-performing CIL method in Theorem 2.
Theorem 2
For continual learning with pre-training, if the loss error $\mathcal{L} \le \xi$, then there always exist (1) a WTP, s.t. $H_{\text{WTP}} \le \xi$; (2) a TII, s.t. $H_{\text{TII}} \le \xi$; and (3) a TAP, s.t. $H_{\text{TAP}} \le \xi$.
Theorem 2 suggests that if a continual learning model is well trained (i.e., with a low loss error), then the errors of WTP, TII, and TAP for sequential tasks are implied to be small as well.
3 Optimization of Hierarchical Components
Motivated by these theoretical insights, we propose to explicitly optimize the hierarchical components (i.e., WTP, TII, and TAP) for continual learning with pre-training. Our proposal stems from two particular advantages of pre-training: (1) the representations can be effectively adapted to downstream tasks through PEFT techniques, and (2) the distributions of unadapted and adapted representations (denoted as $\hat{\mathcal{G}}_c$ and $\mathcal{G}_c$ for each class $c$, respectively) can be effectively preserved through their statistical information. For efficiency and generality, here we employ multiple centroids obtained from K-Nearest Neighbor (KNN) and add Gaussian noise as a specific implementation.
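As a minimal sketch of how such representation statistics might be maintained, the snippet below keeps a single centroid per class and replays pseudo representations by adding isotropic Gaussian noise (the multi-centroid variant described above would store several centroids per class instead); the class name `ClassStatistics`, the noise scale, and the feature dimension are illustrative assumptions rather than the exact implementation.

```python
import torch

class ClassStatistics:
    """Keep one centroid per class and replay pseudo representations with Gaussian noise."""

    def __init__(self, noise_std: float = 0.1):
        self.centroids = {}          # class id -> mean representation (a single centroid per class)
        self.noise_std = noise_std   # assumed noise scale

    @torch.no_grad()
    def update(self, features: torch.Tensor, labels: torch.Tensor):
        # features: [N, D] representations collected for the current task; labels: [N] class ids
        for c in labels.unique().tolist():
            self.centroids[c] = features[labels == c].mean(dim=0)

    def sample(self, num_per_class: int):
        # Draw an equal number of pseudo representations for every stored class.
        feats, labels = [], []
        for c, mu in self.centroids.items():
            noise = torch.randn(num_per_class, mu.numel()) * self.noise_std
            feats.append(mu.unsqueeze(0) + noise)
            labels.append(torch.full((num_per_class,), c, dtype=torch.long))
        return torch.cat(feats), torch.cat(labels)

# Usage: one instance for unadapted and one for adapted representations; update after each task,
# then replay the pseudo representations to optimize the TII and TAP heads described below.
stats = ClassStatistics()
stats.update(torch.randn(64, 768), torch.randint(0, 10, (64,)))   # dummy ViT-sized features
pseudo_feats, pseudo_labels = stats.sample(num_per_class=8)
```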
First, we improve WTP by effectively incorporating task-specific knowledge from each $\mathcal{D}_t$. Specifically, we construct task-specific parameters $e_t$ with a PEFT technique (e.g., Prompt [3], Adapter [11], LoRA [2], FiLM [9], etc.), and optimize them with the cross-entropy (CE) loss. The previous parameters $e_{1:t-1}$ are frozen to avoid catastrophic forgetting, while $e_t$ is initialized with $e_{t-1}$ to transfer knowledge. Besides, the adapted representations of task $t$, although allowing the new task to be performed well, may overlap with those of the old tasks and thus affect TAP. To overcome this issue, we preserve statistics of the adapted representations collected by $f_{\theta}$ and $e_{1:t-1}$, where for classification we calculate the mean $\bar{\mathbf{z}}_c$ of $\mathcal{G}_c$ for each class $c$, and design a contrastive regularization (CR):
$$\mathcal{L}_{\text{CR}}(\mathbf{x}, y) = -\log \frac{\exp\big(\mathrm{sim}(\mathbf{z}, \bar{\mathbf{z}}_{y}) / \tau\big)}{\sum_{c \in \mathcal{Y}_{1:t}} \exp\big(\mathrm{sim}(\mathbf{z}, \bar{\mathbf{z}}_{c}) / \tau\big)}, \quad (5)$$
where $\mathbf{z}$ is the embedding transformation of $\mathbf{x}$ with $f_{\theta}$ and $e_t$, $\mathrm{sim}(\cdot, \cdot)$ denotes cosine similarity, and $\tau$ is the temperature coefficient, to which performance is insensitive and which we set to 0.8 in practice. Then, the loss function of WTP can be defined as
$$\mathcal{L}_{\text{WTP}} = \mathcal{L}_{\text{CE}} + \lambda\, \mathcal{L}_{\text{CR}}. \quad (6)$$
Therefore, the adapted representations of new classes can be well distinguished for WTP while avoiding overlap with those of previous classes. $\lambda$ is a hyperparameter that balances the impact of old classes.
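Below is a sketch of one plausible instantiation of the CR term in Eq. (5) and the combined WTP loss in Eq. (6), treating the preserved class means as prototypes and using cosine similarity; the function names, the InfoNCE-style form, and the default coefficients are assumptions for illustration, not the exact released implementation.

```python
import torch
import torch.nn.functional as F

def contrastive_regularization(z: torch.Tensor, labels: torch.Tensor,
                               prototypes: torch.Tensor, tau: float = 0.8) -> torch.Tensor:
    """Pull each adapted embedding toward its own class prototype and away from all other prototypes.

    z:          [N, D] embeddings of current-task samples (adapted by the task-specific parameters)
    labels:     [N] class ids, indexing rows of `prototypes`
    prototypes: [C, D] one prototype per class observed so far; old-class prototypes come from
                the preserved representation statistics and stay fixed
    """
    z = F.normalize(z, dim=-1)
    prototypes = F.normalize(prototypes, dim=-1)
    logits = z @ prototypes.t() / tau            # [N, C] scaled cosine similarities
    return F.cross_entropy(logits, labels)       # InfoNCE-style objective

def wtp_loss(class_logits, targets, z, prototypes, lam: float = 0.1, tau: float = 0.8):
    # Cross-entropy on the current task plus the regularizer, weighted by an assumed coefficient lam.
    return F.cross_entropy(class_logits, targets) + lam * contrastive_regularization(
        z, targets, prototypes, tau)
```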
Second, we improve TII and TAP by leveraging the approximated distributions of unadapted and adapted representations, respectively. For TII, we construct an auxiliary output layer $\hat{h}_{\omega}$ parameterized by $\omega$, learning explicitly the projection from unadapted representations to task identity via cross-entropy:
$$\mathcal{L}_{\text{TII}} = \frac{1}{|\hat{\mathcal{H}}|} \sum_{(\hat{\mathbf{h}}, i) \in \hat{\mathcal{H}}} \ell_{\text{CE}}\big(\hat{h}_{\omega}(\hat{\mathbf{h}}),\, i\big), \quad (7)$$
where $\hat{\mathcal{H}}$ is constructed by sampling an equal number of pseudo representations from $\hat{\mathcal{G}}_c$ for each class $c \in \mathcal{Y}_{1:t}$, paired with the identity $i$ of the task containing $c$. Similarly, the final output layer $h_{\psi}$ is further optimized for TAP:
$$\mathcal{L}_{\text{TAP}} = \frac{1}{|\mathcal{H}|} \sum_{(\mathbf{h}, c) \in \mathcal{H}} \ell_{\text{CE}}\big(h_{\psi}(\mathbf{h}),\, c\big), \quad (8)$$
where $\mathcal{H}$ is constructed by sampling an equal number of pseudo representations from $\mathcal{G}_c$ for each class $c \in \mathcal{Y}_{1:t}$, paired with its label $c$. As $\hat{h}_{\omega}$ and $h_{\psi}$ are usually light-weight, the optimization of TII and TAP is computationally efficient. At test time, our approach first predicts the task identity $\hat{i} = \operatorname{argmax}_i \hat{h}_{\omega}(f_{\theta}(\mathbf{x}))$ and then the label $\hat{y} = \operatorname{argmax}_c h_{\psi}(f_{\theta}(\mathbf{x}; e_{\hat{i}}))$.
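The sketch below illustrates how the lightweight TII and TAP heads could be fitted on replayed pseudo representations and then chained at test time; `backbone`, `task_params`, and all sizes are hypothetical placeholders under the assumption that the backbone accepts optional task-specific parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

feature_dim, num_tasks, num_classes = 768, 10, 200        # assumed sizes (ViT-B features, 10 tasks)
tii_head = nn.Linear(feature_dim, num_tasks)               # auxiliary output layer for task identity
tap_head = nn.Linear(feature_dim, num_classes)             # final output layer over all observed classes

def fit_head(head, pseudo_feats, pseudo_targets, epochs=20, lr=1e-2):
    # Both heads are lightweight, so refitting them on replayed statistics is cheap.
    opt = torch.optim.SGD(head.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        F.cross_entropy(head(pseudo_feats), pseudo_targets).backward()
        opt.step()

# TII: pseudo *unadapted* representations -> task ids; TAP: pseudo *adapted* representations -> class ids.
# fit_head(tii_head, unadapted_feats, task_ids)
# fit_head(tap_head, adapted_feats, class_ids)

@torch.no_grad()
def predict(x, backbone, task_params):
    """Two-stage inference for a single test sample: infer the task identity, then the class label."""
    task_id = tii_head(backbone(x)).argmax(dim=-1).item()   # unadapted representation -> task identity
    h = backbone(x, task_params[task_id])                   # re-encode with the task-specific parameters
    return tap_head(h).argmax(dim=-1)                       # adapted representation -> class label
```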
4 Experiment
Experimental Setup: We consider two CIL benchmarks widely used for downstream continual learning [19, 18, 12]: Split ImageNet-R [5], consisting of 200 classes of natural images, and Split CUB-200 [14], consisting of 200 classes of bird images, each randomly split into 10 incremental tasks. After learning multiple incremental tasks, we further evaluate upstream continual learning through the ability of few-shot learning, i.e., adapting the backbone to an N-way K-shot task [1] randomly sampled from subsequent unseen classes. We consider supervised and self-supervised pre-training on ImageNet-21K, denoted as Sup-21K and iBOT-21K, respectively.
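For concreteness, here is a minimal sketch of how an N-way K-shot evaluation episode could be sampled from the subsequent unseen classes; the split into support and query sets and the default sizes are assumptions.

```python
import random
from collections import defaultdict

def sample_episode(dataset, n_way=5, k_shot=1, n_query=15):
    """Sample an N-way K-shot episode; `dataset` is assumed to be a list of (image, label) pairs
    restricted to classes unseen during continual learning."""
    by_class = defaultdict(list)
    for img, label in dataset:
        by_class[label].append(img)
    classes = random.sample(sorted(by_class), n_way)            # pick N unseen classes
    support, query = [], []
    for episode_label, c in enumerate(classes):                  # relabel classes 0..N-1 within the episode
        imgs = random.sample(by_class[c], k_shot + n_query)
        support += [(img, episode_label) for img in imgs[:k_shot]]
        query += [(img, episode_label) for img in imgs[k_shot:]]
    return support, query
```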
Experimental Result: We implement representative PEFT techniques as the task-specific parameters in our approach, including Prompt [3] (adjusting intermediate inputs by prepending a short sequence of learnable prompt parameters) and LoRA [2] (adjusting backbone parameters by adding a learnable low-rank parameter matrix). We first evaluate downstream continual learning with different pre-training paradigms and CIL benchmarks. As shown in Table 1, the performance of state-of-the-art prompt-based approaches degrades remarkably under self-supervised pre-training (e.g., iBOT-21K) and fine-grained classification (e.g., Split CUB-200), while all versions of our approach outperform them significantly.
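To make the two PEFT variants concrete, the following sketch wraps a frozen pre-trained linear layer with a learnable low-rank update in the spirit of LoRA, and notes in comments how prompt tuning would instead prepend learnable tokens; the class name, rank, and scaling are illustrative assumptions.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wrap a frozen pre-trained linear layer with a learnable low-rank update (a minimal sketch)."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                       # pre-trained weights stay frozen
        self.lora_a = nn.Parameter(torch.randn(base.in_features, rank) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(rank, base.out_features))
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_a @ self.lora_b) * self.scale

# Prompt tuning instead prepends learnable tokens to the intermediate inputs, e.g.:
# prompts = nn.Parameter(torch.randn(prompt_len, embed_dim) * 0.01)
# tokens = torch.cat([prompts.expand(batch, -1, -1), tokens], dim=1)
```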
On the other hand, a potential limitation of prompt-based methodologies is that the pre-trained knowledge in backbone parameters cannot be updated and enriched from incremental tasks, which has been rarely discussed in previous literature. Motivated by this, we then consider upstream continual learning, i.e., the ability to accumulate knowledge in backbone parameters. Specifically, after downstream continual learning of multiple incremental tasks, we evaluate how well the backbone performs few-shot learning of an additional task randomly sampled from subsequent unseen classes. As shown in Table 2, the backbone adapted by the LoRA version of our approach achieves strong improvements in few-shot learning compared to the unadapted backbone of the Prompt version. In addition to Split ImageNet-R and Split CUB-200, which split all tasks from the same dataset, we further consider a mixture of tasks sampled from the CUB-200 [14] and Cars-196 [4] datasets, where the improvements remain significant. These results demonstrate the importance and feasibility of synchronizing upstream and downstream continual learning.
Table 1: Performance of downstream continual learning under different pre-training paradigms and CIL benchmarks.

| PTM | Method | Split ImageNet-R |  |  | Split CUB-200 |  |  |
|  |  | FAA (↑) | CAA (↑) | FFM (↓) | FAA (↑) | CAA (↑) | FFM (↓) |
| Sup-21K | L2P [19] | 63.65 | 67.25 | 7.51 | 75.58 | 80.32 | 6.38 |
|  | DualPrompt [18] | 68.79 | 71.96 | 4.49 | 81.32 | 83.45 | 5.31 |
|  | S-Prompt [17] | 69.68 | 72.50 | 3.29 | 81.51 | 83.24 | 4.48 |
|  | CODA-Prompt [12] | 70.03 | 74.26 | 5.17 | 74.34 | 80.71 | 7.42 |
|  | Ours-Prompt | 73.55 | 75.93 | 0.95 | 84.60 | 83.87 | 0.21 |
|  | Ours-LoRA | 69.59 | 74.18 | 8.68 | 85.26 | 86.56 | 3.58 |
|  | Ours-Adapter | 70.48 | 75.03 | 7.89 | 84.69 | 86.51 | 4.10 |
| iBOT-21K | L2P [19] | 55.35 | 58.62 | 3.73 | 45.93 | 56.02 | 9.20 |
|  | DualPrompt [18] | 54.55 | 58.69 | 5.38 | 41.46 | 54.57 | 14.03 |
|  | S-Prompt [17] | 55.16 | 58.48 | 4.07 | 39.88 | 53.71 | 13.15 |
|  | CODA-Prompt [12] | 61.22 | 66.76 | 9.66 | 47.79 | 59.24 | 11.81 |
|  | Ours-Prompt | 70.63 | 72.94 | 1.31 | 72.27 | 73.66 | 1.94 |
|  | Ours-LoRA | 70.94 | 74.92 | 5.61 | 71.75 | 76.57 | 5.33 |
|  | Ours-Adapter | 72.07 | 76.01 | 5.48 | 74.29 | 78.74 | 5.35 |
Table 2: Few-shot learning performance of the backbone after downstream continual learning (5W1S: 5-way 1-shot; 5W5S: 5-way 5-shot).

| PTM | Method | Split ImageNet-R |  | Split CUB-200 |  | Split CUB & Cars |  |
|  |  | 5W1S (↑) | 5W5S (↑) | 5W1S (↑) | 5W5S (↑) | 5W1S (↑) | 5W5S (↑) |
| Sup-21K | Ours-Prompt | 49.88 | 67.88 | 77.50 | 80.63 | 65.87 | 82.13 |
|  | Ours-LoRA | 67.00 | 82.88 | 79.50 | 92.88 | 71.87 | 86.93 |
| iBOT-21K | Ours-Prompt | 42.87 | 64.63 | 36.50 | 62.88 | 34.87 | 58.93 |
|  | Ours-LoRA | 58.38 | 78.87 | 53.63 | 79.75 | 40.60 | 68.40 |
5 Discussion
In this work, we propose a general framework for continual learning in the context of pre-training, decomposing the objective into three hierarchical components (i.e., WTP, TII, and TAP) and optimizing them explicitly with PEFT techniques and representation statistics. Through extensive experiments, we demonstrate the superiority and generality of our approach in downstream continual learning, and further elaborate on the importance of upstream continual learning, which requires updating the backbone parameters rather than only instructing (intermediate) inputs. Interestingly, the proposed framework requires sequential invocation of the unadapted and (task-specific) adapted representations for inference, which is consistent with recent advances in biological learning and memory [7, 6] showing that the activation of non-memory cells and memory cells (as well as their specific sub-populations) is internally switched. This connection potentially bridges the intrinsic mechanisms of biological and artificial intelligence in the acquisition of open-world knowledge.
Acknowledgements
This work was supported by the National Key Research and Development Program of China (No. 2020AAA0106302), NSFC Projects (Nos. 62061136001, 92248303, 62106123, 61972224), BNRist (BNR2022RC01006), Tsinghua Institute for Guo Qiang, and the High Performance Computing Center, Tsinghua University. L.W. is also supported by Shuimu Tsinghua Scholar, and J.Z. is also supported by the XPlorer Prize.
References
- [1] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In International Conference on Machine Learning, pages 1126–1135. PMLR, 2017.
- [2] Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.
- [3] Menglin Jia, Luming Tang, Bor-Chun Chen, Claire Cardie, Serge Belongie, Bharath Hariharan, and Ser-Nam Lim. Visual prompt tuning. In European Conference on Computer Vision, pages 709–727. Springer, 2022.
- [4] Jonathan Krause, Michael Stark, Jia Deng, and Li Fei-Fei. 3d object representations for fine-grained categorization. In Proceedings of the IEEE International Conference on Computer Vision Workshops, pages 554–561, 2013.
- [5] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.
- [6] Bo Lei, Bilin Kang, Wantong Lin, Haichao Chen, Yuejun Hao, Jian Ma, Songhai Shi, and Yi Zhong. Adult newborn granule cells confer emotional state–dependent adaptability in memory retrieval. Science Advances, 8(45):eabn2136, 2022.
- [7] Bo Lei, Li Lv, Shiqiang Hu, Yikai Tang, and Yi Zhong. Social experiences switch states of memory engrams through regulating hippocampal rac1 activity. Proceedings of the National Academy of Sciences, 119(15):e2116844119, 2022.
- [8] Sanket Vaibhav Mehta, Darshan Patil, Sarath Chandar, and Emma Strubell. An empirical investigation of the role of pre-training in lifelong learning. arXiv preprint arXiv:2112.09153, 2021.
- [9] Ethan Perez, Florian Strub, Harm De Vries, Vincent Dumoulin, and Aaron Courville. Film: Visual reasoning with a general conditioning layer. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 32, 2018.
- [10] Vinay Venkatesh Ramasesh, Aitor Lewkowycz, and Ethan Dyer. Effect of scale on catastrophic forgetting in neural networks. In International Conference on Learning Representations, 2021.
- [11] Sylvestre-Alvise Rebuffi, Hakan Bilen, and Andrea Vedaldi. Learning multiple visual domains with residual adapters. Advances in Neural Information Processing Systems, 30, 2017.
- [12] James Seale Smith, Leonid Karlinsky, Vyshnavi Gutta, Paola Cascante-Bonilla, Donghyun Kim, Assaf Arbelle, Rameswar Panda, Rogerio Feris, and Zsolt Kira. Coda-prompt: Continual decomposed attention-based prompting for rehearsal-free continual learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 11909–11919, 2023.
- [13] Gido M Van de Ven and Andreas S Tolias. Three scenarios for continual learning. arXiv preprint arXiv:1904.07734, 2019.
- [14] Catherine Wah, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie. The caltech-ucsd birds-200-2011 dataset. 2011.
- [15] Liyuan Wang, Jingyi Xie, Xingxing Zhang, Mingyi Huang, Hang Su, and Jun Zhu. Hierarchical decomposition of prompt-based continual learning: Rethinking obscured sub-optimality. arXiv preprint arXiv:2310.07234, 2023.
- [16] Liyuan Wang, Xingxing Zhang, Hang Su, and Jun Zhu. A comprehensive survey of continual learning: Theory, method and application. arXiv preprint arXiv:2302.00487, 2023.
- [17] Yabin Wang, Zhiwu Huang, and Xiaopeng Hong. S-prompts learning with pre-trained transformers: An occam’s razor for domain incremental learning. Advances in Neural Information Processing Systems, 35:5682–5695, 2022.
- [18] Zifeng Wang, Zizhao Zhang, Sayna Ebrahimi, Ruoxi Sun, Han Zhang, Chen-Yu Lee, Xiaoqi Ren, Guolong Su, Vincent Perot, Jennifer Dy, et al. Dualprompt: Complementary prompting for rehearsal-free continual learning. In European Conference on Computer Vision, pages 631–648. Springer, 2022.
- [19] Zifeng Wang, Zizhao Zhang, Chen-Yu Lee, Han Zhang, Ruoxi Sun, Xiaoqi Ren, Guolong Su, Vincent Perot, Jennifer Dy, and Tomas Pfister. Learning to prompt for continual learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 139–149, 2022.
- [20] Gengwei Zhang, Liyuan Wang, Guoliang Kang, Ling Chen, and Yunchao Wei. Slca: Slow learner with classifier alignment for continual learning on a pre-trained model. arXiv preprint arXiv:2303.05118, 2023.