B-SMALL: A Bayesian Neural Network Approach to Sparse Model-Agnostic Meta-learning
Abstract
There is a growing interest in the learning-to-learn paradigm, also known as meta-learning, where models infer on new tasks using a few training examples. Recently, meta-learning based methods have been widely used in few-shot classification, regression, reinforcement learning, and domain adaptation. The model-agnostic meta-learning (MAML) algorithm is a well-known algorithm that obtains model parameter initialization at meta-training phase. In the meta-test phase, this initialization is rapidly adapted to new tasks by using gradient descent. However, meta-learning models are prone to overfitting since there are insufficient training tasks resulting in over-parameterized models with poor generalization performance for unseen tasks. In this paper, we propose a Bayesian neural network based MAML algorithm, which we refer to as the B-SMALL algorithm. The proposed framework incorporates a sparse variational loss term alongside the loss function of MAML, which uses a sparsifying approximated KL divergence as a regularizer. We demonstrate the performance of B-MAML using classification and regression tasks, and highlight that training a sparsifying BNN using MAML indeed improves the parameter footprint of the model while performing at par or even outperforming the MAML approach. We also illustrate applicability of our approach in distributed sensor networks, where sparsity and meta-learning can be beneficial.
Index Terms— Meta-learning, Bayesian neural networks, overfitting, variational dropout
1 Introduction
The ability to adapt and learn new models with small amounts of data is a critical aspect of several systems such as IOTs, secure communication networks, biomedical signal processing, image processing etc. Traditional signal processing has addressed such problems using Bayesian and sparse signal processing techniques under a model-driven approach, incorporating several statistical assumptions on the distribution of input data. However, the modern era of artificial intelligence, brings in the promise of model-free processing using various machine learning algorithms, with no assumptions required on the statistical properties of the signals involved.
Among several machine learning approaches proposed to deal with low-data regimes [1, 2], meta-learning is a simple yet efficient technique which aims at obtaining rapid adaptation across various tasks111Often, task refers to a subset of observations sampled from the original dataset in such a way that only a subset of the final prediction problem can be solved in the task., given small amount of data for updating the parameters pertaining to each task. In particular, model-agnostic meta-learning (MAML) is an algorithm that trains a model’s parameters such that a small number of gradient updates will lead to fast learning on a new task. Specifically, MAML obtains a meta-initialization at meta-training phase using task-specific training, and this initialization is rapidly adapted to a new task by using gradient descent in the meta-test phase. MAML is a baseline for any state-of-the-art few-shot learning method since it has been used for supervised classification, regression and reinforcement learning in the presence of task variability. Furthermore, MAML substantially outperforms techniques that use pre-training as initialization. In order to further improve on the adaptation and accuracy performance of MAML, several authors have proposed modifications such as introducing novel regularizers by analysing the optimization landscape [3], feature reuse perspective based ANIL framework [4], a meta-regularizer using information theoretic approaches for mitigating the memorization problem in MAML [5], etc.
In signal processing based applications, such as distributed signal processing, there is a need for a technique that rapidly adapts in a distributed manner using low amount of heterogeneous data at each sensor node. Furthermore, it is essential that these machine learning models be computationally simple and memory-efficient in terms of the number of parameters they require [6]. The inherent structure of the MAML algorithm lends itself in such scenarios since the task level learning in the inner iteration can be associated to per-node learning, while outer iteration parameter update agglomerates the updates from neighboring nodes, effectively enabling inference capability at each node. However, a challenge in the existing meta-learning approaches is their tendency to overfit, thereby defeating the true purpose of designing such networks [7]. It is well-known that incorporating sparsity constraints during model training guarantees statistical efficiency and robustness to overfitting, hence improving generalization performance on previously unseen tasks [8]. In the context of compact meta-learning, network pruning [9] and regularization [10, 11] have led to sparse meta-learned models without compromising generalization performance. Several methods have been proposed in order to combine deep networks and probabilistic methods for few-shot learning. In particular, in [12], the authors employ hierarchical Bayesian models for few shot learning. In [13], the authors employ a graphical model via a hierarchical Bayesian model that includes prior distribution over the weights and hyperparameters of the meta-learning model.
A popular approach for incorporating uncertainty in deep networks is using the Bayesian neural networks (BNN) [14, 15]. Although exact inference in BNNs is not possible [16], approximations based on backpropagation and sampling have been effective in incorporating uncertainty into the weights [15]. Furthermore, these networks can be made sparse, and eventually compressed to obtain light neural networks [17]. However, so far, conventional BNNs directly learn only the posterior weight distribution for a single task and have not been employed in the meta-learning framework.
Contributions: We build a meta-learning based method to tackle low-data based ambiguity that occurs while learning from small amounts of data using simple albeit highly-expressive function approximators such as neural networks. To enable this, our natural choice is the optimization based MAML framework. In order to abate overfitting we propose to design a sparse MAML algorithm for BNNs. We propose B-SMALL, where, in each of the parameter update steps of the MAML algorithm, the parameters of the BNN are updated using the sparse variational loss function proposed in the context of the sparse variational dropout (SVD) algorithm [17].
We demonstrate the effectiveness of this technique in achieving sparse models with improved accuracy on well-known datasets in the context of classification as well as regression222Code for the experiments can be found on github at https://github.com/anishmadan23/B-SMALL. Finally, we present a use-case for the proposed B-SMALL algorithm as distributed algorithms for sensor networks.
2 Preliminaries
In this section, we describe MAML, an optimization based meta-learning paradigm, followed by description of the Bayesian neural network and the SVD paradigm.
2.1 Model-Agnostic Meta-Learning (MAML)
MAML considers a set of tasks distributed as , for few-shot meta-learning. In a given meta-training epoch, a model represented by a parameterized function with parameters is adapted to a new task drawn from , using samples drawn from the data distribution(-shot). The resulting update (single) of the model’s parameters given by
(1) |
Typically, the parameter updates are computed using a few gradient descent steps evaluated for each task . The outer iteration consists of meta-optimization across all the tasks, the model parameters are updated using as given by
(2) |
where is the meta step-size. Hence, the test error on sampled tasks is the training error of the meta-learning process [18].
2.2 Bayesian Neural Networks and Sparsity
Among several manifestations of employing Bayesian methods in deep neural networks, Bayesian inference based on variational dropout (VD) for inferring the posterior distribution of network weights is quite popular [15] . In [17], the authors proposed the sparse variational dropout (SVD) technique where they provided a novel approximation of the KL-divergence term in the VD objective [15], and showed that this leads to sparse weight matrices in fully-connected and convolutional layers. The resulting BNNs are robust to over-fitting, learn from small datasets and offer uncertainty estimates through the parameters of per-weight probability distributions.
Consider a BNN with weights , and a prior distribution over the weights, . Training a BNN involves optimizing a variational lower bound given by
(3) |
where , is an approximation of the true posterior of the weights of the network parameterized by , and is the KL-divergence between the true posterior and its approximation. We employ the approximation of the above variational lower bound, termed as sparse variational dropout, as derived in [17]. Here, a multiplicative Gaussian noise is applied on a weight , which is equivalent to sampling from . Training the BNN involves learning , and , i.e., the variational parameters are given by .
3 BNN based Sparse MAML (B-SMALL)
Consider a task distribution over the set of tasks that encompasses the data points in , for . These tasks are used for meta-training a model which predicts given inputs and parameters . We adopt the SVD framework to introduce sparsity and reduce overfitting in our meta-learning pipeline, i.e., we maximize the variational lower bound and accordingly modify the loss function of MAML in the inner and outer loop as follows:
(4) |
Here, similar to the previous section, denotes the variational parameters given by . Further, is interpreted as the dropout rate. Note that SVD enables us to have individual dropout rates for each neuron that are learnable. Furthermore, the regularization term is such that for several neurons. A high dropout value implies we can effectively ignore the corresponding weight or neuron and remove it from the model, leading to lighter neural network models. Also, takes the form similar to the cross entropy loss for discrete classification problem, and squared loss in the case of regression problem [17].
4 Experiments and Results
In this section, we illustrate the performance of the proposed B-SMALL approach in the context of both, classification and regression. We evaluate the classification performance on few-shot image recognition benchmarks such as the Mini-Imagenet[19] and CIFAR-FS[20] datasets [18]. The setup is a -way, -shot experiment, where we randomly select classes and choose images/samples for each class at every training step. All the models in different experiments are trained for steps. We measure sparsity as the ratio of total number of weights above a certain threshold , and total number of weights. We set for all our experiments, and consider those neurons as dropped out when , where is the variational parameters in (4).
4.1 -Shot Regression
We illustrate the performance of the proposed B-SMALL framework on -shot regression, where the underlying ground-truth function that relates the input to the output is . We choose the amplitude range as and phase as and construct the meta-train and meta-test sets by sampling data points uniformly from . We choose a neural network with hidden layers of neurons each, followed by ReLU activation for this experiment. Further, we train the meta-learning model using a single gradient step, and a fixed step size . We train the models only for and fine tune it for . We evaluate mean square error (MSE) for random test points, all of which are adapted using the same points. The results in Table 3 are averaged over different random seeds. We note that like MAML, B-SMALL also continues to improve after a single gradient step (i.e., the number of gradient steps it was trained on as depicted in Fig. 1). This implies that B-SMALL is able to find an initialization for the model such that it lies in region where further improvement is possible, while providing better MSE scores when compared to MAML, as depicted in Table 3. Furthermore, B-SMALL outperforms MAML in all 3 cases alongside providing sparse weight matrices. Even on such a small model, we manage to get sparsity.
4.2 Few-Shot Classification
To illustrate the few-shot classification performance of B-SMALL, we use the Mini-Imagenet dataset which consists of classes from the Imagenet dataset [21], with each class containing images, resized to for training. The dataset is divided into classes for training, for validation and for testing. We also use the CIFAR-FS dataset proposed in [20], which consists of classes and follows a similar split as Mini-Imagenet. We use a neural network architecture with blocks, where each block contains Convolution, batch normalization, a ReLU layer [18]. We also use a Maxpooling layer with kernel size , which is useful to reduce the spatial dimensionality of the intermediate features. We use filters for each convolutional layer. The models were trained using gradient steps, with step size , and evaluated them using steps. We use a batch size of and tasks for and -shot training, respectively. We observe that B-SMALL performs on par with or outperforms MAML as depicted in Table. 1 and Table. 2. The aspect to be highlighted is that B-SMALL leads to sparse models which enables less overfitting during meta-train as depicted in Fig. 2. An interesting observation is the amount of sparsity in each case - when input information is large (more examples while training, i.e., higher in -shot), the models are less sparse since the network encodes the additional information into its weights, in order to drive its decisions.
Model/Experiment | 5 way accuracy | |
---|---|---|
1 shot | 5 shot | |
MAML[18] | ||
CAVIA[22] | ||
MAML(Ours) | ||
B-SMALL | ||
Sparsity |
Model/Experiment | 5 way accuracy | |
---|---|---|
1-shot | 5-shot | |
MAML[23] | ||
MAML(Ours) | ||
B-SMALL | ||
Sparsity | ||
Model/Experiment | 2 way accuracy | |
1-shot | 5-shot | |
MAML[23] | ||
MAML(Ours) | ||
B-SMALL | ||
Sparsity |
Expt | MSE @ Num Grad Steps | ||
---|---|---|---|
1 | 5 | 10 | |
MAML (k=5) | 0.8347 | 0.5415 | 0.5668 |
B-SMALL (k=5) | 0.7697 | 0.4596 | 0.4392 |
MAML (k=10) | 1.493 | 0.8088 | 0.7119 |
B-SMALL (k=10) | 1.2007 | 0.3816 | 0.3386 |
MAML (k=20) | 0.5238 | 0.0848 | 0.04555 |
B-SMALL (k=20) | 0.3445 | 0.0628 | 0.0518 |


