Few-Shot Continual Learning via Flat-to-Wide Approaches

Muhammad Anwar Ma’sum, Mahardhika Pratama, Edwin Lughofer, Lin Liu, Habibullah, Ryszard Kowalczyk
Abstract

Existing approaches to continual learning call for a lot of samples in their training processes. Such approaches are impractical for many real-world problems with limited samples because of the overfitting problem. This paper proposes a few-shot continual learning approach, termed FLat-tO-WidE AppRoach (FLOWER), where a flat-to-wide learning process that finds flat-wide minima is proposed to address the catastrophic forgetting problem. The issue of data scarcity is overcome with a data augmentation approach that uses a ball generator concept to restrict the sampling space to the smallest enclosing ball. Our numerical studies demonstrate the advantage of FLOWER, which achieves significantly improved performance over prior art, notably with small base tasks. For further study, source codes of FLOWER, competitor algorithms and experimental logs are shared publicly at https://github.com/anwarmaxsum/FLOWER.

Index Terms:
Continual Learning, Few-Shot Learning, Incremental Learning, Lifelong Learning

I Introduction

Existing studies on continual learning (CL) mostly focus on the catastrophic forgetting (CF) problem, where old parameters of a deep neural network are overwritten when learning new tasks without access to old samples, so that the network loses its generalization power on old tasks. Three approaches, namely the regularization-based approach, the structure-based approach and the memory-based approach, are well established in the literature [1], and some of them have successfully tackled the class-incremental learning problem [2], i.e., a single-head scenario without any task IDs, which is more challenging than the task-incremental learning problem. However, these approaches suffer from over-fitting when samples are limited. Hence, they are unfit for many real-world problems where data are scarce, such as a slow-speed manufacturing process. This drawback is also at odds with natural learning, where human beings can learn from very few samples [3]. Our interest here is the problem of few-shot continual learning [4], where each task carries very few samples in the N-way K-shot setting, i.e., each task comprises N classes with K samples per class. Such a case requires a model not only to handle a sequence of tasks without the CF problem but also to avoid over-fitting due to scarce samples.

The few-shot continual learning approach is pioneered by the work in [4] where the concept of neural gas (NG) is adopted. Old nodes of the NG structure are stabilized to avoid the CF problem while expanding the NG structure to adapt to new concepts. [5] presents a few-shot continual learning algorithm using the concept of session trainable parameters, deemed unimportant for previous tasks to alleviate the CF problem. The training process is governed by a triplet loss function and classification is performed using the nearest prototype-based classification approach. The vector quantization concept in the deep embedding space is successfully implemented in [6] to deal with both regression and classification problems in few-shot continual learning environments. A continually evolved classifier (CEC) is proposed in [7] using a graph attention network (GAT) to adapt the classifier’s weights. [8] introduces the concept of flat learning, finding flat regions in the base learning task to be maintained in the few-shot learning task. However, it relies on a memory storing exemplar sets of previous tasks. The use of memory might lead to most samples of few-shot learning tasks to be kept in the memory. It heavily focuses on stability of base tasks but ignores plasticity of new tasks.

This paper presents a FLat-tO-WidE appRoach (FLOWER) for few-shot continual learning problems. FLOWER is configured as a prototypical network [9] driven by the prototypical loss. FLOWER puts forward the flat-to-wide learning idea [8, 10] to discover flat-wide minimum regions. It minimizes the losses of noise-perturbed parameters of the feature extractor over $M$ runs. The flat-wide parameters are maintained in the subsequent few-shot learning tasks by constraining feature extractor parameters to prevent the catastrophic forgetting problem. Our approach differs from [8], which focuses heavily on the flat minimum regions, resulting in over-dependence on the base learning task. In addition, no memory is imposed by FLOWER, and the wide learning concept is integrated as an attempt to drive flat regions to be wide, thus scaling well for large-scale problems. The concept of ball generator is integrated in FLOWER to prevent the over-training problem due to scarce samples in few-shot learning tasks.

A feature-space data augmentation approach is applied in a sequence of few-shot learning tasks. The data augmentation strategy follows the ball generator concept [11], where the sampling space is limited to the smallest enclosing ball of few-shot samples. A generator loss is incorporated such that generated samples are close to the centre of their own class and far from the centres of other classes. The projection-based memory aware synapses method is implemented to avoid the classifier’s parameters being overwritten when embracing new few-shot tasks. That is, the memory aware synapses regularization approach is combined with the classifier’s projection concept of [10] and the clamping procedure, inducing flat-wide local optimum regions that scale well for large-scale continual learning tasks. The main network and the ball generator are trained to minimize a joint loss function of the prototypical loss and the generator loss, added with the regularizer. A constraint is applied to the feature extractor parameters, which are clamped if they go beyond the flat-wide regions [8]. This approach also prevents the over-fitting problem because only a few classifier parameters are adapted while the feature extractor parameters stay in a fixed range.

This paper makes five major contributions: 1) it proposes the flat-to-wide approach (FLOWER) for few-shot continual learning, synergizing the flat learning concept and the wide learning concept. The flat-wide minima are enforced to address the CF problem. This enables improved plasticity when learning new tasks and maintains decent performance with small base tasks. Note that the good average performance of [8] often comes at the cost of poor performance on new tasks and small base tasks; 2) the concept of ball generator is proposed to perform feature-space data augmentation as a solution to data scarcity. This approach is modified from [11], which was designed for texts and non-continual learning problems; 3) new loss functions unifying the flat learning loss, wide learning loss and ball-generator loss are derived for both the base learning task and the few-shot learning tasks, leading to an end-to-end training process; 4) a rigorous theoretical study is performed to confirm the convergence of FLOWER, where its losses converge toward zero; 5) the efficacy of FLOWER has been validated and compared with state-of-the-art algorithms, where FLOWER delivers improved performance and works well with small base tasks. Source codes of FLOWER are placed at https://github.com/anwarmaxsum/FLOWER to assure convenient reproduction of our numerical results and further studies.

II Related Works

II-A Continual Learning

There exist three variants of CL approaches to attack the CF problem [1]: the regularization-based approach, the memory-based approach and the architecture-based approach. The regularization-based approach relies on a penalty term for important parameters of previous tasks, ensuring no significant deviations when handling new tasks [12, 13, 14, 15, 16, 10, 17]. The performance of regularization-based approaches is usually inferior to the other two approaches and they often call for the presence of task IDs, thus performing poorly in Class-Incremental Learning (CIL). The structure-based approach applies a network expansion strategy combined with a parameter isolation strategy to overcome the CF problem [18, 19, 20, 21, 22, 23, 24, 25]. In addition to expensive computational costs, the above-mentioned works depend on task IDs and task boundaries, thus being inapplicable to few-shot CIL problems. Although the data-driven approach is fast to adapt to changes and does not require any task IDs and task boundaries, it does not guarantee an optimal structure. The memory-based approach stores a subset of old samples in an episodic memory for experience replay when dealing with new tasks and often offers strong performance in CIL problems [26, 27, 28, 29, 30, 31, 32]. Despite this strength, the memory-based approach is impractical for few-shot CL problems because of the limited sample constraint, i.e., it is akin to the retraining case, undermining the CL spirit.

II-B Few-Shot Continual Learning

Few-shot CL problem aims to resolve the data-scarcity issue in addition to the CF problem. [4] is seen as a pioneering approach in this area where the concept of neural gas is adopted to address new concepts without compromising already seen experiences. [5] puts forward session trainable parameters to avoid the over-training problem while using the triplet-loss based metric learning. [6] applies the vector quantization concept and has been validated for regression as well as classification problems. [7] makes use of graph attention network (GAT) to control the adaptive mechanisms of new tasks. The concept of flat learning is proposed in [8] for the few-shot CL problems. This approach is, however, dependent on a memory which should not exist for the few-shot CL problems, i.e., all samples can be stored in the memory thus becoming equivalent to retraining approaches. FLOWER also differs from [8] where the ball generator concept is integrated for the feature space data augmentation mechanism to better address the data scarcity problem while introducing the wide learning concept to scale up to large-scale tasks.

III Problem Formulation

A few-shot continual learning problem is defined as learning a sequence of fully-labelled tasks $\mathcal{T}_{1},\ldots,\mathcal{T}_{k},\ldots,\mathcal{T}_{K}$, where $k\in\{1,\ldots,K\}$ indexes the tasks and the number of tasks $K$ is unknown before the process runs. Each task carries $N_{k}$ pairs of training samples $\mathcal{T}_{k}=\{x_{i}^{k},y_{i}^{k}\}_{i=1}^{N_{k}}$, where $x_{i}\in\mathcal{X}$ stands for an input image and $y_{i}\in\mathcal{Y}$ denotes its corresponding class label. All data samples are drawn from the same domain $\mathcal{X}\times\mathcal{Y}\in\mathcal{D}$. Each task features the same image size but possesses disjoint target classes. Suppose that $Y_{k}$ denotes the class set of the $k$-th task and $Y_{k-1}$ denotes the class set of the $(k-1)$-th task; then $\forall k,k-1,\; Y_{k}\cap Y_{k-1}=\emptyset$. The few-shot continual learning problem is commonly formulated as the class-incremental learning problem rather than the task-incremental learning problem [2], where a class prediction is inferred by a single-head network structure and the data samples of every task arrive without the task IDs $t_{i}^{k}$.

$\mathcal{T}_{1}$ is a base learning task possessing a reasonably large number of samples, while each of the remaining tasks is a few-shot learning task lacking samples, i.e., each few-shot learning task is set in the $N$-way $K$-shot configuration, with $N$ target classes and $K$ samples per class. In addition, a task $\mathcal{T}_{k>1}$ is only accessible at the $k$-th session and is completely discarded once seen. This constraint causes the catastrophic interference problem, where previously valid parameters are over-written when learning a new task. The memory-based approach is infeasible for few-shot continual learning problems because, under the data scarcity constraint, it may end up storing all past samples, which amounts to retraining. Hence, a few-shot continual learner is supposed not only to accumulate knowledge from already seen experiences but also to use samples efficiently to avoid the over-fitting problem.

Our few-shot continual learner is defined as $g_{W_{C}}(f_{W_{F}}(\cdot))$, where $f_{W_{F}}(\cdot)$ stands for a feature extractor with parameters $W_{F}$ while $g_{W_{C}}(\cdot)$ denotes a fully-connected network with parameters $W_{C}$. Specifically, it is configured as a prototypical network [9] where the final prediction is expressed as follows:

$$P(\hat{y}=c|x;W_{C},W_{F})=\frac{\exp(-d(g_{W_{C}}(f_{W_{F}}(x)),p_{c}))}{\sum_{c\in M}\exp(-d(g_{W_{C}}(f_{W_{F}}(x)),p_{c}))} \tag{1}$$

$c^{*}=\arg\max_{c\in M}P(\hat{y}=c|x;W_{C},W_{F})$ stands for a predicted class label and $M$ denotes the number of classes seen thus far. $d(\cdot)$ is a distance metric implemented as the $L_{2}$ distance metric in this paper. $p_{c}=\frac{1}{|Y_{1}|}\sum_{(y)\in Y_{1}}f_{W_{F}}(x_{i})$ is the prototype of the $c$-th class. The learner is tasked to perform well across all tasks and minimizes $\frac{1}{K}\sum_{k=1}^{K}\mathcal{L}_{k}$ with $\mathcal{L}_{k}\triangleq\mathbb{E}_{(x,y)\sim\mathcal{T}_{k}}[\mathcal{L}(g_{W_{C}}(f_{W_{F}}(x)),y)]$, where $\mathcal{L}(\cdot)$ is a loss function of interest; the cross-entropy loss of the prototypical network is adopted here.

$$\mathcal{L}_{ce}=-\sum_{i}\sum_{c}\mathbf{1}(y=c)\frac{\exp(-d(g_{W_{C}}(f_{W_{F}}(x_{i})),p_{c}))}{\sum_{c\in M}\exp(-d(g_{W_{C}}(f_{W_{F}}(x_{i})),p_{c}))} \tag{2}$$

where $\mathbf{1}(y=c)$ is an indicator function returning 1 if $y=c$ and 0 otherwise. The inference adopts the nearest-mean classification strategy, assigning the class label of the nearest prototype.
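The prediction rule in (1) and the nearest-prototype inference can be sketched as follows. This is a minimal illustration in plain Python, assuming the embeddings have already been produced by the network; the helper names (`class_prototypes`, `predict_proba`, `predict`) and the toy data are hypothetical and not taken from the FLOWER code base.

```python
import math

def l2_distance(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def class_prototypes(embeddings, labels):
    """Average the embeddings of each class to obtain one prototype per class."""
    sums, counts = {}, {}
    for z, y in zip(embeddings, labels):
        if y not in sums:
            sums[y] = [0.0] * len(z)
            counts[y] = 0
        sums[y] = [s + zi for s, zi in zip(sums[y], z)]
        counts[y] += 1
    return {c: [s / counts[c] for s in sums[c]] for c in sums}

def predict_proba(z, prototypes):
    """Softmax over negative distances to every prototype, as in (1)."""
    logits = {c: -l2_distance(z, p) for c, p in prototypes.items()}
    m = max(logits.values())  # subtract the max for numerical stability
    exps = {c: math.exp(v - m) for c, v in logits.items()}
    total = sum(exps.values())
    return {c: e / total for c, e in exps.items()}

def predict(z, prototypes):
    """Nearest-prototype rule: the most probable class is the closest one."""
    probs = predict_proba(z, prototypes)
    return max(probs, key=probs.get)

# Toy embeddings (assumed, for illustration only): two classes in 2-D.
embeddings = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]]
labels = [0, 0, 1, 1]
prototypes = class_prototypes(embeddings, labels)
query = [1.0, 1.0]
print(predict(query, prototypes), predict_proba(query, prototypes))
```

Because the softmax is monotone in the negative distance, taking the arg max of the probabilities is the same as picking the closest prototype, which is why inference needs only the prototypes.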

IV The Flat-to-Wide Approach (FLOWER)

The few-shot continual learning problem consists of two learning stages: a base learning phase and a sequence of few-shot learning phases.

IV-A Base Learning Phase ($k=1$)

We apply the concept of b-flat-wide local minima: for the parameters of interest, i.e., the feature extractor parameters $W_{F}$, the loss remains minimized throughout the flat-wide region, leading to good separation of the classes of interest and robustness against the catastrophic forgetting problem [8]. Once found in the base learning phase, these parameters are only tuned within this region during the few-shot learning phases. In other words, it is sufficient to adjust the network parameters within the flat-wide interval to avoid the over-training problem. A definition of the b-flat-wide local minima is formalised as:

Definition 1 [8]: given a real-valued loss function $\mathcal{L}(\cdot)$, for any $b>0$ and network parameters $\Theta=\{W_{C},W_{F}\}$, the b-flat-wide local minima $\Theta^{*}$ satisfy the following conditions:

  • Condition 1: for $\mathcal{L}(x;\Theta^{*}+\epsilon)$, where $-b\leq\epsilon\leq+b$.

  • Condition 2: there exist $c_{1}<\Theta^{*}-b$ and $c_{2}>\Theta^{*}+b$ s.t. $\mathcal{L}(x;\Theta)>\mathcal{L}(x;\Theta^{*})$ where $c_{1}<\Theta<\Theta^{*}-b$ and $\Theta^{*}+b<\Theta<c_{2}$

These conditions are extremely hard to satisfy in reality, so an approximation is carried out using noise perturbation. That is, we run $M$ trials with different noise perturbations $\epsilon_{j}$, where $-b\leq\epsilon_{j}\leq+b$ and $j=1,\ldots,M$, against the network parameters $\Theta^{*}$. Under different noise perturbations, the network parameters within the estimated b-flat-wide regions are supposed to return similar losses and to have small function values [8]. Fig. 1 visualizes the flat learning paradigm. In a nutshell, the base learning phase is designed to minimize the following loss function.

$$\begin{split}\mathcal{L}_{k=1}&=\frac{1}{M}\sum_{j=1}^{M}\mathcal{L}_{base}^{j}(g_{W_{C}}(f_{W_{F}+\epsilon_{j}}(x)),y)\\ \mathcal{L}_{base}^{j}&=\mathbb{E}_{(x,y)\sim\mathcal{T}_{1}}[\mathcal{L}_{ce}(g_{W_{C}}(f_{W_{F}+\epsilon_{j}}(x)),y)\\ &\quad+\mathcal{L}_{KL}(g_{W_{C}}(f_{W_{F}+\epsilon_{j}}(x)),P_{U})]+\lambda_{1}\mathbb{E}_{y\sim Y_{1}}||p_{c}-p_{c}^{*}||_{2}\end{split} \tag{3}$$

where $p_{c},p_{c}^{*}$ stand for the prototype of the $c$-th class before and after noise perturbations respectively. The first term of $\mathcal{L}_{base}$ is designed to unveil the b-flat-wide region while the second term prevents the prototype drift problem. $P_{U}$ is a uniform distribution and $\mathcal{L}_{KL}(\cdot)$ promotes wide local minima by projecting the outputs to uniform distributions. Note that (3) extends [8], which only solicits flatness and ignores the wide learning issue.
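The averaging over noise-perturbed parameters in (3) can be illustrated with a toy sketch. It covers only the outer average over the $M$ trials, not the KL or prototype-drift terms, and it assumes uniform noise on $[-b,b]$ purely for illustration; the function names and the two toy losses are hypothetical.

```python
import random

def perturbed_average_loss(loss_fn, params, b, num_trials, rng):
    """Average loss_fn over num_trials noisy copies of params.

    Each copy adds independent noise drawn uniformly from [-b, b] to every
    parameter (an assumed noise distribution, chosen only for illustration).
    """
    total = 0.0
    for _ in range(num_trials):
        noisy = [w + rng.uniform(-b, b) for w in params]
        total += loss_fn(noisy)
    return total / num_trials

def sharp_loss(params):
    # Narrow basin: small parameter shifts raise the loss quickly.
    return sum(100.0 * w ** 2 for w in params)

def flat_loss(params):
    # Wide basin: the same shifts barely change the loss.
    return sum(0.01 * w ** 2 for w in params)

minimum = [0.0, 0.0]
# Seed both generators identically so that both losses see the same noise.
sharp = perturbed_average_loss(sharp_loss, minimum, 0.01, 4, random.Random(0))
flat = perturbed_average_loss(flat_loss, minimum, 0.01, 4, random.Random(0))
print(sharp, flat)
```

Both toy losses have the same minimizer, yet the perturbed average is far larger for the sharp basin. Minimizing such an average therefore favours parameters whose neighbourhood of radius $b$ also has low loss.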

Figure 1: Flat Continual Learning: during the base learning phase $k=1$, FLOWER finds the network parameters $\Theta^{*}$ located in the flat local optimum region. These parameters might shift to non-flat local minima regions when learning the few-shot learning tasks $k>1$ via $\Theta^{*}\rightarrow\Theta_{1}$. Clamping is performed for $\Theta_{1}$ to stay around a flat local optimum region.

The goal of (3) is to unveil a flat-wide region working well on all tasks, thus overcoming the catastrophic forgetting problem. It is done by sampling $\epsilon_{j}$ $M$ times to vary the network parameters, leading to $\mathcal{L}_{k=1}$, which is considered a non-convex loss function [8]. It is an approximation of the expected loss $R(\Theta)=\mathbb{E}_{z,\epsilon}[\mathcal{L}_{k=1}]$, which is impossible to minimize in practice. $z_{k}$ is the $k$-th data batch.

Assumption 1 (L-smooth risk function [8]): $R:\Re^{D}\rightarrow\Re$ is continuously differentiable and $L$-smooth with $L>0$:

$$||\nabla R(\Theta)-\nabla R(\Theta^{\prime})||_{2}\leq L||\Theta-\Theta^{\prime}|| \tag{4}$$

This assumption constrains how fast the gradient can change with respect to the parameter vector.

Assumption 2: the expected loss function $R(\Theta)$ holds the following properties [8]:

  • $R$ is bounded below by a scalar $R^{*}$ over the sequence $\{\Theta_{k}\}$, where $k$ indicates the $k$-th data batch. That is, $R^{*}$ is a lower bound on the values $R$ can take;

  • For all $k\in\mathbb{N}$ and $j\in[1,M]$:

    $$\mathbb{E}_{z_{k},\epsilon_{j}}\nabla\mathcal{L}_{base}^{j,k}=\nabla R(\Theta_{k}) \tag{5}$$

    i.e., $\nabla\mathcal{L}_{base}^{j,k}$ is assumed to be an unbiased estimate of $\nabla R(\Theta_{k})$, which simplifies the proof;

  • given $m_{1}\geq 0,m_{2}\geq 0$, for all $k\in\mathbb{N}$ and $j\in[1,M]$:

    $$\mathbb{V}_{z_{k},\epsilon_{j}}[\nabla\mathcal{L}_{base}^{j,k}]\leq(m_{1}+m_{2})||\nabla R(\Theta_{k})||_{2}^{2} \tag{6}$$

    meaning that the variance of the gradient cannot be arbitrarily large.

$\mathbb{E}_{z_{k},\epsilon_{j}}$ and $\mathbb{V}[\cdot]$ respectively denote the expectation w.r.t. the joint distribution of the random variables $z_{k},\epsilon_{j}$ and the variance. These three conditions are reasonable in practice and applicable for convergence analysis.

Assumption 3: the learning rates $\lambda_{k}$ satisfy [8]

$$\sum_{k=1}^{\infty}\lambda_{k}=\infty,\quad\sum_{k=1}^{\infty}\lambda_{k}^{2}<\infty \tag{7}$$

which can be met with a learning rate $\lambda_{k}<1$ that decreases suitably with $k$. These assumptions establish a theorem.
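As a concrete instance of Assumption 3 (an illustrative choice, not necessarily the schedule used in the experiments), consider the harmonic schedule $\lambda_{k}=\lambda_{0}/k$ with $0<\lambda_{0}<1$:

```latex
\sum_{k=1}^{\infty}\lambda_{k}
  =\lambda_{0}\sum_{k=1}^{\infty}\frac{1}{k}=\infty,
\qquad
\sum_{k=1}^{\infty}\lambda_{k}^{2}
  =\lambda_{0}^{2}\sum_{k=1}^{\infty}\frac{1}{k^{2}}
  =\frac{\pi^{2}}{6}\lambda_{0}^{2}<\infty
```

By contrast, $\lambda_{k}=\lambda_{0}/\sqrt{k}$ is also decreasing and below one, but its squares sum to the divergent harmonic series, so the rate of decay matters and not only the fact that the schedule decreases.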

Theorem 1: Given Assumptions 1-3 and that $R$ is twice differentiable, we have:

$$\lim_{k\rightarrow\infty}\mathbb{E}[||\nabla R(\Theta_{k})||_{2}^{2}]=0 \tag{8}$$

which establishes the convergence of the flat-wide minima strategy. The proofs are provided in the appendices.

IV-B Few-Shot Learning Phase ($k>1$)

The few-shot learning phase of FLOWER is driven by a feature-space data augmentation approach to address the data scarcity issue while applying a combination of the flat and wide learning principles to address the CF problem without any memory. Network parameters, namely those of feature extractors, are adjusted in such a way to maintain the b-flat-wide local minima conditions to mitigate the over-training problems and the projection concept is implemented in the memory aware synapses (MAS) method [13] to induce wide local optimum.

IV-B1 Feature Space Data Augmentation

The concept of ball generator [11] is adopted and operates in the feature space $z=f_{W_{F}}(x)\in\Re^{D}$ rather than the high-dimensional data space $x\in\Re^{w\times H\times c}$. The ball generator produces synthetic samples from the smallest enclosing ball of the $c$-th class $\omega(Z_{c})=\{C_{c},\sigma_{c}\}$, where $C_{c},\sigma_{c}$ denote the centroid and radius of the $c$-th ball/class. Unlike [11], the ball generator is executed directly on the few-shot learning tasks, where the data scarcity problem exists, rather than on the support set. A synthetic sample $\hat{z}\in\Re^{D}$ is generated:

$$\hat{z}=C_{c}+u^{\frac{1}{D}}\sigma_{c}\frac{\phi}{||\phi||_{2}} \tag{9}$$

where $u$ denotes a shifting constant sampled from a uniform distribution on $[0,1]$ and $\phi\in\Re^{D}$ is a random vector generated from a normal distribution $\phi\sim\mathcal{N}(0,1)$.
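A minimal sketch of the sampling rule in (9) is given below. It assumes, for simplicity, that the ball is approximated by the class mean and the largest distance to it, which is a stand-in rather than an exact smallest-enclosing-ball solver; the function names and toy features are hypothetical.

```python
import math
import random

def enclosing_ball(features):
    """Approximate ball: centroid = mean, radius = largest distance to it.

    This is a simple stand-in, not an exact smallest-enclosing-ball solver.
    """
    dim = len(features[0])
    center = [sum(f[d] for f in features) / len(features) for d in range(dim)]
    radius = max(
        math.sqrt(sum((f[d] - center[d]) ** 2 for d in range(dim)))
        for f in features
    )
    return center, radius

def sample_in_ball(center, radius, rng):
    """Draw one point inside the ball, following (9)."""
    dim = len(center)
    phi = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # random direction
    norm = math.sqrt(sum(p * p for p in phi))
    u = rng.random()                                 # u drawn from [0, 1)
    scale = (u ** (1.0 / dim)) * radius / norm
    return [c + scale * p for c, p in zip(center, phi)]

# Toy few-shot features of one class (assumed values, for illustration).
features = [[0.0, 0.0], [2.0, 0.0], [1.0, 3.0]]
center, radius = enclosing_ball(features)
rng = random.Random(0)
synthetic = [sample_in_ball(center, radius, rng) for _ in range(200)]
print(center, radius, synthetic[0])
```

The exponent $1/D$ on $u$ compensates for the fact that volume grows with the $D$-th power of the radius, so the samples spread uniformly over the ball instead of concentrating near the centroid.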

We feed a synthetic sample $\hat{z}$ to a transformation module $\hat{\hat{z}}=\kappa_{W_{T}}(\hat{z})\in\Re^{D}$ arranged as a three-layer fully connected network as per [11]. The goal of the transformation module is to avoid the bias of synthetic samples. To this end, a ball generator loss is introduced as follows:

$$\mathcal{L}_{ball}=\sum_{\hat{\hat{z}}\in c_{i},\hat{\hat{z}}\notin c_{j}}\lambda_{2}\max\{0,d(\hat{\hat{z}},C_{i})+r-d(\hat{\hat{z}},C_{j})\} \tag{10}$$

where $r$ is a predefined margin and $\hat{\hat{z}}=\kappa_{W_{T}}(\hat{z})$ stands for a synthetically generated sample of the transformation module. $C_{i},C_{j}$ respectively denote the centroid of the $i$-th and $j$-th ball or class. (10) is inspired by the triplet loss of metric learning, aiming to pull the synthetic sample $\hat{\hat{z}}$ close to its own centroid and to push it away from other centroids. $\lambda_{2}$ is a predefined constant steering the influence of the ball generator loss function.
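The hinge structure of (10) can be sketched in a few lines. This is an illustrative reading of the formula that sums over every synthetic sample and every other class centroid; the function name, argument names and toy values are hypothetical.

```python
import math

def l2(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def ball_generator_loss(samples, sample_labels, centroids, margin, lam2):
    """Triplet-style hinge loss over synthetic samples, following (10).

    A sample of class i is penalised whenever it is not at least `margin`
    closer to its own centroid C_i than to another centroid C_j.
    """
    total = 0.0
    for z, i in zip(samples, sample_labels):
        d_own = l2(z, centroids[i])
        for j, c_j in centroids.items():
            if j == i:
                continue
            total += lam2 * max(0.0, d_own + margin - l2(z, c_j))
    return total

# Toy centroids and transformed synthetic samples (assumed values).
centroids = {0: [0.0, 0.0], 1: [4.0, 0.0]}
samples = [[1.0, 0.0], [2.0, 0.0]]
sample_labels = [0, 0]
loss = ball_generator_loss(samples, sample_labels, centroids, margin=1.0, lam2=1.0)
print(loss)
```

In this toy case the first sample already satisfies the margin and contributes nothing, while the second sample sits halfway between the two centroids and is the only one penalised.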

IV-B2 Projection-based Memory Aware Synapses

FLOWER relies on a regularizer $\mathcal{R}_{PMAS}$ integrating the projection concept [10] into the memory aware synapses approach to address the catastrophic forgetting problem. The projection concept induces wide local optimum regions, improving the scalability of a continual learner since it increases the chance that an overlapping region is found. The PMAS method is formalized as:

$$\begin{split}\mathcal{R}_{PMAS}&=\frac{\lambda_{3}}{N_{k}}\sum_{i=1}^{N_{k}}\mathcal{L}_{KL}(g_{W_{C}}(f_{W_{F}}(x_{i}^{k})),P_{U})\\ &\quad+\lambda_{4}\sum_{i}\Xi_{i}^{k-1}(\Theta_{i}^{k}-\Theta_{i}^{k-1})^{2}\end{split} \tag{11}$$

where $\Xi$ is a parameter importance matrix which can be estimated using different approaches: EWC [12], SI [13], MAS [14], etc. We apply the MAS method, which is online and unsupervised, since it is faster to compute than EWC and SI. $\Theta=\{W_{F},W_{C}\}$, where no regularization mechanism is applied to the transformation module $\kappa_{W_{T}}(\cdot)$ because it only functions to enrich the current concept. The first term of (11) promotes wide local minima, where $\mathcal{L}_{KL}(\cdot)$ denotes the KL divergence loss function minimizing the discrepancy between its two arguments, while the second term of (11) is the MAS-based regularization approach, inspired by the $L_{2}$ regularization technique except that it penalizes the change $(\Theta_{i}^{k}-\Theta_{i}^{k-1})$ weighted by the importance $\Xi_{i}^{k-1}$. $\Theta_{i}^{k-1},\Theta_{i}^{k}$ respectively denote the network parameters before and after seeing the current $k$-th task. $\lambda_{3},\lambda_{4}$ stand for predefined constants controlling the influence of each term.
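The two terms of (11) can be sketched as below. The sketch assumes that the KL term is taken from each softmax output to the uniform distribution (the argument order is our reading of the notation) and that the importance weights are already available as a flat list; how MAS estimates them is not shown. All names and numbers are hypothetical.

```python
import math

def kl_to_uniform(probs):
    """KL divergence from a probability vector to the uniform distribution."""
    m = len(probs)
    return sum(q * math.log(q * m) for q in probs if q > 0.0)

def pmas_regularizer(outputs, params, old_params, importance, lam3, lam4):
    """Projection term plus importance-weighted quadratic penalty, as in (11).

    outputs    : list of softmax outputs on the current task's samples
    params     : current parameters, flattened into a list
    old_params : parameters stored before learning the current task
    importance : per-parameter importance weights (assumed precomputed)
    """
    kl_term = lam3 * sum(kl_to_uniform(q) for q in outputs) / len(outputs)
    mas_term = lam4 * sum(
        xi * (w - w_old) ** 2
        for xi, w, w_old in zip(importance, params, old_params)
    )
    return kl_term + mas_term

# Toy values (assumed): one uniform output, two parameters.
outputs = [[0.5, 0.5]]
reg = pmas_regularizer(outputs, [1.0, 2.0], [1.0, 1.0], [3.0, 0.5], lam3=10.0, lam4=2.0)
print(reg)
```

In this toy case the output is already uniform, so only the second parameter's drift contributes, scaled by its importance. An important parameter that stays put costs nothing, which is the intended behaviour of the penalty.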

(11) introduces the KL divergence term to induce wide local optimum regions, where it is perceivable as the optimization $\min_{Q\in\mathcal{Q}}\mathcal{L}_{KL}(Q|P)$ where $P$ is a given distribution and $\mathcal{Q}$ is a convex set of distributions in the probability simplex $\Delta_{M}$ [10], with $g_{W_{C}}(\cdot):\Re^{D}\rightarrow\Delta_{M}$ and $M$ the label space dimension. After minimizing this term, $P^{*}$ is obtained and represents the distribution in $\mathcal{Q}$ that is closest to $P$. In this case, $\mathcal{L}_{KL}(Q|P)$ can be viewed as analogous to the Euclidean distance, where $(Q,P^{*},P)$ constructs a right triangle.

Lemma 1 [10]. Suppose $P^{*}\in\mathcal{Q}$ such that $\mathcal{L}_{KL}(P^{*}|P)=\min_{Q\in\mathcal{Q}}\mathcal{L}_{KL}(Q|P)$, then

$$\mathcal{L}_{KL}(Q|P)\geq\mathcal{L}_{KL}(Q|P^{*})+\mathcal{L}_{KL}(P^{*}|P),\quad\forall Q\in\mathcal{Q} \tag{12}$$

The classifier’s output $g_{W_{C}}(\cdot)$ is projected to a uniform distribution in a set $\mathcal{C}$ and interpreted as conditional distributions $Q_{Y|X}$ between target $Y$ and input $X$. After learning the $k$-th task with (11), $P^{k*}_{Y|X}\in\mathcal{C}$ is obtained. Suppose there are two classifiers in the $(k+1)$-th task, $P^{k+1}_{Y|X}\notin\mathcal{C}$ without the projection and $P^{(k+1)*}_{Y|X}\in\mathcal{C}$ with the projection; a right triangle is then established across $(P^{k*}_{Y|X},P^{(k+1)}_{Y|X},P^{(k+1)*}_{Y|X})$ as per Lemma 1. $P^{k*}_{Y|X}$ is closer to $P^{(k+1)*}_{Y|X}$ than to $P^{(k+1)}_{Y|X}$, obtained without the projection, when being evaluated on the $k$-th task. A uniform distribution $P_{U}$ is selected because it represents the centroid of $\Delta_{M}$, which results in an upper bound of $\log M$ on the divergence. Note that an ideal classifier performing well on all tasks does not exist in the continual learning context and thus we turn our attention to the set of possible classifiers $\mathcal{C}$ to be a KL divergence ball centered at the uniform distribution $P_{U}$ [10]. The wide learning strategy is illustrated in Fig. 2.
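The $\log M$ bound can be checked directly, assuming the divergence is taken from a classifier output $Q=(q_{1},\ldots,q_{M})\in\Delta_{M}$ to $P_{U}$ (the argument order is our reading of the notation):

```latex
\mathcal{L}_{KL}(Q|P_{U})
  =\sum_{c=1}^{M}q_{c}\log\frac{q_{c}}{1/M}
  =\log M-H(Q)\leq\log M,
\qquad
H(Q)=-\sum_{c=1}^{M}q_{c}\log q_{c}\geq 0
```

Equality holds when $Q$ is one-hot, since then $H(Q)=0$, and the divergence vanishes when $Q=P_{U}$. Every output distribution therefore lies within a KL ball of radius $\log M$ around the uniform distribution.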

Figure 2: Wide Continual Learning: conventional continual learning approaches generate narrow ellipsoids where the local minima region is sharp leaving very small space for model candidate parameters, i.e., intersections of all ellipsoids. On the contrary, the wide continual learning generates wide ellipsoids having wide local minima regions allowing enough space of model candidate parameters.

IV-B3 Loss Function of Few-Shot Learning Phase

The few-shot learning phase of FLOWER is controlled by three loss functions: the cross entropy loss function $\mathcal{L}_{ce}$, the ball generator loss function $\mathcal{L}_{ball}$ and the regularizer $\mathcal{R}_{PMAS}$. The cross entropy loss function (2) learns the current concept supported by the feature space data augmentation method, while the ball generator loss function (10) addresses the synthetic sample bias problem, where a synthetic sample is pulled close to its own centroid and pushed away from other centroids. The regularizer (11) functions to protect against the catastrophic forgetting problem:

$$\mathcal{L}_{k>1}=\mathcal{L}_{ce}+\mathcal{L}_{ball}+\mathcal{R}_{PMAS} \tag{13}$$

The network parameters are maintained in such a way to satisfy the b-flat-wide local minima conditions. That is, parameters of feature extractor are clamped if they go beyond the flat-wide region. Besides, the wide learning paradigm is actualized in the projection mechanism to scale FLOWER up to long sequences of tasks. (13) takes the advantage of the flat learning concept [8] and the wide learning concept [10] to address the CF problem without any memory and thus realizes the flat-wide learning strategy. Pseudo-codes of FLOWER are provided in the algorithm 1 while overview of flat-wide learning strategies is pictorially illustrated in Fig. 3.

Algorithm 1 Learning Policy of FLOWER
Input: The flat-wide region bound $b$, randomly initialized parameters $\Theta=\{W_{F},W_{C}\},W_{T}$, hyper-parameters $\lambda_{1},\lambda_{2},\lambda_{3},\lambda_{4}$, number of epochs $E$, number of synthetic samples $S$.
Output: Updated network parameters $\Theta=\{W_{F},W_{C}\}$

for $e=1:E$ do
     for $j=1:M$ do
         Sample a noise vector $\epsilon_{j}\sim\mathcal{E}$ where $-b\leq\epsilon_{j}\leq b$
         Perturb the feature extractor parameters $\Theta=\{W_{F}+\epsilon_{j},W_{C}\}$
         Compute the base loss $\mathcal{L}_{base}$
         Reset the network parameters $\Theta=\{W_{F},W_{C}\}$
     end for
     Update network parameters with $\mathcal{L}_{k=1}$ as per (3)
end for

for $k=2:K$ do
     for $e=1:E$ do
         $\mathcal{T}^{\prime}_{k}=\emptyset$
         Calculate the ball/class parameters $\omega(Z_{c})=\{C_{c},\sigma_{c}\}$
         for $i^{\prime}=1:S$ do
              Generate a synthetic sample $\hat{z}_{i^{\prime}}$ as per (9)
              Apply the transformation module $\hat{\hat{z}}_{i^{\prime}}=\kappa_{W_{T}}(\hat{z}_{i^{\prime}})$
              $\mathcal{T}^{\prime}_{k}=\mathcal{T}^{\prime}_{k}\cup\{\hat{\hat{z}}_{i^{\prime}},y_{i}\}$
         end for
         $\hat{\mathcal{T}}_{k}^{e}=\mathcal{T}^{\prime}_{k}\cup\mathcal{T}_{k}$
         Calculate the ball generator loss (10) with $\hat{\mathcal{T}}_{k}^{e}$
         Calculate the cross entropy loss (2) with $\hat{\mathcal{T}}_{k}^{e}$
         Calculate the final loss (13)
         Update network parameters $\Theta=\{W_{F},W_{C}\},W_{T}$ with $\mathcal{L}_{k}$
         Clamp feature extractor parameters so that $W_{F^{*}}-b\leq W_{F}\leq W_{F^{*}}+b$
     end for
end for
Figure 3: Few Shot Learning Phase of FLOWER is driven by the three loss functions, the cross entropy loss function, the ball generator loss function and the projection-based synaptic intelligence regularizer. It combines wide and flat learning paradigms supported by the feature space data augmentation module having a transformation module.
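The clamping step at the end of each few-shot epoch in Algorithm 1 amounts to an element-wise clip around the parameters found in the base phase. The following is a minimal sketch on a flat list of parameters; the function name and values are hypothetical.

```python
def clamp_to_flat_region(params, anchor, b):
    """Clip every parameter to the interval [anchor - b, anchor + b].

    `anchor` holds the feature extractor parameters found in the base
    learning phase; `b` is the flat-wide region bound.
    """
    return [min(max(w, a - b), a + b) for w, a in zip(params, anchor)]

anchor = [0.5, 1.0, 1.0]     # parameters after the base phase (toy values)
updated = [0.5, 1.02, 0.995]  # parameters after a few-shot update (toy values)
clamped = clamp_to_flat_region(updated, anchor, b=0.01)
print(clamped)
```

Only the second parameter left the region in this example, so it is pulled back to the boundary while the other two are left untouched.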

V Experiments

This section presents numerical validations of FLOWER: comparisons with prior art approaches with different sizes of base tasks, ablation study and sensitivity analyses of the bound $b$ and the number of trials $M$.

V-A Experimental Setting

Datasets: our numerical study is carried out using CIFAR100, miniImageNet and CUB-200-2011. The CIFAR100 and miniImageNet problems are configured such that 60 classes are reserved for the base task while the remaining 40 classes form the few-shot tasks under the 5-way 5-shot configuration, i.e., 5 classes are sampled where each class comprises 5 samples, thus leading to 8 few-shot tasks. For the CUB-200-2011 problem with 200 classes, the base task is built upon 100 classes while the few-shot tasks utilize the remaining 100 classes with the 10-way 5-shot configuration, thus forming 10 few-shot tasks. In addition, we also evaluate consolidated algorithms in the case of small base tasks, where only 20 classes are included in the miniImageNet and CIFAR100 problems while 50 classes are presented for the CUB problem. Numerical results under a 1-shot configuration are also reported here.
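The session counts above follow from simple arithmetic, sketched below. The sketch assumes that classes are indexed contiguously with the base classes first, which is an illustrative convention rather than something stated in the text; the function name is hypothetical.

```python
def few_shot_sessions(num_classes, num_base, n_way):
    """Split class indices into one base task and N-way few-shot tasks."""
    remaining = num_classes - num_base
    if remaining % n_way != 0:
        raise ValueError("remaining classes must divide evenly into N-way tasks")
    base = list(range(num_base))
    sessions = [
        list(range(start, start + n_way))
        for start in range(num_base, num_classes, n_way)
    ]
    return base, sessions

# CIFAR100 / miniImageNet: 60 base classes, then 5-way tasks.
cifar_base, cifar_sessions = few_shot_sessions(100, 60, 5)
# CUB-200-2011: 100 base classes, then 10-way tasks.
cub_base, cub_sessions = few_shot_sessions(200, 100, 10)
print(len(cifar_sessions), len(cub_sessions))
```

Counting the base task as the first session gives 9 sessions in total for CIFAR100 and miniImageNet, and 11 for CUB-200-2011.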

Benchmark Algorithms: FLOWER is compared with 7 state-of-the-art algorithms: iCaRL [26], Rebalance [33], FSLL [5], cRT [34], TOPIC [4], EEIL [35] and F2M [8]. The baseline algorithm follows [8], where the training process is only carried out in the base task while only exemplar construction is done for the few-shot tasks, followed by a simple nearest class mean classification algorithm. Joint training is also included, where the training process encompasses base classes and few-shot classes. Although the joint-training method commonly serves as the upper bound for continual learning problems, it is not strong enough here because of the extreme imbalance between the base task and the few-shot tasks in few-shot continual learning problems. Here, cRT is added as an additional upper bound because it performs long-tailed classification trained with all encountered data [8]. Numerical results of ICARL, Rebalance, EEIL, TOPIC, cRT and joint-training are taken from [4] and [8], while we run ICARL, Rebalance, Baseline, F2M and FSLL in the same computational environment, applying the implementation of [8] with their best hyper-parameters found with grid search.

Experimental Details: Our numerical study is executed on a single NVIDIA A100 GPU with 40 GB memory across five runs. ResNet18 is adopted as the backbone network for all methods in the miniImageNet and CUB-200-2011 problems, while ResNet20 is used as the backbone network for all methods in the CIFAR100 problem. As with [8], the number of random trials $M$ is set between 2 and 4, noise is injected into the last 4 or 8 convolutional layers, and the flat region bound $b$ is set to 0.01. Each few-shot session is trained for 15 epochs; the learning rate is initially 0.2 and is scaled by a factor of 1e-6 after 5 epochs. Hyper-parameters of all methods are detailed in Table I.

hyperparameter value
image_size 32 (CIFAR), 84 (MiniImageNet), 224 (CUB)
networks ResNet18, ResNet20 (CIFAR)
optimizer SGD
batch_size 256
epoch
- base task: 1000, 600 (CUB)
- continual tasks: 15, generate prototype only (CUB)
initial learning rate (lr)
- base task: 0.1, 0.001 (CUB)
- continual tasks: 0.2
gamma
- base task: 0.1, 1e-6 (CUB)
- continual tasks: 1e-6
weight_decay 5e-4
momentum 0.9
random_noise
- distribution: DiscreteBeta
- type: suffix_conv_weight
- num_layers: 4 or 8
- low: 0.1
- high: 5.0, 3.0 (CUB)
- reduction_factor: 4
- bound_value: 0.01, 0.0002 (CUB)
- random_times: 2-4
lambda
- lambda1: 1.0
- lambda2: 1.0
- lambda3: 10
- lambda4: 100
TABLE I: Hyper-parameter settings for our experiments
Session
Method 1 2 3 4 5 6 7 8 9 Avg Gap
Base classes = 60
cRT* [34] 67.30 64.15 60.59 57.32 54.22 51.43 48.92 46.78 44.85 55.06 0.94
Joint-training* [8] 67.30 62.34 57.79 54.08 50.93 47.65 44.65 42.61 40.29 51.96 -2.17
Baseline [8] 66.78 61.29 57.25 54.17 51.54 49.00 46.56 44.99 43.78 52.82 -1.31
F2M [8] 66.07 61.05 56.82 53.51 50.76 48.26 45.79 44.07 42.62 52.11 -2.02
FSLL [5] 66.98 57.23 52.47 49.66 47.45 45.20 43.29 42.06 41.18 49.50 -4.62
ICARL [26] 66.05 56.47 53.26 50.14 47.55 45.08 42.47 41.04 39.60 49.07 -5.05
Rebalance [33] 66.45 60.66 55.61 51.38 47.93 44.64 41.40 38.75 37.11 49.33 -4.80
ICARL+ [26] 61.31 46.32 42.94 37.63 30.49 24.00 20.89 18.80 17.21 33.29 -20.84
EEIL+ [35] 61.31 46.58 44.00 37.29 33.14 27.12 24.10 21.57 19.58 34.97 -19.16
Rebalance+ [33] 61.31 47.80 39.31 31.91 25.68 21.35 18.67 17.24 14.17 30.83 -23.30
TOPIC+ [4] 61.31 50.09 45.17 41.16 37.48 35.52 32.19 29.46 24.42 39.64 -14.48
FLOWER 68.83 63.27 59.00 55.61 52.64 49.96 47.56 45.86 44.40 54.13 0.00
Base classes = 20
Baseline [8] 67.90 46.32 37.24 30.13 25.56 22.11 20.03 18.14 16.80 31.58 -3.65
F2M [8] 65.45 45.70 36.15 29.19 24.73 21.29 19.18 17.18 15.91 30.53 -4.70
FSLL [5] 68.05 47.05 37.74 30.70 26.00 22.43 20.26 18.22 17.11 31.95 -3.28
ICARL [26] 60.70 38.91 30.65 25.30 21.46 17.97 16.38 14.58 13.77 26.64 -8.59
Rebalance [33] 67.55 44.89 33.67 29.04 25.32 22.19 20.73 18.88 17.70 31.11 -4.12
FLOWER 75.70 52.66 41.67 33.88 28.64 24.63 22.00 19.67 18.21 35.23 0.00
TABLE II: Classification accuracy on the split MiniImagenet dataset with 60 and 20 base classes, averaged across 5 runs. #B indicates the number of base classes, * indicates results copied from F2M, and + indicates results copied from TOPIC.
Session
Method 1 2 3 4 5 6 7 8 9 Avg Gap
Base classes = 60
cRT* [34] 65.18 63.89 60.20 57.23 53.71 50.39 48.77 47.29 45.28 54.66 -4.76
Joint-training* [8] 65.18 61.45 57.36 53.68 50.84 47.33 44.79 42.62 40.08 51.48 -7.94
Baseline [8] 72.2 67.88 64.06 60.43 57.28 54.54 52.83 50.93 48.8 58.77 -0.65
F2M [8] 71.4 66.65 63.2 59.54 56.61 54.08 52.2 50.49 48.4 58.06 -1.36
FSLL [5] 72.52 64.2 59.6 55.09 52.85 51.57 51.21 49.66 47.86 56.06 -3.36
ICARL [26] 71.72 64.86 60.46 56.61 53.76 51.1 49.1 46.95 44.64 55.47 -3.95
Rebalance [33] 74.57 66.65 60.96 55.59 50.87 46.55 43.82 40.29 37.23 52.95 -6.47
ICARL+ [26] 64.10 53.28 41.69 34.13 27.93 26.06 20.41 15.48 13.73 32.98 -26.44
EEIL+ [35] 64.10 53.11 42.71 35.15 28.96 24.98 21.01 17.26 15.85 33.68 -25.74
Rebalance+ [33] 64.10 53.05 43.96 36.97 31.61 26.73 21.23 16.78 13.54 34.22 -25.20
TOPIC+ [4] 64.10 55.88 47.07 45.16 40.11 36.38 33.96 31.55 29.37 42.62 -16.80
FLOWER 73.4 68.98 64.98 61.2 57.88 55.14 53.28 51.16 48.78 59.42 0.00
Base classes = 20
Baseline [8] 74.75 54.39 41.83 35.24 30.72 27.67 24.91 22.82 20.94 37.03 -3.29
F2M [8] 73.4 53.25 42.1 35.42 31.03 28.12 24.95 23.04 21.38 36.97 -3.35
FSLL [5] 75 52.76 39.87 33.29 29.74 26.73 24.12 22.22 20.28 36.00 -4.32
ICARL [26] 76.85 53.58 41.51 34.65 30.22 27.15 23.73 21.65 19.62 36.55 -3.77
Rebalance [33] 73.9 49.27 36.95 31.73 28.62 26.68 24.44 22.62 20.86 35.01 -5.31
FLOWER 83.2 58.93 46.01 38.56 33.48 29.54 26.49 24.4 22.27 40.32 0.00
TABLE III: Classification accuracy on the split CIFAR100 dataset with 60 and 20 base classes, averaged across 5 runs. #B indicates the number of base classes, * indicates results copied from F2M, and + indicates results copied from TOPIC.
Session
Method 1 2 3 4 5 6 7 8 9 10 11 Avg Gap
Base classes = 100
cRT* [34] 80.83 78.51 76.12 73.93 71.46 68.96 67.73 66.75 64.22 62.53 61.08 70.19 4.52
Joint-training* [8] 80.83 77.57 74.11 70.75 68.52 65.97 64.58 62.22 60.18 58.49 56.78 67.27 1.60
Baseline [8] 76.92 72.83 68.96 64.74 62.38 59.63 58.44 57.25 54.97 54.33 53.49 62.18 -3.49
F2M [8] 77.41 73.5 69.52 65.27 63.07 60.41 59.2 58.02 55.89 55.49 54.51 62.94 -2.73
FSLL [5] 76.92 70.58 64.73 57.77 55.96 54.34 53.62 53.04 50.45 50.36 49.48 57.93 -7.74
ICARL° [26] 75.73 69.97 65.95 61.27 57.62 55.5 54.16 53.37 50.75 50.35 48.77 58.49 -7.18
Rebalance [33] 74.83 55.51 48.36 42.38 36.98 32.82 30.15 27.82 25.96 24.44 22.83 38.37 -27.30
ICARL+ [26] 68.68 52.65 48.61 44.16 36.62 29.52 27.83 26.26 24.01 23.89 21.16 36.67 -29.00
EEIL+ [35] 68.68 53.63 47.91 44.20 36.30 27.46 25.93 24.70 23.95 24.13 22.11 36.27 -29.40
Rebalance+ [33] 68.68 62.55 50.33 45.07 38.25 32.58 28.71 26.28 23.80 19.91 17.82 37.63 -28.04
TOPIC+ [4] 68.68 62.49 54.81 49.99 42.25 41.40 38.35 35.36 32.22 28.31 26.28 43.65 -22.02
FLOWER 79.02 75.77 72.01 67.96 65.99 63.38 62.14 61.12 58.96 58.52 57.49 65.67 0.00
Base classes = 50
Baseline [8] 77.83 66.94 61.18 59.13 54.85 49.96 46.73 43.93 42.3 39.98 39.74 52.96 -4.86
F2M [8] 78.11 68.52 63.5 61.61 57.58 52.64 49.26 46.43 44.81 42.71 42.41 55.23 -2.59
FSLL [5] 77.61 65.96 60.37 58.59 55.96 51.75 48.35 45.84 44.92 42.8 42.62 54.07 -3.75
ICARL° [26] 75.88 62.54 57.7 54.73 49.16 43.96 40.84 38.09 36.63 34.38 34.79 48.06 -9.76
Rebalance [33] 74.73 60.91 56.29 53.93 48.31 42.49 37.4 33.42 31.32 29.29 27.93 45.09 -12.73
FLOWER 79.63 71.14 66.18 64.22 60.03 55.17 51.74 49.13 47.79 45.68 45.29 57.82 0.00
TABLE IV: Classification accuracy on the split CUB-200-2011 dataset with 100 and 50 base classes, averaged across 5 runs. #B indicates the number of base classes, * indicates results copied from F2M, + indicates results copied from TOPIC, and ° indicates that the method was run only once due to a system crash.
Session
Configuration (FM / PMAS / Ball) 1 9 1-9(Avg) Gap
Without flat learning 67.82 44.21 53.77 -0.36
Without wide learning 67.43 44.20 53.63 -0.50
Without ball generator 68.83 3.48 40.43 -13.70
Full FLOWER 68.83 44.40 54.13 0.00
TABLE V: Ablation study on the split MiniImagenet dataset with 60 base classes, averaged across 5 runs.
Session
Bound value (b) 1 9 1-9(Avg)
0.0001 67.03 43.93 53.26
0.001 67.43 43.53 53.14
0.0025 68.43 44.68 54.48
0.005 68.37 44.77 54.29
0.01 68.83 44.40 54.13
0.025 65.47 41.73 51.51
0.05 65.93 42.48 52.05
0.075 65.33 42.50 51.80
0.1 64.45 41.43 50.85
TABLE VI: Sensitivity analysis of the impact of the bound value on classification accuracy on the MiniImageNet dataset with 60 base classes, averaged across 5 runs.
Session
Random times (M) 1 9 1-9(Avg)
1* 1.67 - -
2 68.83 44.40 54.13
3 68.75 44.65 54.52
4 69.12 44.49 54.49
5 68.73 43.74 53.89
6 68.45 40.64 53.26
7 68.77 42.69 53.57
8 69.05 44.48 54.36
9 69.13 33.77 50.12
10 68.45 44.26 54.04
TABLE VII: Sensitivity analysis of the impact of the number of random times on classification accuracy on the MiniImageNet dataset with 60 base classes, averaged across 5 runs. * denotes a setting run only in the base task due to poor performance.
Dataset Baseline [8] F2M [8] FSLL [5] ICARL [26] Rebalance [33] FLOWER
MiniImageNet 51.59 50.90 46.49 8.44 51.19 53.11
CIFAR-100 57.01 56.07 52.89 9.07 57.44 57.43
CUB-2011 57.76 57.73 52.86 8.22* 40.62 59.27
TABLE VIII: Average accuracy over all sessions under the 1-shot setting on the MiniImageNet, CIFAR-100 and CUB-200 datasets, averaged across 5 runs. * denotes a method run only once due to a system crash. Complete results are offered in the appendices.

V-B Numerical Results

Numerical results of the consolidated algorithms on the CIFAR100, miniImageNet and CUB problems under moderate base tasks, i.e., 60 base classes for CIFAR100 and miniImageNet and 100 for CUB, are reported in Tables II, III and IV, while the detailed results with standard deviations are tabulated in the appendices. FLOWER delivers the most encouraging performance in the CIFAR100 problem and even outperforms the upper bounds, cRT and joint-training, in average accuracy. It beats Baseline by a 0.65% gap and F2M by a 1.36% gap. A similar finding is observed for per-session accuracy, where FLOWER matches or outperforms the other consolidated algorithms in nearly all sessions. Apart from the cRT upper bound, FLOWER consistently exceeds the other algorithms in the miniImageNet problem, with even larger differences than in the CIFAR100 problem: it beats Baseline by a 1.31% gap and F2M by a 2.02% gap, and it also delivers higher per-session accuracy than those algorithms. As with the other two problems, FLOWER is superior to the other continual learning algorithms in the CUB problem, where it outperforms F2M (2.73% margin) and Baseline (3.49% margin) while consistently attaining higher per-session accuracy than them.
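The Avg and Gap columns can be reproduced from the per-session accuracies. The sketch below does this for the FLOWER and Baseline rows of Table II (60 base classes); the helper names are hypothetical, and the Gap convention (a method's rounded average minus FLOWER's rounded average) is inferred from the tabulated values rather than stated explicitly in the text.

```python
def average_accuracy(per_session):
    """Mean accuracy over all sessions, as in the Avg column."""
    return sum(per_session) / len(per_session)


def gap_to_reference(method_avg, reference_avg):
    """Difference of the two-decimal averages, as in the Gap column (assumed convention)."""
    return round(round(method_avg, 2) - round(reference_avg, 2), 2)


# Per-session accuracies copied from Table II, base classes = 60.
flower = [68.83, 63.27, 59.00, 55.61, 52.64, 49.96, 47.56, 45.86, 44.40]
baseline = [66.78, 61.29, 57.25, 54.17, 51.54, 49.00, 46.56, 44.99, 43.78]

flower_avg = round(average_accuracy(flower), 2)
baseline_avg = round(average_accuracy(baseline), 2)
baseline_gap = gap_to_reference(average_accuracy(baseline), average_accuracy(flower))
```

A negative gap therefore means the method trails FLOWER in average accuracy, while a positive gap (as for cRT in Tables II and IV) means it is ahead.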

FLOWER’s performance is even more promising in the small base task than in the moderate base task, where numerical results of the consolidated algorithms are reported in Tables II, III and IV. It outperforms Baseline by a 3.29% gap and F2M by a 3.35% gap in the CIFAR100 problem. It achieves higher accuracy than FSLL in the miniImageNet problem with a 3.28% margin and than F2M in the CUB problem with a 2.59% margin. The same pattern is observed for per-session accuracy. These results support the efficacy of FLOWER for few-shot continual learning problems, notably for small base tasks. We also confirm that Baseline, where only prototype constructions are performed in the few-shot learning sessions while all other parameters are left fixed, is capable of producing strong performance, as observed in [8]. That is, in few-shot continual learning problems a model must retain decent performance on the base task because it accounts for a large number of classes.

Summarized numerical results in the 1-shot configuration are tabulated in Table VIII, while the detailed per-session results are tabulated in the appendices. FLOWER remains superior or on par with the other algorithms in the three benchmark problems under the 1-shot setting, i.e., each class of the few-shot learning phases comprises 1 sample only. It produces better accuracy than Baseline with visible gaps of about 1.5% for miniImageNet and CUB, while its accuracy is on par with Rebalance in CIFAR100. Note that the moderate base tasks are used here and that Rebalance performs poorly on CUB. The performance of the consolidated algorithms for moderate base tasks under the 5-shot setting is illustrated in Fig. 4, where ours consistently delivers the best trends. Visualizations for small base tasks under the 5-shot setting and for the 1-shot setting are provided in Fig. 5 and Fig. 6, where both exhibit the same consistent trends.

Figure 4: Visualization of Performances of Consolidated Algorithms across three problems under the moderate base tasks
Figure 5: Visualization of Consolidated algorithms in MiniImagenet, CIFAR-100, and CUB-2011 dataset with 20, 20, and 50 base classes respectively.
Figure 6: Visualization of Consolidated algorithms in MiniImagenet, CIFAR-100, and CUB-2011 dataset with 60, 60, and 100 base classes respectively with 1-shot setting.

V-C Ablation Study

An ablation study is executed to study the contribution of each learning module of FLOWER to the overall learning performance, where numerical results under the moderate base task of the miniImageNet problem are tabulated in Table V. In a nutshell, each learning module of FLOWER contributes positively, and the absence of any one component results in performance degradation. FLOWER is underpinned by the ball generator for feature-space data augmentation to alleviate the issue of data scarcity, and its absence brings down FLOWER’s average accuracy by 13.70%. The absence of flat learning and wide learning incurs performance losses of 0.36% and 0.50% respectively.

V-D Sensitivity Analysis

Sensitivity analyses are performed to study the effect of the flat-wide region bound $b$ and the number of trials $M$. Numerical results across different bounds $b$ are reported in Table VI. A bound that is too small degrades performance because the searched region is no longer flat-wide, which in turn confirms the presence of flat-wide local minima. On the other hand, a bound that is too large causes inaccurate approximations of the flat-wide regions, thus ending up with the CF problem. $b=0.01$ is applied in our experiments.

Numerical results across different numbers of random trials $M$ are presented in Table VII. $M=1$ significantly compromises FLOWER’s performance since no approximation of flat-wide regions is carried out in this case. Increasing the number of trials beyond $M=1$ results in only minor differences in performance in most cases. However, a high $M$ slows down execution since FLOWER has to undergo many trials in the base learning task to find flat-wide regions. In this paper, we set the number of trials $M\in[2,4]$ for all experiments.
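To make the roles of $b$ and $M$ concrete, the sketch below averages a loss over $M$ noisy copies of a parameter vector, with every noise entry confined to $[-b, b]$. This is only a schematic illustration under stated assumptions: the uniform noise is a stand-in for the DiscreteBeta distribution listed in Table I, the quadratic loss and the function names are hypothetical, and the actual FLOWER objective contains further terms that are not modelled here.

```python
import numpy as np


def perturbed_loss(loss_fn, params, bound, num_trials, rng):
    """Average loss over `num_trials` noisy copies of `params`.

    Each noise entry is drawn uniformly from [-bound, bound]; the uniform
    distribution is an assumption made for this sketch.
    """
    losses = []
    for _ in range(num_trials):
        noise = rng.uniform(-bound, bound, size=params.shape)
        losses.append(loss_fn(params + noise))
    return float(np.mean(losses))


def quadratic_loss(w):
    """Toy loss with its minimum at the origin (hypothetical stand-in)."""
    return float(np.sum(w ** 2))


rng = np.random.default_rng(0)
params = np.zeros(4)
avg_loss = perturbed_loss(quadratic_loss, params, bound=0.01, num_trials=4, rng=rng)
```

Each additional trial costs one more loss evaluation per step, which is consistent with the observation that a high $M$ slows down execution.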

VI Conclusion

This paper presents FLOWER as a solution to few-shot continual learning problems. Our in-depth numerical study confirms the advantage of FLOWER, where it outperforms prior arts by 1-4% margins in the three benchmark problems for both average and per-session accuracy. This advantage is even more obvious with small base tasks, where larger margins over its counterparts are observed than with moderate base tasks, whereas prior arts are compromised in such a setting because they focus heavily on the stability of the base tasks. The ablation study and sensitivity analysis further demonstrate the advantage of each learning module and the robustness of FLOWER. Current few-shot continual learning relies on a strong assumption that each task is drawn from the same domain. Our future study will be devoted to addressing cross-domain few-shot continual learning problems.

References

  • [1] G. I. Parisi, R. Kemker, J. L. Part, C. Kanan, and S. Wermter, “Continual lifelong learning with neural networks: A review,” Neural networks : the official journal of the International Neural Network Society, vol. 113, pp. 54–71, 2019.
  • [2] G. M. van de Ven and A. Tolias, “Three scenarios for continual learning,” ArXiv, vol. abs/1904.07734, 2019.
  • [3] C. Finn, P. Abbeel, and S. Levine, “Model-agnostic meta-learning for fast adaptation of deep networks,” ArXiv, vol. abs/1703.03400, 2017.
  • [4] X. Tao, X. Hong, X. Chang, S. Dong, X. Wei, and Y. Gong, “Few-shot class-incremental learning,” 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 12180–12189, 2020.
  • [5] P. Mazumder, P. Singh, and P. Rai, “Few-shot lifelong learning,” in AAAI, 2021.
  • [6] K. Chen and C.-G. Lee, “Incremental few-shot learning via vector quantization in deep embedded space,” in ICLR, 2021.
  • [7] C. Zhang, N. Song, G. Lin, Y. Zheng, P. Pan, and Y. Xu, “Few-shot incremental learning with continually evolved classifiers,” 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 12450–12459, 2021.
  • [8] G. Shi, J. Chen, W. Zhang, L.-M. Zhan, and X.-M. Wu, “Overcoming catastrophic forgetting in incremental few-shot learning by finding flat minima,” in NeurIPS, 2021.
  • [9] J. Snell, K. Swersky, and R. S. Zemel, “Prototypical networks for few-shot learning,” ArXiv, vol. abs/1703.05175, 2017.
  • [10] S. Cha, H. Hsu, F. du Pin Calmon, and T. Moon, “Cpr: Classifier-projection regularization for continual learning,” ArXiv, vol. abs/2006.07326, 2021.
  • [11] P. Sun, Y. Ouyang, W. Zhang, and X. Dai, “Meda: Meta-learning with data augmentation for few-shot text classification,” in IJCAI, 2021.
  • [12] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, D. Hassabis, C. Clopath, D. Kumaran, and R. Hadsell, “Overcoming catastrophic forgetting in neural networks,” ArXiv, vol. abs/1612.00796, 2016.
  • [13] F. Zenke, B. Poole, and S. Ganguli, “Continual learning through synaptic intelligence,” Proceedings of machine learning research, vol. 70, pp. 3987–3995, 2017.
  • [14] R. Aljundi, F. Babiloni, M. Elhoseiny, M. Rohrbach, and T. Tuytelaars, “Memory aware synapses: Learning what (not) to forget,” in ECCV, 2018.
  • [15] I. Paik, S. Oh, T. Kwak, and I. Kim, “Overcoming catastrophic forgetting by neuron-level plasticity control,” ArXiv, vol. abs/1907.13322, 2020.
  • [16] F. Mao, W. Weng, M. Pratama, and E. Yapp, “Continual learning via inter-task synaptic mapping,” ArXiv, vol. abs/2106.13954, 2021.
  • [17] Z. Li and D. Hoiem, “Learning without forgetting,” IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 40, pp. 2935–2947, 2018.
  • [18] A. A. Rusu, N. C. Rabinowitz, G. Desjardins, H. Soyer, J. Kirkpatrick, K. Kavukcuoglu, R. Pascanu, and R. Hadsell, “Progressive neural networks,” ArXiv, vol. abs/1606.04671, 2016.
  • [19] J. Yoon, E. Yang, J. Lee, and S. J. Hwang, “Lifelong learning with dynamically expandable networks,” ArXiv, vol. abs/1708.01547, 2018.
  • [20] B. Zoph and Q. V. Le, “Neural architecture search with reinforcement learning,” ArXiv, vol. abs/1611.01578, 2017.
  • [21] X. lai Li, Y. Zhou, T. Wu, R. Socher, and C. Xiong, “Learn to grow: A continual structure learning framework for overcoming catastrophic forgetting,” in ICML, 2019.
  • [22] J. Xu, J. Ma, X. Gao, and Z. Zhu, “Adaptive progressive continual learning,” IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. PP, 2021.
  • [23] A. Ashfahani and M. Pratama, “Unsupervised continual learning in streaming environments,” IEEE transactions on neural networks and learning systems, vol. PP, 2022.
  • [24] M. Pratama, A. Ashfahani, and E. Lughofer, “Unsupervised continual learning via self-adaptive deep clustering approach,” ArXiv, vol. abs/2106.14563, 2021.
  • [25] A. Rakaraddi, S.-K. Lam, M. Pratama, and M. V. de Carvalho, “Reinforced continual learning for graphs,” Proceedings of the 31st ACM International Conference on Information & Knowledge Management, 2022.
  • [26] S.-A. Rebuffi, A. Kolesnikov, G. Sperl, and C. H. Lampert, “iCaRL: Incremental classifier and representation learning,” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 5533–5542, 2017.
  • [27] D. Lopez-Paz and M. Ranzato, “Gradient episodic memory for continual learning,” in NIPS, 2017.
  • [28] A. Chaudhry, M. Ranzato, M. Rohrbach, and M. Elhoseiny, “Efficient lifelong learning with a-gem,” ArXiv, vol. abs/1812.00420, 2019.
  • [29] A. Chaudhry, A. Gordo, P. Dokania, P. H. S. Torr, and D. Lopez-Paz, “Using hindsight to anchor past knowledge in continual learning,” in AAAI, 2021.
  • [30] P. Buzzega, M. Boschini, A. Porrello, D. Abati, and S. Calderara, “Dark experience for general continual learning: a strong, simple baseline,” ArXiv, vol. abs/2004.07211, 2020.
  • [31] T. Dam, M. Pratama, M. M. Ferdaus, S. G. Anavatti, and H. Abbas, “Scalable adversarial online continual learning,” ArXiv, vol. abs/2209.01558, 2022.
  • [32] M. V. de Carvalho, M. Pratama, J. Zhang, and Y. San, “Class-incremental learning via knowledge amalgamation,” ArXiv, vol. abs/2209.02112, 2022.
  • [33] S. Hou, X. Pan, C. C. Loy, Z. Wang, and D. Lin, “Learning a unified classifier incrementally via rebalancing,” 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 831–839, 2019.
  • [34] B. Kang, S. Xie, M. Rohrbach, Z. Yan, A. Gordo, J. Feng, and Y. Kalantidis, “Decoupling representation and classifier for long-tailed recognition,” ArXiv, vol. abs/1910.09217, 2020.
  • [35] F. M. Castro, M. J. Marín-Jiménez, N. G. Mata, C. Schmid, and A. Karteek, “End-to-end incremental learning,” ArXiv, vol. abs/1807.09536, 2018.