
Unified Framework for Histopathology Image Augmentation and Classification via Generative Models

Meng Li*, Chaoyi Li*, Can Peng, Brian C. Lovell * denotes equal contribution The University of Queensland, School of EECS, QLD 4072, Australia [email protected], [email protected], [email protected], [email protected]
Abstract

Deep learning techniques have become widely utilized in histopathology image classification due to their superior performance. However, this success heavily relies on the availability of substantial labeled data, which necessitates extensive and costly manual annotation by domain experts. To address this challenge, researchers have recently employed generative models to synthesize data for augmentation, thereby enhancing classification model performance. Traditionally, this involves generating synthetic data first and then training the classification model with both synthetic and real data, which creates a two-stage, time-consuming workflow. To overcome this limitation, we propose an innovative unified framework that integrates the data generation and model training stages into a unified process. Our approach utilizes a pure Vision Transformer (ViT)-based conditional Generative Adversarial Network (cGAN) model to simultaneously handle both image synthesis and classification. An additional classification head is incorporated into the cGAN model to enable simultaneous classification of histopathology images. To improve training stability and enhance the quality of generated data, we introduce a conditional class projection technique that helps maintain class separation during the generation process. We also employ a dynamic multi-loss weighting mechanism to effectively balance the losses of the classification tasks. Furthermore, our selective augmentation mechanism actively selects the most suitable generated images for data augmentation to further improve performance. Extensive experiments on histopathology datasets show that our unified synthetic augmentation framework consistently enhances the performance of histopathology image classification models.

Index Terms:
Conditional Transformer-based GAN, Histopathology Image Classification, Image Synthesis, Data Augmentation
publicationid: pubid: ©2024 IEEE. Published in the Digital Image Computing: Techniques and Applications, 2024 (DICTA 2024), 27-29 November 2024 in Perth, Western Australia, Australia. Personal use of this material is permitted. However, permission to reprint/republish this material for advertising or promotional purposes or for creating new collective works for resale or redistribution to servers or lists, or to reuse any copyrighted component of this work in other works, must be obtained from the IEEE. Contact: Manager, Copyrights and Permissions / IEEE Service Center / 445 Hoes Lane / P.O. Box 1331 / Piscataway, NJ 08855-1331, USA. Telephone: + Intl. 908-562-3966.

I Introduction

The benefits of deep learning are greatly enhanced by the availability of labeled data, which has facilitated the successful application of deep learning in histopathology image classification tasks. However, most of these tasks require the expertise of domain specialists and the acquisition of substantial amounts of labeled data [1]. This process is not only labor-intensive and time-consuming but also impractical in the context of rare diseases, early-stage clinical studies, or emerging imaging modalities. To address these challenges, several studies have utilized generative adversarial networks (GANs) to synthesize images for data augmentation, achieving satisfactory results [1, 2]. Despite their success, these approaches generally follow a traditional two-stage workflow in which image generation and data augmentation are treated as separate tasks. This separation requires a distinct model for each stage, increasing training complexity and effort.

Moreover, a distinct characteristic of histopathological images is the prevalence of non-local or long-range information within their composition [3]. CNN-based generative models may struggle to synthesize realistic histopathology images since their locally focused receptive fields are not effective at capturing the non-local information in these images. In contrast, vision transformers (ViTs) offer promising advantages in modeling non-local context dependencies [4]. Consequently, ViT-based generative models hold significant potential for histopathology image analysis tasks.

To leverage the advantages of ViTs and transition from the two-stage paradigm to a unified approach, we propose a framework that employs a pure ViT-based GAN model for synthetic augmentation in histopathology image analysis. This framework integrates models from both stages into one, conditionally generating synthetic images while also performing classification prediction. This approach faces a significant challenge, as GANs are notorious for their training instability [5] and introducing multiple loss functions can further complicate this issue. To address these challenges, we first propose a conditional class projection technique, designed to aid in separating conditional information during training. Drawing inspiration from the success of multi-task learning methods [6, 7], we also introduce a multi-loss weighting function to dynamically balance the losses when training the GAN model. Additionally, to further enhance the effectiveness of synthetic augmentation, we incorporate a selective augmentation mechanism that actively chooses the most suitable generated images for data augmentation.

The main contributions of this paper can be summarized as follows:

  • We propose a novel unified framework that integrates the two stages of image generation and classification into a single stage, achieving higher efficiency compared to the traditional paradigm of using synthetic data for augmentation;

  • We explore the application of a conditional, pure ViT-based GAN for histopathology image analysis, highlighting its substantial practical benefits;

  • We introduce a conditional class projection technique to improve class separation during training;

  • We propose a multi-loss weighting function that effectively stabilizes training and improves performance across tasks;

  • Extensive experiments on lymph node histopathology datasets demonstrate that our approach significantly improves classification performance.

Figure 1: Overview of our proposed framework. During training, the latent vector $z$ is split into multiple chunks in the generator, and each chunk is concatenated with the conditional label $c$ as the input to the class projection layer. A multi-scale pyramid technique is employed to learn global and local information. The selective data augmentation mechanism filters the predictions of synthetic data from the class head. Finally, the output of the discriminator is the objective function. During inference, test data is fed only into the discriminator to obtain predictions.

II Related Work

II-A Generative Models for Medical Image Augmentation

Data scarcity remains a significant obstacle in applying deep learning technologies in the medical field. To mitigate this challenge, data augmentation has emerged as a powerful technique for expanding training datasets. However, conventional augmentation methods used in general image processing often struggle or cannot be applied due to the complexities and unique characteristics of medical images. This limitation underscores the appeal of generative models, which have demonstrated strong performance in augmenting medical data.

Currently, three primary generative models are utilized for medical image synthesis: variational autoencoders (VAEs), generative adversarial networks (GANs), and diffusion models [8]. GANs, in particular, are widely favored for their capability to generate realistic images even with limited data availability [9], despite challenges such as training instability, convergence issues, and mode collapse [10]. In contrast, VAEs offer greater output diversity and avoid mode collapse compared to GANs, but they often produce blurry images, which limits their adoption for augmentation purposes. Recently, diffusion models [11, 12] have shown promise in generating realistic and diverse outputs, though they require extensive data and computational resources for training [8]. Given the inherent data scarcity in medical datasets, our study focuses on augmenting data with limited availability using generative models. Considering these constraints, GANs emerge as the preferred choice due to their proven effectiveness in generating high-quality medical images under restricted data conditions.

II-B Transformer-based GANs

A GAN model consists of two key components, a discriminator and a generator, which are trained alternately. The discriminator learns to differentiate between the synthetic images produced by the generator and real training data. Conversely, the generator learns to create synthetic images that are as realistic as possible to deceive the discriminator. During testing, only the well-trained generator is used to synthesize realistic images.

GANs were originally built on convolutional neural network (CNN) architectures, formulating a min-max optimization problem that narrows the gap between the real and synthetic data distributions [13]. Recently, transformer networks [14] have attracted attention for their ability to effectively model non-local contextual dependencies. In the realm of medical image synthesis, there is a growing trend towards replacing CNNs with transformer networks, given the significant role of global understanding of training data in generating realistic medical images. For instance, Korkmaz et al. [15] introduced the GVTrans model, which employs cross-attentive visual transformers to map low-dimensional noise and latent variables onto MRI data. Zhang et al. [16] proposed PTNet3D, leveraging pyramid transformer networks to generate high-resolution 3D longitudinal infant brain MRI data. However, these methods often adopt a hybrid structure that integrates a transformer-based generator alongside a CNN-based discriminator.

In contrast to hybrid models, a pure transformer-based GAN offers architectural simplicity that can potentially reduce method complexity, thereby facilitating stable training and reducing computational overhead. MedViTGAN stands out as a pure vision transformer-based conditional GAN designed specifically for generating histopathology images for data augmentation [17]. Inspired by MedViTGAN, our study adopts a pure transformer GAN framework to capture non-local or long-range information in histopathology images. Furthermore, different from previous methods, our framework integrates conditional class projection and selective data augmentation techniques to enhance both class separation and generation quality.

II-C Enhancing GAN Performance with Multi-Loss Integration

To enhance performance in various image synthesis applications, several GAN-based studies have explored combining the Wasserstein loss with different loss functions. For instance, Liu et al. [18] integrated total variation loss, pixel loss, and feature-level losses linearly with the Wasserstein loss function to train an auto-painter, resulting in improved painting quality. Similarly, Ebenezer et al. [19] employed a combination of Wasserstein loss, L1 pixel loss, and VGG loss to train their GAN. However, these approaches typically involve manual tuning of the weights assigned to each loss component. This manual hyperparameter tuning relies heavily on expert knowledge and can lead to time-consuming and computationally expensive training processes. Moreover, it may limit the generalizability of the method and result in sub-optimal performance.

The success of multi-task learning suggests a more sophisticated approach to combining loss functions across different tasks. Kendall et al. [6] introduced a multi-task loss formulation that leverages the homoscedastic uncertainty of each task to dynamically weigh their respective losses. Inspired by their work, we propose a multi-loss weighting function that automates the adjustment of these weights during the training process.

III Method

The overall framework of our method, illustrated in Fig. 1, aims to integrate image generation and classification into a unified stage. We build upon the transformer-based GAN architecture introduced in MedViTGAN [17] as the baseline model in our approach. This transformer structure effectively captures essential global contextual dependencies and simplifies training.

During the training stage, we introduce a conditional training strategy to enhance the model’s capabilities. This strategy incorporates multiple class projection modules within the generator and adds an extra classification head to the discriminator. The discriminator is thus equipped with two classification heads: a source head for distinguishing real and synthetic images, and a class head for image class classification. We then utilize a multi-loss weighting function to train the generator and the discriminator. We further refine the synthetic augmentation process by implementing a selective data augmentation mechanism, which filters and prioritizes high-quality synthetic images for augmentation. In the inference stage, classification is performed solely by the discriminator. The pseudo-code outlining the training and inference procedures is provided in Algorithm 1.

Algorithm 1 Training and inference of the proposed framework
Training Phase
Input: Truncation value $\tau$, class label $c$, training data $I_{real}$, conditional class projection module $CP$, selective data augmentation mechanism $SA$, objective function $\mathcal{L}_{total}$, generator $G$, discriminator feature extractor $D$ with source head $f^{1}$ and class head $f^{2}$
Output: Generator $G$, discriminator feature extractor $D$ with $f^{1}$ and $f^{2}$
Initialize $G$, $D$ with $f^{1}$ and $f^{2}$
for $epoch = 1$ to $N$ do
      Sample a noise vector $z \sim \text{TruncNormal}(0, 1, -\tau, \tau)$
      $I_{syn} = G(CP(z, c))$
      $y_{1}, y_{2} = f^{2}(D(I_{real}, I_{syn}))$
      $y_{3} = f^{1}(D(I_{real}, I_{syn}))$
      $y_{2}^{filtered} = SA(y_{2})$
      Compute $\mathcal{L}_{total}$ using $y_{1}$, $y_{2}^{filtered}$, $y_{3}$ to train $G$ and $D$
end for
Return $G$, $D$ with $f^{1}$ and $f^{2}$
▷ $y_{1,2,3}$ represent the classification predictions of real and synthetic data from the class head and the real/synthetic predictions from the source head.
Inference Phase
Input: Test data $I_{test}$, trained discriminator feature extractor $D$ with class head $f^{2}$
Output: Prediction $y_{t}$
$y_{t} = f^{2}(D(I_{test}))$
Return $y_{t}$
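As a concrete illustration of the training phase of Algorithm 1, the PyTorch sketch below shows a single discriminator update under simplifying assumptions; the component interfaces (G, D, f1, f2, CP, SA) and the helper losses mlw_loss and gp_loss (analogous helpers are sketched in later sections) are illustrative, not the authors' released implementation. The generator update mirrors this step with the adversarial sign flipped and the synthetic images left attached to the graph.

```python
import torch

def train_step_D(G, D, f1, f2, CP, SA, mlw_loss, gp_loss, opt_D,
                 real_images, real_labels, num_classes=2, z_dim=128, tau=0.7):
    """One discriminator update of the unified framework (cf. Algorithm 1).

    G, D, f1, f2, CP and SA denote the generator, discriminator feature
    extractor, source head, class head, class projection module and selective
    augmentation mechanism; mlw_loss and gp_loss stand for the multi-loss
    weighting and gradient-penalty terms. All interfaces are assumptions.
    """
    batch = real_images.size(0)
    # Crude stand-in for truncated-normal sampling (values folded into (-tau, tau)).
    z = torch.fmod(torch.randn(batch, z_dim), tau)
    c = torch.randint(0, num_classes, (batch,))
    syn_images = G(CP(z, c)).detach()                 # conditional synthesis

    feat_real, feat_syn = D(real_images), D(syn_images)
    y1, y2 = f2(feat_real), f2(feat_syn)              # class-head logits
    y2_filtered, c_filtered = SA(y2, c)               # selective augmentation

    # Adversarial (WGAN-GP) part of Eq. (2) plus the weighted classification
    # losses of Eq. (7); together they approximate L_total in Eq. (8).
    adv = f1(feat_syn).mean() - f1(feat_real).mean()
    loss = (adv
            + gp_loss(D, f1, real_images, syn_images)
            + mlw_loss(y1, real_labels, y2_filtered, c_filtered))
    opt_D.zero_grad()
    loss.backward()
    opt_D.step()
    return loss.item()
```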

III-A The Architecture of Transformer-based GAN

The GAN architecture utilized in our framework is a pure transformer-based model inspired by TransGAN [5]. The core component is the Vision Transformer (ViT) encoder [20], comprising multi-head self-attention modules and a feed-forward multi-layer perceptron (MLP) with GELU non-linearity. The self-attention modules and MLPs in the ViT encoder are equipped with residual connections and layer normalization for improved stability and performance.

Our GAN’s generator is structured as four stage blocks designed to progressively learn and upsample the given latent vector. Each stage block consists of a class projection layer, an upsampling module, and four ViT encoders. As the generator advances through these stages, it incrementally increases the feature map resolution until it reaches the target dimensions. Initially, the latent noise vector $z$ is concatenated with the one-hot class label $c$ and passed through a linear projection layer to generate embedding tokens $X_{0}\in\mathbb{R}^{H_{0}\times W_{0}\times C}$. These tokens are then sent to the first stage of the generator. In the first stage, the embedded feature map is upsampled from $X_{0}\in\mathbb{R}^{H_{0}\times W_{0}\times C}$ to $X_{1}\in\mathbb{R}^{2H_{0}\times 2W_{0}\times C}$ using cubic interpolation, ensuring early feature learning without reducing dimensions. In the following three stages, a pixel-shuffle module [21] is used to upsample the resolution by a factor of $2\times$ while decreasing the channels to one quarter. This approach reduces memory requirements and enhances network efficiency. To tackle the issue where the self-attention module sacrifices local information to learn global correspondences during the high-resolution generation stages (the third and fourth stages), we replace the self-attention module with the grid self-attention module [5] in the ViT encoder. This modification enables our model to effectively capture both global and local information. Specifically, we set the predefined window size to $32\times 32$ in the third stage and $16\times 16$ in the fourth stage.
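To make the stage structure concrete, the following sketch shows one upsampling stage in PyTorch under simplifying assumptions: a pixel-shuffle doubles the height and width while quartering the channels, and standard transformer encoder layers stand in for the ViT encoders; the grid self-attention of the third and fourth stages and the class projection layer are omitted here for brevity.

```python
import torch
import torch.nn as nn

class GeneratorStage(nn.Module):
    """One upsampling stage of the ViT generator (stages 2-4), simplified.
    Pixel-shuffle maps (C, H, W) -> (C/4, 2H, 2W); channels//4 must be
    divisible by num_heads. All dimensions are illustrative assumptions."""
    def __init__(self, channels, num_heads=4, depth=4):
        super().__init__()
        self.shuffle = nn.PixelShuffle(2)
        layer = nn.TransformerEncoderLayer(
            d_model=channels // 4, nhead=num_heads,
            dim_feedforward=channels, activation="gelu", batch_first=True)
        self.encoders = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, tokens, h, w):
        # tokens: (B, H*W, C) -> image grid -> upsample -> back to tokens
        b, _, c = tokens.shape
        x = tokens.transpose(1, 2).reshape(b, c, h, w)
        x = self.shuffle(x)                            # (B, C/4, 2H, 2W)
        tokens = x.flatten(2).transpose(1, 2)          # (B, 4*H*W, C/4)
        return self.encoders(tokens), 2 * h, 2 * w

# Usage sketch: stage = GeneratorStage(channels=256)
# tokens, h, w = stage(torch.randn(2, 8 * 8, 256), 8, 8)   # -> (2, 256, 64) tokens
```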

The discriminator is similarly structured into four stages. The first three stages consist of a ViT encoder and an average pooling layer. The final stage includes two ViT encoders that handle the class and source heads, respectively. Between the first three stages, we apply a multi-scale technique that combines the outputs from the last block with patches of varying sizes extracted from the same input image. After passing through an MLP layer, the patch information is encoded into a sequence of embeddings for concatenation, allowing the model to learn both semantic structure and texture details efficiently.

For conditional learning, a class token and a source token are appended to the beginning of the 1D sequence at the end of the third block. This modified sequence then passes through the fourth stage, enabling the model to make both class predictions and real/synthesis judgments.
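A minimal sketch of this dual-head readout is given below; it prepends a class token and a source token to the token sequence and reads the two predictions from their final embeddings. For brevity the sketch shares one encoder stack between the heads, whereas the final stage described above uses a separate ViT encoder per head; all dimensions are illustrative assumptions.

```python
import torch
import torch.nn as nn

class DualHeadReadout(nn.Module):
    """Final discriminator stage (sketch): class/source tokens plus two heads."""
    def __init__(self, dim, num_classes=2, depth=2, num_heads=4):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.src_token = nn.Parameter(torch.zeros(1, 1, dim))
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads,
                                           activation="gelu", batch_first=True)
        self.encoders = nn.TransformerEncoder(layer, num_layers=depth)
        self.class_head = nn.Linear(dim, num_classes)   # f^2: class prediction
        self.source_head = nn.Linear(dim, 1)            # f^1: real/synthetic score

    def forward(self, tokens):
        b = tokens.size(0)
        extra = torch.cat([self.cls_token, self.src_token], dim=1).expand(b, -1, -1)
        out = self.encoders(torch.cat([extra, tokens], dim=1))
        return self.class_head(out[:, 0]), self.source_head(out[:, 1])
```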

III-B Conditional Class Projection

To further improve conditional image generation, we propose the conditional class projection module, including the conditional direct skip connection and the class projection layer. Inspired by the skip-z approach [22], we propose conditional skip-z, which integrates skip-z with class-specific information. Skip-z [22] feeds the noise vector $z$ into multiple layers of the generator, instead of just the initial layer. This design enables the generator to utilize the diversity of input noise to enhance the quality and diversity of the generated images. In our approach, we divide $z$ into chunks corresponding to each stage and concatenate each chunk with the class vector $c$. The combined vector $[c,z]$ is then processed through the class projection layer at each stage, enhancing both the quality and diversity of the synthetic images within the class.

Next, we project $[c,z]$ onto the token embeddings via the class projection layer, which is a conditional layer normalization. Unlike class-conditional methods that use batch normalization [23, 22, 24], we find that layer normalization [25] yields superior performance for ViT-based conditional GANs. This is possibly explained by the fact that layer normalization can preserve positional information learned by the attention mechanism, whereas batch normalization may compromise this information [26]. The approach transforms a layer’s activations $a$ into a class-specific normalized activation $\bar{a}$, which is described as:

$\bar{a}=\frac{a-\hat{\mu}}{\sqrt{\hat{\sigma}+\epsilon}}\cdot\gamma+\beta,$ (1)

where $\hat{\mu}$ and $\hat{\sigma}$ represent the mean and variance of the input, and $\epsilon$ ensures numerical stability. We introduce class-conditional information through the parameters $\gamma$ and $\beta$, obtained as linear transformations of the conditional vector, where $\gamma:=W_{\gamma}^{\top}[c,z]$ and $\beta:=W_{\beta}^{\top}[c,z]$, with $[c,z]$ denoting the concatenation of $c$ and $z$ in the conditional direct skip-z.
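The class projection layer of Eq. (1) can be sketched as a conditional layer normalization whose scale and shift are linear functions of $[c,z]$, as below; dimensions and initialization are illustrative assumptions rather than the paper's exact configuration.

```python
import torch
import torch.nn as nn

class ConditionalLayerNorm(nn.Module):
    """Class projection layer (sketch): layer normalization whose scale and
    shift are predicted from the concatenated class/noise vector [c, z]."""
    def __init__(self, embed_dim, cond_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.to_gamma = nn.Linear(cond_dim, embed_dim)  # gamma = W_gamma^T [c, z]
        self.to_beta = nn.Linear(cond_dim, embed_dim)   # beta  = W_beta^T  [c, z]

    def forward(self, tokens, cond):
        # tokens: (B, N, embed_dim); cond: (B, cond_dim) = concat(one-hot c, z chunk)
        mu = tokens.mean(dim=-1, keepdim=True)
        var = tokens.var(dim=-1, keepdim=True, unbiased=False)
        normed = (tokens - mu) / torch.sqrt(var + self.eps)
        gamma = self.to_gamma(cond).unsqueeze(1)        # broadcast over tokens
        beta = self.to_beta(cond).unsqueeze(1)
        return normed * gamma + beta

# Conditional skip-z usage sketch: z is split into per-stage chunks, and each
# chunk is concatenated with the one-hot class label before conditioning a stage.
# cln = ConditionalLayerNorm(embed_dim=384, cond_dim=2 + 32)
# out = cln(tokens, torch.cat([one_hot_c, z_chunk], dim=-1))
```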

III-C Multi-Loss Weighting Function

Our framework’s learning strategy is inspired by AC-GAN [27], which incorporates an auxiliary class head in the discriminator. However, the auxiliary classifier in [27] primarily functions to guide diversified image generation, often resulting in suboptimal classification performance. To address this issue, we introduce a multi-loss weighting function that integrates the WGAN-GP loss [28] with the classification cross-entropy loss. This approach ensures that both the generation and classification aspects of the model perform satisfactorily. We first apply the WGAN-GP loss to our GAN. The objective function is defined as follows:

$\mathcal{L}_{S}=\min_{G}\max_{f^{1},D}\;\mathbb{E}_{I\sim\mathbb{P}_{r}}\left[f^{1}(D(I))\right]-\mathbb{E}_{z\sim\mathbb{P}_{z}}\left[f^{1}(D(G(c,z)))\right]+\lambda\,\mathbb{E}_{\hat{I}\sim\mathbb{P}_{\hat{I}}}\left[\left(\left\|\nabla_{\hat{I}}f^{1}(D(\hat{I}))\right\|_{2}-1\right)^{2}\right],$ (2)

where $f^{1}$ is the source head of the discriminator, $D$ denotes the discriminator feature extractor, and $G$ is the generator. $\mathbb{P}_{r}$ and $\mathbb{P}_{z}$ denote the distribution of the real data $I_{real}$ and the normal distribution of the random noise vector $z$, while $\mathbb{P}_{\hat{I}}$ denotes the distribution of samples interpolated between pairs of points drawn from the real and generated data. Each generated sample is associated with a class label $c\sim\mathbb{P}_{c}$ in addition to $z$, and $G$ uses the conditioned noise to generate images $I_{syn}=G(c,z)$. Note that the source head does not involve conditioned information.
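For reference, the gradient-penalty term of Eq. (2) can be computed as in the following sketch, which penalizes the gradient norm of the source-head score on interpolates between real and synthetic images; the default $\lambda=10$ follows WGAN-GP [28] and is an assumption here.

```python
import torch

def gradient_penalty(D, f1, real, syn, lam=10.0):
    """Gradient-penalty term of Eq. (2) on real/synthetic interpolates (sketch)."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    inter = (alpha * real + (1 - alpha) * syn).requires_grad_(True)
    scores = f1(D(inter)).sum()                      # scalar source-head score
    grads, = torch.autograd.grad(scores, inter, create_graph=True)
    grad_norm = grads.flatten(1).norm(2, dim=1)      # per-sample gradient norm
    return lam * ((grad_norm - 1) ** 2).mean()
```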

To enable conditional learning, we employ a class head $P(C\mid I)$ on the discriminator. The log-likelihood objective function is given by:

$\mathcal{L}_{C}=\mathbb{E}\left[\log p\left(C=c\mid I_{real}\right)\right]+\mathbb{E}\left[\log p\left(C=c\mid I_{syn}\right)\right],$ (3)

where both $D$ and $G$ are trained to maximize $\mathcal{L}_{C}$. We find that simply combining $\mathcal{L}_{C}$ and $\mathcal{L}_{S}$ is prone to training failure: the WGAN-GP loss is relatively large most of the time, so the cross-entropy loss is overwhelmed during training. In addition, the discriminator tends to perform well on only one particular classification task, leading to an imbalance problem.

To this end, we extend the concept from the multi-task learning realm [6] to weight the classification losses and propose a multi-loss weighting function, which balances the losses of the multiple classification tasks by considering the uncertainty of each task. Specifically, let $f^{2}(I)$ be the output of the class head and $\exp(-\sigma)$ be a trainable scaling parameter; the adapted classification likelihood of the model output through the softmax function can be written as:

$p\left(y\mid f^{2}(I),\sigma\right)=\operatorname{Softmax}\left(\exp(-\sigma)\,f^{2}(I)\right),$ (4)

with $\sigma\in(-\infty,+\infty)$. The scaling can be interpreted as a Boltzmann (Gibbs) distribution, where $\exp(-\sigma)$ acts as an inverse temperature on the input. The magnitude of the learnable parameter determines how “flat” the discrete distribution is. The output is then related to uncertainty through the log-likelihood, which can be written as:

$\begin{aligned}
-\log p\left(y=c\mid f^{2}(I),\sigma\right) &= -\log\operatorname{Softmax}\left(\exp(-\sigma)f^{2}_{c}(I)\right) \\
&= -\log\frac{\exp\left[\exp(-\sigma)f^{2}_{c}(I)\right]}{\sum_{c^{\prime}}\exp\left[\exp(-\sigma)f^{2}_{c^{\prime}}(I)\right]} \\
&= -\exp(-\sigma)\log\frac{\exp\left(f^{2}_{c}(I)\right)}{\sum_{c^{\prime}}\exp\left(f^{2}_{c^{\prime}}(I)\right)}+\log\frac{\sum_{c^{\prime}}\exp\left[\exp(-\sigma)f^{2}_{c^{\prime}}(I)\right]}{\left(\sum_{c^{\prime}}\exp\left(f^{2}_{c^{\prime}}(I)\right)\right)^{\exp(-\sigma)}} \\
&\approx \exp(-\sigma)\mathcal{L}+\sigma,
\end{aligned}$ (5)

where we write $\mathcal{L}=-\log\operatorname{Softmax}\left(y,f^{2}(I)\right)$ for the unscaled cross-entropy loss. An explicit simplifying assumption is then introduced: $\exp(-\sigma)\sum_{c^{\prime}}\exp\left[\exp(-\sigma)f_{c^{\prime}}^{2}(I)\right]\approx\left(\sum_{c^{\prime}}\exp\left(f_{c^{\prime}}^{2}(I)\right)\right)^{\exp(-\sigma)}$, which becomes an equality as $\sigma\to 0$. This assumption simplifies $\log\frac{\sum_{c^{\prime}}\exp\left[\exp(-\sigma)f^{2}_{c^{\prime}}(I)\right]}{\left(\sum_{c^{\prime}}\exp\left(f^{2}_{c^{\prime}}(I)\right)\right)^{\exp(-\sigma)}}$ to $\sigma$ and empirically yields promising results.

Given multiple classification outputs from the class head, we define the likelihood to factorize over the outputs. Our multi-task likelihood is given by:

$p\left(y_{1},y_{2}\mid f^{2}(I_{real,syn})\right)=p\left(y_{1}\mid f^{2}(I_{real})\right)\,p\left(y_{2}\mid f^{2}(I_{syn})\right),$ (6)

where $y_{1}$ and $y_{2}$ represent the classification predictions of real and synthetic data from the class head, respectively. Next, the joint loss of the weighted log-likelihood of the multi-classification tasks is given as:

$\begin{aligned}
\mathcal{L}_{mlw} &= -\log p\left(y_{1},y_{2}=c_{1},c_{2}\mid f^{2}(I_{real,syn}),\sigma_{1},\sigma_{2}\right) \\
&= \exp(-\sigma_{1})\mathcal{L}_{1}+\exp(-\sigma_{2})\mathcal{L}_{2}+\sigma_{1}+\sigma_{2},
\end{aligned}$ (7)

$\mathcal{L}_{mlw}$ can be regarded as learning the weight of each output loss. A larger value of $\sigma$ results in a reduced contribution of the corresponding loss to the objective, whereas a smaller $\sigma$ increases its weight; when $\sigma$ is very small, the function is predominantly regulated by the remaining terms. Our experiments indicate that the range of $\sigma$ values impacts model performance during training. Consequently, unlike the approach of [6], we utilize $\exp(-\sigma)$ as the weight of each loss. This allows $\sigma$ to act as an unbounded yet regularized value, enhancing the model’s performance by providing a more balanced loss function across the tasks. The total loss $\mathcal{L}_{total}$ is given as follows:

$\mathcal{L}_{total}=\mathcal{L}_{S}+\mathcal{L}_{mlw}.$ (8)
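Equation (7) reduces to a small trainable module: one log-scale parameter per classification task, with $\exp(-\sigma)$ weighting each cross-entropy term and $\sigma$ added as a regularizer. The sketch below is a minimal PyTorch rendering under the assumption that $\sigma$ is initialized to zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiLossWeighting(nn.Module):
    """Uncertainty-style weighting of the two classification losses (Eq. 7):
    L_mlw = exp(-sigma1)*L1 + exp(-sigma2)*L2 + sigma1 + sigma2 (sketch)."""
    def __init__(self):
        super().__init__()
        self.sigma = nn.Parameter(torch.zeros(2))   # one log-scale per task

    def forward(self, logits_real, labels_real, logits_syn, labels_syn):
        l1 = F.cross_entropy(logits_real, labels_real)   # real-data class loss
        l2 = F.cross_entropy(logits_syn, labels_syn)     # synthetic-data class loss
        weights = torch.exp(-self.sigma)
        return weights[0] * l1 + weights[1] * l2 + self.sigma.sum()
```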

III-D Selective Data Augmentation

We further enhance the performance of synthetic augmentation by employing a selective data augmentation mechanism. To ensure high-fidelity image generation, we first use a truncation method that resamples $z$ from a truncated normal distribution with truncation value $\tau$. After generating the images conditionally, we leverage the class head to select high-quality augmentation images: only generated images whose prediction values exceed a threshold $\lambda$ are used to compute the classification loss. This approach helps to mitigate the negative impact of unrealistic synthetic data on model training. Our experiments indicate that setting $\tau=0.7$ and $\lambda=0.6$ provides the best performance. The pseudo-code of the selective data augmentation mechanism is shown in Algorithm 2.

Algorithm 2 Selective Data Augmentation Mechanism
Input: Truncation value $\tau$, confidence threshold $\lambda$, conditional class projection module $CP$, class label $c$, generator $G$, discriminator feature extractor $D$ with source head $f^{1}$ and class head $f^{2}$
Output: Filtered prediction $y_{2}^{filtered}$
Sample a noise vector $z \sim \text{TruncNormal}(0, 1, -\tau, \tau)$
$I_{syn} = G(CP(z, c))$
$y_{2} = f^{2}(D(I_{syn}))$
if $y_{2} > \lambda$ then
      $y_{2} \to y_{2}^{filtered}$
else
      $y_{2}$ is discarded
end if
return $y_{2}^{filtered}$
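A minimal sketch of the truncation sampling and confidence filtering in Algorithm 2 is shown below; the thresholding is applied to the softmax probability of the conditioning class, which is one plausible reading of "prediction values exceeding λ", and the helper names and the fallback branch are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def truncated_normal(shape, tau=0.7):
    """Resample entries outside [-tau, tau] until all lie inside (truncation trick)."""
    z = torch.randn(shape)
    while (z.abs() > tau).any():
        resample = z.abs() > tau
        z[resample] = torch.randn(int(resample.sum()))
    return z

def select_confident(logits_syn, labels_syn, threshold=0.6):
    """Keep only synthetic samples whose predicted probability for the
    conditioning class exceeds the threshold (sketch of Algorithm 2)."""
    probs = F.softmax(logits_syn, dim=-1)
    conf = probs.gather(1, labels_syn.unsqueeze(1)).squeeze(1)
    keep = conf > threshold
    if keep.sum() == 0:
        return logits_syn, labels_syn   # fallback (not specified in the paper)
    return logits_syn[keep], labels_syn[keep]
```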
Figure 2: Visualization comparison between real images from the training data and images generated by the proposed model.
TABLE I: Comparison of various baseline methods with and without our data synthesis augmentation method on the PCam dataset [29].
Method | Accuracy | AUC | Sensitivity | Specificity
ResNet34 [30] | 0.881 ± 0.079 | 0.945 ± 0.022 | 0.846 ± 0.086 | 0.916 ± 0.020
ResNet34 + Synthetic Data | 0.916 ± 0.144 | 0.954 ± 0.042 | 0.891 ± 0.100 | 0.951 ± 0.109
ResNet50_CBAM [31] | 0.899 ± 0.131 | 0.955 ± 0.045 | 0.863 ± 0.096 | 0.935 ± 0.040
ResNet50_CBAM + Synthetic Data | 0.922 ± 0.077 | 0.962 ± 0.046 | 0.879 ± 0.047 | 0.945 ± 0.110
DenseNet169 [32] | 0.894 ± 0.059 | 0.955 ± 0.036 | 0.881 ± 0.094 | 0.908 ± 0.032
DenseNet169 + Synthetic Data | 0.928 ± 0.076 | 0.960 ± 0.049 | 0.898 ± 0.125 | 0.967 ± 0.092
MedViTGAN [17] | 0.939 ± 0.059 | 0.980 ± 0.073 | 0.906 ± 0.079 | 0.974 ± 0.054
Ours | 0.945 ± 0.054 | 0.981 ± 0.045 | 0.910 ± 0.063 | 0.977 ± 0.102

TABLE II: Ablation study of the proposed approach on the PatchCamelyon benchmark dataset. CP: conditional class projection. SA: selective data augmentation.
Baseline | CP | SA | Accuracy | AUC | Sensitivity | Specificity
✓ |   |   | 0.916 ± 0.070 | 0.971 ± 0.098 | 0.900 ± 0.033 | 0.923 ± 0.081
✓ | ✓ |   | 0.931 ± 0.029 | 0.966 ± 0.101 | 0.897 ± 0.277 | 0.957 ± 0.123
✓ |   | ✓ | 0.929 ± 0.110 | 0.977 ± 0.093 | 0.889 ± 0.042 | 0.961 ± 0.069
✓ | ✓ | ✓ | 0.945 ± 0.054 | 0.981 ± 0.045 | 0.910 ± 0.063 | 0.977 ± 0.102

IV Experiments

In this section, we present experiments conducted on a benchmark dataset to validate the effectiveness of the proposed method. Additionally, we perform an ablation study to evaluate the contribution of each component of our approach.

IV-A Datasets

We use the public PatchCamelyon (PCam) benchmark dataset [29] for our experiments. The PCam dataset consists of 327,680 color images obtained from lymph node sections, digitized with a 40x objective at a pixel resolution of 0.243 microns. Each image patch has a resolution of 96 × 96 pixels and is classified into two categories based on the presence of metastatic tissue in the central region. The dataset is split into 75% for training, 12.5% for validation, and 12.5% for testing, using a hard-negative mining regime. To simulate scenarios with limited training data, we randomly select only 10% of the training images (32,768 images). The entire test set is used to evaluate model performance.

IV-B Implementation Details

All experiments are conducted using two Tesla V100 GPUs, each with 16 GB of RAM, and implemented in PyTorch. An Adam optimizer with parameters $\beta_{1}=0$, $\beta_{2}=0.99$, and a learning rate of $1\times10^{-4}$ is used to tune both the generator and the discriminator. The batch size is set to 12 for both networks. Training is performed over 500 epochs for all experiments. We select the best epoch based on two criteria: the Fréchet Inception Distance (FID) score between the generated images and the validation data, and the classification performance on the validation set.
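For reference, the optimizer configuration above corresponds to the following sketch; G and D stand for the generator and discriminator modules, and the helper name is illustrative.

```python
import torch

def make_optimizers(G, D, lr=1e-4, betas=(0.0, 0.99)):
    """Adam optimizers with the settings listed above (beta1=0, beta2=0.99,
    learning rate 1e-4) for both the generator and the discriminator."""
    return (torch.optim.Adam(G.parameters(), lr=lr, betas=betas),
            torch.optim.Adam(D.parameters(), lr=lr, betas=betas))
```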

IV-C Evaluation Metrics

To comprehensively evaluate the performance of the trained classifier, we employed four evaluation metrics in our experiments: accuracy, area under the ROC curve (AUC), sensitivity, and specificity. Accuracy is useful for datasets with a balanced distribution of categories but can be misleading when categories are unbalanced. AUC, on the other hand, provides an assessment of the model’s overall performance across different thresholds and is suitable for datasets with unbalanced categories. Sensitivity is critical in tasks where identifying all positive samples is essential, as it measures recall. Specificity is important in tasks requiring precise identification of negative samples, as it measures the true negative rate. This combined assessment offers a thorough evaluation of the model’s performance in terms of its accuracy and its ability to correctly identify both positive and negative samples.
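The four metrics can be computed from the binary predictions and positive-class probabilities as in the sketch below; scikit-learn is assumed for the AUC and confusion matrix, and the 0.5 decision threshold is an assumption.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, confusion_matrix

def evaluate(y_true, y_prob, threshold=0.5):
    """Accuracy, AUC, sensitivity and specificity for a binary classifier.
    y_prob is the predicted probability of the positive (metastatic) class."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "auc": roc_auc_score(y_true, y_prob),
        "sensitivity": tp / (tp + fn),    # recall on positives
        "specificity": tn / (tn + fp),    # true-negative rate
    }
```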

IV-D Main Results

We conducted a comparative analysis between our proposed method, MedViTGAN [17], and several representative baseline models under identical settings. The baseline models include ResNet34 [30], ResNet50 CBAM [31], and DenseNet169 [32]. For the baseline models, we present results both with and without utilizing synthetic data generated by our proposed method during training. Each model was executed five times with random initialization to ensure a fair comparison. Mean and standard deviation values of the results are reported for comprehensive evaluation.

Table I presents the experimental results. Comparing the results of the baseline models with and without synthetic data demonstrates a significant improvement in classification performance when using data generated by our method. For instance, compared to the baseline DenseNet169, the DenseNet169 trained with synthetic data shows improvements of 3.4%, 0.5%, 1.7%, and 5.9% in accuracy, AUC, sensitivity, and specificity, respectively. Furthermore, our proposed method surpasses MedViTGAN in classification performance. As illustrated in Fig. 2, the synthetic images exhibit fidelity and diversity comparable to the real training images. Thus, we believe that our proposed model is effective for handling the challenges of limited training data in histopathology image tasks.

IV-E Ablation Study

To further validate the effectiveness of our proposed method, we conducted an ablation study to assess the contributions of its two main components: the conditional class projection module and selective data augmentation. The findings are presented in Table II. First, we removed both the conditional class projection and selective data augmentation from the framework to establish a baseline. Then, we gradually added these two modules back into the framework to analyze their individual and combined effects. Compared to the baseline, adding the conditional class projection improved accuracy by 1.5% and specificity by 3.4%, although it slightly underperformed the baseline in terms of AUC and sensitivity. Similarly, incorporating the selective data augmentation mechanism resulted in a 1.3% improvement in accuracy, 0.6% increase in AUC, and 3.8% enhancement in specificity over the baseline, while maintaining sensitivity at nearly the same level as the baseline. The decrease in AUC and sensitivity when using only conditional class projection may be due to the generation of some noisy synthetic data, slightly diminishing the model’s performance in these areas. Using only selective data augmentation avoids this issue, but its improvements in accuracy and specificity are not as pronounced as those achieved with conditional class projection alone. Overall, the results of our proposed method showed a 3.2% increase in accuracy over the baseline, demonstrating the effectiveness of our framework in enhancing classification performance.

Additionally, we analyzed the impact of different loss functions on the generated images. In these experiments, all settings were kept constant except for the loss function. We compared the multi-task loss function proposed by Kendall et al. [6] (denoted as log_) with our multi-loss weighting function (denoted as exp_). The results in Fig. 3 demonstrate that with our loss function applied, the fake loss converges rapidly, indicating that the learning of feature separation achieves robust results.

Figure 3: Comparison of training loss performance between our proposed multi-task loss function (exp_) and the multi-task loss function by Kendall et al. [6] (log_).

V Conclusion

In this paper, we introduce a unified framework that integrates the two stages of image generation and classification into a single stage for the synthetic augmentation of histopathology images. To facilitate conditional learning, we incorporate an auxiliary discriminator head for classification. Additionally, we integrate a conditional class projection and a selective data augmentation mechanism to enhance class separation and improve the quality of the generated images. Furthermore, we propose a multi-loss weighting function to stabilize training and boost classification performance. Experimental results demonstrate that our approach significantly outperforms baseline methods. Future work will focus on extending our framework to other generative models, such as diffusion models.

References

  • [1] D. Nie, R. Trullo, J. Lian, C. Petitjean, S. Ruan, Q. Wang, and D. Shen, “Medical image synthesis with context-aware generative adversarial networks,” in MICCAI.   Springer, 2017, pp. 417–425.
  • [2] S. U. Dar, M. Yurt, L. Karacan, A. Erdem, E. Erdem, and T. Çukur, “Image synthesis in multi-contrast mri with conditional generative adversarial networks,” IEEE transactions on medical imaging, vol. 38, no. 10, pp. 2375–2388, 2019.
  • [3] C. L. Srinidhi, O. Ciga, and A. L. Martel, “Deep neural network models for computational histopathology: A survey,” Medical image analysis, vol. 67, p. 101813, 2021.
  • [4] K. Lee, H. Chang, L. Jiang, H. Zhang, Z. Tu, and C. Liu, “Vitgan: Training gans with vision transformers,” arXiv preprint arXiv:2107.04589, 2021.
  • [5] Y. Jiang, S. Chang, and Z. Wang, “Transgan: Two transformers can make one strong gan,” arXiv preprint arXiv:2102.07074, vol. 1, no. 2, p. 7, 2021.
  • [6] A. Kendall, Y. Gal, and R. Cipolla, “Multi-task learning using uncertainty to weigh losses for scene geometry and semantics,” in Proceedings of the IEEE conference on computer vision and pattern recognition, 2018, pp. 7482–7491.
  • [7] S. Ruder, “An overview of multi-task learning in deep neural networks,” arXiv preprint arXiv:1706.05098, 2017.
  • [8] A. Kebaili, J. Lapuyade-Lahorgue, and S. Ruan, “Deep learning approaches for data augmentation in medical imaging: a review,” Journal of Imaging, vol. 9, no. 4, p. 81, 2023.
  • [9] Y. Chen, X.-H. Yang, Z. Wei, A. A. Heidari, N. Zheng, Z. Li, H. Chen, H. Hu, Q. Zhou, and Q. Guan, “Generative adversarial networks in medical image augmentation: a review,” Computers in Biology and Medicine, vol. 144, p. 105382, 2022.
  • [10] L. Mescheder, A. Geiger, and S. Nowozin, “Which training methods for gans do actually converge?” in International conference on machine learning.   PMLR, 2018, pp. 3481–3490.
  • [11] J. Sohl-Dickstein, E. Weiss, N. Maheswaranathan, and S. Ganguli, “Deep unsupervised learning using nonequilibrium thermodynamics,” in International conference on machine learning.   PMLR, 2015, pp. 2256–2265.
  • [12] J. Ho, A. Jain, and P. Abbeel, “Denoising diffusion probabilistic models,” Advances in neural information processing systems, vol. 33, pp. 6840–6851, 2020.
  • [13] M. Arjovsky, S. Chintala, and L. Bottou, “Wasserstein generative adversarial networks,” in International conference on machine learning.   PMLR, 2017, pp. 214–223.
  • [14] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, “Bert: Pre-training of deep bidirectional transformers for language understanding,” arXiv preprint arXiv:1810.04805, 2018.
  • [15] Y. Korkmaz, M. Yurt, S. U. H. Dar, M. Özbey, and T. Cukur, “Deep mri reconstruction with generative vision transformers,” in Machine Learning for Medical Image Reconstruction: 4th International Workshop, MLMIR 2021, Held in Conjunction with MICCAI 2021, Strasbourg, France, October 1, 2021, Proceedings 4.   Springer, 2021, pp. 54–64.
  • [16] X. Zhang, X. He, J. Guo, N. Ettehadi, N. Aw, D. Semanek, J. Posner, A. Laine, and Y. Wang, “Ptnet3d: A 3d high-resolution longitudinal infant brain mri synthesizer based on transformers,” IEEE transactions on medical imaging, vol. 41, no. 10, pp. 2925–2940, 2022.
  • [17] M. Li, C. Li, P. Hobson, T. Jennings, and B. C. Lovell, “Medvitgan: End-to-end conditional gan for histopathology image augmentation with vision transformers,” in 2022 26th International Conference on Pattern Recognition (ICPR).   IEEE, 2022, pp. 4406–4413.
  • [18] Y. Liu, Z. Qin, T. Wan, and Z. Luo, “Auto-painter: Cartoon image generation from sketch by using conditional wasserstein generative adversarial networks,” Neurocomputing, vol. 311, pp. 78–87, 2018.
  • [19] J. P. Ebenezer, B. Das, and S. Mukhopadhyay, “Single image haze removal using conditional wasserstein generative adversarial networks,” in 2019 27th European Signal Processing Conference (EUSIPCO).   IEEE, 2019, pp. 1–5.
  • [20] A. Dosovitskiy et al., “An image is worth 16x16 words: Transformers for image recognition at scale,” arXiv preprint arXiv:2010.11929, 2020.
  • [21] W. Shi et al., “Real-time single image and video super-resolution using an efficient sub-pixel convolutional neural network,” in CVPR, 2016, pp. 1874–1883.
  • [22] A. Brock, J. Donahue, and K. Simonyan, “Large scale gan training for high fidelity natural image synthesis,” arXiv preprint arXiv:1809.11096, 2018.
  • [23] T. Miyato and M. Koyama, “cgans with projection discriminator,” arXiv preprint arXiv:1802.05637, 2018.
  • [24] M. Kang and J. Park, “Contragan: Contrastive learning for conditional image generation,” Advances in Neural Information Processing Systems, vol. 33, pp. 21,357–21,369, 2020.
  • [25] J. L. Ba, J. R. Kiros, and G. E. Hinton, “Layer normalization,” arXiv preprint arXiv:1607.06450, 2016.
  • [26] R. Xiong, Y. Yang, D. He, K. Zheng, S. Zheng, C. Xing, H. Zhang, Y. Lan, L. Wang, and T. Liu, “On layer normalization in the transformer architecture,” in International Conference on Machine Learning.   PMLR, 2020, pp. 10,524–10,533.
  • [27] A. Odena, C. Olah, and J. Shlens, “Conditional image synthesis with auxiliary classifier gans,” in International conference on machine learning.   PMLR, 2017, pp. 2642–2651.
  • [28] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville, “Improved training of wasserstein gans,” arXiv preprint arXiv:1704.00028, 2017.
  • [29] B. S. Veeling, J. Linmans, J. Winkens, T. Cohen, and M. Welling, “Rotation equivariant cnns for digital pathology,” in International Conference on Medical image computing and computer-assisted intervention.   Springer, 2018, pp. 210–218.
  • [30] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE conference on computer vision and pattern recognition, 2016, pp. 770–778.
  • [31] S. Woo, J. Park, J.-Y. Lee, and I. S. Kweon, “Cbam: Convolutional block attention module,” in Proceedings of the European conference on computer vision (ECCV), 2018, pp. 3–19.
  • [32] G. Huang, Z. Liu, L. Van Der Maaten, and K. Q. Weinberger, “Densely connected convolutional networks,” in Proceedings of the IEEE conference on computer vision and pattern recognition, 2017, pp. 4700–4708.