4.3 Use-case: Sparse MAML in Sensor Networks
Consider a sensor network whose communication links are governed by a graph given by , where represents the set of vertices with and represents the set of edges. The degree of the -th vertex is given by , and each vertex is equipped with a neural network. We also assume that the sensors are connected to fusion center, which can communicate with all the sensor nodes. Without loss of generality, we assume that at the -th vertex, the neural network learns the model parameters pertaining to a set of tasks, . Translating the MAML algorithm for the sensor network, say the inner iterations of the MAML algorithm are executed at each node, i.e., at the -th node, the parameter update is given by (1). The inner iteration update is communicated to the fusion center, which obtains such updates from other vertices as well. The fusion center executes the outer iteration using (2), and the final updated weights can be communicated to the sensors for the purposes of inference. It is challenging to extend B-MAML to a distributed sensor network (in the absence of the fusion center). For instance, if is a complete graph, i.e., for all , then it is possible to implement exact MAML with a slight modification to the original algorithm. Furthermore, individual sensors have limited computational capability, and bandwidth of the communication links are limited. Hence, it is pertinent to develop distributed algorithms that are memory-efficient with minimal message exchange between nodes. We address these aspects of B-SMALL in future work.
5 Conclusion and Future Work
In this paper, we proposed the B-SMALL framework, a sparse BNN-based MAML approach for rapid adaptation to various tasks using small amounts of data. The parameters of the BNN are learnt using a sparse variational loss function. We demonstrated that the proposed framework outperformed MAML in most of the scenarios, while resulting in sparse neural network models. The results obtained builds on the theory that often, in deep learning, we have more parameters as compared to training instances, and such models are prone to overfitting [24]. This gap is amplified in meta-learning since it operates in the low-data-regime and therefore it is important to use regularization technique as in B-SMALL. This helps to reduce the parameter footprint thereby reducing overfitting, and boosts generalization performance. As a future work, we plan to design and analyse B-MAML type algorithms for distributed processing.
References
- [1] Matthew Olson, Abraham Wyner, and Richard Berk, “Modern neural networks generalize on small data sets,” in Advances in Neural Information Processing Systems, 2018, pp. 3619–3628.
- [2] Brenden M Lake, Ruslan Salakhutdinov, and Joshua B Tenenbaum, “Human-level concept learning through probabilistic program induction,” Science, vol. 350, no. 6266, pp. 1332–1338, 2015.
- [3] Simon Guiroy, Vikas Verma, and Christopher Pal, “Towards understanding generalization in gradient-based meta-learning,” arXiv preprint arXiv:1907.07287, 2019.
- [4] Aniruddh Raghu, Maithra Raghu, Samy Bengio, and Oriol Vinyals, “Rapid learning or feature reuse? towards understanding the effectiveness of maml,” arXiv preprint arXiv:1909.09157, 2019.
- [5] Mingzhang Yin, George Tucker, Mingyuan Zhou, Sergey Levine, and Chelsea Finn, “Meta-learning without memorization,” arXiv preprint arXiv:1912.03820, 2019.
- [6] Aritra Dutta, El Houcine Bergou, Ahmed M Abdelmoniem, Chen-Yu Ho, Atal Narayan Sahu, Marco Canini, and Panos Kalnis, “On the discrepancy between the theoretical analysis and practical implementations of compressed communication for distributed deep learning,” AAAI, 2020.
- [7] Jaesik Yoon, Taesup Kim, Ousmane Dia, Sungwoong Kim, Yoshua Bengio, and Sungjin Ahn, “Bayesian model-agnostic meta-learning,” in Advances in Neural Information Processing Systems, 2018, pp. 7332–7342.
- [8] Michael E Tipping, “Sparse bayesian learning and the relevance vector machine,” Journal of machine learning research, vol. 1, no. Jun, pp. 211–244, 2001.
- [9] Hongduan Tian, Bo Liu, Xiao-Tong Yuan, and Qingshan Liu, “Meta-learning with network pruning,” arXiv preprint arXiv:2007.03219, 2020.
- [10] Hung-Yu Tseng, Yi-Wen Chen, Yi-Hsuan Tsai, Sifei Liu, Yen-Yu Lin, and Ming-Hsuan Yang, “Regularizing meta-learning via gradient dropout,” arXiv preprint arXiv:2004.05859, 2020.
- [11] Sibo Gai and Donglin Wang, “Sparse model-agnostic meta-learning algorithm for few-shot learning,” in 2019 2nd China Symposium on Cognitive Computing and Hybrid Intelligence (CCHI). IEEE, 2019, pp. 127–130.
- [12] Harrison Edwards and Amos Storkey, “Towards a neural statistician,” arXiv preprint arXiv:1606.02185, 2016.
- [13] Chelsea Finn, Kelvin Xu, and Sergey Levine, “Probabilistic model-agnostic meta-learning,” in Advances in Neural Information Processing Systems, 2018, pp. 9516–9527.
- [14] Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra, “Weight uncertainty in neural networks,” arXiv preprint arXiv:1505.05424, 2015.
- [15] Durk P Kingma, Tim Salimans, and Max Welling, “Variational dropout and the local reparameterization trick,” in Advances in neural information processing systems, 2015, pp. 2575–2583.
- [16] Jiri Hron, Alexander G de G Matthews, and Zoubin Ghahramani, “Variational bayesian dropout: pitfalls and fixes,” arXiv preprint arXiv:1807.01969, 2018.
- [17] Molchanov Dmitry, Ashukha Arsenii, and Vetrov Dmitry, “Variational dropout sparsifies deep neural networks,” in 34th International Conference on Machine Learning, 2017, pp. 2498–2507.
- [18] Chelsea Finn, Pieter Abbeel, and Sergey Levine, “Model-agnostic meta-learning for fast adaptation of deep networks,” arXiv preprint arXiv:1703.03400, 2017.
- [19] Sachin Ravi and Hugo Larochelle, “Optimization as a model for few-shot learning,” 2016.
- [20] Boris Oreshkin, Pau Rodríguez López, and Alexandre Lacoste, “Tadam: Task dependent adaptive metric for improved few-shot learning,” in Advances in Neural Information Processing Systems, 2018, pp. 721–731.
- [21] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton, “Imagenet classification with deep convolutional neural networks,” in Advances in neural information processing systems, 2012, pp. 1097–1105.
- [22] Luisa Zintgraf, Kyriacos Shiarli, Vitaly Kurin, Katja Hofmann, and Shimon Whiteson, “Fast context adaptation via meta-learning,” in International Conference on Machine Learning. PMLR, 2019, pp. 7693–7702.
- [23] Luca Bertinetto, Joao F Henriques, Philip HS Torr, and Andrea Vedaldi, “Meta-learning with differentiable closed-form solvers,” arXiv preprint arXiv:1805.08136, 2018.
- [24] Geoffrey E Hinton, Nitish Srivastava, Alex Krizhevsky, Ilya Sutskever, and Ruslan R Salakhutdinov, “Improving neural networks by preventing co-adaptation of feature detectors,” arXiv preprint arXiv:1207.0580, 2012.