
School of Computer Science, Faculty of Engineering, The University of Sydney. Email: {guli5858, yson6207}@uni.sydney.edu.au

SAU: A Dual-Branch Network to Enhance Long-Tailed Recognition via Generative Models

Guangxi Li, Yinsheng Song, Mingkai Zheng (corresponding author)
Abstract

Long-tailed distributions in image recognition pose a considerable challenge because of the severe imbalance between a few dominant classes with numerous examples and many minority classes with few samples. Large generative models have recently been used to create synthetic data for image classification, but using synthetic data to address long-tailed recognition remains relatively unexplored. In this work, we propose using synthetic data as a complement to long-tailed datasets to mitigate the impact of data imbalance. To handle the resulting real-synthetic mixed dataset, we design a two-branch model that contains Synthetic-Aware and Synthetic-Unaware branches (SAU). The core ideas are (1) a synthetic-unaware branch for classification that mixes real and synthetic data and treats all data equally without distinguishing between them, and (2) a synthetic-aware branch that improves the robustness of the feature extractor by distinguishing between real and synthetic data and learning their discrepancies. Extensive experiments demonstrate that our method improves the accuracy of long-tailed image recognition. Notably, our approach achieves state-of-the-art Top-1 accuracy and significantly surpasses other methods on the CIFAR-10-LT and CIFAR-100-LT datasets across various imbalance factors. Our code is available at https://github.com/lgX1123/gm4lt.

Keywords:
Long-Tailed Image Recognition, Image Generation, Imbalanced Learning

1 Introduction

In the field of computer vision, mainstream datasets [9, 18] are balanced by design, and this balance is a crucial factor underpinning the success of deep neural networks in image recognition. In real-world scenarios, however, data frequently follow a long-tailed distribution. Models trained on such imbalanced data tend to be dominated by the head classes and neglect the tail classes, which leads to suboptimal performance on a balanced test set. Improving model performance on long-tailed data has therefore become a significant challenge.

Traditional methods for mitigating data imbalance include class re-balancing strategies such as re-sampling [34] and re-weighting [8, 28], which correct the imbalance by placing more emphasis on under-represented classes. More recently, other approaches have emerged. Contrastive learning [7, 30, 37], for example, strengthens the model’s feature extraction by pulling representations of similar images together and pushing those of dissimilar images apart, which improves its ability to discriminate between classes. Nevertheless, most of these methods attempt to calibrate the discrimination of the tail classes rather than address the root issue of data imbalance itself.

In recent years, generative models [13, 1, 29, 27, 26, 22] have advanced significantly and shown immense potential across various applications. Researchers have begun to use synthetic data generated by text-to-image (T2I) models for visual tasks [15, 3, 16]. However, synthetic images often contain unrealistic hallucinations that can mislead downstream image recognition tasks. Effectively using the vast amounts of synthetic data that T2I models can generate therefore remains a crucial open challenge.

In this work, inspired by [15, 35], we aim to enhance long-tailed image recognition by leveraging large generative models. Specifically, we adopt GLIDE [22] as our T2I model. To generate diverse, high-quality synthetic images, we use GPT-4 [1] to create varied image descriptions and CLIP [25] to filter out low-quality synthetic images. The resulting synthetic data are used to augment the long-tailed dataset into a more balanced dataset for training. To handle this real-synthetic mixed dataset, we propose SAU, a two-branch framework with a Synthetic-Aware and a Synthetic-Unaware branch that share one feature extractor. In the synthetic-unaware branch, we apply mix-based augmentations, namely MixUp [33] and CutMix [32], to blend real and synthetic data. This encourages the model to ignore whether an image is real or synthetic, which makes the branch suitable for classification: new images are classified regardless of their origin. In the synthetic-aware branch, we improve feature extraction with supervised contrastive learning (SupCon) [17]. Within this branch, a K-Nearest-Neighbor-based label correction strategy dynamically identifies low-quality synthetic data. The identified samples are treated as noise, and we design three noise-dropping strategies, each with its own contrastive objective.

We evaluate our method on three widely used long-tailed image recognition benchmarks: CIFAR-10-LT, CIFAR-100-LT, and ImageNet-LT. Extensive experiments show that our method achieves state-of-the-art Top-1 accuracy and significantly surpasses other methods on CIFAR-10-LT and CIFAR-100-LT over a range of imbalance factors.

Our main contributions can be summarized as follows:

  • We propose using synthetic data generated by text-to-image (T2I) models, optionally assisted by LLMs, to augment long-tailed datasets into balanced datasets that address the data imbalance problem. To handle such datasets effectively, we design a two-branch network comprising a Synthetic-Aware and a Synthetic-Unaware branch.

  • We propose a K-Nearest-Neighbor-based label correction procedure that dynamically detects low-quality synthetic data for contrastive learning, and we design three variants of the supervised contrastive loss to handle the detected samples.

  • The experimental results show that our method achieves state-of-the-art Top-1 accuracy on popular long-tailed benchmarks including CIFAR-10-LT and CIFAR-100-LT.

2 Related Works

2.1 Generative Models for Image Recognition

Generative models have recently developed rapidly and achieved remarkable results. Two families are relevant here: Large Language Models (LLMs) [13, 1, 29] and Text-to-Image (T2I) models [27, 26, 22]. Many studies have investigated whether content generated by these models can improve the performance of downstream tasks. He et al. [15] show that synthetic data are ready to be applied to image recognition and can significantly improve zero-shot and few-shot recognition. Qiu et al. [24] explore the potential of synthetic data for face recognition. In long-tailed recognition, however, there is little research on using generated data. A central challenge of long-tailed recognition is the scarcity of tail-class samples, which generated data can help alleviate.

2.2 Long-Tailed Recognition

Existing methods for long-tailed image recognition fall into four main categories: re-sampling, re-weighting, mix-based augmentation, and contrastive learning. Re-sampling methods [4, 34] use over-sampling or under-sampling to create balanced mini-batches during training. Re-weighting methods [8, 28] modify the loss function so that different classes have uneven influence. Mix-based augmentation methods [32, 33] enrich tail-class representations by generating mixed images from over-sampled tail-class images. Contrastive learning methods [17, 7, 37, 12] pull samples of the same class together while pushing samples of different classes apart, and have achieved great success in long-tailed image recognition. However, all of these methods still struggle to learn tail-class representations because tail-class data are scarce. In this work, we leverage synthetic data: we use mix-based augmentation to reduce the domain gap between real and synthetic data, and contrastive learning to improve the model’s understanding and generalization capabilities.

3 Methods

3.1 Overall Framework

Our overall framework is shown in Figure 1. It consists of seven main components:

Figure 1: The overall framework of our proposed method. By leveraging LLM and T2I models, we generate synthetic data as a complement to a long-tailed dataset to obtain a balanced dataset. The synthetic-unaware branch takes the same view $v_{1}$ from two samples to calculate the mixing loss. The synthetic-aware branch takes two different views $v_{2}$ and $v_{3}$ from one sample to calculate the supervised contrastive loss.
  • Synthetic data generation module: For each class that requires synthetic images, we optionally utilize LLMs to enhance the prompts, and then employ T2I models to generate class-related synthetic images.

  • Data augmentation module: $T(\cdot)$. For each input sample $\mathbf{x}$, we generate three random augmented views, denoted as $\mathbf{v}=T(\mathbf{x})$. Two of these views are used in the synthetic-aware branch, and the remaining one is used in the synthetic-unaware branch.

  • Encoder network: $\mathcal{F}(\cdot)$, which maps an input sample $\mathbf{x}$ to the feature space, denoted as $\mathbf{h}=\mathcal{F}(\mathbf{x})$. The synthetic-unaware and synthetic-aware branches share this network.

  • Mix module: MixUp and CutMix. For each pair of augmented input samples, it mixes the images and their labels in two ways, producing one MixUp pair and one CutMix pair.

  • Classifier head: a linear classifier that maps the feature vector $\mathbf{h}$ to the class space. During training, it is optimized with the mixed cross-entropy loss.

  • Multi-layer perceptron projection: $\phi$, which maps the feature vector $\mathbf{h}$ to a low-dimensional representation $\mathbf{z}$.

  • Label correction module: a module that dynamically detects low-quality synthetic image-label pairs in the synthetic-aware branch.

3.2 Synthetic Data Generation

In our work, we employ LLMs and T2I models to collaboratively generate high-quality and diverse datasets for model training.

3.2.1 Prompt with GPT-4.

This step creates a description of each class that serves as the prompt guiding T2I generation. We consider two kinds of prompts: basic prompts and LLM-enhanced prompts.

Basic Prompts.

For basic prompts, we follow the simple and effective template introduced by CLIP [25]:

$$P_{y}=\text{"A photo of a [CLASS]."}$$ (1)

where $P_{y}$ denotes the prompt for class $y$. The main shortcoming of basic prompts is their lack of diversity.

LLM Enhanced Prompts.

To increase the diversity of the synthetic images, we use the LLM GPT-4 to generate diverse descriptions for synthetic image generation. Specifically, we assign GPT-4 a role through a system command: "You are a helpful assistant designed to output a prompt (no more than 50 words) describing a given object’s appearance and actions.". We also provide a response template for a given class $y$: "A photo of the class [y], with specific features and with a specific background.". Finally, we design a query template for each class, for example "What the [CLASS] looks like?". In summary, the LLM-enhanced prompt for class $y$ can be formulated as:

$$P_{y}=\text{GPT-4}_{\text{command}}(\text{Query}_{y})$$ (2)
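The prompt construction in Eqs. (1) and (2) can be sketched as plain string assembly. This is a minimal sketch under stated assumptions: the helper names are hypothetical, the GPT-4 call itself is omitted, and the list of role/content dictionaries is an assumed generic chat format rather than a specific client API. Only the template strings come from the text.

```python
from typing import Dict, List

# Template string quoted from the text above.
SYSTEM_COMMAND = (
    "You are a helpful assistant designed to output a prompt (no more than 50 words) "
    "describing a given object's appearance and actions."
)


def basic_prompt(class_name: str) -> str:
    """Eq. (1): the CLIP-style basic prompt."""
    return f"A photo of a {class_name}."


def llm_messages(class_name: str) -> List[Dict[str, str]]:
    """Hypothetical helper assembling the inputs of Eq. (2) for one class.

    The role/content layout is an assumed chat format; sending the messages
    to GPT-4 and reading back the generated prompt are left out.
    """
    response_template = (
        f"A photo of the class [{class_name}], with specific features "
        "and with a specific background."
    )
    query = f"What the {class_name} looks like?"
    return [
        {"role": "system", "content": SYSTEM_COMMAND},
        # Assumption: the response template is passed as an extra system message.
        {"role": "system", "content": response_template},
        {"role": "user", "content": query},
    ]


print(basic_prompt("tench"))  # A photo of a tench.
```

Keeping the templates as data makes it easy to fall back to the basic prompt of Eq. (1) when LLM prompts are too costly, as done for ImageNet-LT.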

3.2.2 T2I with GLIDE.

After obtaining the texts, we employ them to generate synthetic images by leveraging advanced T2I models. In detail, we use GLIDE [22] to generate synthetic images, which can be formulated as:

$$I^{n}_{y}=\text{GLIDE}(P^{n}_{y})$$ (3)

where $n\in\{1,\ldots,N_{y}\}$ indexes the $N_{y}$ synthetic images generated for class $y$, and $I^{n}_{y}$ and $P^{n}_{y}$ denote the $n$-th generated image and its prompt, respectively.

However, some synthetic images may not represent their class well. To address this problem, we use CLIP [25] to evaluate how well each synthetic image matches its corresponding text prompt and filter out poorly matched images.
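The text does not state how the CLIP score is thresholded. A minimal sketch, assuming the image and prompt embeddings have already been computed with a CLIP model (not shown) and that images whose cosine similarity falls below a hypothetical threshold are discarded:

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def filter_by_clip_score(image_embs: Sequence[Sequence[float]],
                         text_embs: Sequence[Sequence[float]],
                         threshold: float = 0.25) -> List[int]:
    """Indices of synthetic images whose image-text similarity reaches `threshold`.

    image_embs[n] and text_embs[n] are assumed to be CLIP embeddings of the n-th
    synthetic image and of the prompt it was generated from. The default
    threshold is a hypothetical value, not one reported in the paper.
    """
    return [n for n, (img, txt) in enumerate(zip(image_embs, text_embs))
            if cosine(img, txt) >= threshold]
```

In practice the threshold would be tuned per dataset; an alternative design is to keep the top-scoring fraction of images per class, which guarantees that every class still receives its quota of synthetic data.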

3.3 Mixed Augmentation on Synthetic-Unaware Branch

The key idea of this branch is to mix synthetic data with real data through data augmentation, which encourages the model to treat all data alike. To this end, we select two simple and efficient mixing methods.

3.3.1 MixUp.

First, we apply MixUp [33] to mix real and synthetic samples. Let $x\in\mathbb{R}^{W\times H\times C}$ and $y$ denote a training image and its label, respectively. For two images and their labels $(x_{i},y_{i})$ and $(x_{j},y_{j})$, the mixed image and its label $(\tilde{x}_{ij},\tilde{y}_{ij})$ are calculated as follows:

$$\tilde{x}_{ij}=\lambda x_{i}+(1-\lambda)x_{j},\qquad \tilde{y}_{ij}=\lambda y_{i}+(1-\lambda)y_{j}$$ (4)

where the combination ratio $\lambda$ is sampled from a Beta distribution $\text{Beta}(\alpha,\alpha)$. The hyper-parameter $\alpha$ is set to 1 here, which makes $\lambda$ uniform on $[0,1]$ and gives a more random mixture.
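Eq. (4) can be sketched on flattened images with one-hot labels. This is a minimal pure-Python illustration rather than the authors’ implementation; the function name and the option to pass $\lambda$ explicitly (useful for reproducible tests) are assumptions.

```python
import random
from typing import List, Optional, Sequence, Tuple


def mixup(x_i: Sequence[float], x_j: Sequence[float],
          y_i: Sequence[float], y_j: Sequence[float],
          alpha: float = 1.0, lam: Optional[float] = None,
          rng: Optional[random.Random] = None) -> Tuple[List[float], List[float], float]:
    """Eq. (4) on flattened images and one-hot labels.

    lam ~ Beta(alpha, alpha); alpha = 1 makes lam uniform on [0, 1].
    Pass `lam` explicitly to make the result deterministic.
    """
    if lam is None:
        lam = (rng or random).betavariate(alpha, alpha)
    x_mix = [lam * a + (1.0 - lam) * b for a, b in zip(x_i, x_j)]
    y_mix = [lam * a + (1.0 - lam) * b for a, b in zip(y_i, y_j)]
    return x_mix, y_mix, lam
```

Because the same $\lambda$ weights both the pixels and the labels, the mixed label always remains a valid probability vector; in SAU one of the two inputs can be real and the other synthetic.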

3.3.2 CutMix.

Another mixing technique, CutMix [32], also mixes real and synthetic samples; it combines two images by replacing a region of one image with a patch from the other. The combining operation is defined as:

$$\tilde{x}_{ij}=\mathbf{M}\odot x_{i}+(\mathbf{1}-\mathbf{M})\odot x_{j},\qquad \tilde{y}_{ij}=\lambda y_{i}+(1-\lambda)y_{j}$$ (5)

where $\mathbf{M}\in\{0,1\}^{W\times H}$ is a binary mask indicating which pixels are kept from $x_{i}$; the remaining region is filled with the corresponding pixels of $x_{j}$. $\mathbf{1}$ is a mask filled with ones, and $\odot$ denotes element-wise multiplication. Specifically, we sample the bounding box coordinates $\mathbf{B}=(r_{x},r_{y},r_{w},r_{h})$ that indicate the cropping region on $x_{i}$ and $x_{j}$. The mask $\mathbf{M}$ is set to 0 inside the bounding box $\mathbf{B}$ and to 1 elsewhere. The box coordinates are sampled according to:

$$\begin{aligned} r_{x}&\sim\text{Uniform}(0,W), & r_{w}&=W\sqrt{1-\lambda}\\ r_{y}&\sim\text{Uniform}(0,H), & r_{h}&=H\sqrt{1-\lambda}\end{aligned}$$ (6)

where $\lambda$ is again sampled from a Beta distribution, so that the cropped area ratio $r_{w}r_{h}/(WH)$ equals $1-\lambda$.
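The mask construction of Eqs. (5) and (6) can be sketched on a single-channel image stored as a list of rows. This is a minimal sketch, not the authors’ code: it treats $(r_x, r_y)$ as the box centre and clips the box at the image border, a common convention in CutMix implementations. Because clipping can shrink the box, it recomputes $\lambda$ from the actual box so that the label weight matches the true pixel fraction (Eq. 6 assumes an unclipped box, whose area ratio is exactly $1-\lambda$).

```python
import math
import random
from typing import List, Optional, Tuple

Image = List[List[float]]  # H rows of W pixels, one channel for brevity


def sample_box(width: int, height: int, lam: float,
               rng: random.Random) -> Tuple[int, int, int, int]:
    """Eq. (6): a W*sqrt(1-lam) x H*sqrt(1-lam) box centred at a uniform point, clipped."""
    cut_w = width * math.sqrt(1.0 - lam)
    cut_h = height * math.sqrt(1.0 - lam)
    r_x = rng.uniform(0, width)
    r_y = rng.uniform(0, height)
    x1 = max(int(r_x - cut_w / 2), 0)
    x2 = min(int(r_x + cut_w / 2), width)
    y1 = max(int(r_y - cut_h / 2), 0)
    y2 = min(int(r_y + cut_h / 2), height)
    return x1, y1, x2, y2


def cutmix(img_i: Image, img_j: Image, lam: float,
           rng: Optional[random.Random] = None) -> Tuple[Image, float]:
    """Eq. (5): M = 0 inside the box (pixels from img_j), 1 elsewhere (pixels from img_i).

    Returns the mixed image and the label weight of img_i, i.e. the exact
    fraction of pixels kept from img_i after clipping.
    """
    rng = rng or random.Random()
    height, width = len(img_i), len(img_i[0])
    x1, y1, x2, y2 = sample_box(width, height, lam, rng)
    mixed = [row[:] for row in img_i]  # start from img_i (M = 1 everywhere)
    for y in range(y1, y2):
        for x in range(x1, x2):
            mixed[y][x] = img_j[y][x]  # paste the patch of img_j (M = 0 in the box)
    lam_adj = 1.0 - (x2 - x1) * (y2 - y1) / (width * height)
    return mixed, lam_adj
```

The mixed label is then $\lambda_{adj} y_i + (1-\lambda_{adj}) y_j$, following Eq. (5) with the adjusted ratio.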

After constructing the two kinds of real-synthetic mixed data, we compute a cross-entropy loss for each in the same way:

$$\begin{aligned}\mathcal{L}_{MixUp}&=\mathcal{L}_{ce}(f(\tilde{x}_{MixUp}),\ \tilde{y}_{MixUp})\\ \mathcal{L}_{CutMix}&=\mathcal{L}_{ce}(f(\tilde{x}_{CutMix}),\ \tilde{y}_{CutMix})\end{aligned}$$ (7)

where $f(\tilde{x})$ denotes the prediction for $\tilde{x}$.
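Because $\tilde{y}$ is a soft target, $\mathcal{L}_{ce}$ in Eq. (7) is the cross-entropy against a probability vector. A minimal sketch (the function name is hypothetical) also illustrates a useful consequence: since the loss is linear in the target, $\mathcal{L}_{ce}(f(\tilde{x}),\lambda y_i+(1-\lambda)y_j)=\lambda\mathcal{L}_{ce}(f(\tilde{x}),y_i)+(1-\lambda)\mathcal{L}_{ce}(f(\tilde{x}),y_j)$.

```python
import math
from typing import Sequence


def soft_cross_entropy(logits: Sequence[float], target: Sequence[float]) -> float:
    """Cross-entropy between softmax(logits) and a soft (mixed) target, as in Eq. (7)."""
    m = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    # -sum_c target_c * log softmax(logits)_c
    return -sum(t * (v - log_z) for t, v in zip(target, logits))
```

The linearity means the mixed loss can equivalently be computed as a $\lambda$-weighted sum of two ordinary hard-label losses, which is how many MixUp implementations write it.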

3.4 Supervised Contrastive Learning on Synthetic-Aware Branch

The main idea of the synthetic-aware branch is to enable the model to distinguish between real data and synthetic data, thereby learning their discrepancies and improving the model’s understanding and generalization capabilities.

3.4.1 Preliminaries on Contrastive Learning.

Contrastive learning methods typically use a noise contrastive estimation (NCE) objective to distinguish between similar and dissimilar data points: different views of the same sample are pulled together in a latent space, while other samples are pushed away. In supervised classification, where labels are available, different views of samples from the same class are pulled together while the others are pushed away. Concretely, we randomly sample a mini-batch of images $\{\mathbf{x}_{i}\}^{N}_{i=1}$ and apply a random augmentation function $T(\cdot)$ to obtain two different augmented views of each image. Within this multiviewed batch, let $i\in I\equiv\{1,2,\ldots,2N\}$ be the index of a sample. Following previous works [7, 17], the loss function is defined as:

$$\mathcal{L}_{SC}=\sum_{i\in I}\mathcal{L}_{SC}^{i}=\sum_{i\in I}\left(-\frac{1}{|I_{y}|-1}\sum_{j\in I_{y}\setminus\{i\}}\log\frac{\exp(\mathbf{z}_{i}\cdot\mathbf{z}_{j}/\tau)}{\sum_{k\in I\setminus\{i\}}\exp(\mathbf{z}_{i}\cdot\mathbf{z}_{k}/\tau)}\right)$$ (8)

where $\mathbf{z}_{i}=\phi(\mathcal{F}(T(\mathbf{x}_{i})))$ is the latent representation of an input image, $I_{y}\equiv\{i\in I:\mathbf{y}_{i}=y\}$ is the set of views of class $y$, $|\cdot|$ denotes the number of elements, and $\tau>0$ is a scalar temperature hyper-parameter.
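Eq. (8) can be sketched directly with loops over the multiviewed batch. This is an illustrative, unoptimized sketch: it L2-normalizes the projections (an assumption, common in supervised contrastive learning, that the text does not state for $\mathbf{z}$), skips anchors that have no positive, and uses log-sum-exp for numerical stability.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _logsumexp(values: List[float]) -> float:
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def supcon_loss(z: Sequence[Sequence[float]], labels: Sequence[int], tau: float = 0.1) -> float:
    """Eq. (8): summed over anchors i, averaged over each anchor's |I_y| - 1 positives.

    `z` holds the projections of all 2N views; `labels[i]` is the class of view i.
    """
    z = [_normalize(v) for v in z]
    n = len(z)
    total = 0.0
    for i in range(n):
        others = [k for k in range(n) if k != i]  # k in I \ {i}
        sims = {k: sum(a * b for a, b in zip(z[i], z[k])) / tau for k in others}
        positives = [j for j in others if labels[j] == labels[i]]  # I_y \ {i}
        if not positives:
            continue  # no other view of this class in the batch
        log_denom = _logsumexp(list(sims.values()))
        total += -sum(sims[j] - log_denom for j in positives) / len(positives)
    return total
```

Embeddings that cluster by class give a much lower loss than the same embeddings with shuffled labels, which is exactly the signal that pulls same-class views together.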

3.4.2 Label Correction.

To further address the problem of low-quality synthetic images, and motivated by [36], we propose a K-Nearest-Neighbor (KNN) based method to dynamically detect low-quality synthetic image-label pairs. Specifically, given the representations $\{\mathbf{z}_{real}\}^{N_{1}}_{i=1}$ and $\{\mathbf{z}_{syn}\}^{N_{2}}_{i=1}$ of the real and synthetic images and their labels $\{\mathbf{y}_{real}\}^{N_{1}}_{i=1}$ and $\{\mathbf{y}_{syn}\}^{N_{2}}_{i=1}$, we generate corrected labels for the synthetic images from the labels of their $k$ nearest real neighbors. Combining them with the original labels $\mathbf{y}_{org}$, the final labels of the synthetic images are defined as:

$$\mathbf{y}_{new}=\begin{cases}\mathbf{y}_{org}, & \text{if } (\mathbf{y}_{cor}^{1}=\mathbf{y}_{org})\ \&\ (\mathbf{y}_{cor}^{2}=\mathbf{y}_{org})\\ \mathbf{y}_{noise}, & \text{otherwise}\end{cases}$$ (9)

where $\mathbf{y}_{new}$ denotes the new labels of the synthetic data, $\mathbf{y}_{cor}^{1}$ and $\mathbf{y}_{cor}^{2}$ are the corrected labels obtained from the two augmented views, and $\mathbf{y}_{noise}$ denotes the noise labels assigned to low-quality synthetic images. In other words, a synthetic image keeps its original label only if the real neighbors of both of its views agree with it.
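A minimal sketch of the label correction in Eq. (9), assuming cosine similarity as the neighbor metric and a plain majority vote over the $k$ nearest real representations. The value of $k$, the tie-breaking rule, and the helper names are assumptions not given in the text. Noise labels are assigned as $-1,-2,\ldots$, matching the convention introduced for Eq. (12).

```python
import math
from collections import Counter
from typing import List, Sequence


def _cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def knn_label(query: Sequence[float], z_real: Sequence[Sequence[float]],
              y_real: Sequence[int], k: int) -> int:
    """Majority label among the k real representations most similar to `query`."""
    order = sorted(range(len(z_real)), key=lambda idx: _cosine(query, z_real[idx]), reverse=True)
    votes = Counter(y_real[idx] for idx in order[:k])
    # Ties go to the label encountered first, i.e. the one held by the nearer neighbor.
    return votes.most_common(1)[0][0]


def correct_labels(z_syn_v1, z_syn_v2, y_org, z_real, y_real, k: int = 5) -> List[int]:
    """Eq. (9): keep y_org only if both views agree with it, else assign a unique noise label."""
    y_new: List[int] = []
    next_noise = -1
    for a, b, y in zip(z_syn_v1, z_syn_v2, y_org):
        y_cor1 = knn_label(a, z_real, y_real, k)  # corrected label from view 1
        y_cor2 = knn_label(b, z_real, y_real, k)  # corrected label from view 2
        if y_cor1 == y and y_cor2 == y:
            y_new.append(y)
        else:
            y_new.append(next_noise)  # -1, -2, ... : one new class per noisy sample
            next_noise -= 1
    return y_new
```

Requiring agreement from both views makes the check conservative: a synthetic image is flagged as noise as soon as either augmentation lands closer to another class.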

Figure 2: An example of label correction process.

3.4.3 Prototypes Complement.

To ensure that every class of real samples appears in every mini-batch, we follow previous work [21] and compute one prototype per class by averaging the real training data of that class at the end of the previous epoch; these prototypes are then used throughout the current epoch. Formally, the normalized prototype of class $y$ is:

$$\bar{\mathcal{P}}_{y}=\frac{1}{N_{1}^{y}}\sum_{\mathbf{z}_{i}\in\mathbf{z}_{real},\ \mathbf{y}_{i}=y}\mathbf{z}_{i},\qquad \mathcal{P}_{y}=\frac{\bar{\mathcal{P}}_{y}}{\lVert\bar{\mathcal{P}}_{y}\rVert_{2}}$$ (10)

where $\mathcal{P}\in\mathbb{R}^{C\times d}$ contains $C$ prototypes of dimension $d$, $N_{1}^{y}$ is the number of real samples of class $y$, and $\lVert\cdot\rVert_{2}$ denotes the $\ell_{2}$-norm. Since the prototypes $\mathcal{P}$ are averages of real data, we treat them as real data in what follows.
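Eq. (10) amounts to a per-class mean followed by $\ell_2$ normalization. A minimal sketch over real-sample representations collected during the previous epoch (the function name is hypothetical):

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def class_prototypes(z_real: Sequence[Sequence[float]],
                     y_real: Sequence[int]) -> Dict[int, List[float]]:
    """Eq. (10): mean real representation of each class, then L2-normalized."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = defaultdict(int)
    for z, y in zip(z_real, y_real):
        if y not in sums:
            sums[y] = [0.0] * len(z)
        sums[y] = [s + v for s, v in zip(sums[y], z)]
        counts[y] += 1
    prototypes = {}
    for y, s in sums.items():
        mean = [v / counts[y] for v in s]            # \bar{P}_y
        norm = math.sqrt(sum(v * v for v in mean))
        prototypes[y] = [v / norm for v in mean]     # P_y
    return prototypes
```

Each prototype can then be appended to the mini-batch as one extra real, non-noise sample carrying its class label, which is why the positive count grows by one in Eq. (11).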

3.4.4 Noise Dropping Strategy.

The key idea to tackle low-quality synthetic data is treating them as noise within a mini-batch. Here, we propose three forms of loss for this scenario.

First, the most intuitive way is simply to ignore all the noise data points:

$$\mathcal{L}_{1}=\sum_{i\in A\setminus I'}\left(-\frac{1}{|I_{y}|}\sum_{j\in A_{y}\setminus(I'_{y}\cup\{i\})}\log\frac{\exp(\mathbf{z}_{i}\cdot\mathbf{z}_{j}/\tau)}{\sum_{k\in A\setminus(I'\cup\{i\})}\exp(\mathbf{z}_{i}\cdot\mathbf{z}_{k}/\tau)}\right)$$ (11)

where $A=I\cup I'$ denotes all samples in the mini-batch: $I$ consists of real data, prototypes, and high-quality synthetic data, and $I'$ consists of the noise data. $A_{y}$ and $I'_{y}$ denote the samples of class $y$ in $A$ and $I'$, and $|\cdot|$ is the number of elements. Note that the $|I_{y}|-1$ term in Eq. 8 becomes $|I_{y}|$ here, because the prototypes contribute one additional data point per class.

Second, instead of removing the noise data points, we keep them and treat each noise data point as a positive sample of its own new class. Concretely, we assign unique noise labels $\mathbf{y}_{noise}=\{-1,-2,\ldots,-N_{n}\}$ to the $N_{n}$ noise data points, representing new classes distinct from all original labels:

$$\mathcal{L}_{2}=\sum_{i\in A}\mathcal{L}_{2}^{i}=\sum_{i\in A}\left(-\frac{1}{|I_{y}|+|I'_{y}|}\sum_{j\in A_{y}\setminus\{i\}}\log\frac{\exp(\mathbf{z}_{i}\cdot\mathbf{z}_{j}/\tau)}{\sum_{k\in A\setminus\{i\}}\exp(\mathbf{z}_{i}\cdot\mathbf{z}_{k}/\tau)}\right)$$ (12)

where $y_{i}\in\{\mathbf{y}_{org},\mathbf{y}_{noise}\}$, i.e., each noise data point forms a new class in addition to the original classes, and the normalization term counts all samples in both $I_{y}$ and $I'_{y}$.

Third, similar to $\mathcal{L}_{2}$, we keep the noise data points, but only as negative samples:

$$\mathcal{L}_{3}=\sum_{i\in A\setminus I'}\left(-\frac{1}{|I_{y}|}\sum_{j\in A_{y}\setminus(I'_{y}\cup\{i\})}\log\frac{\exp(\mathbf{z}_{i}\cdot\mathbf{z}_{j}/\tau)}{\sum_{k\in A\setminus\{i\}}\exp(\mathbf{z}_{i}\cdot\mathbf{z}_{k}/\tau)}\right)$$ (13)
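The three variants differ only in which samples act as anchors, positives, and denominator terms. The sketch below makes that explicit with a `mode` switch. It is an illustrative reading of Eqs. (11) to (13), not the authors’ code: it L2-normalizes the projections, averages over each anchor’s positives, assumes noise rows carry the unique negative labels produced by the label-correction step, and takes the Eq. (11) denominator over all non-noise samples.

```python
import math
from typing import List, Sequence

MODES = ("drop", "own_class", "negative_only")


def _normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _logsumexp(values: List[float]) -> float:
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def noise_aware_supcon(z: Sequence[Sequence[float]], labels: Sequence[int],
                       is_noise: Sequence[bool], mode: str, tau: float = 0.1) -> float:
    """Illustrative reading of the three noise-dropping losses.

    mode="drop"          ~ Eq. (11): noise rows are ignored entirely.
    mode="own_class"     ~ Eq. (12): every row is an anchor; each noise sample has a
                           unique negative label, so it only matches its own views.
    mode="negative_only" ~ Eq. (13): noise rows only enter the denominator.
    Prototypes can be appended to `z` as ordinary non-noise rows with their class label.
    """
    if mode not in MODES:
        raise ValueError(f"unknown mode: {mode!r}")
    z = [_normalize(v) for v in z]
    n = len(z)
    total = 0.0
    for i in range(n):
        if is_noise[i] and mode != "own_class":
            continue  # noise samples are never anchors in Eqs. (11) and (13)
        if mode == "drop":
            candidates = [k for k in range(n) if k != i and not is_noise[k]]
        else:
            candidates = [k for k in range(n) if k != i]
        positives = [j for j in candidates
                     if labels[j] == labels[i] and (mode == "own_class" or not is_noise[j])]
        if not positives:
            continue
        sims = {k: sum(a * b for a, b in zip(z[i], z[k])) / tau for k in candidates}
        log_denom = _logsumexp(list(sims.values()))
        # Mean over positives of -log softmax similarity, summed over anchors.
        total += -sum(sims[j] - log_denom for j in positives) / len(positives)
    return total
```

Without noise rows all three modes coincide; with noise, Eq. (13) adds extra negatives to the denominator of every clean anchor and is therefore never smaller than Eq. (11).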

In summary, the overall training loss of our framework is:

$$\mathcal{L}_{overall}=\lambda\mathcal{L}_{MixUp}+\beta\mathcal{L}_{CutMix}+\gamma\mathcal{L}_{1,2,3}$$ (14)

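where $\mathcal{L}_{1,2,3}$ denotes one of the three contrastive losses above, and $\lambda$, $\beta$, $\gamma$ are weighting hyper-parameters (this $\lambda$ is a loss weight, not the mixing ratio of Eqs. 4 to 6). Algorithm 1 gives the full training procedure.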

Algorithm 1 Learning algorithm
1: Input: $\{\mathbf{x}_{1}\}^{N}_{i=1}$ and $\{\mathbf{x}_{2}\}^{N}_{i=1}$: two batches of samples; $\{\mathbf{y}_{1}\}^{N}_{i=1}$ and $\{\mathbf{y}_{2}\}^{N}_{i=1}$: labels of the samples; $\mathbbm{1}$: indicator of whether each sample is synthetic; $T(\cdot)$: the augmentation function; $\mathcal{F}$: the backbone network; $classifier(\cdot)$: the classifier head of the synthetic-unaware branch; $\phi$: the projection head of the synthetic-aware branch.
2: while network not converged do
3:     Calculate class prototypes $\mathbf{p}^{2}$, $\mathbf{p}^{3}$;
4:     for $i=1$ to steps do
5:         $\mathbf{v}^{1}_{1}, \mathbf{v}^{2}_{1}, \mathbf{v}^{3}_{1}=T(\{\mathbf{x}_{1}\}^{N}_{i=1})$
6:         # Synthetic-unaware branch
7:         $\mathbf{v}^{1}_{2},\_,\_=T(\{\mathbf{x}_{2}\}^{N}_{i=1})$
8:         $\tilde{\mathbf{x}}_{m}, \tilde{\mathbf{y}}_{m}=MixUp(\mathbf{v}^{1}_{1},\mathbf{v}^{1}_{2})$ (Eq. 4);   $\tilde{\mathbf{x}}_{c}, \tilde{\mathbf{y}}_{c}=CutMix(\mathbf{v}^{1}_{1},\mathbf{v}^{1}_{2})$ (Eq. 5)
9:         $\mathbf{h}_{m}=\mathcal{F}(\tilde{\mathbf{x}}_{m})$, $\mathbf{h}_{c}=\mathcal{F}(\tilde{\mathbf{x}}_{c})$
10:         $\hat{\mathbf{y}}_{m}=classifier(\mathbf{h}_{m})$, $\hat{\mathbf{y}}_{c}=classifier(\mathbf{h}_{c})$
11:         Calculate $\mathcal{L}_{MixUp}(\hat{\mathbf{y}}_{m},\tilde{\mathbf{y}}_{m})$ and $\mathcal{L}_{CutMix}(\hat{\mathbf{y}}_{c},\tilde{\mathbf{y}}_{c})$ (Eq. 7)
12:         # Synthetic-aware branch
13:         $\mathbf{h}^{2}=\mathcal{F}(\mathbf{v}_{1}^{2})$, $\mathbf{h}^{3}=\mathcal{F}(\mathbf{v}_{1}^{3})$
14:         $\mathbf{z}^{2}=\phi(\mathbf{h}^{2})$, $\mathbf{z}^{3}=\phi(\mathbf{h}^{3})$
15:         Generate new labels $\mathbf{y}_{1}^{2}$, $\mathbf{y}_{1}^{3}$ for $\{\mathbf{x}_{1}\}^{N}_{i=1}$ based on $\mathbf{z}^{2}$, $\mathbf{z}^{3}$ (Eq. 9)
16:         Calculate $\mathcal{L}_{SC}(\mathbf{z}^{2},\mathbf{z}^{3},\mathbf{p}^{2},\mathbf{p}^{3},\mathbf{y}_{1}^{2},\mathbf{y}_{1}^{3})$ (Eq. 11, 12, or 13)
17:         Optimize the network with $\mathcal{L}_{overall}$ (Eq. 14)
18:     end for
19: end while
20: Output: the trained model $\mathcal{F}$ and $classifier(\cdot)$

4 Experiments

4.1 Experiment Setup

4.1.1 Datasets.

Both CIFAR-10 and CIFAR-100 [18] consist of 60,000 32×32 color images, with 50,000 images for training and 10,000 for testing. CIFAR-10 has 10 classes with 6,000 images each, and CIFAR-100 has 100 classes with 600 images each. CIFAR-10-LT and CIFAR-100-LT are long-tailed subsets of CIFAR-10 and CIFAR-100. In this work, the training datasets consist of real data and generated data. For the real data, following previous works [2, 12, 37], we use the exponential decay function $N_{i}=N_{o}\,\mu^{-\frac{i}{n-1}}$ to obtain the long-tailed datasets, where $i$ is the class index (0-indexed), $n$ is the number of classes, $N_{o}$ is the original number of training images per class, and $\mu$ is the imbalance factor. The imbalance factor $\mu=N_{max}/N_{min}$ reflects the degree of imbalance; we use $\mu\in\{10, 50, 100, 200\}$ for both CIFAR-10-LT and CIFAR-100-LT. The synthetic data are simply set as the complement of the real data. Following the official GLIDE T2I process, we use different prompts to generate 64×64 color synthetic images.
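The long-tailed split and its synthetic complement can be sketched as follows. This sketch assumes that per-class counts are truncated to integers and interprets "complement" as topping every class back up to the original $N_o$ images (5,000 per class for CIFAR-10, 500 for CIFAR-100); both choices are assumptions, and the function names are hypothetical.

```python
from typing import List


def long_tailed_counts(num_max: int, num_classes: int, imb_factor: float) -> List[int]:
    """N_i = N_o * mu^(-i / (n - 1)) for i = 0..n-1, truncated to integers."""
    return [int(num_max * (1.0 / imb_factor) ** (i / (num_classes - 1)))
            for i in range(num_classes)]


def synthetic_complement(real_counts: List[int], target: int) -> List[int]:
    """Synthetic images needed per class so that every class reaches `target` images."""
    return [target - c for c in real_counts]


real = long_tailed_counts(num_max=5000, num_classes=10, imb_factor=100)
syn = synthetic_complement(real, target=5000)
print(real[0], real[-1], syn[-1])  # 5000 50 4950
```

For CIFAR-10-LT with $\mu=100$, the head class keeps all 5,000 real images and needs no synthetic data, while the rarest class has 50 real images and receives 4,950 synthetic ones.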

ImageNet-LT, proposed by [20], is a long-tailed version of the original ImageNet [9] dataset designed to mimic real-world scenarios. The original ImageNet training set contains about 1.28 million images from 1,000 classes. ImageNet-LT is a subset of it with 1,280 images in the largest class and 5 in the smallest, i.e., an imbalance factor of 256. To reduce computational cost, we only use basic prompts to generate 256×256 color synthetic images.

4.1.2 Implementation Details

We use the PyTorch [23] deep learning framework to train models on all datasets.

For CIFAR-10-LT and CIFAR-100-LT, we follow [12, 7, 37] and use ResNet-32 [14] for a fair comparison with previous studies. We optimize with a standard SGD optimizer with a momentum of 0.9, starting from an initial learning rate of 0.01 with a cosine annealing scheduler. The batch size is 128 and the weight decay is 5e-3. The model is trained for 200 epochs and then evaluated on the test set. For data augmentation, we use random horizontal flipping and cropping for the classification branch, and AutoAugment [6], Cutout [10], and SimAugment [5] for the contrastive learning branch. The MLP projection head consists of two hidden layers of size 2048 and a 128-dimensional output, with batch normalization and a ReLU activation between the two hidden layers.
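For reference, cosine annealing has the closed form $\eta_t=\eta_{min}+\frac{1}{2}(\eta_0-\eta_{min})(1+\cos(\pi t/T))$, which is also the closed form documented for PyTorch’s `torch.optim.lr_scheduler.CosineAnnealingLR`. A minimal sketch, assuming per-epoch updates, $\eta_{min}=0$, and no warm-up (none of which the text specifies):

```python
import math


def cosine_lr(epoch: int, base_lr: float = 0.01, total_epochs: int = 200,
              min_lr: float = 0.0) -> float:
    """Learning rate at `epoch` under cosine annealing from base_lr down to min_lr."""
    progress = epoch / total_epochs  # fraction of training completed
    return min_lr + (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress)) / 2.0
```

With the CIFAR settings above ($\eta_0=0.01$, $T=200$), the learning rate is 0.005 at epoch 100 and decays to 0 at epoch 200; for ImageNet-LT one would use `base_lr=0.1` and `total_epochs=180`.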

For ImageNet-LT, we use ResNet-50 [14] and ResNeXt-50-32x4d [31] as backbones, with a standard SGD optimizer with a momentum of 0.9. The initial learning rate is 0.1 and is adjusted by a cosine annealing scheduler. The batch size is 256, the weight decay is 1e-4, and the model is trained for 180 epochs.

4.1.3 Evaluation Protocol

For all datasets, we report Top-1 accuracy. We train models on a balanced dataset that combines the real long-tailed training set with the synthetic set and evaluate them on the balanced validation/test set. For ImageNet-LT, we additionally report accuracy on many-shot classes (more than 100 real training samples), medium-shot classes (20 to 100 real training samples), and few-shot classes (fewer than 20 real training samples).
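The many/medium/few split can be sketched as below. The function names are hypothetical, and accuracy is pooled over the test images of each group; on a balanced test set, where every class has the same number of test images, this equals the mean per-class accuracy of the group.

```python
from typing import Dict, Sequence


def shot_group(train_count: int) -> str:
    """Many: > 100 real training images; medium: 20 to 100; few: < 20."""
    if train_count > 100:
        return "many"
    if train_count >= 20:
        return "medium"
    return "few"


def grouped_accuracy(train_counts: Sequence[int], correct: Sequence[int],
                     total: Sequence[int]) -> Dict[str, float]:
    """Top-1 accuracy per shot group from per-class correct/total test counts."""
    hits = {"many": 0, "medium": 0, "few": 0}
    seen = {"many": 0, "medium": 0, "few": 0}
    for count, k, t in zip(train_counts, correct, total):
        group = shot_group(count)  # grouping uses real training counts only
        hits[group] += k
        seen[group] += t
    return {g: hits[g] / seen[g] for g in hits if seen[g] > 0}
```

Note that the grouping is based on the number of real training images only; the synthetic complement does not move a class into a different group.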

Table 1: Top-1 accuracy (%) of ResNet-32 on CIFAR-10-LT and CIFAR-100-LT with different imbalance factors (IF = 200, 100, 50, 10). The best results are marked in bold. † denotes results reproduced by ourselves using the code released by the authors. * reports the results based on real-synthetic mixed balanced data.
CIFAR-10 CIFAR-100
Method IF=200 IF=100 IF=50 IF=10 IF=200 IF=100 IF=50 IF=10
CE 65.87 70.14 74.94 86.18 34.70 38.46 44.02 55.73
CB-Focal [8] 68.89 74.57 79.27 87.49 36.23 39.60 45.42 57.99
Mixup [33] - 73.06 77.82 87.10 - 39.54 54.99 58.02
PaCo [7] - - - - - 52.0 56.0 64.2
ProCo [11] - 85.9 88.2 91.9 - 52.8 57.1 65.5
Hybrid-SC [30] - 81.40 85.36 91.12 - 46.72 51.87 63.05
BCL [37] - 84.32 87.24 91.12 - 51.93 56.59 64.87
GLMC [12] - 87.25 90.39 94.33 - 57.70 62.70 72.63
SURE [19] - 86.93 90.22 94.96 - 57.34 63.13 73.24
CE 82.29 84.45 87.27 91.10 55.24 57.80 59.85 67.93
GLMC 85.13 88.33 91.35 94.80 58.46 61.80 65.21 74.11
SAU (Ours) 89.96 92.21 93.91 95.88 61.95 64.47 68.22 76.31
Table 2: Top-1 accuracy (%) on ImageNet-LT compared with state-of-the-art methods using different backbones. The best results are marked in bold. * reports the results based on real-synthetic mixed balanced data.
Method Backbone Many Med Few All
CE ResNet-50 64 33.8 5.8 41.6
CB-Focal [8] ResNet-50 39.6 32.7 16.8 33.2
BCL (90 epochs) [37] ResNeXt-50 67.2 53.9 36.5 56.7
BCL (180 epochs) [37] ResNeXt-50 67.9 54.2 36.6 57.1
PaCo (400 epochs) [7] ResNeXt-50 - - - 58.2
PaCo (180 epochs) [7] ResNeXt-50 64.4 55.7 33.7 56.0
ProCo (180 epochs) [11] ResNet-50 68.2 55.1 38.1 57.8
GLMC [12] ResNeXt-50 70.1 52.4 30.4 56.3
CE ResNet-50 58.3 47.7 32.1 49.7
SAU (ours) ResNet-50 62.3 51.7 37.2 53.7
SAU (ours) ResNeXt-50 64.7 52.6 38.4 55.2

4.2 Main Results

4.2.1 CIFAR-10-LT and CIFAR-100-LT

The results are reported in Table 1. As baselines, we mainly consider previous state-of-the-art methods including CB-Focal [8], PaCo [7], Hybrid-SC [30], BCL [37], GLMC [12], and SURE [19], all trained from scratch. For a fairer comparison, we also apply the recent state-of-the-art method GLMC to our real-synthetic mixed balanced datasets. Our method achieves the best results for all imbalance factors.

4.2.2 ImageNet-LT

As for CIFAR, we report the Top-1 accuracy of our models and other methods in Table 2. As baselines, we consider both current and previous state-of-the-art methods including Cross-Entropy (CE), CB-Focal [8], BCL [37], PaCo [7], ProCo [11], and GLMC [12]. Our method achieves the best performance on the few-shot classes.

Table 3: Effectiveness of different noise dropping strategies. We report the Top-1 accuracy on CIFAR-10-LT and CIFAR-100-LT (IF=100) with the ResNet-32 backbone.
Method CIFAR-10-LT CIFAR-100-LT
$\mathcal{L}_{1}$ 91.34 63.03
$\mathcal{L}_{2}$ 92.21 64.47
$\mathcal{L}_{3}$ 91.05 64.28
Table 4: Effectiveness of primary components. We report Top-1 Acc. on CIFAR-100-LT (IF=100). The backbone is ResNet-32. Note that we use $\mathcal{L}_{2}$ as $\mathcal{L}_{SC}$ for all experiments.
$\mathcal{L}_{CE}$ $\mathcal{L}_{MixUp}$ $\mathcal{L}_{CutMix}$ $\mathcal{L}_{SC}$ Top-1 Acc.
✓ 57.84
✓ 59.95
✓ 58.22
✓ ✓ 61.68
✓ ✓ 62.45
✓ ✓ 61.09
✓ ✓ ✓ 64.47

4.3 Ablation Studies

4.3.1 Effectiveness of Noise-Dropping Strategy.

First, we compare the three contrastive losses ($\mathcal{L}_{1}$, $\mathcal{L}_{2}$, and $\mathcal{L}_{3}$) for handling poor-quality synthetic images. As Table 3 shows, all three losses are effective, and $\mathcal{L}_{2}$ performs best on both CIFAR-10-LT and CIFAR-100-LT.

4.3.2 Effectiveness of Components.

We also conduct an extensive ablation study to examine the effectiveness of the primary components; the results in Table 4 validate the design of our method. The first row is the baseline, which uses only the cross-entropy loss $\mathcal{L}_{CE}$ and is trained on the real-synthetic mixed dataset. The final row, which combines $\mathcal{L}_{MixUp}$ and $\mathcal{L}_{CutMix}$ with $\mathcal{L}_{SC}$, gives the best Top-1 accuracy of 64.47%, a gain of 6.63 points over the baseline (57.84%).
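
The exact forms of $\mathcal{L}_{MixUp}$ and $\mathcal{L}_{CutMix}$ are defined earlier in the paper. As a hedged illustration, the following pure-Python sketch shows the standard MixUp [33] and CutMix [32] recipes these losses are named after: both combine two training images and weight the cross-entropy of the two labels by the mixing ratio λ. We assume the paper's losses follow this standard form; the Beta parameter (α = 1.0 here), any consistency terms, and batching details may differ. All names (`Image`, `mixup_images`, `cutmix_box`, `cutmix_images`, `mixed_cross_entropy`) are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple

Image = List[List[float]]  # hypothetical single-channel H x W image, for illustration


def mixup_images(x_a: Image, x_b: Image, lam: float) -> Image:
    """MixUp: pixel-wise convex combination lam * x_a + (1 - lam) * x_b."""
    return [[lam * a + (1.0 - lam) * b for a, b in zip(row_a, row_b)]
            for row_a, row_b in zip(x_a, x_b)]


def cutmix_box(height: int, width: int, lam: float,
               rng: random.Random) -> Tuple[int, int, int, int]:
    """CutMix: sample a box (y1, y2, x1, x2) covering at most (1 - lam) of the image."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
    cy, cx = rng.randrange(height), rng.randrange(width)  # random box centre
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    return y1, y2, x1, x2


def cutmix_images(x_a: Image, x_b: Image,
                  box: Tuple[int, int, int, int]) -> Tuple[Image, float]:
    """Paste the box from x_b into a copy of x_a; also return the area-corrected
    lambda, i.e. the fraction of pixels that still come from x_a."""
    y1, y2, x1, x2 = box
    out = [row[:] for row in x_a]
    for y in range(y1, y2):
        out[y][x1:x2] = x_b[y][x1:x2]
    height, width = len(x_a), len(x_a[0])
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (height * width)
    return out, lam


def mixed_cross_entropy(probs: Sequence[float], y_a: int, y_b: int,
                        lam: float) -> float:
    """Loss shared by MixUp and CutMix: lam-weighted cross-entropy of both labels."""
    return -(lam * math.log(probs[y_a]) + (1.0 - lam) * math.log(probs[y_b]))


rng = random.Random(0)
lam = rng.betavariate(1.0, 1.0)  # lambda ~ Beta(alpha, alpha); alpha = 1.0 is an assumption
```

In an actual training loop these operations would act on batched tensors (e.g. in PyTorch [23]); the list-based version only makes the label-weighting arithmetic explicit. Note that CutMix recomputes λ from the pasted area, so the loss weights match the pixels actually taken from each image.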

5 Conclusion

In this work, to improve long-tailed image recognition, we leverage LLMs and text-to-image (T2I) models to generate synthetic images, which complement the long-tailed dataset to form a balanced one. We further propose SAU, a dual-branch framework for training on such real-synthetic mixed datasets. SAU achieves state-of-the-art Top-1 accuracy on both CIFAR-10-LT and CIFAR-100-LT and improves tail-class performance on ImageNet-LT, achieving the best few-shot accuracy among the compared methods.

References

  • [1] Achiam, J., Adler, S., Agarwal, S., Ahmad, L., Akkaya, I., Aleman, F.L., Almeida, D., Altenschmidt, J., Altman, S., Anadkat, S., et al.: GPT-4 technical report. arXiv preprint arXiv:2303.08774 (2023)
  • [2] Alshammari, S., Wang, Y.X., Ramanan, D., Kong, S.: Long-tailed recognition via weight balancing. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 6897–6907 (2022)
  • [3] Besnier, V., Jain, H., Bursuc, A., Cord, M., Pérez, P.: This dataset does not exist: training models from generated images. In: ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). pp. 1–5. IEEE (2020)
  • [4] Chawla, N.V., Bowyer, K.W., Hall, L.O., Kegelmeyer, W.P.: SMOTE: Synthetic minority over-sampling technique. Journal of Artificial Intelligence Research 16, 321–357 (2002)
  • [5] Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: International Conference on Machine Learning. pp. 1597–1607. PMLR (2020)
  • [6] Cubuk, E.D., Zoph, B., Mane, D., Vasudevan, V., Le, Q.V.: AutoAugment: Learning augmentation strategies from data. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 113–123 (2019)
  • [7] Cui, J., Zhong, Z., Liu, S., Yu, B., Jia, J.: Parametric contrastive learning. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 715–724 (2021)
  • [8] Cui, Y., Jia, M., Lin, T.Y., Song, Y., Belongie, S.: Class-balanced loss based on effective number of samples. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 9268–9277 (2019)
  • [9] Deng, J., Dong, W., Socher, R., Li, L.J., Li, K., Fei-Fei, L.: ImageNet: A large-scale hierarchical image database. In: 2009 IEEE Conference on Computer Vision and Pattern Recognition. pp. 248–255. IEEE (2009)
  • [10] DeVries, T., Taylor, G.W.: Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552 (2017)
  • [11] Du, C., Wang, Y., Song, S., Huang, G.: Probabilistic contrastive learning for long-tailed visual recognition. IEEE Transactions on Pattern Analysis and Machine Intelligence (2024)
  • [12] Du, F., Yang, P., Jia, Q., Nan, F., Chen, X., Yang, Y.: Global and local mixture consistency cumulative learning for long-tailed visual recognitions. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 15814–15823 (2023)
  • [13] Floridi, L., Chiriatti, M.: GPT-3: Its nature, scope, limits, and consequences. Minds and Machines 30, 681–694 (2020)
  • [14] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 770–778 (2016)
  • [15] He, R., Sun, S., Yu, X., Xue, C., Zhang, W., Torr, P., Bai, S., Qi, X.: Is synthetic data from generative models ready for image recognition? arXiv preprint arXiv:2210.07574 (2022)
  • [16] Jahanian, A., Puig, X., Tian, Y., Isola, P.: Generative models as a data source for multiview representation learning. arXiv preprint arXiv:2106.05258 (2021)
  • [17] Khosla, P., Teterwak, P., Wang, C., Sarna, A., Tian, Y., Isola, P., Maschinot, A., Liu, C., Krishnan, D.: Supervised contrastive learning. Advances in Neural Information Processing Systems 33, 18661–18673 (2020)
  • [18] Krizhevsky, A., Hinton, G., et al.: Learning multiple layers of features from tiny images (2009)
  • [19] Li, Y., Chen, Y., Yu, X., Chen, D., Shen, X.: SURE: Survey recipes for building reliable and robust deep networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2024)
  • [20] Liu, Z., Miao, Z., Zhan, X., Wang, J., Gong, B., Yu, S.X.: Large-scale long-tailed recognition in an open world. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 2537–2546 (2019)
  • [21] Nassar, I., Hayat, M., Abbasnejad, E., Rezatofighi, H., Haffari, G.: ProtoCon: Pseudo-label refinement via online clustering and prototypical consistency for efficient semi-supervised learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 11641–11650 (2023)
  • [22] Nichol, A., Dhariwal, P., Ramesh, A., Shyam, P., Mishkin, P., McGrew, B., Sutskever, I., Chen, M.: GLIDE: Towards photorealistic image generation and editing with text-guided diffusion models. arXiv preprint arXiv:2112.10741 (2021)
  • [23] Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., et al.: PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems 32 (2019)
  • [24] Qiu, H., Yu, B., Gong, D., Li, Z., Liu, W., Tao, D.: SynFace: Face recognition with synthetic data. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 10880–10890 (2021)
  • [25] Radford, A., Kim, J.W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., et al.: Learning transferable visual models from natural language supervision. In: International Conference on Machine Learning. pp. 8748–8763. PMLR (2021)
  • [26] Ramesh, A., Dhariwal, P., Nichol, A., Chu, C., Chen, M.: Hierarchical text-conditional image generation with CLIP latents. arXiv preprint arXiv:2204.06125 (2022)
  • [27] Ramesh, A., Pavlov, M., Goh, G., Gray, S., Voss, C., Radford, A., Chen, M., Sutskever, I.: Zero-shot text-to-image generation. In: International Conference on Machine Learning. pp. 8821–8831. PMLR (2021)
  • [28] Ren, J., Yu, C., Ma, X., Zhao, H., Yi, S., et al.: Balanced meta-softmax for long-tailed visual recognition. Advances in Neural Information Processing Systems 33, 4175–4186 (2020)
  • [29] Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M.A., Lacroix, T., Rozière, B., Goyal, N., Hambro, E., Azhar, F., et al.: LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971 (2023)
  • [30] Wang, P., Han, K., Wei, X.S., Zhang, L., Wang, L.: Contrastive learning based hybrid networks for long-tailed image classification. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 943–952 (2021)
  • [31] Xie, S., Girshick, R., Dollár, P., Tu, Z., He, K.: Aggregated residual transformations for deep neural networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. pp. 1492–1500 (2017)
  • [32] Yun, S., Han, D., Oh, S.J., Chun, S., Choe, J., Yoo, Y.: CutMix: Regularization strategy to train strong classifiers with localizable features. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 6023–6032 (2019)
  • [33] Zhang, H., Cisse, M., Dauphin, Y.N., Lopez-Paz, D.: mixup: Beyond empirical risk minimization. arXiv preprint arXiv:1710.09412 (2017)
  • [34] Zhang, Z., Pfister, T.: Learning fast sample re-weighting without reward data. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 725–734 (2021)
  • [35] Zhao, Q., Dai, Y., Li, H., Hu, W., Zhang, F., Liu, J.: LTGC: Long-tail recognition via leveraging LLMs-driven generated content. arXiv preprint arXiv:2403.05854 (2024)
  • [36] Zheng, M., Wang, F., You, S., Qian, C., Zhang, C., Wang, X., Xu, C.: Weakly supervised contrastive learning. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 10042–10051 (2021)
  • [37] Zhu, J., Wang, Z., Chen, J., Chen, Y.P.P., Jiang, Y.G.: Balanced contrastive learning for long-tailed visual recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. pp. 6908–6917 (2022)