
Explanation-Guided Training for Cross-Domain Few-Shot Classification

Jiamei Sun1, Sebastian Lapuschkin2, Wojciech Samek2, Yunqing Zhao1,
Ngai-Man Cheung1, and Alexander Binder1
1Information Systems Technology and Design, Singapore University of Technology and Design; 2Department of Video Coding & Analytics, Fraunhofer Heinrich Hertz Institute, Berlin, Germany
Abstract

The cross-domain few-shot classification task (CD-FSC) combines few-shot classification with the requirement to generalize across domains represented by datasets. This setup faces challenges that originate from the limited labeled data in each class and, additionally, from the domain shift between training and test sets. In this paper, we introduce a novel training approach for existing FSC models. It leverages explanation scores for intermediate feature maps of the models, obtained by applying existing explanation methods to the predictions of FSC models. Firstly, we tailor the layer-wise relevance propagation (LRP) method to explain the predictions of FSC models. Secondly, we develop a model-agnostic explanation-guided training strategy that dynamically finds and emphasizes the features that are important for the predictions. Our contribution is not a novel explanation method but a novel application of explanations in the training phase. We show that explanation-guided training effectively improves model generalization. We observe improved accuracy for three different FSC models (RelationNet, cross attention network, and a graph neural network-based formulation) on five few-shot learning datasets: miniImagenet, CUB, Cars, Places, and Plantae. The source code is available at https://github.com/SunJiamei/few-shot-lrp-guided

I Introduction

Human beings can recognize new objects after seeing only a few examples. However, common image classification models require large amounts of labeled samples for training or fine-tuning. To address this issue, few-shot classification (FSC) aims at generalizing to new categories from a few training samples [1, 2, 3, 4, 5, 6, 7, 8]. This is relevant for setups in which, after model deployment, humans annotate a few examples of novel categories that were not present in the originally trained model. FSC models are commonly evaluated using test data from the same domain as the training dataset. Recently, [9] reported that current FSC methods struggle when there is a domain shift between the training data (source domain) and the test data (target domain). For example, people can recognize different kinds of birds and plants from a few examples of each category. In contrast, existing FSC models trained on the bird domain may not accurately recognize various kinds of plants, as demonstrated in [9, 10]. Cross-domain few-shot classification is therefore a more challenging and more useful task.

Tackling the domain shift problem requires additional efforts to avoid overfitting to the source domain. A recent work addresses the domain shift issue by learning a noise distribution for intermediate layers in the feature encoder [10]. Other approaches rely on adding batch spectral regularization over the encoded image features [11] or on employing novel losses [12, 9]. This paper proposes a novel approach for improving CD-FSC models from a different perspective: we leverage explanations computed for intermediate feature maps of FSC models to guide the model toward learning better feature representations. By explanations, we refer to methods such as gradient- or Shapley-type methods, LRP [13], or LIME [14], which compute a score for every dimension of a feature map that denotes its importance for the final prediction.

Although a large number of explanation methods have contributed substantial progress to the field of explaining model predictions [15, 16, 13, 17, 14, 18, 19, 20], they are usually applied in the testing phase and frequently do not consider concrete use cases of explanations. Known use cases include the audit of predictions [21], explanation-weighted document representations that are more comprehensive [22], and the identification of biases in datasets [17]. We add a new use case, namely the use of explanations during the training phase, and examine whether explanations are suitable for improving model performance in cross-domain few-shot classification.

Many explanation methods [13, 15, 17] explain predictions on a per-sample basis. Given a target label and an input sample, these methods assign a score to each neuron of every feature map within the model. These scores are related to the importance of a neuron for the target label. Explanations are usually generated with a modified backward pass and require no additional trainable parameters inside the model. In this paper, we study whether the explanation scores of intermediate feature maps can be employed to improve model generality in few-shot classification, which is still an open question.

Figure 1: LRP explanation heatmaps of the input images for 5 target labels. The model is a RelationNet trained on miniImagenet under the 5-way 5-shot setting. The first row shows examples of the support images. The other two rows show the explanation heatmaps of two query images, Q1: African hunting dog (denoted as dog) and Q2: lion. Both images are correctly predicted, and the heatmaps are generated using different target labels. Red pixels indicate positive LRP relevance scores and blue pixels negative ones. The color intensity corresponds to the magnitude of the LRP relevance scores.

Concretely, we adapt LRP-type explanations [13] to FSC models. LRP has been used to explain convolutional neural networks (CNN) [13], recurrent neural networks (RNN) [23], graph neural networks (GNN) [24], and clustering models [25]. It backpropagates the relevance score of a target label through the neural network and assigns relevance scores to the neurons within the network. The sign and the amplitude of LRP relevance scores reflect the contribution of a neuron to the prediction, as shown in Figure 1. Relying on this property, we propose “explanation-guided training” for FSC models. The LRP relevance scores of intermediate feature maps are employed as weights to construct LRP-weighted feature maps. This step emphasizes the feature dimensions that are more relevant to the model prediction and downscales the less relevant ones. The LRP-weighted features are then fed into the network to guide the training. Since LRP explanations are calculated for each sample-label pair separately, explanation-guided training adds a label-dependent feature weighting mechanism during training. We will show that this mechanism can reduce overfitting to the source domain. We remark that the principles of the explanation-guided training strategy are model-agnostic: they can be combined with other CD-FSC methods, such as the learned feature-wise transformation (LFT) [10], and with other explanation methods. The main contributions of this paper are as follows.

  • We derive explanations for FSC models using LRP.

  • We investigate the potential of improving model performance using explanations in the training phase under few-shot settings.

  • We propose an explanation-guided training strategy to tackle the domain shift problem in FSC.

  • We conduct experiments to show that the explanation-guided training strategy improves the model generalization for a number of FSC models and datasets.

  • We combine our explanation-guided training strategy with another recent approach, LFT [10], which, like our approach, can be applied on top of existing models, and observe that combining the two methods further improves performance.

II Related Work

II-A Few-shot Classification Methods

Optimization-based and metric-based approaches constitute two prominent directions in few-shot learning. The former either learn initialization parameters that can be quickly adapted to new categories [2, 26, 6, 7] or design a meta-optimizer that learns how to update the model parameters [27, 28, 29, 30]. Metric-based methods learn a distance metric to compare the support and query samples and assign a query image to the closest category [1, 4, 3, 5, 31, 32, 8]. Other approaches are also noteworthy. [33, 34] design and add task-conditional layers to the model. [35, 36, 37] dynamically update the classifier weights for new categories. [38, 39] combine multimodal information, such as the word embedding of the class label. [40] augments the training data by hallucinating new samples. [41, 42] leverage unlabeled training samples with semi-supervised training strategies. [43] equips the model with a self-supervision mechanism. However, recent research has shown that existing FSC methods may struggle under domain shift, a more challenging and practical problem [9].

II-B Cross-domain Few-shot Classification Methods

It is common to develop cross-domain few-shot classification methods from existing FSC methods. LFT [10] learns a noise distribution and adds the noise to intermediate feature maps to generate more diverse features during training and to improve model generality. In the recent CVPR Cross-Domain Few-Shot Learning challenge, [44, 11] ensembled multiple feature encoders and employed batch spectral regularization over the image features of each encoder. Batch spectral regularization penalizes the singular values of the feature matrix within a batch so that the learned features maintain similar spectra across domains. [45] combined the first-order MAML [2] and the GNN metric-based method [5]. [12] applied a prototypical triplet loss to increase the inter-class distance and a large margin cosine loss to minimize the intra-class distance. This is in line with [9], which found that reducing intra-class variation benefits FSC, especially for shallow image feature encoders. Unlike [10], our approach does not introduce additional parameters. Similar to [11] and [12], we add constraints on the image features. However, we use LRP-weighted features to guide the model to dynamically correct itself for each instance instead of penalizing feature statistics over a batch. The LRP-weighting idea has been used to generate more comprehensive document representations [22]. Different from [22], we embed the re-weighting strategy into the training phase to improve the model.

II-C Explanation for Few-shot Classification Models

There exist explanation methods for deep neural networks (DNN) [13, 18, 15, 46, 17, 24] that can be adapted to FSC models, since many FSC models adopt CNNs to encode image features and many metric-based methods also adopt DNNs to learn the distance metric [4, 5, 31]. For FSC models that use non-parametric distance metrics, we refer to [25], which transforms various K-means classifiers into neural network structures and then applies LRP to obtain explanations. In this paper, we have chosen LRP due to its reasonable performance [47], our understanding of its hyperparameters, and its reasonable speed compared to LIME or to some theoretically equally well-motivated but exhaustive Shapley-type approaches. While using other, similarly fast explanation methods would be possible, this would not change the qualitative message of this paper regarding the applicability of explanation methods to few-shot training. The results here are meant to make a case for explanation methods in general, even though they are demonstrated with one method.

III Explanation-Guided Training

Before presenting our explanation-guided training, we first introduce the cross-domain few-shot learning task and the notation. For a K-way N-shot task, denoted as an episode, we are given a support set $\mathcal{S}=\{(x_s,y_s)\}_{s=1}^{K\cdot N}$ containing $K$ classes and $N$ labeled samples per class for training, and a query set $\mathcal{Q}=\{(x_q,y_q)\}_{q=1}^{n_q}$ from the same classes as $\mathcal{S}$ for testing. The CD-FSC task is to train an FSC model using episodes $\{\mathcal{S}_i,\mathcal{Q}_i\}$ randomly sampled from a base domain $\mathcal{D}_{seen}$ and to test the model with episodes sampled from an unseen domain $\mathcal{D}_{unseen}$. In our study, we consider FSC models that follow the outline in Figure 2, which includes many metric-based FSC models.
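To make the episodic setup concrete, here is a minimal sketch of drawing one K-way N-shot episode from a labeled pool. It is not taken from the authors' code. The pool layout (a dict from class name to a list of samples), the function name `sample_episode`, and the use of a fixed number of query samples per class are illustrative assumptions.

```python
import random
from typing import Dict, List, Tuple

def sample_episode(pool: Dict[str, List[str]], k_way: int, n_shot: int,
                   n_query_per_class: int, rng: random.Random
                   ) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """Draw a support set S (K*N samples) and a query set Q from the same K classes.

    Labels are re-indexed to 0..K-1 inside each episode because the
    sampled classes change from one episode to the next.
    """
    classes = rng.sample(sorted(pool), k_way)  # K distinct classes
    support: List[Tuple[str, int]] = []
    query: List[Tuple[str, int]] = []
    for episode_label, cls in enumerate(classes):
        # Draw without replacement so that S and Q never share a sample.
        items = rng.sample(pool[cls], n_shot + n_query_per_class)
        support += [(x, episode_label) for x in items[:n_shot]]
        query += [(x, episode_label) for x in items[n_shot:]]
    return support, query

# Toy pool: 10 classes with 30 samples each (hypothetical placeholder names).
pool = {f"class{c}": [f"img{c}_{i}" for i in range(30)] for c in range(10)}
S, Q = sample_episode(pool, k_way=5, n_shot=5, n_query_per_class=16, rng=random.Random(0))
print(len(S), len(Q))  # 25 80
```

Re-indexing the labels per episode reflects the fact that the classifier must compare a query against whichever K classes the current support set contains, rather than against a fixed label space.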

The support set $\mathcal{S}$ and query set $\mathcal{Q}$ are encoded by a CNN [4, 8], possibly with augmented layers [33, 10], to obtain the support image features $f_s$ and the query image features $f_q$. $f_s$ and $f_q$ are further processed before classification. For example, [4] averages $f_s$ within each class and concatenates the averaged class representations pairwise with $f_q$; [8] designs an attention module and generates attention-weighted support and query image features; [31] applies a GNN to $f_s$ and $f_q$ to obtain graph-structured features. The processed features are fed into a classifier for prediction. The classifier can be the cosine similarity [8], the Euclidean distance [3], the Mahalanobis distance [34], or a neural network [4, 5].
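For the RN-style processing of [4] described above, a minimal sketch looks as follows. It assumes each encoded feature has already been flattened into a plain vector, whereas RelationNet concatenates feature maps along the channel dimension. The helper names are hypothetical.

```python
from typing import List

Vector = List[float]

def class_prototypes(support_feats: List[Vector], support_labels: List[int],
                     k_way: int) -> List[Vector]:
    """Average the support features f_s within each class (one representation per class)."""
    dim = len(support_feats[0])
    sums = [[0.0] * dim for _ in range(k_way)]
    counts = [0] * k_way
    for feat, label in zip(support_feats, support_labels):
        counts[label] += 1
        sums[label] = [s + v for s, v in zip(sums[label], feat)]
    return [[s / counts[c] for s in sums[c]] for c in range(k_way)]

def pairwise_concat(prototypes: List[Vector], query_feat: Vector) -> List[Vector]:
    """Concatenate each class representation with the query feature f_q.

    Each of the K pairs would then be scored by the relation module (the classifier).
    """
    return [proto + query_feat for proto in prototypes]

# 2-way 2-shot toy example with 2-dimensional features.
protos = class_prototypes([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0], [0.0, 4.0]], [0, 0, 1, 1], k_way=2)
pairs = pairwise_concat(protos, [5.0, 5.0])
print(pairs)  # [[2.0, 0.0, 5.0, 5.0], [0.0, 3.0, 5.0, 5.0]]
```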

Figure 2: Explanation-guided training. Blue paths denote conventional FSC training. The red paths originate from the explanation method and are added one step after the blue paths. The support samples $\mathcal{S}$ and the query samples $\mathcal{Q}$ are fed into an image encoder to obtain features $f_s$ and $f_q$, which are further processed by a feature processing module. The output of feature processing, $f_p$, is fed into a classifier to make predictions. Both the feature processing and the classifier modules vary across FSC methods. The Explain block explains the model prediction $p$ and generates the explanations for $f_p$, denoted as $R(f_p)$, which are used to calculate the LRP weight $w_{lrp}$. The LRP-weighted feature $w_{lrp}\odot f_p$ is fed into the classifier, resulting in the updated prediction $p_{lrp}$.

Explanation-guided training for FSC models involves the following steps. For each training episode:

Step 1: Perform one forward pass through the model to obtain the prediction $p$, illustrated as the blue path in Figure 2.

Step 2: Explain the classifier. We initialize the LRP relevance for each label and apply LRP to explain the classifier, which yields the relevance of the classifier input, $R(f_p)$, illustrated as the Explain block.

For FSC models that implement a neural network as the classifier, the relevance scores for each label can be initialized with their logits. For models using non-parametric measures such as the cosine similarity or the Euclidean distance, the predicted scores are positive for all labels, which results in similar explanations for all labels. For such cases, we follow the logit function in [25] to initialize the relevance scores. Taking the cosine similarity as an example, we first calculate the probability of each class with the exponential function, as in Eq. (1).¹
¹ For distance measures such as the Euclidean distance, the negative distance is used in place of the similarity measure.

$$P(y_c\mid f_p)=\frac{\exp(\beta\cdot cs_c(f_p))}{\sum_{k=1}^{K}\exp(\beta\cdot cs_k(f_p))} \tag{1}$$

$cs_k(\cdot)$ denotes the cosine similarity between a query sample and class $k$, and $f_p$ is the processed feature fed to the classifier. $\beta$ is a constant scale parameter that sharpens the distribution toward the highest probability. Using the probability defined above, the relevance score of class $c$ is defined as

$$R_c=\log\left(\frac{P(y_c\mid f_p)}{1-P(y_c\mid f_p)}(K-1)\right) \tag{2}$$

$R_c$ ($c=1,\dots,K$) is positive when $P(y_c\mid f_p)$ is larger than $1/K$, because $\frac{P(y_c\mid f_p)}{1-P(y_c\mid f_p)}(K-1)>1$ is equivalent to $P(y_c\mid f_p)>1/K$. In other words, a class label whose probability exceeds the random-guessing probability receives a positive relevance score. With the relevance score $R_c$ of each target label, standard LRP can be applied to backpropagate $R_c$ through the classifier and generate the explanations.
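Eqs. (1) and (2) can be checked numerically with a short sketch. The similarity values below are made up for illustration. $\beta=7$ is the value the experiments use for CAN, and subtracting the maximum inside the softmax is a standard numerical-stability choice that leaves Eq. (1) unchanged.

```python
import math
from typing import List

def class_probabilities(similarities: List[float], beta: float) -> List[float]:
    """Eq. (1): softmax over the beta-scaled cosine similarities cs_k(f_p)."""
    m = max(similarities)  # shifting by the max does not change the softmax
    exps = [math.exp(beta * (s - m)) for s in similarities]
    total = sum(exps)
    return [e / total for e in exps]

def initial_relevance(similarities: List[float], beta: float) -> List[float]:
    """Eq. (2): R_c = log(P_c / (1 - P_c) * (K - 1)).

    R_c > 0 exactly when P_c > 1/K. Assumes 0 < P_c < 1 (a very large beta
    could saturate P_c to 1.0 and make the log undefined).
    """
    k = len(similarities)
    probs = class_probabilities(similarities, beta)
    return [math.log(p / (1.0 - p) * (k - 1)) for p in probs]

# Hypothetical cosine similarities of one query to K = 5 classes.
relevance = initial_relevance([0.9, 0.5, 0.4, 0.3, 0.2], beta=7.0)
print([r > 0 for r in relevance])  # [True, False, False, False, False]
```

When all similarities are equal, every class has probability $1/K$ and every initial relevance is (numerically) zero. This is the "random guessing" boundary described above.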

Consider the forward pass from layer $l$ to layer $l+1$:

$$y_j^{l+1}=\sum_i w_{ij}z_i^l+b_j,\qquad z_j^{l+1}=f(y_j^{l+1}) \tag{3}$$

where $i$ and $j$ are the indices of neurons in the $l$-th and $(l+1)$-th layers, and $f(\cdot)$ is an activation function. Let $R(\cdot)$ denote the relevance of a neuron and $R_{i\leftarrow j}$ the relevance attributed from $z_j^{l+1}$ to $z_i^l$. We rely on two established LRP backpropagation mechanisms here, the $LRP_\epsilon$-rule and the $LRP_\alpha$-rule [13].

  1. $LRP_\epsilon$-rule

    $$R_{i\leftarrow j}=R(z_j^{l+1})\frac{z_i^l w_{ij}}{y_j^{l+1}+\epsilon\cdot\mathrm{sign}(y_j^{l+1})} \tag{4}$$

    $\epsilon$ is a small positive number, and the term $\epsilon\cdot\mathrm{sign}(y_j^{l+1})$ guarantees safe division.

  2. $LRP_\alpha$-rule

    $$R_{i\leftarrow j}=R(z_j^{l+1})\left(\alpha\frac{(z_i^l w_{ij})^{+}}{(y_j^{l+1})^{+}}-(\alpha-1)\frac{(z_i^l w_{ij})^{-}}{(y_j^{l+1})^{-}}\right) \tag{5}$$

    where $\alpha\geq 1$ controls the ratio of positive relevance to backpropagate, $(\ast)^{+}=\max(\ast,0)$, and $(\ast)^{-}=\min(\ast,0)$. In the denominators, $(y_j^{l+1})^{+}=\sum_{i}(z_i^l w_{ij})^{+}$ and $(y_j^{l+1})^{-}=\sum_{i}(z_i^l w_{ij})^{-}$ denote the sums of the positive and negative contributions to neuron $j$, as in the standard formulation of this rule.

The relevance of $z_i^l$ is the sum of all relevance attributions flowing into it:

$$R(z_i^l)=\sum_j R_{i\leftarrow j} \tag{6}$$

We adopt the $LRP_\epsilon$-rule for linear layers and the $LRP_\alpha$-rule for convolutional layers to obtain $R(f_p)$, which is the setting suggested for explaining CNNs in [48]. $R(f_p)$ is then normalized by its maximal absolute value.
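The two rules and the summation in Eq. (6) can be illustrated on a single fully connected layer. This is a minimal, framework-free sketch for exposition, not the implementation used in the paper. A convolution is also a linear map, so the same redistribution applies to it at each output position. Ignoring the bias in the $LRP_\alpha$ branch is an assumption of this sketch, and the function name `lrp_dense` is hypothetical.

```python
from typing import List

def lrp_dense(z: List[float], w: List[List[float]], b: List[float],
              r_out: List[float], rule: str = "epsilon",
              eps: float = 1e-3, alpha: float = 1.0) -> List[float]:
    """Redistribute the relevance R(z_j^{l+1}) of a dense layer onto its inputs z_i^l.

    z: input activations (length I); w[i][j]: weight w_ij; b: biases b_j (length J);
    r_out: relevance of the layer outputs (length J). Returns R(z_i^l) as in Eq. (6).
    """
    r_in = [0.0] * len(z)
    for j in range(len(b)):
        contrib = [z[i] * w[i][j] for i in range(len(z))]  # z_i^l * w_ij
        if rule == "epsilon":  # Eq. (4)
            y = sum(contrib) + b[j]  # pre-activation y_j^{l+1}, Eq. (3)
            denom = y + eps * (1.0 if y >= 0 else -1.0)  # sign(0) taken as +1
            for i, c in enumerate(contrib):
                r_in[i] += r_out[j] * c / denom
        elif rule == "alpha":  # Eq. (5), bias ignored
            pos = sum(c for c in contrib if c > 0)  # (y_j^{l+1})^+
            neg = sum(c for c in contrib if c < 0)  # (y_j^{l+1})^-
            for i, c in enumerate(contrib):
                share = 0.0
                if c > 0:
                    share += alpha * c / pos
                elif c < 0:
                    share -= (alpha - 1.0) * c / neg
                r_in[i] += r_out[j] * share
        else:
            raise ValueError(f"unknown rule: {rule}")
    return r_in

# Toy layer with two inputs and two outputs (values are illustrative only).
z = [1.0, 2.0]
w = [[1.0, -1.0], [0.5, 1.0]]
b = [0.0, 0.0]
r_out = [1.0, 0.5]
print(lrp_dense(z, w, b, r_out, rule="alpha", alpha=1.0))  # [0.5, 1.0]
print(sum(lrp_dense(z, w, b, r_out, rule="epsilon")))  # close to sum(r_out) = 1.5
```

With $\alpha=1$, only positive contributions receive relevance, and the total relevance is conserved exactly. The $\epsilon$ term absorbs a small amount of relevance in exchange for numerical stability.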

Step 3: Compute LRP-weighted features. To emphasize the features that are more relevant to the prediction and downscale the less relevant ones, we define the LRP weights and the LRP-weighted features as

$$w_{lrp}=1+R(f_p) \tag{7}$$
$$f_{p-lrp}=w_{lrp}\odot f_p \tag{8}$$

where $\odot$ is the element-wise product. Note that $R(f_p)\in[-1,1]$ after normalization; thus $w_{lrp}\in[0,2]$ magnifies the features with positive relevance scores and downscales those with negative relevance scores. The maximal scaling factor applied to a feature by $w_{lrp}$ is 2.
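A minimal sketch of Step 3 follows. It includes the normalization of $R(f_p)$ by its maximal absolute value, which precedes Eqs. (7) and (8). Feature maps are flattened into plain lists for readability, and the function name is hypothetical.

```python
from typing import List

def lrp_weighted_features(f_p: List[float], relevance: List[float]) -> List[float]:
    """Eqs. (7)-(8): f_{p-lrp} = (1 + R(f_p)) ⊙ f_p, with R(f_p) scaled to [-1, 1]."""
    max_abs = max(abs(r) for r in relevance)
    # Normalize by the maximal absolute value; an all-zero explanation leaves f_p unchanged.
    r_norm = [r / max_abs for r in relevance] if max_abs > 0 else [0.0] * len(relevance)
    w_lrp = [1.0 + r for r in r_norm]  # every weight lies in [0, 2]
    return [w * f for w, f in zip(w_lrp, f_p)]

print(lrp_weighted_features([1.0, 1.0, 1.0], [2.0, -1.0, 0.0]))  # [2.0, 0.5, 1.0]
```

The feature with the largest positive relevance is doubled, and a feature with zero relevance is left unchanged. A feature whose relevance equals the most negative normalized value of $-1$ would be suppressed to zero.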

Step 4: Finally, we forward the LRP-weighted features to the classifier to generate the explanation-guided prediction $p_{lrp}$. The objective function combines the model prediction $p$ and the explanation-guided prediction $p_{lrp}$:

$$\mathcal{L}=\xi\,\mathcal{L}_{ce}(y,p)+\lambda\,\mathcal{L}_{ce}(y,p_{lrp}) \tag{9}$$

where $\mathcal{L}_{ce}$ is the cross-entropy loss, and $\xi$ and $\lambda$ are non-negative scalars that control how much $p$ and $p_{lrp}$ contribute to the objective. In our experiments, $\xi$ and $\lambda$ are adjusted empirically for each FSC model.
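Eq. (9) can be written out for a single query as shown below. This sketch assumes that the classifier outputs logits for both $p$ and $p_{lrp}$. For a cosine-similarity classifier, the $\beta$-scaled similarities of Eq. (1) would play that role. In practice, the loss would be averaged over the queries of an episode and minimized with an autodiff framework, which is omitted here.

```python
import math
from typing import List

def cross_entropy(logits: List[float], target: int) -> float:
    """L_ce(y, p) for one sample: log-sum-exp of the logits minus the target logit."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]

def explanation_guided_loss(logits_p: List[float], logits_p_lrp: List[float],
                            target: int, xi: float, lam: float) -> float:
    """Eq. (9): xi * L_ce(y, p) + lambda * L_ce(y, p_lrp)."""
    return xi * cross_entropy(logits_p, target) + lam * cross_entropy(logits_p_lrp, target)

# With uniform logits each cross-entropy term equals log(5) for a 5-way episode.
loss = explanation_guided_loss([0.0] * 5, [0.0] * 5, target=2, xi=1.0, lam=0.5)
print(round(loss, 4))  # 2.4142
```

Setting `xi=0.0` recovers the variant that relies only on the explanation-guided prediction.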

IV Experiments

We evaluate the proposed explanation-guided training on RelationNet (RN) [4] and two state-of-the-art models, the cross attention network (CAN) [8] and a GNN-based model [5]. Table I summarizes how the three FSC models correspond to the framework in Figure 2. We will show that explanation-guided training improves the performance of the three models on four cross-domain test sets in most settings.

Moreover, we combine explanation-guided training with another approach, LFT [10]. We show that explanation-guided training is compatible with LFT and that the combination further improves the performance.

IV-A Dataset and Model Preparation

TABLE I: The correspondence of RelationNet(RN), cross attention network(CAN), and graph neural network(GNN) to the framework in Figure 2.
Model | Feature processing | Classifier
RN | pairwise concatenation | relation module
CAN | cross attention module | cosine similarity
GNN | fc layer and concatenation | graph neural network

Five datasets are used in our experiments: miniImagenet [49], CUB [50], Cars [51], Places [52], and Plantae [53], as introduced in [10]. Each dataset consists of train/val/test splits. We choose miniImagenet as $\mathcal{D}_{seen}$, train the FSC models on its training set, validate them on its validation set, and use the test sets of the other four datasets for testing.

We use ResNet10 [54] as the image encoder for RN and GNN, and ResNet12 for CAN. The three models are trained under the 5-way 5-shot and 5-way 1-shot settings. The LRP parameters are $\alpha=1$ and $\epsilon=0.001$ for all experiments, following the suggestions in [48].

We experimented with varying $\xi$ and $\lambda$ in Eq. (9). For classifiers with trainable parameters, such as RN and GNN, fully relying on $\mathcal{L}_{ce}(y,p_{lrp})$ ($\xi=0$) makes the model hard to converge and yields only marginal improvement. In contrast, this setting works well for non-parametric classifiers such as the cosine similarity in CAN. We attribute this to the fact that explanations of a poorly trained classifier are less meaningful and distract the classifier parameters from the beginning of training, especially when there are fewer shots (smaller $N$). Thus, we include $\mathcal{L}_{ce}(y,p)$ to stabilize training and increase its relative contribution in the 1-shot setting. In the experiments using RN and GNN, we set $\xi=1,\lambda=0.5$ for the 5-way 1-shot setting and $\xi=1,\lambda=1$ for the 5-way 5-shot setting. For the CAN model, we set $\beta$ in Eq. (1) to 7, the same value as in the original model, and $\xi=0,\lambda=1$.
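The settings reported above can be collected in one place. The dictionary layout and the `loss_weights` helper are hypothetical conveniences, not part of the released code. The values are the ones stated in this section.

```python
from typing import Dict, Tuple

# (model, shots) -> (xi, lambda) for Eq. (9), as reported in this section.
LOSS_WEIGHTS: Dict[Tuple[str, int], Tuple[float, float]] = {
    ("RN", 1): (1.0, 0.5),
    ("RN", 5): (1.0, 1.0),
    ("GNN", 1): (1.0, 0.5),
    ("GNN", 5): (1.0, 1.0),
    ("CAN", 1): (0.0, 1.0),  # CAN relies only on the explanation-guided term
    ("CAN", 5): (0.0, 1.0),
}
LRP_PARAMS = {"alpha": 1.0, "epsilon": 1e-3}  # shared by all experiments
CAN_BETA = 7.0  # scale beta in Eq. (1), same as the original CAN model

def loss_weights(model: str, n_shot: int) -> Tuple[float, float]:
    """Return (xi, lambda) for a 5-way episode with the given number of shots."""
    return LOSS_WEIGHTS[(model, n_shot)]

print(loss_weights("RN", 1))  # (1.0, 0.5)
```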

We follow the implementation details of [10] (https://github.com/hytseng0509/CrossDomainFewShot) and [8] (https://github.com/blue-blue272/fewshot-CAN) to train the RN, GNN, and CAN models. At test time, we evaluate the performance over 2000 randomly sampled episodes, with 16 query images per class in each episode.
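The accuracies in Tables II to IV are reported as means over 2000 test episodes with 95% confidence intervals. A common way to compute such an interval is the normal approximation sketched below. The text does not state whether the population or the sample standard deviation was used, so the choice of `statistics.pstdev` is an assumption.

```python
import math
import statistics
from typing import List, Tuple

def mean_and_ci95(episode_accuracies: List[float]) -> Tuple[float, float]:
    """Mean accuracy and half-width of a 95% confidence interval (normal approximation)."""
    n = len(episode_accuracies)
    mean = statistics.mean(episode_accuracies)
    std = statistics.pstdev(episode_accuracies)  # population std (assumption)
    return mean, 1.96 * std / math.sqrt(n)

# Synthetic per-episode accuracies (in %) for 2000 episodes.
accs = [60.0, 70.0] * 1000
mean, half_width = mean_and_ci95(accs)
print(f"{mean:.2f}±{half_width:.2f}%")  # 65.00±0.22%
```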

TABLE II: Evaluation of explanation-guided training on cross-domain datasets using RN and CAN. We report the average accuracy over 2000 episodes with 95% confidence intervals. The models are trained on the miniImagenet training set and tested on the test sets of the various domains. The prefix LRP- denotes explanation-guided training using LRP; T indicates transductive inference.
miniImagenet | 1-shot | 1-shot-T | 5-shot | 5-shot-T
RN | 58.31±0.47% | 61.52±0.58% | 72.72±0.37% | 73.64±0.40%
LRP-RN | 60.06±0.47% | 62.65±0.56% | 73.63±0.37% | 74.67±0.39%
CAN | 64.66±0.48% | 67.74±0.54% | 79.61±0.33% | 80.34±0.35%
LRP-CAN | 64.65±0.46% | 69.10±0.53% | 80.89±0.32% | 82.56±0.33%
mini-CUB | 1-shot | 1-shot-T | 5-shot | 5-shot-T
RN | 41.98±0.41% | 42.52±0.48% | 58.75±0.36% | 59.10±0.42%
LRP-RN | 42.44±0.41% | 42.88±0.48% | 59.30±0.40% | 59.22±0.42%
CAN | 44.91±0.41% | 46.63±0.50% | 63.09±0.39% | 62.09±0.43%
LRP-CAN | 46.23±0.42% | 48.35±0.52% | 66.58±0.39% | 66.57±0.43%
mini-Cars | 1-shot | 1-shot-T | 5-shot | 5-shot-T
RN | 29.32±0.34% | 28.56±0.37% | 38.91±0.38% | 37.45±0.40%
LRP-RN | 29.65±0.33% | 29.61±0.37% | 39.19±0.38% | 38.31±0.39%
CAN | 31.44±0.35% | 30.06±0.42% | 41.46±0.37% | 40.17±0.40%
LRP-CAN | 32.66±0.46% | 32.35±0.42% | 43.86±0.38% | 42.57±0.42%
mini-Places | 1-shot | 1-shot-T | 5-shot | 5-shot-T
RN | 50.87±0.48% | 53.63±0.58% | 66.47±0.41% | 67.43±0.43%
LRP-RN | 50.59±0.46% | 53.07±0.57% | 66.90±0.40% | 68.25±0.43%
CAN | 56.90±0.49% | 60.70±0.58% | 72.94±0.38% | 74.44±0.41%
LRP-CAN | 56.96±0.48% | 61.60±0.58% | 74.91±0.37% | 76.90±0.39%
mini-Plantae | 1-shot | 1-shot-T | 5-shot | 5-shot-T
RN | 33.53±0.36% | 33.69±0.42% | 47.40±0.36% | 46.51±0.40%
LRP-RN | 34.80±0.37% | 34.54±0.42% | 48.09±0.35% | 47.67±0.39%
CAN | 36.57±0.37% | 36.69±0.42% | 50.45±0.36% | 48.67±0.40%
LRP-CAN | 38.23±0.45% | 38.48±0.43% | 53.25±0.36% | 51.63±0.41%
TABLE III: Evaluation of explanation-guided training on cross-domain datasets using GNN. We report the average accuracy over 2000 episodes with 95% confidence intervals. The models are trained on the miniImagenet training set and tested on the test sets of the various domains. The prefix LRP- denotes explanation-guided training using LRP.
5-way 1-shot | miniImagenet | Cars | Places | CUB | Plantae
GNN | 64.47±0.55% | 30.97±0.37% | 54.64±0.56% | 46.76±0.50% | 37.39±0.43%
LRP-GNN | 65.03±0.54% | 32.78±0.39% | 54.83±0.56% | 48.29±0.51% | 37.49±0.43%
5-way 5-shot | miniImagenet | Cars | Places | CUB | Plantae
GNN | 80.74±0.41% | 42.59±0.42% | 72.14±0.45% | 63.91±0.47% | 54.52±0.44%
LRP-GNN | 82.03±0.40% | 46.20±0.46% | 74.45±0.47% | 64.44±0.48% | 54.46±0.46%

IV-B Evaluation of Explanation-Guided Training in the Cross-Domain Setting

In this section, we evaluate the performance of RN, GNN, and CAN models trained with and without explanation-guided training on CD-FSC tasks. For a more comprehensive analysis, we also implement the transductive inference proposed by [8]. During the test phase, transductive inference iteratively augments the support set with confidently classified query images. Specifically, we first predict the labels of the query images with the trained model. Second, we select the query images with the highest predicted scores as candidate images. The candidate images and their predicted labels are then added to the support set, and the process is repeated. In our experiments, we perform two transductive iterations, with 35 candidates in the first iteration and 70 in the second, following the same strategy as [8]. Since GNN requires a fixed number of support images, we implement transductive inference only for the RN and CAN models.
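The transductive procedure can be sketched as follows. To keep the sketch self-contained, a nearest-prototype classifier with cosine similarity stands in for the trained RN or CAN model, which is a simplification. Two further assumptions are that candidates are re-selected from all queries in each iteration rather than accumulated, and that the function names are illustrative.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Labeled = Tuple[Vector, int]

def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / (norm + 1e-12)

def prototypes(support: List[Labeled], k_way: int) -> List[Vector]:
    """Class means of the (possibly augmented) support set."""
    dim = len(support[0][0])
    sums = [[0.0] * dim for _ in range(k_way)]
    counts = [0] * k_way
    for feat, label in support:
        counts[label] += 1
        sums[label] = [s + v for s, v in zip(sums[label], feat)]
    return [[s / counts[c] for s in sums[c]] for c in range(k_way)]

def predict(protos: List[Vector], q: Vector) -> Tuple[int, float]:
    """Stand-in classifier: most similar prototype and its score."""
    scores = [cosine(p, q) for p in protos]
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores[best]

def transductive_inference(support: List[Labeled], queries: List[Vector], k_way: int,
                           schedule: Sequence[int] = (35, 70)) -> List[int]:
    """Augment the support set with the most confident queries, once per schedule entry."""
    augmented = list(support)
    for n_candidates in schedule:
        protos = prototypes(augmented, k_way)
        preds = [predict(protos, q) for q in queries]
        ranked = sorted(range(len(queries)), key=lambda i: preds[i][1], reverse=True)
        # Candidates (with predicted labels) are added to the original support set.
        augmented = list(support) + [(queries[i], preds[i][0]) for i in ranked[:n_candidates]]
    protos = prototypes(augmented, k_way)
    return [predict(protos, q)[0] for q in queries]

support = [([1.0, 0.0], 0), ([0.0, 1.0], 1)]
queries = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
print(transductive_inference(support, queries, k_way=2, schedule=(2, 4)))  # [0, 0, 1, 1]
```

Because pseudo-labels can be wrong, only the most confident queries are added. The growing candidate counts (35, then 70) follow the schedule stated above.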

Table II and Table III summarize the accuracy of the RN, CAN, and GNN models trained with and without explanation-guided training. Explanation-guided training improves accuracy in almost all settings on both the seen-domain and the cross-domain test sets. In the few exceptions (CAN on miniImagenet 1-shot, RN on mini-Places 1-shot and 1-shot-T, and GNN on Plantae 5-shot), the decrease is smaller than the half-width of the reported 95% confidence intervals. The results are also competitive with the recent LFT approach [10], which learns a noise distribution by adding feature-wise transformation layers to the image encoder, whereas explanation-guided training does not introduce additional trainable parameters. To show that our approach exploits a different mechanism to improve the model, we also combine LFT with our explanation-guided training in the next section.

TABLE IV: Results of the multiple-domain experiment using RelationNet. We report the average accuracy over 2000 episodes with 95% confidence intervals. FT and LFT denote the feature-wise transformation layers with fixed and trainable parameters, respectively. The prefix LRP- denotes explanation-guided training using LRP; LFT-LRP is the combination of LFT and explanation-guided training.
5-way 1-shot | Cars | Places | CUB | Plantae
RN | 29.40±0.33% | 48.05±0.46% | 44.33±0.43% | 34.57±0.38%
FT-RN | 30.09±0.36% | 48.12±0.45% | 44.87±0.44% | 35.53±0.39%
LRP-RN | 30.00±0.32% | 48.74±0.45% | 45.64±0.42% | 36.04±0.38%
LFT-RN | 30.27±0.34% | 48.07±0.46% | 47.35±0.44% | 35.54±0.38%
LFT-LRP-RN | 30.68±0.34% | 50.19±0.47% | 47.78±0.43% | 36.58±0.40%
5-way 5-shot | Cars | Places | CUB | Plantae
RN | 40.01±0.37% | 64.56±0.40% | 62.50±0.39% | 47.58±0.37%
FT-RN | 40.52±0.40% | 64.92±0.40% | 61.87±0.39% | 48.54±0.38%
LRP-RN | 41.05±0.37% | 66.08±0.40% | 62.71±0.39% | 48.78±0.37%
LFT-RN | 41.51±0.39% | 65.35±0.40% | 64.11±0.39% | 49.29±0.38%
LFT-LRP-RN | 42.38±0.40% | 66.23±0.40% | 64.62±0.39% | 50.50±0.39%

IV-C Synergies in Combining Explanation-guided Training with Feature-wise Transformation

To compare and combine our idea with the LFT method, we apply explanation-guided training to the multiple-domain experiment of [10]. The LFT model is trained using a pseudo-seen domain and pseudo-unseen domains. In our experiment, miniImagenet is the pseudo-seen domain, three of the other four datasets serve as the pseudo-unseen domains, and the model is tested on the remaining domain. The pseudo-unseen domains are used to train the feature-wise transformation layers, and the pseudo-seen domain is used to update the other trainable parameters of the model. If the parameters of the feature-wise transformation layers are fixed, we obtain the FT method, which adds noise with a fixed distribution to certain intermediate layers.
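The leave-one-domain-out protocol described above yields four training and testing configurations, one per held-out target domain. The sketch below only enumerates these configurations. The function name and the dictionary keys are hypothetical.

```python
from typing import Dict, List, Union

def multi_domain_splits(seen: str, others: List[str]) -> List[Dict[str, Union[str, List[str]]]]:
    """One configuration per held-out test domain."""
    return [
        {
            "pseudo_seen": seen,  # updates the model's regular trainable parameters
            "pseudo_unseen": [d for d in others if d != test],  # trains the FT layers (LFT)
            "test": test,
        }
        for test in others
    ]

splits = multi_domain_splits("miniImagenet", ["CUB", "Cars", "Places", "Plantae"])
for s in splits:
    print(s["test"], "<-", s["pseudo_unseen"])
# CUB <- ['Cars', 'Places', 'Plantae']
# Cars <- ['CUB', 'Places', 'Plantae']
# Places <- ['CUB', 'Cars', 'Plantae']
# Plantae <- ['CUB', 'Cars', 'Places']
```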

Table IV shows the performance of the standard RN, the FT and LFT methods, explanation-guided training, and its combination with LFT. These models are trained with the same random seed, learning rate, optimizer, and datasets. The combination of our explanation-guided training and LFT (LFT-LRP-RN) achieves the best accuracy in all eight settings. In a comparison of FT-RN and LRP-RN, explanation-guided training performs better in seven of the eight settings without adding any layers to the model.

We remark that the improvement observed when combining explanation-guided training with LFT suggests that the two methods improve the model from different angles. This indicates that the approaches are complementary and that each is effective on its own.

IV-D Explaining the Effect of Explanation-Guided Training

In this section, we provide an intuition for why explanation-guided training improves FSC models. It is known from the information bottleneck framework that training a discriminative classifier implies learning to filter out irrelevant features [55]. This compression of task-irrelevant information is also acknowledged in recent works that shed critical light on the application of the information bottleneck to deep networks [56]. Traditional classification and few-shot classification differ in which information can be removed. Here, removing information means that the intermediate feature channels related to the removed features are not activated. The traditional classification task is to classify a fixed set of classes. Therefore, removing information irrelevant to these classes does not impair the discriminative capability. For example, a classifier for cat breeds will likely learn features of eyes, tails, and legs rather than features of sofas or grass. In FSC, however, the classes vary across episodes. Thus, information that is irrelevant in one episode can be discriminative in the next. Excessive information removal can therefore be detrimental for FSC, which requires generalization to new classes. We believe that this is also one reason for the low accuracy on the cross-domain test sets in Tables II, III, and IV.

For a classifier trained on a fixed set of classes, one would expect that, in higher-layer feature maps, a few channels are highly activated somewhere in the spatial dimensions, while most channels show only low values overall. Explanation-guided training uses the explanation scores of the predicted class to re-weight intermediate features. Suppose a classifier is overfitting and frequently predicts a wrong class label. Explanation-guided training then identifies the features relevant to the wrongly predicted class (Step 2 in Section III) and upscales them (Step 3), and the subsequent loss minimization penalizes these upscaled features more strongly (Step 4). Thus, it prevents the intermediate features from becoming too specialized toward a fixed set of classes and achieves better generalization.

To quantify the above intuition, we analyze the CNN-encoded image features of the RN, GNN, and CAN models trained with or without explanation-guided training under the 5-way 5-shot setting, i.e., the same models as in Section IV-B. We use the test set images of the four cross-domain datasets in this experiment. Each CNN-encoded image feature has the shape $f_{CNN}\in\mathbb{R}^{C\times H\times W}$. We first pool over the spatial dimensions $[H,W]$, then compute a statistic over the channels $C$, and finally average the statistics over the test images. We use the 95% quantile for spatial pooling, resulting in a vector $f\in\mathbb{R}^{C}$. We do not use spatial average pooling because the features are spatially sparse: discriminative parts are usually present only in a small region of an image. For the same reason, median pooling would mostly yield zeros.

Figure 3: The variance (first row) and the quantile difference (second row) of the CNN-encoded image feature vectors. We report the mean and the standard deviation of the two feature vector statistics over all test set images of the four cross-domain datasets. The models are RelationNet (RN), GNN, and CAN with (dark pink) and without (dark blue) explanation-guided training.

To verify that explanation-guided training indeed reduces excessive information removal, we examine two dispersion statistics of the image feature vector $f$: the variance $S^{2}=\frac{1}{C}\sum_{i=1}^{C}\left(f^{i}-\bar{f}\right)^{2}$, where $\bar{f}$ is the mean of $f$ over the channels, and the difference between the 95% and 45% quantiles of $f$. We calculate the two statistics for each image and compute their mean and standard deviation over all test set images of the four cross-domain datasets, as illustrated in Figure 3.
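The two statistics and their aggregation over images can be sketched as follows. This assumes the pooled vectors $f$ are stacked as rows of an array of shape (N, C), uses the population variance (division by $C$, matching the formula above), and takes the population standard deviation across images; the latter choice is an assumption, since the normalization of the reported standard deviation is not specified.

```python
import numpy as np


def channel_statistics(f: np.ndarray) -> tuple[float, float]:
    """Return (S^2, 95%-45% quantile difference) of a pooled feature vector f in R^C."""
    s2 = float(np.mean((f - f.mean()) ** 2))   # population variance, divides by C
    qdiff = float(np.quantile(f, 0.95) - np.quantile(f, 0.45))
    return s2, qdiff


def summarize(pooled: np.ndarray) -> dict:
    """Mean and standard deviation of both statistics over images (rows of `pooled`)."""
    stats = np.array([channel_statistics(f) for f in pooled])  # shape (N, 2)
    return {
        "S2": (stats[:, 0].mean(), stats[:, 0].std()),
        "qdiff": (stats[:, 1].mean(), stats[:, 1].std()),
    }


# Two toy "images" with C = 5 channels each.
pooled = np.array([[0.0, 1.0, 2.0, 3.0, 4.0],   # spread out: S2 = 2.0, qdiff = 2.0
                   [2.0, 2.0, 2.0, 2.0, 2.0]])  # flat: S2 = 0.0, qdiff = 0.0
summary = summarize(pooled)   # both statistics: mean 1.0, std 1.0
```

A feature vector whose energy is spread evenly over the channels (the second row) scores zero on both statistics, while one concentrated on a few channels scores higher, which is the property the comparison in Figure 3 relies on.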

Lower $S^{2}$ and quantile difference mean that the features are not concentrated on a few channels but are more evenly distributed across channels, which preserves more diverse information and results in better generalization to new classes. The consistent decrease of $S^{2}$ and of the quantile difference over the four cross-domain datasets after applying explanation-guided training provides some evidence that explanation-guided training effectively avoids excessive information removal and overfitting on the source domain. We note that the lower $S^{2}$ and quantile difference are not due to lower first-order statistics such as the mean. For the CAN model, we observe an increased mean of $f$ and a decreased $S^{2}$ with explanation-guided training for all cross-domain datasets. Furthermore, the variance of some first-order statistics of $f$ over the test set also decreases with explanation-guided training. This effect is comparable to that of batch normalization, which, however, is naturally less effective for FSC.

IV-E Qualitative Results of LRP Explanation for FSC Models

The above experiments have demonstrated that explanation-guided training, which uses the LRP explanation of an intermediate feature map to re-weight the same feature map, effectively improves the performance of FSC models and reduces the domain gap. In this section, we visualize the LRP explanations of the input images as heatmaps. The LRP heatmaps show which parts of an image the model uses to make its predictions, in other words, which features the model has learned to differentiate classes. To the best of our knowledge, this is the first attempt to explain FSC models, although many existing explanation methods are in principle applicable.
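The heatmaps in this section are produced with the LRP variant tailored to FSC models in this paper. As a generic illustration of how LRP redistributes a prediction score onto the input, the sketch below applies the standard LRP-ε rule [13] to a toy two-layer ReLU network in NumPy. The network, its sizes, the random weights, and ε are arbitrary assumptions, and this is not the exact rule used for RN, CAN, or GNN.

```python
import numpy as np


def lrp_epsilon_dense(a, W, b, R_out, eps=1e-9):
    """LRP-epsilon: redistribute relevance R_out of z = a @ W + b onto the inputs a."""
    z = a @ W + b
    z = z + eps * np.where(z >= 0, 1.0, -1.0)  # stabilizer keeps denominators away from 0
    s = R_out / z                              # relevance per unit of pre-activation
    return a * (W @ s)                         # each input's share of the relevance


rng = np.random.default_rng(0)
x = rng.random(8)                              # toy non-negative "pixel" input
W1, b1 = rng.standard_normal((8, 16)), np.zeros(16)
W2, b2 = rng.standard_normal((16, 5)), np.zeros(5)

h = np.maximum(0.0, x @ W1 + b1)               # ReLU hidden layer
y = h @ W2 + b2                                # 5 class scores
c = int(np.argmax(y))                          # explain the predicted class

R_y = np.zeros_like(y)
R_y[c] = y[c]                                  # start from the explained logit only
R_h = lrp_epsilon_dense(h, W2, b2, R_y)        # relevance passes through ReLU unchanged
R_x = lrp_epsilon_dense(x, W1, b1, R_h)        # per-input relevance, i.e. the "heatmap"
```

With zero biases and a small ε, the input relevances approximately sum to the explained score `y[c]` (relevance conservation). For image inputs, the per-pixel relevance is commonly summed over the color channels to obtain a 2D heatmap; positive values mark evidence for the explained label and negative values evidence against it.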

Figure 1 has already presented some heatmaps for the RelationNet. We further illustrate the LRP explanations of the CAN model under the 5-way 1-shot setting in Figure 4. Since there is only one training sample per class, we also show the attention heatmaps for the support images. For the correctly classified Q1 and Q3, the LRP heatmaps for the correct label highlight the relevant features. Specifically, the LRP heatmaps capture the features of the window frames for the bus and the head features for the malamute.

Figure 4: LRP heatmaps and attention heatmaps of the CAN model for one episode. The model is trained under the 5-way 1-shot setting. The first row shows the support images of each class. For each query image, we show the attention heatmaps and the LRP heatmaps of both the support images and the query image for each of the 5 target labels.

On the other hand, the LRP heatmaps for the wrong labels show more negative evidence, although we can still find some interesting resemblances between the query image and the explained label. For example, in Figure 1, when we explain the label lion for Q1: African hunting dog, the LRP heatmap highlights the legs of the African hunting dog, and when we explain the label cuirass (a kind of medieval soldier's armor) for Q2: lion, the LRP heatmap emphasizes the round contour that resembles an armor plate. In Figure 4, when we explain the label trifle for Q3: malamute, the LRP heatmap highlights the texture within a circular structure.

V Conclusion

This paper shows the usefulness of explanation methods for few-shot learning during the training phase, exemplified by, but not limited to, LRP. We find two points noteworthy. First, explanation-guided training successfully addresses the domain shift problem in few-shot learning, as demonstrated in the cross-domain few-shot classification task. Second, combining explanation-guided training with feature-wise transformation further improves model performance, indicating that the two approaches improve the model in complementary ways. We conclude that applying explanation methods to few-shot classification can not only provide intuitive and informative visualizations but also improve the models.

VI Acknowledgement

This work was supported by the Singaporean Ministry of Education Tier 2 Grant MOE-T2-2-154 and the SUTD internal grant SGPAIRS1811. This work was also partly supported by the German Ministry for Education and Research as BIFOLD (ref. 01IS18025A and ref. 01IS18037A), and TraMeExCo (ref. 01IS18056A).

References

  • [1] O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra et al., “Matching networks for one shot learning,” in NIPS, 2016, pp. 3630–3638.
  • [2] C. Finn, P. Abbeel, and S. Levine, “Model-agnostic meta-learning for fast adaptation of deep networks,” in Proceedings of the 34th ICML, Volume 70. JMLR.org, 2017, pp. 1126–1135.
  • [3] J. Snell, K. Swersky, and R. Zemel, “Prototypical networks for few-shot learning,” in NIPS, 2017, pp. 4077–4087.
  • [4] F. Sung, Y. Yang, L. Zhang, T. Xiang, P. H. Torr, and T. M. Hospedales, “Learning to compare: Relation network for few-shot learning,” in Proceedings of the IEEE CVPR, 2018, pp. 1199–1208.
  • [5] V. G. Satorras and J. B. Estrach, “Few-shot learning with graph neural networks,” in ICLR, 2018.
  • [6] A. A. Rusu, D. Rao, J. Sygnowski, O. Vinyals, R. Pascanu, S. Osindero, and R. Hadsell, “Meta-learning with latent embedding optimization,” in ICLR, 2019.
  • [7] Q. Sun, Y. Liu, T.-S. Chua, and B. Schiele, “Meta-transfer learning for few-shot learning,” in Proceedings of the IEEE CVPR, 2019, pp. 403–412.
  • [8] R. Hou, H. Chang, M. Bingpeng, S. Shan, and X. Chen, “Cross attention network for few-shot classification,” in NIPS, 2019, pp. 4005–4016.
  • [9] W.-Y. Chen, Y.-C. Liu, Z. Kira, Y.-C. F. Wang, and J.-B. Huang, “A closer look at few-shot classification,” in ICLR, 2019.
  • [10] H.-Y. Tseng, H.-Y. Lee, J.-B. Huang, and M.-H. Yang, “Cross-domain few-shot classification via learned feature-wise transformation,” in ICLR, 2020.
  • [11] B. Liu, Z. Zhao, Z. Li, J. Jiang, Y. Guo, H. Shen, and J. Ye, “Feature transformation ensemble model with batch spectral regularization for cross-domain few-shot classification,” arXiv preprint arXiv:2005.08463, 2020.
  • [12] J.-F. Yeh, H.-Y. Lee, B.-C. Tsai, Y.-R. Chen, P.-C. Huang, and W. H. Hsu, “Large margin mechanism and pseudo query set on cross-domain few-shot learning,” arXiv preprint arXiv:2005.09218, 2020.
  • [13] S. Bach, A. Binder, G. Montavon, F. Klauschen, K.-R. Müller, and W. Samek, “On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation,” PLOS ONE, vol. 10, no. 7, p. e0130140, 2015.
  • [14] M. T. Ribeiro, S. Singh, and C. Guestrin, “Why should I trust you?: Explaining the predictions of any classifier,” in Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. ACM, 2016, pp. 1135–1144.
  • [15] K. Simonyan, A. Vedaldi, and A. Zisserman, “Deep inside convolutional networks: Visualising image classification models and saliency maps,” in ICLR (workshop track), 2014.
  • [16] M. Sundararajan, A. Taly, and Q. Yan, “Axiomatic attribution for deep networks,” in Proceedings of the 34th ICML, Volume 70. JMLR.org, 2017, pp. 3319–3328.
  • [17] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra, “Grad-CAM: Visual explanations from deep networks via gradient-based localization,” in Proceedings of the IEEE ICCV, 2017, pp. 618–626.
  • [18] G. Montavon, S. Lapuschkin, A. Binder, W. Samek, and K.-R. Müller, “Explaining nonlinear classification decisions with deep taylor decomposition,” Pattern Recognition, vol. 65, pp. 211–222, 2017.
  • [19] A. Shrikumar, P. Greenside, and A. Kundaje, “Learning important features through propagating activation differences,” in ICML, 2017, pp. 3145–3153.
  • [20] P. Kindermans, K. T. Schütt, M. Alber, K.-R. Müller, D. Erhan, B. Kim, and S. Dähne, “Learning how to explain neural networks: PatternNet and PatternAttribution,” in ICLR, 2018.
  • [21] S. Lapuschkin, S. Wäldchen, A. Binder, G. Montavon, W. Samek, and K.-R. Müller, “Unmasking clever hans predictors and assessing what machines really learn,” Nature Communications, vol. 10, no. 1, p. 1096, 2019.
  • [22] L. Arras, F. Horn, G. Montavon, K.-R. Müller, and W. Samek, ““what is relevant in a text document?”: An interpretable machine learning approach,” PLOS ONE, vol. 12, no. 8, p. e0181142, 2017.
  • [23] L. Arras, G. Montavon, K.-R. Müller, and W. Samek, “Explaining recurrent neural network predictions in sentiment analysis,” in Proceedings of the 8th Workshop on Computational Approaches to Subjectivity, Sentiment and Social Media Analysis, 2017, pp. 159–168.
  • [24] T. Schnake, O. Eberle, J. Lederer, S. Nakajima, K. T. Schütt, K.-R. Müller, and G. Montavon, “XAI for Graphs: Explaining graph neural network predictions by identifying relevant walks,” arXiv preprint arXiv:2006.03589, 2020.
  • [25] J. Kauffmann, M. Esders, G. Montavon, W. Samek, and K.-R. Müller, “From clustering to cluster explanations via neural networks,” arXiv preprint arXiv:1906.07633, 2019.
  • [26] A. Nichol and J. Schulman, “Reptile: a scalable metalearning algorithm,” arXiv preprint arXiv:1803.02999, vol. 2, no. 3, p. 4, 2018.
  • [27] S. Ravi and H. Larochelle, “Optimization as a model for few-shot learning,” in ICLR, 2017.
  • [28] A. Santoro, S. Bartunov, M. Botvinick, D. Wierstra, and T. Lillicrap, “Meta-learning with memory-augmented neural networks,” in ICML, 2016, pp. 1842–1850.
  • [29] N. Mishra, M. Rohaninejad, X. Chen, and P. Abbeel, “A simple neural attentive meta-learner,” in ICLR, 2018.
  • [30] T. Munkhdalai and H. Yu, “Meta networks,” in Proceedings of the 34th ICML, Volume 70, 2017, pp. 2554–2563.
  • [31] Y. Liu, J. Lee, M. Park, S. Kim, E. Yang, S. Hwang, and Y. Yang, “Learning to propagate labels: transductive propagation network for few-shot learning,” in ICLR, 2019.
  • [32] A. Devos and M. Grossglauser, “Regression networks for meta-learning few-shot classification,” 7th ICML Workshop on Automated Machine Learning, 2020.
  • [33] B. Oreshkin, P. R. López, and A. Lacoste, “TADAM: Task dependent adaptive metric for improved few-shot learning,” in NIPS, 2018, pp. 721–731.
  • [34] P. Bateni, R. Goyal, V. Masrani, F. Wood, and L. Sigal, “Improved few-shot visual classification,” in Proceedings of the IEEE CVPR, 2020, pp. 14493–14502.
  • [35] S. Gidaris and N. Komodakis, “Dynamic few-shot visual learning without forgetting,” in Proceedings of the IEEE CVPR, 2018, pp. 4367–4375.
  • [36] H. Qi, M. Brown, and D. G. Lowe, “Low-shot learning with imprinted weights,” in Proceedings of the IEEE CVPR, 2018, pp. 5822–5830.
  • [37] S. Gidaris and N. Komodakis, “Generating classification weights with GNN denoising autoencoders for few-shot learning,” in Proceedings of the IEEE CVPR, 2019, pp. 21–30.
  • [38] A. Li, W. Huang, X. Lan, J. Feng, Z. Li, and L. Wang, “Boosting few-shot learning with adaptive margin loss,” in Proceedings of the IEEE CVPR, 2020, pp. 12576–12584.
  • [39] C. Xing, N. Rostamzadeh, B. Oreshkin, and P. O. Pinheiro, “Adaptive cross-modal few-shot learning,” in NIPS, 2019, pp. 4848–4858.
  • [40] Y.-X. Wang, R. Girshick, M. Hebert, and B. Hariharan, “Low-shot learning from imaginary data,” in Proceedings of the IEEE CVPR, 2018, pp. 7278–7286.
  • [41] X. Li, Q. Sun, Y. Liu, Q. Zhou, S. Zheng, T.-S. Chua, and B. Schiele, “Learning to self-train for semi-supervised few-shot classification,” in NIPS, 2019, pp. 10276–10286.
  • [42] M. Ren, E. Triantafillou, S. Ravi, J. Snell, K. Swersky, J. B. Tenenbaum, H. Larochelle, and R. S. Zemel, “Meta-learning for semi-supervised few-shot classification,” in ICLR, 2018.
  • [43] S. Gidaris, A. Bursuc, N. Komodakis, P. Perez, and M. Cord, “Boosting few-shot visual learning with self-supervision,” in Proceedings of the IEEE ICCV, October 2019.
  • [44] Y. Guo, N. C. Codella, L. Karlinsky, J. R. Smith, T. Rosing, and R. Feris, “A new benchmark for evaluation of cross-domain few-shot learning,” arXiv preprint arXiv:1912.07200, 2019.
  • [45] J. Cai and S. M. Shen, “Cross-domain few-shot learning with meta fine-tuning,” arXiv preprint arXiv:2005.10544, 2020.
  • [46] J. Springenberg, A. Dosovitskiy, T. Brox, and M. Riedmiller, “Striving for simplicity: The all convolutional net,” in ICLR (workshop track), 2015.
  • [47] N. Poerner, B. Roth, and H. Schütze, “Evaluating neural network explanation methods using hybrid documents and morphosyntactic agreement,” Association for Computational Linguistics (ACL), 2018.
  • [48] M. Kohlbrenner, A. Bauer, S. Nakajima, A. Binder, W. Samek, and S. Lapuschkin, “Towards best practice in explaining neural network decisions with LRP,” in IJCNN, 2020, pp. 1–7.
  • [49] O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra et al., “Matching networks for one shot learning,” in NIPS, 2016, pp. 3630–3638.
  • [50] C. Wah, S. Branson, P. Welinder, P. Perona, and S. Belongie, “The Caltech-UCSD Birds-200-2011 dataset,” 2011.
  • [51] J. Krause, M. Stark, J. Deng, and L. Fei-Fei, “3D object representations for fine-grained categorization,” in Proceedings of the IEEE CVPR workshops, 2013, pp. 554–561.
  • [52] B. Zhou, A. Lapedriza, A. Khosla, A. Oliva, and A. Torralba, “Places: A 10 million image database for scene recognition,” IEEE TPAMI, vol. 40, no. 6, pp. 1452–1464, 2017.
  • [53] G. Van Horn, O. Mac Aodha, Y. Song, Y. Cui, C. Sun, A. Shepard, H. Adam, P. Perona, and S. Belongie, “The iNaturalist species classification and detection dataset,” in Proceedings of the IEEE CVPR, 2018, pp. 8769–8778.
  • [54] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE CVPR, 2016, pp. 770–778.
  • [55] N. Tishby and N. Zaslavsky, “Deep learning and the information bottleneck principle,” in 2015 IEEE Information Theory Workshop (ITW). IEEE, 2015, pp. 1–5.
  • [56] A. M. Saxe, Y. Bansal, J. Dapello, M. Advani, A. Kolchinsky, B. D. Tracey, and D. D. Cox, “On the information bottleneck theory of deep learning,” Journal of Statistical Mechanics: Theory and Experiment, vol. 2019, no. 12, p. 124020, 2019.