
Vision Pair Learning: An Efficient Training Framework for Image Classification

Bei Tong,1 Xiaoyuan Yu 1
Abstract

Transformer is a potentially powerful architecture for vision tasks. Although equipped with more parameters and an attention mechanism, its performance is not yet as dominant as that of CNNs. CNNs are usually computationally cheaper and remain the leading competitors in various vision tasks. One research direction is to adopt the successful ideas of CNNs to improve the transformer, but this often relies on elaborate and heuristic network design. Observing that transformer and CNN are complementary in representation learning and convergence speed, we propose an efficient training framework called Vision Pair Learning (VPL) for the image classification task. VPL builds a network composed of a transformer branch, a CNN branch and a pair learning module. With a multi-stage training strategy, VPL enables the branches to learn from their partners during the appropriate stage of the training process, and makes both achieve better performance at less time cost. Without external data, VPL promotes the top-1 accuracy of ViT-Base and ResNet-50 on the ImageNet-1k validation set to 83.47% and 79.61% respectively. Experiments on datasets from various other domains prove the efficacy of VPL and suggest that the transformer performs better when paired with the differently structured CNN in VPL. We also analyze the importance of the components through an ablation study.

Introduction

The transformer architecture has attracted much attention since its first application in machine translation, and it has proven effective in various natural language processing tasks. As a result, the transformer has recently been introduced to computer vision. Although researchers have successfully adapted the transformer for CV tasks, CNN-based models are currently superior to transformer-based ones in terms of performance and speed. The main reasons include: 1) The special structure of the transformer lets it pay more attention to global features but ignore detailed information. Because of that, pure transformer-based methods often fail in fine-grained tasks such as small object detection. 2) Problems also arise in model training. Due to the relatively large number of parameters, training a transformer tends to be more computationally expensive than training CNN models. For example, compared with ViT-L/16 (Dosovitskiy et al. 2020), training EfficientNetV2 (Tan and Le 2021) is 5×-11× faster. This imposes unfriendly hardware requirements and slows down prototyping experiments.

To tackle these problems, effective components of CNNs are combined with the transformer, both to enhance performance and to decrease computation cost. These approaches generally fall into two categories: 1) Utilize CNN blocks to minimize the information loss brought by image tokenization and reduce the dimension of the input feature, so as to balance performance and speed, e.g. Visual Transformers (Wu et al. 2020) and T2T-ViT (Yuan et al. 2021). 2) Modify the general design of the network. TNT (Han et al. 2021) follows the "deep-narrow" idea and makes the transformer deeper and slimmer. CvT (Wu et al. 2021) replaces linear projection with convolutional projection for less computation. Other widely-used techniques, e.g. pruning (Zhu et al. 2021) and sliding windows (Liu et al. 2021), are adopted as well.

The performance of the above methods is satisfactory, but mixing the advantages of CNN and transformer often relies on elaborate and heuristic network design, and may hurt transferability. In this paper, we propose a pair learning method based on CNN and transformer for image classification tasks. The method enables two different networks to learn from each other and achieve better performance together. The structure is shown in Figure 1.

Figure 1: Architecture of Vision Pair Learning. The green bold arrows represent the forward pass in the PLM. The red dashed arrows indicate the direction of restricted gradient flow.

The network consists of three parts: a CNN branch, a transformer branch and a Pair Learning Module (PLM). Since the target of VPL is to build a high-level training framework for better representation learning, we directly use CNN and transformer baseline models as branches without modification. Being the core part of the model, the PLM introduces a contrastive-learning-like task. The most common idea of contrastive learning (He et al. 2020; Chen et al. 2020b; Grill et al. 2020; Chen and He 2020; Caron et al. 2020) is to feed a pair of randomly transformed images to two parameter-sharing branches and predict whether they come from the same source image based on the extracted embeddings. Different from that, the branches of our method read the same image as input but are distinct from each other in terms of basic network block. Due to the natural difference in the two branches' architecture and capacity, the model collapse problem common in contrastive learning is easily avoided. Meanwhile, due to the different inductive biases of global attention and local convolution, the transformer and CNN branches produce representations of diverse views. The task forces the high-level representations learned by the two branches to interact, thus providing extra supervision signals and accelerating convergence. A distillation-like loss is utilized as well.

Considering that the two branches have different data efficiency, we additionally adjust the training strategy so that each branch can learn from its partner during the appropriate stage of the training process. The heterogeneous model with the PLM and the modified training protocol together constitute vision pair learning. After training, the PLM is removed and the branches can be used for inference without extra computation.

To summarize, our contributions are three-fold:

  • We propose a new training framework named vision pair learning. By building a heterogeneous model composed of CNN and transformer, and introducing the pair learning module (PLM) and a multi-stage training strategy, the framework offers both branches faster convergence and superior performance in the image classification task. After training, the PLM is removed and the CNN and transformer work without additional components; therefore vision pair learning brings no extra computation burden at inference.

  • With the help of pair learning, ViT-Base obtains 81% top-1 accuracy on the ImageNet validation set (Deng et al. 2009) with 100-epoch training, which surpasses our reimplemented 77.1% baseline (Touvron et al. 2020). Furthermore, with 300-epoch training, ViT-Base reaches 83.47% and exceeds the official result of 81.8% by 1.67 percent, and ResNet-50 achieves a competitive 79.61%. In addition, RegNet-16GF (Radosavovic et al. 2020) and ViT-Base are trained jointly as well to further confirm the effectiveness of the proposed method; their accuracy is increased by 2.36% and 2% respectively compared to independent training.

  • Experiments on more image classification datasets suggest that vision pair learning is a robust and efficient training framework. Ablation studies prove the effectiveness of the proposed components, including simultaneous training, restricted gradient flow, the necessity of two differently structured branches, and so on.

Related Work

Transformer (Vaswani et al. 2017) and its applications (Devlin et al. 2018; Brown et al. 2020) have achieved great success in various natural language processing tasks in recent years. Different from traditional RNN models, its attention-only building block is effective in modeling long-range correlations of input sequences. The large number of parameters introduced by the transformer architecture increases the model capacity, as well as the difficulty of optimization. When optimizing a big model, abundant training data is usually necessary for satisfactory performance. To better apply the transformer to data-scarce tasks, NLP researchers propose to utilize the pretraining technique. A transformer is typically pretrained in an unsupervised manner on a large-scale corpus, and then fine-tuned on downstream tasks. For example, BERT (Devlin et al. 2018) exploits the encoder part of the transformer for sequence modeling and proposes two unsupervised pretraining tasks, masked language modeling and next sentence prediction. BERT is proven powerful and reaches the state of the art on 11 downstream NLP tasks. The transformer-decoder-based GPT series models (Radford et al. 2018, 2019; Brown et al. 2020), which lead the trend of expanding model size, suggest that a pretrained causal language model is good at few-shot/zero-shot learning and natural language generation.

Due to the impressive performance the transformer presents in NLP, researchers have attempted to introduce it to various CV tasks. (Girdhar et al. 2019) proposes an action transformer model for recognizing and localizing human actions in video clips. The transformer-encoder-based model combines I3D (Carreira and Zisserman 2017) and RPN (Ren et al. 2016) and wins by a significant margin on the Atomic Visual Actions dataset (Gu et al. 2018). DETR (Carion et al. 2020) transfers the transformer to the object detection task. The main idea is to learn a 2D representation of an image using a CNN backbone and feed it to the transformer to make the final detection prediction. DETR demonstrates significantly better performance on large objects but fails on small ones. Since the results are not particularly satisfactory, its follow-up UP-DETR (Dai et al. 2020) puts forward a random query patch detection method and boosts the performance of DETR with faster convergence and higher precision. IPT (Chen et al. 2020a) generates corrupted image pairs from ImageNet (Deng et al. 2009) and pretrains a transformer on them. By fine-tuning the model on low-level CV tasks such as denoising, super-resolution and deraining, IPT outperforms contemporaneous approaches.

Though the transformer has made much progress in CV tasks, the majority of proposed approaches use it only as a component, and researchers continue to explore transformer-only solutions. ViT (Dosovitskiy et al. 2020) first shows that a pure transformer is promising in image classification tasks. Even though its performance is weaker than some CNN-based methods, it can overtake them when pretrained on huge datasets. DeiT (Touvron et al. 2020), based on ViT, further proves that the transformer can outperform CNN-based methods trained on the same dataset by optimizing the training strategies. Furthermore, the authors present a distillation strategy for better performance at the same inference speed.

Although the transformer is promising in CV tasks, it is rarely used in production due to the huge computation resources it requires. VPL aims at reducing the cost by combining the advantages of both transformer and CNN. It can accelerate training and achieve comparable results with fewer training iterations. A similar approach, deep mutual learning (DML) (Zhang et al. 2018), has been proposed before, which pays more attention to promoting CNN models by using distillation iteratively. We introduce the detailed differences between DML and our approach in the next section.

Method

Given a batch of input images $\bm{X}=\{\bm{x}_{1},\dots,\bm{x}_{N}\}$, the image classification task aims to predict the ground-truth labels $\bm{Y}=\{\bm{y}_{1},\dots,\bm{y}_{N}\}$, where $\bm{y}_{i}\in\mathbb{N}^{C}$ is a one-hot vector, $N$ is the batch size and $C$ is the number of classes. The proposed vision pair learning is composed of two branches, a pair learning module and multi-stage training. The two branches are the original CNN and transformer baseline models, and the pair learning module provides extra supervision signals for them. The multi-stage training helps the branches learn from their partners.

Two Branches

Both branches are widely-used image classification models, except that the backbone of one is a CNN and the other is a transformer. We denote the extracted image embeddings (the input features of the last fully-connected layer before softmax) as $\bm{H}^{cnn}=\{\bm{h}^{cnn}_{1},\dots,\bm{h}^{cnn}_{N}\}$ and $\bm{H}^{trans}=\{\bm{h}^{trans}_{1},\dots,\bm{h}^{trans}_{N}\}$, where $\bm{H}^{cnn},\bm{H}^{trans}\in\mathbb{R}^{N\times d}$ and $\bm{h}^{cnn}_{i},\bm{h}^{trans}_{i}\in\mathbb{R}^{d}$. Note that if the embeddings of the two branches are of different size, we add an extra affine layer to make their embedding sizes equal. The logits fed to softmax are denoted as $\bm{Z}^{cnn}=\{\bm{z}^{cnn}_{1},\dots,\bm{z}^{cnn}_{N}\}$ and $\bm{Z}^{trans}=\{\bm{z}^{trans}_{1},\dots,\bm{z}^{trans}_{N}\}$, where $\bm{Z}^{cnn},\bm{Z}^{trans}\in\mathbb{R}^{N\times C}$ and $\bm{z}^{cnn}_{i},\bm{z}^{trans}_{i}\in\mathbb{R}^{C}$. The branches are trained with the ground-truth classification labels and cross-entropy loss. Their objective functions are

\mathcal{L}_{CE-cnn}=-\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C}y_{i,j}\log\frac{\exp(z^{cnn}_{i,j})}{\sum_{k=1}^{C}\exp(z^{cnn}_{i,k})} \qquad (1)
\mathcal{L}_{CE-trans}=-\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C}y_{i,j}\log\frac{\exp(z^{trans}_{i,j})}{\sum_{k=1}^{C}\exp(z^{trans}_{i,k})} \qquad (2)
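To make the setup concrete, the following PyTorch-style sketch shows how a branch can expose both the embedding $\bm{h}$ and the logits $\bm{z}$ used above. It is a minimal illustration under our own naming (BranchWrapper, feat_dim and proj_dim are not from the paper); the affine layer is only inserted when the two branches' embedding sizes differ (e.g. 2048 for ResNet-50 vs. 768 for ViT-Base).

```python
import torch.nn as nn

class BranchWrapper(nn.Module):
    """Minimal sketch (not the authors' code): wrap a backbone so it
    returns the embedding h used by the PLM and the logits z of Eqs. (1)-(2)."""
    def __init__(self, backbone, feat_dim, proj_dim, num_classes):
        super().__init__()
        self.backbone = backbone  # any feature extractor returning (N, feat_dim)
        # Extra affine layer, only needed when the two branches' embedding
        # sizes differ (e.g. ResNet-50's 2048 vs. ViT-Base's 768).
        self.align = nn.Identity() if feat_dim == proj_dim else nn.Linear(feat_dim, proj_dim)
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, images):
        feat = self.backbone(images)   # pooled features before the last FC layer
        z = self.classifier(feat)      # logits fed to softmax / cross entropy
        h = self.align(feat)           # embedding consumed by the PLM
        return h, z
```

The cross-entropy terms of Eqs. (1)-(2) are then the standard losses on z for each branch.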

Pair Learning Module

Due to the structural differences between CNN and transformer, they focus on different types of features. The former is better at extracting local features such as texture and edges, while the latter is more proficient at building a global receptive field and global features. To combine the advantages of both, we introduce the pair learning module, which consists of two objectives: the proposed contrastive loss and a KL-divergence loss. Different from the majority of contrastive learning methods, which extract the embeddings of a pair of randomly transformed images with two parameter-sharing branches and predict whether they come from the same source image, we apply contrastive learning to our network from a different perspective. A group of images is simultaneously fed to the two branches and two groups of embeddings are extracted respectively. We ask the model to find out which pair of embeddings represents the same image. For example, the probability of $\bm{h}^{trans}_{i}$ being the counterpart of $\bm{h}^{cnn}_{i}$ is formulated as

P(\bm{h}^{cnn}_{i},\bm{h}^{trans}_{i})=\frac{\exp(sim(\bm{h}^{cnn}_{i},\bm{h}^{trans}_{i})/\tau)}{Z} \qquad (3)
Z=\sum_{j=1}^{N}\exp(sim(\bm{h}^{cnn}_{i},\bm{h}^{trans}_{j})/\tau)+\sum_{k=1,k\neq i}^{N}\exp(sim(\bm{h}^{cnn}_{i},\bm{h}^{cnn}_{k})/\tau) \qquad (4)

which can be regarded as a $(2N-1)$-way classification. The loss function of the contrastive learning target is formulated as

\mathcal{L}_{CL}=-\frac{1}{2N}\left(\sum_{i=1}^{N}\log P(\bm{h}^{cnn}_{i},\bm{h}^{trans}_{i})+\sum_{j=1}^{N}\log P(\bm{h}^{trans}_{j},\bm{h}^{cnn}_{j})\right) \qquad (5)

where $\tau$ is the temperature and $sim(\cdot,\cdot)$ denotes the dot product. As mentioned above, since the two branches of the network have different architectures, and hence unequal model capacity and representation learning ability, our proposed contrastive loss does not easily lead to model collapse.
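A compact PyTorch-style sketch of Eqs. (3)-(5) is given below, taking $sim$ to be the dot product as stated; the function names and the temperature value 0.1 are our own illustrative choices, not taken from the paper.

```python
import torch
import torch.nn.functional as F

def contrastive_direction(h_a, h_b, tau=0.1):
    """One direction of Eqs. (3)-(4): anchors h_a (e.g. the CNN branch),
    positives are the partner embeddings h_b of the same images."""
    n = h_a.size(0)
    cross = h_a @ h_b.t() / tau                     # (N, N), diagonal = positives
    intra = h_a @ h_a.t() / tau                     # in-branch similarities
    eye = torch.eye(n, dtype=torch.bool, device=h_a.device)
    intra = intra.masked_fill(eye, float('-inf'))   # drop the k = i terms
    logits = torch.cat([cross, intra], dim=1)       # (2N-1)-way classification
    targets = torch.arange(n, device=h_a.device)    # the positive index is i
    return F.cross_entropy(logits, targets)

def contrastive_loss(h_cnn, h_trans, tau=0.1):
    """Symmetric average over both directions, Eq. (5)."""
    return 0.5 * (contrastive_direction(h_cnn, h_trans, tau)
                  + contrastive_direction(h_trans, h_cnn, tau))
```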

We also adopt a KL-divergence based loss for representation interaction between the branches. Similar to knowledge distillation, we use the classification logits as inputs and minimize the KL-divergence between the predicted probabilities of the two branches:

\mathcal{L}_{KL}=\frac{1}{N}\sum_{i=1}^{N}KL\big(g(\bm{z}^{trans}_{i}/\rho)\ ||\ g(\bm{z}^{cnn}_{i}/\rho)\big) \qquad (6)

where $\rho$ is the temperature and $g(\cdot)$ represents the softmax function.

Preliminary distillation experiments show that if we choose only one type of loss function, the CNN performs better with the KL-divergence loss while the transformer prefers the contrastive loss. Based on these results, we propose to restrict the gradient flow in the PLM. To be more specific, the gradients produced by the contrastive loss propagate back only to the transformer branch, while the KL-divergence loss only affects the CNN branch, as shown in Figure 1. This intentional design is supported by the empirical results in the experiment section.

Though distillation shares ideas with our method, the simultaneous training of the branches and the restricted gradient flow bring better performance to vision pair learning according to our experiments.
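In code, the restriction can be implemented with stop-gradients. The sketch below (our own helper, reusing contrastive_loss from the previous sketch; the temperature values are placeholders) detaches the CNN embeddings inside the contrastive loss and the transformer logits inside the KL term of Eq. (6), so each loss updates only its intended branch.

```python
import torch.nn.functional as F

def plm_losses(h_cnn, h_trans, z_cnn, z_trans, tau=0.1, rho=1.0):
    """Pair learning module with restricted gradient flow (sketch)."""
    # Contrastive loss: the CNN side is detached, so its gradient
    # reaches only the transformer branch.
    loss_cl = contrastive_loss(h_cnn.detach(), h_trans, tau)
    # KL loss (Eq. 6): the transformer probabilities act as a fixed target,
    # so only the CNN branch is updated by this term.
    target = F.softmax(z_trans.detach() / rho, dim=1)
    log_pred = F.log_softmax(z_cnn / rho, dim=1)
    loss_kl = F.kl_div(log_pred, target, reduction='batchmean')
    return loss_cl, loss_kl
```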

Multi-Stage Training

When comparing the learning curves of the models in independent training, we observe that ResNet (CNN) converges considerably faster than ViT-Base (transformer) in the early stage, but ViT-Base catches up gradually as the number of epochs increases and reaches a higher final accuracy. Based on this observation, we propose the multi-stage training strategy so that the outperforming branch helps the other one.

We split the training process into three stages. The overall loss $\mathcal{L}_{VPL}$ changes across stages and is formulated as

\mathcal{L}_{VPL}=\mathcal{L}_{CE-cnn}+\mathcal{L}_{CE-trans}+\begin{cases}\mathcal{L}_{CL} & \text{first } x\%\text{ epochs}\\ \mathcal{L}_{CL}+\mathcal{L}_{KL} & \text{middle } (100-x-y)\%\text{ epochs}\\ \mathcal{L}_{KL} & \text{last } y\%\text{ epochs}\end{cases} \qquad (7)

The hard classification labels and the corresponding cross-entropy losses always participate in training. Meanwhile, the CNN leads the training and feeds additional gradient flow to the transformer in the first x% of epochs. During the middle (100-x-y)% of epochs, the CNN and the transformer influence each other by optimizing different types of loss functions. In the last y% of epochs, the transformer plays the role of teacher until training ends.

The training strategy is rather flexible. If y is set to 100, the model is similar to knowledge distillation with the transformer as the teacher. If x=y=50, the two branches each lead the training for half of the time. The training becomes one-stage if x=y=0. We search for the best setting and finally set x=20 and y=20, which gives the best empirical performance; a sketch of the schedule follows below. For more detail about how different settings influence the accuracy, see the ablation study in the experiment section.
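The per-epoch schedule of Eq. (7) can be written as a simple helper (our own sketch, reusing the loss names from the previous blocks):

```python
def vpl_loss(epoch, total_epochs, ce_cnn, ce_trans, loss_cl, loss_kl, x=20, y=20):
    """Eq. (7): stage-dependent combination of the four loss terms."""
    loss = ce_cnn + ce_trans
    first_end = total_epochs * x // 100           # end of the first stage
    last_start = total_epochs * (100 - y) // 100  # start of the last stage
    if epoch < first_end:        # first x% epochs: CNN teaches the transformer
        loss = loss + loss_cl
    elif epoch < last_start:     # middle stage: both PLM losses are active
        loss = loss + loss_cl + loss_kl
    else:                        # last y% epochs: transformer teaches the CNN
        loss = loss + loss_kl
    return loss
```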

Our proposed method is similar to DML (Zhang et al. 2018) in terms of its two-tower structure and distillation-like objective, but the two differ in several aspects. First, the motivation of VPL is to leverage two naturally different types of image encoders, CNN and transformer, to avoid the model collapse problem in contrastive learning, while DML aims to help multiple models learn from each other on equal footing, regardless of their architectures. Second, based on this motivation, our pair learning module focuses on feature learning and uses a contrastive-learning-style objective in addition to the distillation-style loss in DML. Third, DML updates each model iteratively with its own independent loss function, which may cause low data efficiency, while VPL proposes a unified loss function and updates both models simultaneously. Fourth, we show that the transformer model performs better when paired with a differently structured CNN rather than another transformer.

Experiment

In this section, we first introduce the implementation details. Then we evaluate the performance of VPL on the ImageNet classification dataset (Deng et al. 2009). Transfer learning experiments are conducted on various other datasets as well, including CIFAR-10 (Krizhevsky, Hinton et al. 2009), CIFAR-100 (Krizhevsky, Hinton et al. 2009), Oxford Flowers-102 (Nilsback and Zisserman 2008), and Stanford Cars (Krause et al. 2013). The statistics of the datasets are in Table 1. Last, we analyze the importance of the different components of VPL in the ablation study.

Dataset Train images Eval Images Classes
ImageNet (Deng et al. 2009) 1,281,167 50,000 1000
CIFAR-10 (Krizhevsky, Hinton et al. 2009) 50,000 10,000 10
CIFAR-100 (Krizhevsky, Hinton et al. 2009) 50,000 10,000 100
Flowers-102 (Nilsback and Zisserman 2008) 2,040 6,149 102
Stanford Cars (Krause et al. 2013) 8,144 8,041 196
Table 1: Statistics of image classification datasets

Implementation Details

Epoch Framework ImageNet (Top-1 / Top-5) Real (Top-1 / Top-5) V2 (Top-1 / Top-5)
100 ResNet-50× (He et al. 2016) 76.14 92.90 82.69 95.50 63.07 84.62
ResNet-50 77.22 93.34 84.18 96.08 65.19 85.88
VPL ResNet-50 77.45 93.70 84.44 96.31 65.79 86.54
ViT-Base 77.15 93.04 82.96 95.35 64.11 85.09
VPL ViT-Base 80.95 95.57 86.57 97.38 69.89 89.32
300 ResNet-50 79.13 94.66 85.45 96.82 67.47 87.54
VPL ResNet-50 79.61 94.88 85.84 96.99 68.28 88.03
ViT-Base 81.65 95.69 86.84 97.21 70.85 89.54
VPL ViT-Base 83.47 96.59 88.24 97.94 73.10 91.20
Table 2: Top-1 and Top-5 accuracy of ResNet-50, ViT-Base and their VPL versions. "×" indicates our implementation using the official training strategies and "⋆" indicates our implementation with the reported better training scheme. For fair comparison, the input size is $224^{2}$ and no extra datasets are used for training or pre-training.

Unless otherwise specified, we use ResNet-50 (He et al. 2016) and ViT-Base (Dosovitskiy et al. 2020) as the CNN and transformer backbones respectively in the following experiments.

Different from the training recipe in (He et al. 2016), we use AdamW (Loshchilov and Hutter 2017) as the optimizer and a cosine decay schedule (Loshchilov and Hutter 2016) to change the learning rate. The maximum learning rate is tuned for each dataset: we set the maximum learning rates of ViT-Base and ResNet-50 to 0.002 and 0.005 for the ImageNet task, and to 0.0001 and 0.0005 for transferring to other datasets. The weight decay ratio is set to 0.05 for both branches and 1e-8 for fine-tuning tasks. 16 V100 GPUs are used for ImageNet training and the training batch size is 1,600. We also use an Exponential Moving Average (EMA) of the weights, whose decay rate changes with the number of training iterations; the EMA operation benefits both ResNet and ViT. For image pre-processing, we use AutoAugment (Cubuk et al. 2019) and Random Erasing (Zhong et al. 2020) to increase the difficulty of training and help the network converge better. All reported results are the average performance over five runs.
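For reference, the recipe above translates roughly into the following PyTorch setup (a sketch under our own function names; the fixed EMA decay is a placeholder, since the paper varies it with the iteration count):

```python
import torch

def build_optimizer(cnn_branch, vit_branch, num_epochs):
    """Optimizer and schedule sketch following the recipe above."""
    optimizer = torch.optim.AdamW(
        [{'params': vit_branch.parameters(), 'lr': 2e-3},   # ViT-Base max LR
         {'params': cnn_branch.parameters(), 'lr': 5e-3}],  # ResNet-50 max LR
        weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    return optimizer, scheduler

class SimpleEMA:
    """Exponential moving average of model weights (fixed decay used here
    only for illustration)."""
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)
```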

It is worth mentioning that we re-train ResNet-50 using the settings above. With 100 and 300 epochs of training, ResNet-50 obtains 77.2% and 79.1% top-1 accuracy on the ImageNet validation set respectively, ahead of the popular baseline results used in (Ge et al. 2021; Zbontar et al. 2021; Wang and Qi 2021). By strengthening the baseline, we hope to obtain strong competitors under a unified experimental setting and make fair and meaningful comparisons in the subsequent experiments.

Results on ImageNet series datasets

Figure 2: Top-1 accuracy vs. epoch number curves of ResNet-50, ViT-Base and their VPL-based versions on the ImageNet validation set.

Table 2 reports the overall performance of the models on ImageNet (Deng et al. 2009), ImageNet Real (Beyer et al. 2020) and ImageNet V2 matched frequency (Recht et al. 2019). We split the experiments into two groups by training epoch number. VPL consistently accelerates the convergence of both branches. For ResNet-50, though our implementation with the improved training scheme is stronger than the original version, VPL-based ResNet-50 still outperforms it. With 100-epoch training, VPL pushes up the top-1 accuracy of ViT-Base significantly by 3.5% and achieves performance comparable to the model trained independently for 300 epochs. Figure 2 shows the convergence curves of the models. When ViT-Base is fully trained with VPL, it reaches an even higher accuracy without any extra supervision signal. The results suggest that VPL is able to offer two better-optimized models with only slight modification of the network and training strategy.

Framework Epoch ImageNet Real V2
DML ResNet-50 300 79.00 85.61 67.37
DML ViT-Base 300 82.86 88.03 71.92
VPL ResNet-50 300 79.61 85.84 68.28
VPL ViT-Base 300 83.47 88.24 73.10
Table 3: Top-1 accuracy of DML and VPL on ImageNet series datasets.

We reimplement DML (Zhang et al. 2018) and use ResNet-50 and ViT-Base as the two small models (i.e. the branches in VPL) in training. The results are shown in Table 3. With 300-epoch training, VPL achieves superior results to DML for both branches. DML treats all small models equally and lets them play the role of student in turn. When the models are of diverse capacity and learning efficiency, letting the smaller model guide the bigger one all the time may not be the best choice. Under this situation, VPL utilizes the pair learning module and multi-stage training to ensure better performance. Beyond that, DML results in lower data efficiency since it needs two iterations to update both branches once, which becomes even more complicated if more networks are included. Experimental results show that VPL is 1.4× faster than DML.

We also compare the time cost of VPL and independent training. With 16 V100 GPUs, independently training ResNet-50 and ViT-Base costs 5 minutes and 15.33 minutes per epoch respectively, while VPL costs 19.58 minutes per epoch. This suggests that when computation resources are limited and both CNN and transformer models are needed, VPL is clearly a better choice than independent training in terms of training efficiency and final performance.

Epoch Framework ImageNet Real V2
100 DIS ResNet-50 77.18 84.11 65.45
VPL ResNet-50 77.45 84.44 65.79
DIS ViT-Base 81.04 86.89 70.16
VPL ViT-Base 80.95 86.57 69.89
300 DIS ResNet-50 79.35 85.61 67.91
VPL ResNet-50 79.61 85.84 68.28
DIS ViT-Base 83.28 88.00 72.81
VPL ViT-Base 83.47 88.24 73.10
Table 4: Top-1 accuracy of VPL and the corresponding distillation models on ImageNet. A framework starting with "DIS" means the model is trained under the distillation recipe.

Transfer Learning

We further investigate the transferability of VPL. We first pretrain the models on the ImageNet dataset and then transfer them to other datasets for fine-tuning. The results in Table 5 suggest that VPL gives both branches a big boost across these datasets and outperforms the independently trained classification models. Compared with ResNet-50, ViT-Base benefits from VPL more and surpasses the original ViT-B/16 (Dosovitskiy et al. 2020) significantly on three of the four datasets. Furthermore, it also achieves better performance than the distillation version of DeiT-B (Touvron et al. 2020), which uses RegNetY-16GF (Radosavovic et al. 2020) as a teacher for distillation.

Framework CF-10 CF-100 Flowers Cars
ResNet-50 98.27 87.51 98.23 94.07
VPL ResNet-50 98.41 87.87 98.28 94.11
ViT-B/16 98.13 87.13 89.49 -
DeiT-B 99.10 91.30 98.80 92.90
ViT-Base 99.08 91.36 98.41 92.50
VPL ViT-Base 99.10 91.62 98.80 94.14
Table 5: Top-1 accuracy of different models and their VPL versions in transfer learning. CF-10, CF-100, Flowers and Cars stand for the CIFAR-10, CIFAR-100, Flowers-102 and Stanford Cars datasets respectively.

Ablation Study

Different Architecture

Branch 1 Branch 2 B1 B2 B1 solo
ResNet-50 ViT-Base 79.61 83.47 79.13
RegNet-16GF ViT-Base 83.57 83.78 81.21
ViT-Small ViT-Base 80.66 81.71 79.46
Table 6: Top-1 accuracy for different architecture settings of VPL on the ImageNet dataset. "B1" and "B2" refer to the accuracy of "Branch 1" and "Branch 2" in VPL. "B1 solo" refers to the accuracy of "Branch 1" under independent training.

To test the compatibility of VPL and investigate how the branch architectures influence performance, we first substitute ResNet-50 with a stronger and larger CNN, RegNet-16GF. The results are shown in Table 6. As expected, both branches obtain better results than with solo training. ViT-Base reaches a higher accuracy with the help of a better companion, and RegNet-16GF gets boosted as well. This suggests that VPL is relatively insensitive to the architecture of the CNN branch.

To further explore how much the architecture difference contributes to VPL, we use ViT-Small, whose parameter count is close to that of ResNet-50, in place of the CNN branch, i.e. both branches are ViT-based models and share the same inductive bias. Although both transformers reach improved accuracy compared with independent training, the performance of ViT-Base is significantly inferior to that of the one paired with ResNet-50. The results lead to an interesting conclusion: a weaker model may not be a bad teacher if it can offer informative representations and signals from different views.

Distillation

Considering that vision pair learning trains the branches with the multi-stage strategy and promotes both branches to converge simultaneously, it is different from simple distillation. For convincing results, we develop distillation-based models and compare them with the VPL models. The distillation scheme is as follows: first, one branch (the teacher) is trained from scratch with ground-truth labels and the best parameter setting (the baselines in Table 2); then its weights are frozen and we train the other branch (the student) with both the hard labels and the losses defined in the pair learning module.
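The second step of this baseline can be sketched as follows (assumed names such as teacher, student, images and labels; which PLM term is applied depends on which branch plays the student role, following the gradient-flow rule above):

```python
import torch
import torch.nn.functional as F

# Distillation baseline, step 2 (sketch, not the authors' code): the
# pretrained teacher is frozen and only the student receives gradients.
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)            # teacher weights stay fixed

with torch.no_grad():
    h_t, z_t = teacher(images)         # fixed teacher embedding and logits
h_s, z_s = student(images)             # student forward pass

loss = F.cross_entropy(z_s, labels)    # hard labels
# ... plus the PLM term(s) computed between (h_t, z_t) and (h_s, z_s).
loss.backward()
```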

Since the VPL-based models are always stronger than the independently trained ones (the teacher models), as shown in Table 2, we focus on the results of the student models in Table 4. Though the gap between the two groups of models is small, VPL produces slightly better classification accuracy in most settings. In addition, one-step VPL is easier to implement than two-step distillation.

Gradient Flow in Pair Learning Module

Framework ImageNet Real V2
VPL ResNet-50 79.61 85.84 68.28
VPL ResNet-50 Bi 78.63 85.32 66.83
VPL ViT-Base 83.47 88.24 73.10
VPL ViT-Base Bi 82.57 88.01 72.03
Table 7: Comparison of Top-1 accuracy between different designs of gradient flow in the pair learning module. Models ending with "Bi" indicate that the loss functions influence both branches during optimization.

As described in the Pair Learning Module section, we control the gradients produced by the two proposed loss functions and let each of them influence only one branch during optimization. To prove the effectiveness of this design, we conduct experiments where the restriction is removed and the gradient flow updates the weights of both branches. As shown in Table 7, the restriction brings consistently better performance for VPL on the ImageNet datasets. As for the reason, we believe there is no one-size-fits-all loss function in vision pair learning, and models of different architectures should be paired with an appropriate loss function for better performance.

Multi-stage training strategies

Framework ImageNet Real V2
VPL ResNet-50  x=y=0   79.51 85.75 68.06
VPL ResNet-50  x=y=20  79.61 85.84 68.28
VPL ResNet-50  x=y=40  79.53 85.76 67.72
VPL ViT-Base   x=y=0   83.40 88.10 72.99
VPL ViT-Base   x=y=20  83.47 88.24 73.10
VPL ViT-Base   x=y=40  83.04 87.74 72.46
Table 8: Results of multi-stage training with different proportions. The percentages x and y indicate the proportions of the first stage and the last stage.

We investigate how the proportions of the different stages affect model performance. To keep the comparison brief, we set the proportions of the first and last stages to be equal, i.e. x=y in multi-stage training, and vary x. As shown in Table 8, the results for different proportions vary within a relatively narrow range: multi-stage training slightly enhances the accuracy, but the improvement is limited. The best performance is achieved when x and y are equal to 20.

Conclusion

In this paper, we have proposed a new training framework named vision pair learning (VPL) for the image classification task. VPL makes use of the complementary advantages of CNN and transformer and promotes their performance simultaneously. We prove the effectiveness of VPL on the ImageNet series datasets and four downstream datasets, and the experiments further verify its ability to build strong representations. Future work will focus on improving the framework's capability and flexibility to handle more diverse structures and on verifying its ability in other vision fields. We hope the viewpoints and experimental results described in this work will be helpful to follow-up work.

References

  • Beyer et al. (2020) Beyer, L.; Hénaff, O. J.; Kolesnikov, A.; Zhai, X.; and Oord, A. v. d. 2020. Are we done with ImageNet? arXiv preprint arXiv:2006.07159.
  • Brown et al. (2020) Brown, T. B.; Mann, B.; Ryder, N.; Subbiah, M.; Kaplan, J.; Dhariwal, P.; Neelakantan, A.; Shyam, P.; Sastry, G.; Askell, A.; et al. 2020. Language models are few-shot learners. arXiv preprint arXiv:2005.14165.
  • Carion et al. (2020) Carion, N.; Massa, F.; Synnaeve, G.; Usunier, N.; Kirillov, A.; and Zagoruyko, S. 2020. End-to-end object detection with transformers. In European Conference on Computer Vision, 213–229. Springer.
  • Caron et al. (2020) Caron, M.; Misra, I.; Mairal, J.; Goyal, P.; Bojanowski, P.; and Joulin, A. 2020. Unsupervised learning of visual features by contrasting cluster assignments. arXiv preprint arXiv:2006.09882.
  • Carreira and Zisserman (2017) Carreira, J.; and Zisserman, A. 2017. Quo vadis, action recognition? a new model and the kinetics dataset. In proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 6299–6308.
  • Chen et al. (2020a) Chen, H.; Wang, Y.; Guo, T.; Xu, C.; Deng, Y.; Liu, Z.; Ma, S.; Xu, C.; Xu, C.; and Gao, W. 2020a. Pre-trained image processing transformer. arXiv preprint arXiv:2012.00364.
  • Chen et al. (2020b) Chen, T.; Kornblith, S.; Norouzi, M.; and Hinton, G. 2020b. A simple framework for contrastive learning of visual representations. In International conference on machine learning, 1597–1607. PMLR.
  • Chen and He (2020) Chen, X.; and He, K. 2020. Exploring Simple Siamese Representation Learning. arXiv preprint arXiv:2011.10566.
  • Cubuk et al. (2019) Cubuk, E. D.; Zoph, B.; Mane, D.; Vasudevan, V.; and Le, Q. V. 2019. Autoaugment: Learning augmentation strategies from data. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 113–123.
  • Dai et al. (2020) Dai, Z.; Cai, B.; Lin, Y.; and Chen, J. 2020. UP-DETR: Unsupervised Pre-training for Object Detection with Transformers. arXiv preprint arXiv:2011.09094.
  • Deng et al. (2009) Deng, J.; Dong, W.; Socher, R.; Li, L.-J.; Li, K.; and Fei-Fei, L. 2009. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, 248–255. Ieee.
  • Devlin et al. (2018) Devlin, J.; Chang, M.-W.; Lee, K.; and Toutanova, K. 2018. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.
  • Dosovitskiy et al. (2020) Dosovitskiy, A.; Beyer, L.; Kolesnikov, A.; Weissenborn, D.; Zhai, X.; Unterthiner, T.; Dehghani, M.; Minderer, M.; Heigold, G.; Gelly, S.; et al. 2020. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
  • Ge et al. (2021) Ge, Y.; Choi, C. L.; Zhang, X.; Zhao, P.; Zhu, F.; Zhao, R.; and Li, H. 2021. Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification. arXiv preprint arXiv:2104.13298.
  • Girdhar et al. (2019) Girdhar, R.; Carreira, J.; Doersch, C.; and Zisserman, A. 2019. Video action transformer network. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 244–253.
  • Grill et al. (2020) Grill, J.-B.; Strub, F.; Altché, F.; Tallec, C.; Richemond, P. H.; Buchatskaya, E.; Doersch, C.; Pires, B. A.; Guo, Z. D.; Azar, M. G.; et al. 2020. Bootstrap your own latent: A new approach to self-supervised learning. arXiv preprint arXiv:2006.07733.
  • Gu et al. (2018) Gu, C.; Sun, C.; Ross, D. A.; Vondrick, C.; Pantofaru, C.; Li, Y.; Vijayanarasimhan, S.; Toderici, G.; Ricco, S.; Sukthankar, R.; et al. 2018. Ava: A video dataset of spatio-temporally localized atomic visual actions. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 6047–6056.
  • Han et al. (2021) Han, K.; Xiao, A.; Wu, E.; Guo, J.; Xu, C.; and Wang, Y. 2021. Transformer in transformer. arXiv preprint arXiv:2103.00112.
  • He et al. (2020) He, K.; Fan, H.; Wu, Y.; Xie, S.; and Girshick, R. 2020. Momentum contrast for unsupervised visual representation learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 9729–9738.
  • He et al. (2016) He, K.; Zhang, X.; Ren, S.; and Sun, J. 2016. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, 770–778.
  • Krause et al. (2013) Krause, J.; Stark, M.; Deng, J.; and Fei-Fei, L. 2013. 3d object representations for fine-grained categorization. In Proceedings of the IEEE international conference on computer vision workshops, 554–561.
  • Krizhevsky, Hinton et al. (2009) Krizhevsky, A.; Hinton, G.; et al. 2009. Learning multiple layers of features from tiny images.
  • Liu et al. (2021) Liu, Z.; Lin, Y.; Cao, Y.; Hu, H.; Wei, Y.; Zhang, Z.; Lin, S.; and Guo, B. 2021. Swin transformer: Hierarchical vision transformer using shifted windows. arXiv preprint arXiv:2103.14030.
  • Loshchilov and Hutter (2016) Loshchilov, I.; and Hutter, F. 2016. Sgdr: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983.
  • Loshchilov and Hutter (2017) Loshchilov, I.; and Hutter, F. 2017. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101.
  • Nilsback and Zisserman (2008) Nilsback, M.-E.; and Zisserman, A. 2008. Automated flower classification over a large number of classes. In 2008 Sixth Indian Conference on Computer Vision, Graphics & Image Processing, 722–729. IEEE.
  • Radford et al. (2018) Radford, A.; Narasimhan, K.; Salimans, T.; and Sutskever, I. 2018. Improving language understanding by generative pre-training.
  • Radford et al. (2019) Radford, A.; Wu, J.; Child, R.; Luan, D.; Amodei, D.; and Sutskever, I. 2019. Language models are unsupervised multitask learners. OpenAI blog, 1(8): 9.
  • Radosavovic et al. (2020) Radosavovic, I.; Kosaraju, R. P.; Girshick, R.; He, K.; and Dollár, P. 2020. Designing network design spaces. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 10428–10436.
  • Recht et al. (2019) Recht, B.; Roelofs, R.; Schmidt, L.; and Shankar, V. 2019. Do imagenet classifiers generalize to imagenet? In International Conference on Machine Learning, 5389–5400. PMLR.
  • Ren et al. (2016) Ren, S.; He, K.; Girshick, R.; and Sun, J. 2016. Faster R-CNN: towards real-time object detection with region proposal networks. IEEE transactions on pattern analysis and machine intelligence, 39(6): 1137–1149.
  • Tan and Le (2021) Tan, M.; and Le, Q. V. 2021. EfficientNetV2: Smaller Models and Faster Training. arXiv preprint arXiv:2104.00298.
  • Touvron et al. (2020) Touvron, H.; Cord, M.; Douze, M.; Massa, F.; Sablayrolles, A.; and Jégou, H. 2020. Training data-efficient image transformers & distillation through attention. arXiv preprint arXiv:2012.12877.
  • Vaswani et al. (2017) Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A. N.; Kaiser, L.; and Polosukhin, I. 2017. Attention is all you need. arXiv preprint arXiv:1706.03762.
  • Wang and Qi (2021) Wang, X.; and Qi, G.-J. 2021. Contrastive Learning with Stronger Augmentations. arXiv preprint arXiv:2104.07713.
  • Wu et al. (2020) Wu, B.; Xu, C.; Dai, X.; Wan, A.; Zhang, P.; Yan, Z.; Tomizuka, M.; Gonzalez, J.; Keutzer, K.; and Vajda, P. 2020. Visual transformers: Token-based image representation and processing for computer vision. arXiv preprint arXiv:2006.03677.
  • Wu et al. (2021) Wu, H.; Xiao, B.; Codella, N.; Liu, M.; Dai, X.; Yuan, L.; and Zhang, L. 2021. Cvt: Introducing convolutions to vision transformers. arXiv preprint arXiv:2103.15808.
  • Yuan et al. (2021) Yuan, L.; Chen, Y.; Wang, T.; Yu, W.; Shi, Y.; Tay, F. E.; Feng, J.; and Yan, S. 2021. Tokens-to-token vit: Training vision transformers from scratch on imagenet. arXiv preprint arXiv:2101.11986.
  • Zbontar et al. (2021) Zbontar, J.; Jing, L.; Misra, I.; LeCun, Y.; and Deny, S. 2021. Barlow twins: Self-supervised learning via redundancy reduction. arXiv preprint arXiv:2103.03230.
  • Zhang et al. (2018) Zhang, Y.; Xiang, T.; Hospedales, T. M.; and Lu, H. 2018. Deep mutual learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 4320–4328.
  • Zhong et al. (2020) Zhong, Z.; Zheng, L.; Kang, G.; Li, S.; and Yang, Y. 2020. Random erasing data augmentation. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, 13001–13008.
  • Zhu et al. (2021) Zhu, M.; Han, K.; Tang, Y.; and Wang, Y. 2021. Visual Transformer Pruning. arXiv preprint arXiv:2104.08500.