
Feature Transformation Ensemble Model with Batch Spectral Regularization for Cross-Domain Few-Shot Classification

Bingyu Liu, Zhen Zhao, Zhenpeng Li, Jianan Jiang, Yuhong Guo, Jieping Ye
AI Tech, DiDi ChuXing
Abstract

In this paper, we propose a feature transformation ensemble model with batch spectral regularization for the cross-domain few-shot learning (CD-FSL) challenge. Specifically, we propose to construct an ensemble prediction model by performing diverse feature transformations after a feature extraction network. On each branch prediction network of the model, we use a batch spectral regularization term to suppress the singular values of the feature matrix during pre-training, in order to improve the generalization ability of the model. The proposed model can then be fine-tuned in the target domain to address few-shot classification. We further apply label propagation, entropy minimization and data augmentation to mitigate the shortage of labeled data in target domains. Experiments are conducted on a number of CD-FSL benchmark tasks with four target domains, and the results demonstrate the superiority of our proposed model.

1 Introduction

Current deep learning methods for visual recognition tasks often rely on large amounts of labeled training data to achieve high performance. Collecting and annotating such large training datasets is expensive and impractical in many cases. To speed up research progress on this problem, the cross-domain few-shot learning (CD-FSL) challenge [2] has been released. It contains data from the CropDiseases, EuroSAT, ISIC2018 and ChestX datasets, which reflect realistic use cases for deep learning.

Meta-learning is a widely used strategy for few-shot learning. However, recent research [2] indicates that traditional “pre-training and fine-tuning” can outperform meta-learning based few-shot learning algorithms when there exists a large domain gap between source base classes and target novel classes. Nevertheless, the capacity of fine-tuning can still be limited when facing a large domain gap. To tackle this problem, in this paper we propose a batch spectral regularization (BSR) mechanism that suppresses all the singular values of the feature matrix during pre-training, so that the pre-trained model can avoid overfitting to the source domain and generalize well to the target domain. Moreover, we propose a feature transformation ensemble model that builds multiple predictors in diverse projected feature spaces to facilitate cross-domain adaptation and increase prediction robustness. To mitigate the shortage of labeled data in the target domain, we exploit the unlabeled query set during fine-tuning through entropy minimization. We also apply a label propagation (LP) step to refine the original classification results, and exploit data augmentation techniques to augment both the few-shot and test instances from different angles to improve prediction performance. Experiments are conducted on the CD-FSL benchmark tasks, and the results demonstrate the superiority of our proposed model over the strong fine-tuning baseline.

2 Approach

Figure 1: An overview of the proposed approach.

In the cross-domain few-shot learning setting, we have a source domain $(X^s, Y^s)$ and a target domain $(X^t, Y^t)$. We use all the labeled data in the source domain for pre-training. In the target domain, a number of $K$-way $N$-shot classification tasks are sampled, each with a support set $S=\{(x_i, y_i)\}_{i=1}^{K\times N}$ composed of $N$ labeled examples from each of $K$ novel classes. The labeled support set can be used to fine-tune the pre-trained model, while a query set from the same $K$ classes is used to evaluate the model performance. Inspired by the ensemble networks for zero-shot learning [4], we build an ensemble prediction model for cross-domain few-shot learning. An overview of the proposed ensemble model is depicted in Fig. 1. Below we introduce the components of the model.
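As a concrete illustration, the following is a minimal sketch of how a $K$-way $N$-shot episode could be sampled from a target domain; `images_by_class` and `sample_episode` are hypothetical names, and the challenge's official sampler may differ in detail.

```python
import random

def sample_episode(images_by_class, K=5, N=5, n_query=15):
    """Sample one K-way N-shot episode: N support and n_query query images per class."""
    classes = random.sample(list(images_by_class), K)
    support, query = [], []
    for label, c in enumerate(classes):
        idx = random.sample(images_by_class[c], N + n_query)
        support += [(i, label) for i in idx[:N]]   # labeled support set S
        query += [(i, label) for i in idx[N:]]     # query set used for evaluation
    return support, query
```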

2.1 Feature Transformation Ensemble Model

We build the ensemble model by increasing the diversity of the feature representation space while still using the entire training data for each branch prediction network. As shown in Fig. 1, we use a Convolutional Neural Network (CNN) $F^B$ to extract high-level visual features $f^B \in \mathbb{R}^m$ from the input data, and then transform the features into multiple diverse feature representation spaces using different random orthogonal projection matrices $\{E_{(1)}, E_{(2)}, \cdots, E_{(M)}\}$ on different branches. Each projection matrix $E_{(i)}$ is generated in the following way. We randomly generate a symmetric matrix $Z_{(i)} \in [0,1]^{m\times m}$, and then form the orthogonal matrix $E_{(i)}$ from the eigenvectors of $Z_{(i)}$ such that $E_{(i)} = [e_1, e_2, \ldots, e_{m-1}]^\top$, where $e_j$ denotes the eigenvector corresponding to the $j$-th largest eigenvalue of $Z_{(i)}$. With each projection matrix $E_{(i)}$, we can transform the extracted features into a new feature representation space such that $f^E_{(i)} = E_{(i)} f^B$, and build a soft-max predictor $C_{(i)}$ in this feature space. By using $M$ randomly generated orthogonal projection matrices, we can thus build $M$ classifiers. In the pre-training stage in the source domain, all the labeled source data are used to train each branch network, which consists of the composite feature extractor $F^E_{(i)}(x) = E_{(i)} F^B(x)$ and the classifier $C_{(i)}$, by minimizing the cross-entropy loss. For a training batch $(X_B, Y_B) = \{(X_1, Y_1), \cdots, (X_b, Y_b)\}$, the loss function can be written as

$\ell_{ce}(X_B, Y_B; C_{(i)} \circ F^E_{(i)}) = \frac{1}{b}\sum_{j=1}^{b} L_{ce}\!\left(C_{(i)} \circ F^E_{(i)}(X_j),\, Y_j\right)$  (1)

where $L_{ce}$ denotes the cross-entropy loss function. After pre-training in the source domain, fine-tuning on the labeled support set of the target domain can be conducted in a similar way, while predictions on the query instances are naturally produced in an ensemble manner by averaging the $M$ classifiers' prediction results.
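As a rough PyTorch sketch (our own illustration, not the authors' released code), each branch's orthogonal projection can be obtained from the eigenvectors of a random symmetric matrix, and the ensemble prediction is the average of the branch soft-max outputs. For simplicity this sketch keeps all $m$ eigenvectors, whereas the paper's row selection may differ.

```python
import torch

def random_orthogonal_projection(m, seed=0):
    """Build one branch's projection E_(i) from a random symmetric matrix Z_(i)."""
    g = torch.Generator().manual_seed(seed)
    A = torch.rand(m, m, generator=g)
    Z = (A + A.T) / 2                    # random symmetric matrix in [0, 1]^{m x m}
    _, V = torch.linalg.eigh(Z)          # columns of V are orthonormal eigenvectors
    return V.T                           # rows are eigenvectors, so E @ E.T = I

def ensemble_predict(feats, projections, classifiers):
    """Average the M branch classifiers' soft-max predictions on features F^B(x)."""
    probs = [torch.softmax(C(feats @ E.T), dim=-1)
             for E, C in zip(projections, classifiers)]
    return torch.stack(probs).mean(dim=0)
```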

2.2 Batch Spectral Regularization

Previous work [1] shows that penalizing smaller singular values of a feature matrix can help mitigate negative transfer in fine-tuning. We extend this penalizer to the full spectrum and propose a batch spectral regularization (BSR) mechanism to suppress all the singular values of the batch feature matrices in pre-training, aiming to avoid overfitting to the source domain and increase generalization ability to the target domain. This regularization is applied to each branch network of the ensemble model separately in the same way. For simplicity, we omit the branch network index in the following presentation.

Specifically, for a stochastic gradient descent based training algorithm, we work with training batches. Given a batch of training instances $(X_B, Y_B)$, its feature matrix is $A = [f^E_1, \cdots, f^E_b]$, where $b$ is the batch size and $f^E_i = F^E(X_i)$ is the feature vector of the $i$-th instance in the batch. The BSR term can then be written as

$\ell_{bsr}(A) = \sum_{i=1}^{b} \sigma_i^2$  (2)

where $\sigma_1, \sigma_2, \cdots, \sigma_b$ are the singular values of the batch feature matrix $A$. The spectral regularized training loss for each batch is then:

$\mathcal{L} = \ell_{ce}\left(X_B, Y_B; C \circ F^E\right) + \lambda\, \ell_{bsr}\left(F^E(X_B)\right)$  (3)

where $\lambda$ is a trade-off parameter.
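A minimal PyTorch sketch of this batch-level objective is given below, assuming a per-branch feature extractor and soft-max classifier; `torch.linalg.svdvals` computes the singular values of the batch feature matrix, and the value of $\lambda$ follows Section 3.1.

```python
import torch
import torch.nn.functional as F

def bsr_loss(batch_features):
    """Eq. (2): sum of squared singular values of the batch feature matrix A."""
    return (torch.linalg.svdvals(batch_features) ** 2).sum()

def pretrain_loss(x, y, feature_extractor, classifier, lam=0.001):
    """Eq. (3): cross-entropy plus batch spectral regularization for one branch."""
    feats = feature_extractor(x)          # F^E(X_B), shape (b, d)
    logits = classifier(feats)
    return F.cross_entropy(logits, y) + lam * bsr_loss(feats)
```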

2.3 Label Propagation

Due to the lack of labeled data in the target domain, the model fine-tuned with the support set can easily make wrong predictions on the query instances. Following the effective label refinement procedure in [3], we apply a label propagation (LP) step to exploit the semantic information of unlabeled test data in the extracted feature space and refine the original classification results.

Given the prediction score matrix $\hat{Y}^0$ on the query instances $X_Q$ produced by the fine-tuned classifier $C_t$, we keep the top-$\delta$ fraction of scores in each class (the columns of $\hat{Y}^0$) and set the other values to zero, so that only the most confident predictions are propagated. We then build a k-NN graph over the query instances based on the extracted features $F^E(X_Q)$, using the squared Euclidean distance between each pair of images, $d(i,j) = \left\|F^E(x_i) - F^E(x_j)\right\|^2$. The RBF kernel based affinity matrix $W$ is computed as follows:

$W_{ij} = \begin{cases} \exp\!\left(\frac{-d(i,j)}{2\gamma^2}\right), & i \in \mathrm{KNN}(j)\ \text{or}\ j \in \mathrm{KNN}(i) \\ 0, & \text{otherwise} \end{cases}$  (4)

where $\gamma^2$ is the radius of the RBF kernel and $\mathrm{KNN}(i)$ denotes the k-nearest neighbors of the $i$-th image. The normalized Laplacian matrix $L$ can then be calculated as $L = Q^{-1/2} W Q^{-1/2}$, where $Q$ is a diagonal matrix with $Q_{ii} = \sum_j W_{ij}$. Label propagation is then performed to obtain the following refined prediction score matrix:

$Y^{*} = \left(I - \alpha L\right)^{-1} \hat{Y}^0$  (5)

where $I$ is the identity matrix and $\alpha \in [0,1]$ is a trade-off parameter. After LP, $\hat{y}_i = \arg\max_j Y^{*}_{ij}$ is used as the predicted class for the $i$-th image.
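The following NumPy sketch illustrates the LP step under our reading of Eqs. (4)-(5); the hyper-parameter values follow Section 3.1, and the exact thresholding and tie-breaking details are assumptions.

```python
import numpy as np

def label_propagation(feats, scores, k=10, delta=0.2, alpha=0.5):
    """Refine query predictions: feats = F^E(X_Q), scores = initial score matrix Y^0."""
    n = feats.shape[0]
    d2 = ((feats[:, None] - feats[None, :]) ** 2).sum(-1)    # squared Euclidean distances
    knn = np.argsort(d2, axis=1)[:, 1:k + 1]                  # k nearest neighbors (self excluded)
    gamma2 = d2[np.arange(n)[:, None], knn].mean()            # average squared edge length
    W = np.zeros((n, n))
    rows = np.repeat(np.arange(n), k)
    W[rows, knn.ravel()] = np.exp(-d2[rows, knn.ravel()] / (2 * gamma2))
    W = np.maximum(W, W.T)                                    # i in KNN(j) or j in KNN(i)
    Q_inv_sqrt = np.diag(1.0 / np.sqrt(W.sum(1) + 1e-12))
    L = Q_inv_sqrt @ W @ Q_inv_sqrt                           # L = Q^{-1/2} W Q^{-1/2}
    Y0 = scores.copy()                                        # keep only the top-delta scores per class
    Y0[Y0 < np.quantile(Y0, 1 - delta, axis=0)] = 0.0
    Y_star = np.linalg.solve(np.eye(n) - alpha * L, Y0)       # Eq. (5): (I - alpha L)^{-1} Y^0
    return Y_star.argmax(1)                                   # refined class predictions
```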

2.4 Entropy Minimization

We extend the semi-supervised learning mechanism into the fine-tuning phase in the target domain by minimizing the prediction entropy on the unlabeled query set:

$\ell_{ent}(X_B^q; C \circ F^E) = -\frac{1}{b}\sum_{i=1}^{b} C \circ F^E(X_i^q)\, \log\!\left(C \circ F^E(X_i^q)\right)$  (6)

where $X_B^q$ denotes a query batch. We add this term to the original cross-entropy loss on the support set batch $(X_B^s, Y_B^s)$ to form a transductive fine-tuning loss function:

$\mathcal{L}_{ft} = \ell_{ce}(X_B^s, Y_B^s; C \circ F^E) + \beta\, \ell_{ent}(X_B^q; C \circ F^E)$  (7)

where $\beta$ is a trade-off parameter.
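A minimal PyTorch sketch of this transductive objective, under the assumption of a single branch with a soft-max classifier head and $\beta$ as in Section 3.1:

```python
import torch
import torch.nn.functional as F

def transductive_loss(xs, ys, xq, feature_extractor, classifier, beta=0.1):
    """Eq. (7): support cross-entropy plus query entropy minimization."""
    ce = F.cross_entropy(classifier(feature_extractor(xs)), ys)         # labeled support batch
    probs_q = torch.softmax(classifier(feature_extractor(xq)), dim=-1)  # unlabeled query batch
    ent = -(probs_q * torch.log(probs_q + 1e-12)).sum(dim=1).mean()     # Eq. (6)
    return ce + beta * ent
```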

2.5 Data Augmentation

We also exploit a data augmentation (DA) strategy with several random operations to supplement the support set and expose the models to more variation during learning. In particular, we use combinations of operations such as image scaling, random crop, random flip, random rotation and color jitter to generate a few variants of each image. Fine-tuning is then conducted on the augmented support set. The same augmentation can be applied to the query set as well, where several variants of each image are generated and share the same label. The prediction for each query image is then determined by averaging the prediction results over all the augmented variants of that image.
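For illustration, one compound mode ("SJHR": Scale + ImageJitter + RandomHorizontalFlip + RandomRotation, cf. Tables 1-2) could be approximated with torchvision transforms as sketched below; ImageJitter is approximated here with `ColorJitter`, and the paper's exact implementation may differ.

```python
from torchvision import transforms

# "SJHR" compound augmentation: Scale + ImageJitter + HorizontalFlip + Rotation
sjhr = transforms.Compose([
    transforms.Resize((84, 84)),                                            # Scale (S)
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),   # ImageJitter (J)
    transforms.RandomHorizontalFlip(p=0.5),                                 # RandomHorizontalFlip (H)
    transforms.RandomRotation(degrees=(0, 45)),                             # RandomRotation (R)
    transforms.ToTensor(),
])
# Each support/query image is expanded into several such variants sharing its
# label; query predictions are averaged over the augmented variants.
```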

3 Experiments

Table 1: Hyper-parameters of augmentation operations.
Augmentation Hyper-parameters
Scale (S)  84×84
RandomResizedCrop (C)  84×84
ImageJitter (J)  Brightness 0.4, Contrast 0.4, Color 0.4
RandomHorizontalFlip (H)  Flip probability: 50%
RandomRotation (R)  0–45 degrees
Table 2: Compound modes of augmentation operations.
Dataset Augmentation
ISIC, EuroSAT & CropDiseases  S + SJHR + SR + SJ + SH
ChestX  S + SJH + C + CJ + CH
Table 3: Results on the CD-FSL benchmark.
Methods ChestX ISIC
5-way 5-shot 5-way 20-shot 5-way 50-shot 5-way 5-shot 5-way 20-shot 5-way 50-shot
Fine-tuning [2] 25.97%±0.41% 31.32%±0.45% 35.49%±0.45% 48.11%±0.64% 59.31%±0.48% 66.48%±0.56%
BSR 26.84%±0.44% 35.63%±0.54% 40.18%±0.56% 54.42%±0.66% 66.61%±0.61% 71.10%±0.60%
BSR+LP 27.10%±0.45% 35.92%±0.55% 40.56%±0.56% 55.86%±0.66% 67.48%±0.60% 72.17%±0.58%
BSR+DA 28.20%±0.46% 36.72%±0.51% 42.08%±0.53% 54.97%±0.68% 66.43%±0.57% 71.62%±0.60%
BSR+LP+ENT 26.86%±0.45% 35.60%±0.51% 42.26%±0.53% 56.82%±0.68% 68.97%±0.56% 74.13%±0.56%
BSR+LP+DA 28.50%±0.48% 36.95%±0.52% 42.32%±0.53% 56.25%±0.69% 67.31%±0.57% 72.33%±0.58%
BSR (Ensemble) 28.44%±0.45% 37.05%±0.50% 43.22%±0.54% 55.47%±0.68% 68.00%±0.59% 73.36%±0.54%
BSR+LP (Ensemble) 28.66%±0.44% 37.44%±0.51% 43.72%±0.54% 57.14%±0.67% 68.99%±0.58% 74.62%±0.54%
BSR+DA (Ensemble) 29.09%±0.45% 37.89%±0.53% 43.98%±0.56% 56.13%±0.66% 67.10%±0.61% 73.16%±0.54%
BSR+LP+ENT (Ensemble) 28.00%±0.46% 36.87%±0.52% 43.79%±0.53% 58.02%±0.68% 70.22%±0.59% 74.94%±0.55%
BSR+LP+DA (Ensemble) 29.72%±0.45% 38.34%±0.53% 44.43%±0.56% 57.40%±0.67% 68.09%±0.60% 74.08%±0.55%
Methods EuroSAT CropDiseases
5-way 5-shot 5-way 20-shot 5-way 50-shot 5-way 5-shot 5-way 20-shot 5-way 50-shot
Fine-tuning [2] 79.08%±0.61% 87.64%±0.47% 90.89%±0.36% 89.25%±0.51% 95.51%±0.31% 97.68%±0.21%
BSR 80.89%±0.61% 90.44%±0.40% 93.88%±0.31% 92.17%±0.45% 97.90%±0.22% 99.05%±0.14%
BSR+LP 84.35%±0.59% 91.99%±0.37% 95.02%±0.27% 94.45%±0.40% 98.65%±0.19% 99.38%±0.11%
BSR+DA 82.75%±0.55% 92.61%±0.31% 95.26%±0.39% 93.99%±0.39% 98.62%±0.15% 99.42%±0.08%
BSR+LP+ENT 85.70%±0.53% 92.90%±0.33% 95.40%±0.29% 95.69%±0.35% 98.60%±0.18% 99.27%±0.12%
BSR+LP+DA 85.97%±0.52% 93.73%±0.29% 96.07%±0.30% 95.97%±0.33% 99.10%±0.12% 99.66%±0.07%
BSR (Ensemble) 83.93%±0.53% 92.55%±0.33% 95.11%±0.24% 93.54%±0.41% 98.34%±0.20% 99.22%±0.12%
BSR+LP (Ensemble) 86.08%±0.55% 93.81%±0.30% 95.97%±0.23% 95.48%±0.38% 98.94%±0.16% 99.49%±0.11%
BSR+DA (Ensemble) 85.19%±0.51% 93.68%±0.28% 96.14%±0.26% 94.80%±0.36% 98.69%±0.16% 99.51%±0.09%
BSR+LP+ENT (Ensemble) 87.17%±0.52% 93.96%±0.29% 96.09%±0.22% 96.04%±0.36% 98.94%±0.16% 99.45%±0.10%
BSR+LP+DA (Ensemble) 88.13%±0.49% 94.72%±0.28% 96.89%±0.19% 96.59%±0.31% 99.16%±0.14% 99.73%±0.06%
Table 4: Averages across all datasets and shot levels.
Methods Average
Fine-tuning [2] 67.23%±0.46%
BSR 70.76%±0.46%
BSR+LP 71.91%±0.44%
BSR+DA 71.89%±0.44%
BSR+LP+ENT 72.68%±0.42%
BSR+LP+DA 72.85%±0.42%
BSR (Ensemble) 72.35%±0.43%
BSR+LP (Ensemble) 73.36%±0.42%
BSR+DA (Ensemble) 72.95%±0.42%
BSR+LP+ENT (Ensemble) 73.62%±0.42%
BSR+LP+DA (Ensemble) 73.94%±0.40%

3.1 Setup

In the experiments, we use the evaluation protocol of the CD-FSL challenge [2], which takes 15 images from each class as the query set and uses 600 randomly sampled few-shot episodes in each target domain. The average accuracy and 95% confidence interval are reported.

As for the model architecture, we use ResNet-10 as the CNN feature extractor $F^B$ and a fully-connected layer with soft-max activation as the classifier $C$. We set the trade-off parameters $\lambda = 0.001$ and $\beta = 0.1$, and the number of branches $M = 10$. For the label propagation step, we use $k = 10$ for the k-NN graph construction and set $\gamma^2$ to the average of the squared distances of the edges in the k-NN graph. The parameters $\delta$ and $\alpha$ are set to 0.2 and 0.5 respectively. We adopt mini-batch SGD with a momentum of 0.9 for both pre-training and fine-tuning. During the pre-training stage, models are trained for 400 epochs, with the learning rate and weight decay set to 0.001 and 0.0005 respectively. During fine-tuning, we set the learning rate to 0.01 and fine-tune for 100 epochs.

For data augmentation (DA), we choose 5 types of augmentation operations as shown in Table 1. We use 5 compound modes of these operations to generate data in different target domains. The specific operations used in each target domain are shown in Table 2.

3.2 Results

We investigate a number of variants of the proposed model, comparing against the strong fine-tuning baseline reported in [2]. We first investigate a single prediction network with batch spectral regularization (BSR), without ensemble, and its variants that further incorporate label propagation (LP) and/or data augmentation (DA). We then extend these variants into the ensemble model framework with $M = 10$. The results are reported in Table 3, and the average accuracies (and 95% confidence intervals) across all datasets and shot levels are shown in Table 4.

We can see that even with BSR alone, the proposed method already significantly outperforms the fine-tuning baseline (70.76% vs. 67.23% on average). The ensemble BSR further improves the results (72.35%). The LP and DA components also help improve CD-FSL performance. We observe that LP performs better than DA on target domains more similar to the source domain, and vice versa, which suggests that LP and DA focus on different aspects of the data; as a result, combining LP and DA can further improve performance. Moreover, DA is not very effective for the ISIC domain, where it even degrades performance in some cases, and the running time is typically longer with DA. By replacing DA with ENT, we obtain similar overall performance. With a single model, the best average result, achieved by BSR+LP+DA, is 72.85%, while BSR+LP+ENT achieves 72.68%. With the ensemble model, BSR+LP+DA (Ensemble) produces the best average result of 73.94%, while BSR+LP+ENT (Ensemble) yields 73.62%.

4 Conclusion

In this paper, we proposed a feature transformation based ensemble model for CD-FSL. The model incorporates batch spectral regularization in pre-training, and exploits data augmentation and label propagation during fine-tuning and testing in the target domain. The combined models produced superior CD-FSL results compared to the strong fine-tuning baseline.

References

  • [1] Xinyang Chen, Sinan Wang, Bo Fu, Mingsheng Long, and Jianmin Wang. Catastrophic forgetting meets negative transfer: Batch spectral shrinkage for safe transfer learning. In NeurIPS, 2019.
  • [2] Yunhui Guo, Noel CF Codella, Leonid Karlinsky, John R Smith, Tajana Rosing, and Rogerio Feris. A new benchmark for evaluation of cross-domain few-shot learning. arXiv preprint arXiv:1912.07200, 2019.
  • [3] Meng Ye and Yuhong Guo. Labelless scene classification with semantic matching. In BMVC, 2017.
  • [4] Meng Ye and Yuhong Guo. Progressive ensemble networks for zero-shot recognition. In CVPR, 2019.