
Provable Optimization for Adversarial Fair Self-supervised Contrastive Learning

Qi Qi
Department of Computer Science
The University of Iowa
Iowa City, IA 52242, USA
[email protected]
Quanqi Hu
Department of Computer Science & Engineering
Texas A&M University
College Station, TX 77843, USA
[email protected]
Qihang Lin
Department of Computer Science
The University of Iowa
Iowa City, IA 52242, USA
[email protected]
Tianbao Yang
Department of Computer Science & Engineering
Texas A&M University
College Station, TX 77843, USA
[email protected]
Abstract

This paper studies learning fair encoders in a self-supervised learning (SSL) setting, in which all data are unlabeled and only a small portion of them are annotated with a sensitive attribute. Adversarial fair representation learning is well suited for this scenario: it minimizes a contrastive loss over unlabeled data while maximizing an adversarial loss of predicting the sensitive attribute over the data annotated with the sensitive attribute. Nevertheless, optimizing adversarial fair representation learning presents significant challenges because it requires solving a non-convex non-concave minimax game. The complexity deepens when incorporating a global contrastive loss that contrasts each anchor data point against all other examples. A central question is "can we design a provable yet efficient algorithm for solving adversarial fair self-supervised contrastive learning?" Building on advanced optimization techniques, we propose a stochastic algorithm dubbed SoFCLR with a convergence analysis under reasonable conditions without requiring a large batch size. We conduct extensive experiments to demonstrate the effectiveness of the proposed approach for downstream classification with eight fairness notions.

1 Introduction

Self-supervised learning (SSL) has become a pivotal paradigm in deep learning (DL), offering a groundbreaking approach to addressing challenges related to labeled data scarcity. The significance of SSL lies in its ability to leverage vast amounts of unlabeled data to learn encoder networks for extracting meaningful representations from input data that are useful for various downstream tasks. One state-of-the-art method for SSL is contrastive learning by minimizing a contrastive loss that contrasts a positive data pair with a number of negative data pairs. It has shown remarkable success in pretraining encoder networks, e.g., Google’s SimCLR model [6] pretrained on image data and OpenAI’s CLIP model [35] pretrained on image-text data, leading to improved performance when fine-tuned on downstream tasks.

However, like traditional supervised learning, SSL is not immune to fairness concerns and biases inherent in the data. The potential for self-supervised models to produce unfair outcomes stems from the biases embedded in the unlabeled data used for training. The contrastive loss could inadvertently reinforce certain biases present in the data. For example, a biased feature that is highly relevant to gender (e.g., long hair) can be easily learned from the data when contrasting a female face image with a set of images dominated by male face images. As a result, the learned feature representations will likely induce unfair predictive models for downstream tasks. One approach to address this issue is to utilize traditional supervised fairness-aware approaches to learn a predictive model in the downstream task by removing the disparate impact of biased features highly relevant to sensitive attributes. However, this approach suffers from several limitations: (i) it requires repeated efforts for different downstream tasks; (ii) it requires labeled data to be annotated with the sensitive attribute as well, which may raise privacy issues.

Several studies have put forward techniques to enhance the fairness of contrastive learning of representations [29, 25, 47, 17]. However, the existing approaches focus on modifying the contrastive loss by restricting the space of positive and/or negative data of an anchor using the sensitive attributes of all data or an external image generator [47]. Different from the prior studies on fair contrastive learning, we revisit a classical idea of adversarial fair representation learning (AFRL) by solving the following minimax problem:

$$\min_{\bf{w}}\max_{{\bf{w}}^{\prime}}F({\bf{w}},{\bf{w}}^{\prime}):=F_{\text{GCL}}({\bf{w}})+\alpha F_{\text{fair}}({\bf{w}},{\bf{w}}^{\prime}), \qquad (1)$$

where $F_{\text{GCL}}({\bf{w}})$ denotes a self-supervised global contrastive loss (GCL) for learning an encoder network parameterized by ${\bf{w}}$, and $F_{\text{fair}}({\bf{w}},{\bf{w}}^{\prime})$ denotes an adversarial loss of a discriminator parameterized by ${\bf{w}}^{\prime}$ for predicting the sensitive attribute given the encoded representation. There are several benefits of this approach compared with existing fairness-aware contrastive learning [30, 17, 25, 47]. First, the contrastive loss remains intact. Hence, it offers more flexibility in choosing different contrastive losses, including bimodal contrastive losses [35]. Second, only the fairness-promoting regularizer depends on the sensitive attribute. Therefore, it is not necessary to annotate all unlabeled data with the sensitive attribute, which makes it more suitable for SSL with a large number of unlabeled data but only limited data with the sensitive attribute.

Despite the simplicity of the framework, a challenging question remains: is it possible to design an efficient algorithm that can be proved to converge for solving (1)? There are two hurdles to be overcome: (i) the problem could be non-convex and non-concave if the discriminator is not a linear model, which makes the convergence analysis formidable; (ii) the standard mini-batch gradient estimator yields a biased stochastic gradient of the primal variable due to the presence of the GCL, which does not ensure convergence. Although existing studies in ML have shown that solving a general non-convex minimax game (e.g., generative adversarial networks (GANs)) is not stable [20], it is still possible to develop provable yet efficient algorithms for AFRL because: (i) AFRL can use a simple/shallow network for the discriminator that operates on an encoded representation; hence one step of the dual update of ${\bf{w}}^{\prime}$ followed by a step of the primal update of ${\bf{w}}$ is sufficient; (ii) non-convex minimax games have been proved to converge under structured conditions weaker than concavity. Based on these two observations, we design an efficient algorithm and provide its convergence guarantee. Our contributions are outlined below:

  1. Theoretically, we propose a stochastic algorithm for optimizing the non-convex compositional minimax game and establish a convergence result under reasonable conditions without requiring convexity/concavity on either side.

  2. Empirically, we conduct extensive experiments to demonstrate the effectiveness of the proposed approach for downstream classification with eight fairness notions.

This paper is different from the large body of studies on fair representation learning, which either rely on labeled data with an annotated sensitive attribute, simply use a mini-batch based contrastive loss for SSL that has no convergence guarantee, or only examine a few standard fairness measures for downstream classification.

Table 1: Comparison with existing studies on fair deep representation learning. The columns "Label" and "Sensitive Attribute" indicate whether the label information or the sensitive attribute information is used in the training process. "Partial" means that the method can be extended to a setting with only part of the data annotated with the sensitive attribute.
Category Reference Adversarial Contrastive Label Sensitive Attribute Sample Generator Theoretical Analysis
(Semi-)Supervised [10] Yes No Yes Yes No No
[23] No No Partial Yes No No
[3] Yes No Yes Partial No No
[40] Yes No Yes Yes No Yes
[11] Yes No Yes Yes No No
[36] Yes No Yes Yes No Yes
[27] No No Yes Yes No No
[29] No Yes Yes Yes No No
Unsupervised [23] No No No Yes No No
[27] No No No Yes No No
Self-supervised [25] No Yes No Yes No No
[5] No Yes Partial No No No
[17] No Yes No Yes Yes No
[47] No Yes No Partial Yes No
Our work Yes Yes No Partial No Yes

2 Related Work

While there is a tremendous amount of work on fairness-aware learning [31, 26, 1, 32], we restrict our discussion below to the studies most relevant to fair representation learning, which learns a mapping function that induces representations of data for fairer predictions.

(Semi-)Supervised Fair Representation Learning. The seminal work [45] initiated the study of fair representation learning. The goal is to learn a fair mapping that maps the original input feature representation into a space that not only preserves the information of the data but also obfuscates any information about membership in the protected group. Nevertheless, their approach is a shallow approach, which is not suitable for deep learning.

For DL, many prior studies have considered learning an encoder network that induces fair representations of data with respect to sensitive attributes. A classical idea is adversarial learning, which minimizes the loss of predicting class labels while maximizing the loss of predicting the sensitive attribute given the encoded representations. This has been studied in [10, 11, 40, 38, 3, 12, 46, 36] for different applications or different contexts. For example, [12, 7] tackle the domain adaptation setting, where the encoder network is learned such that it cannot discriminate between the training (source) and test (target) domains. Other approaches have explicitly considered fair classification with respect to some sensitive attribute. Among these studies, [3] raised the challenge of collecting sensitive attribute information for all data and considered a setting where only part of the labeled data have a sensitive attribute. In addition to adversarial learning, variational auto-encoder (VAE) based methods have been explored for fair representation learning [23, 13, 27, 16]. These methods can work in an unsupervised learning setting where no labels are given or a semi-supervised learning setting where only part of the data are labeled. However, they do not consider how to leverage data that are not annotated with attribute information. Moreover, VAE-based representation learning methods lag significantly behind self-supervised contrastive representation learning in terms of performance on complicated tasks [4].

Fair Contrastive Learning. Another category of research that is highly related to our study is fair contrastive learning [30, 17, 25, 47]. These methods usually use the sensitive attribute information to restrict the space of negative and positive data in the contrastive loss. [30] utilized a supervised contrastive loss for representation learning and modified the contrastive loss by incorporating sensitive attribute information to encourage samples with the same class but different attributes to be closer together while pushing samples with the same sensitive attribute but different classes further apart. Their approach requires all data to be labeled and annotated with sensitive attribute information.

Several papers have modified the contrastive loss for self-supervised learning [25, 17, 47]. For example, in [25] the authors define a contrastive loss for each group of the sensitive attribute separately. [17] proposed to incorporate counterfactual fairness by using a counterfactual version of an anchor data point as the positive in contrastive learning; it is generated by flipping the sensitive attribute (e.g., female to male) using a sample generator (a cyclic variational autoencoder), which is learned separately. [47] used a similar idea by constructing a positive sample of an anchor data point with a different sensitive attribute generated by an image attribute editor, and constructing the negative samples as the views generated from different images with the same sensitive attribute. To this end, they also need to train an image attribute editor that can generate a sample with a different sensitive attribute.

[5] proposed a bilevel learning framework for the setting where no sensitive attribute information is available. They used a weighted self-supervised contrastive loss as the lower-level objective for learning a representation and an averaged top-$K$ classification loss on validation data as the upper-level objective to learn the weights and the classifier. Table 1 summarizes prior studies and our work.

3 Preliminaries

Notations. Let $\mathcal{D}$ represent an unlabeled set of $n$ images, and let $\mathcal{D}_{a}\subset\mathcal{D}$ denote a subset of $k\ll n$ training images with attribute information. Let ${\bf{x}}\sim\mathcal{D}$ denote a random data point uniformly sampled from $\mathcal{D}$. For each $({\bf{x}},a)\in\mathcal{D}_{a}$, $a\in\{1,\ldots,K\}$ denotes the sensitive attribute.

We denote an encoder network by $E_{\bf{w}}(\cdot)$ parameterized by ${\bf{w}}$, and let $E_{{\bf{w}}}({\bf{x}})\in\mathbb{R}^{d}$ represent a normalized output representation of input data ${\bf{x}}$. For simplicity, we omit the parameters and use $E(\cdot)$ to refer to the encoder. $\mathcal{P}$ denotes a set of standard data augmentation operators generating various views of a given image [6], and $\mathcal{A}\sim\mathcal{P}$ is a random data augmentation operator. We use $p(X)$ to denote the probability distribution of a random variable $X$, and use $p(X|Y)$ to represent the conditional distribution of a random variable $X$ given $Y$.

A state-of-the-art method of SSL is to optimize a contrastive loss [33, 6]. A standard approach for defining a contrastive loss of image data is the following mini-batch based contrastive loss for each image ${\bf{x}}_{i}$ and two random augmentation operators $\mathcal{A},\mathcal{A}^{\prime}\sim\mathcal{P}$ [6]:

$$L_{\text{CL}}({\bf{x}}_{i},\mathcal{A},\mathcal{A}^{\prime},\mathcal{B})=-\log\frac{\exp(E(\mathcal{A}({\bf{x}}_{i}))^{\top}E(\mathcal{A}^{\prime}({\bf{x}}_{i}))/\tau)}{\sum_{\tilde{{\bf{x}}}\in\mathcal{B}_{i}^{-}\cup\{\mathcal{A}^{\prime}({\bf{x}}_{i})\}}\exp(E(\mathcal{A}({\bf{x}}_{i}))^{\top}E(\tilde{{\bf{x}}})/\tau)},$$

where $\tau>0$ is the temperature parameter, $\mathcal{B}\subset\mathcal{D}$ is a random mini-batch, and $\mathcal{B}_{i}^{-}=\{\mathcal{A}({\bf{x}}),\mathcal{A}^{\prime}({\bf{x}})\,|\,{\bf{x}}\in\mathcal{B}\setminus{\bf{x}}_{i}\}$ denotes the set of all other samples in the mini-batch and their two random augmentations. However, this approach requires a very large batch size to achieve good performance [6] due to a large optimization error with a small batch size [43].
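For concreteness, the following is a minimal PyTorch sketch of the mini-batch loss $L_{\text{CL}}$ above for a single anchor; the function and variable names (`minibatch_contrastive_loss`, `emb_a`, `emb_b`) are illustrative, not part of the original method.

```python
# A minimal sketch of the mini-batch contrastive loss L_CL for one anchor x_i,
# assuming `emb_a` and `emb_b` hold the L2-normalized encoder outputs E(A(x))
# and E(A'(x)) for a mini-batch B (shape (|B|, d)).
import torch

def minibatch_contrastive_loss(emb_a, emb_b, i, tau=0.1):
    """L_CL for anchor i: contrast E(A(x_i)) against E(A'(x_i)) (positive)
    and the 2(|B|-1) augmented views of the other samples (negatives)."""
    anchor = emb_a[i]                                    # E(A(x_i)), shape (d,)
    pos = torch.exp(anchor @ emb_b[i] / tau)             # positive-pair term
    mask = torch.ones(emb_a.size(0), dtype=torch.bool)   # all samples ...
    mask[i] = False                                      # ... except the anchor
    negs = torch.cat([emb_a[mask], emb_b[mask]], dim=0)  # (2(|B|-1), d)
    denom = pos + torch.exp(negs @ anchor / tau).sum()
    return -torch.log(pos / denom)
```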

To address this challenge, Yuan et al. [43] proposed to optimize a global contrastive loss (GCL) based on advanced optimization techniques with a rigorous convergence guarantee. We adopt the second variant of GCL defined in their work in order to derive a convergence guarantee. A GCL for a given sample ${\bf{x}}_{i}$ and two augmentation operators $\mathcal{A},\mathcal{A}^{\prime}\sim\mathcal{P}$ can be defined as:

$$L_{\text{GCL}}({\bf{x}}_{i},\mathcal{A},\mathcal{A}^{\prime},\mathcal{D})=-\tau\log\frac{\exp(E(\mathcal{A}({\bf{x}}_{i}))^{\top}E(\mathcal{A}^{\prime}({\bf{x}}_{i}))/\tau)}{\epsilon_{0}+\mathbb{E}_{\mathcal{A}}\sum_{\tilde{{\bf{x}}}\in\mathcal{S}_{i}^{-}}\exp(E(\mathcal{A}({\bf{x}}_{i}))^{\top}E(\tilde{{\bf{x}}})/\tau)},$$

where $\epsilon_{0}$ is a small constant, and $\mathcal{S}_{i}^{-}=\{\mathcal{A}({\bf{x}})\,|\,\mathcal{A}\in\mathcal{P},{\bf{x}}\in\mathcal{D}\backslash{\bf{x}}_{i}\}$ denotes the set of all data to be contrasted with ${\bf{x}}_{i}$, which can be constructed by including all other images except ${\bf{x}}_{i}$ and their augmentations. Then, the averaged GCL becomes $F_{\text{GCL}}({\bf{w}})=\mathbb{E}_{{\bf{x}}_{i}\sim\mathcal{D},\mathcal{A},\mathcal{A}^{\prime}\sim\mathcal{P}}L_{\text{GCL}}({\bf{x}}_{i},\mathcal{A},\mathcal{A}^{\prime},\mathcal{D})$. To facilitate the design of stochastic optimization, we cast the above loss into the following form:

$$F_{\text{GCL}}({\bf{w}}):=f_{1}({\bf{w}})+\frac{1}{n}\sum_{i=1}^{n}f_{2}(g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))+c, \qquad (2)$$

where

$$f_{1}({\bf{w}})=\mathbb{E}_{{\bf{x}}\sim\mathcal{D},\mathcal{A},\mathcal{A}^{\prime}\sim\mathcal{P}}[f_{1}({\bf{w}};{\bf{x}}_{i},\mathcal{A},\mathcal{A}^{\prime})],\quad f_{1}({\bf{w}};{\bf{x}}_{i},\mathcal{A},\mathcal{A}^{\prime})=-E(\mathcal{A}({\bf{x}}_{i}))^{\top}E(\mathcal{A}^{\prime}({\bf{x}}_{i})),$$
$$f_{2}(g)=\tau\log(\epsilon_{0}^{\prime}+g),\quad g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-})=\mathbb{E}_{\tilde{{\bf{x}}}\sim\mathcal{S}_{i}^{-}}\mathbb{E}_{\mathcal{A}}\exp(E(\mathcal{A}({\bf{x}}_{i}))^{\top}E(\tilde{{\bf{x}}})/\tau),$$

and $\epsilon_{0}^{\prime}=\epsilon_{0}/|\mathcal{S}_{i}^{-}|$ is a small constant and $c$ is a constant.
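The compositional structure can be read directly off this decomposition: the inner function $g$ is an average of exponentiated pairwise similarities and the outer function $f_{2}$ is a scaled logarithm. Below is a minimal sketch of these two pieces, assuming normalized embeddings; the function names and default values are illustrative.

```python
# Sketch of the compositional pieces of the GCL in (2).
import torch

def g_hat(anchor_view, neg_views, tau=0.1):
    """Mini-batch estimate of the inner function g(w; x_i, S_i^-):
    average of exp(E(A(x_i))^T E(x~)/tau) over sampled negatives x~."""
    sims = neg_views @ anchor_view / tau   # (m,) similarities to the anchor
    return torch.exp(sims).mean()

def f2(g, tau=0.1, eps0=1e-8):
    """Outer function f2(g) = tau * log(eps0' + g)."""
    return tau * torch.log(eps0 + g)
```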

4 The Formulation and Justification

Our method is built on a classical idea of adversarial training by solving a minimax zero-sum game [12]. The idea is to maximize the loss of predicting the sensitive attribute by a model based on learned representations while minimizing a certain loss of learning the encoder network. To this end, we define a discriminator $D_{{\bf{w}}^{\prime}}(\textbf{v}):\mathbb{R}^{d}\rightarrow\mathbb{R}^{K}$ parameterized by ${\bf{w}}^{\prime}$ that outputs the probabilities of the different values of $a$ for an input data point with encoded representation vector $\textbf{v}\in\mathbb{R}^{d}$. For example, a simple choice of $D_{{\bf{w}}^{\prime}}$ could be $[D_{{\bf{w}}^{\prime}}(\textbf{v})]_{k}=\frac{\exp({\bf{w}}_{k}^{\prime\top}\textbf{v})}{\sum_{l=1}^{K}\exp({{\bf{w}}^{\prime}_{l}}^{\top}\textbf{v})}$.

Then, we introduce a fairness-promoting regularizer:

$$\max_{{\bf{w}}^{\prime}}F_{\text{fair}}({\bf{w}},{\bf{w}}^{\prime}):=\mathbb{E}_{({\bf{x}},a)\sim\mathcal{D}_{a},\mathcal{A}\sim\mathcal{P}}\,\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}),a),$$

where $\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}),a)$ denotes the log-likelihood of the discriminator for predicting the sensitive attribute $a$ of the augmented data $\mathcal{A}({\bf{x}})$ based on $E_{\bf{w}}(\mathcal{A}({\bf{x}}))$, i.e., $\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}),a):=\log([D_{{\bf{w}}^{\prime}}(E_{{\bf{w}}}(\mathcal{A}({\bf{x}})))]_{a})$. Thus, the minimax zero-sum game for learning the encoder network and the discriminator is given by:

$$\min_{\bf{w}}\max_{{\bf{w}}^{\prime}}F({\bf{w}},{\bf{w}}^{\prime}):=F_{\text{GCL}}({\bf{w}})+\alpha F_{\text{fair}}({\bf{w}},{\bf{w}}^{\prime}), \qquad (3)$$

where $\alpha$ is a parameter that controls the trade-off between the GCL and the fairness regularizer. There are several benefits of this approach compared with existing fairness-aware contrastive learning [30, 17, 25, 47]. First, the contrastive loss remains intact. Hence, it offers more flexibility in choosing different contrastive losses or even other losses for SSL, and it makes it possible to extend our framework to multi-modal SSL [35]. Second, only the fairness-promoting regularizer depends on the sensitive attribute. Therefore, it is not necessary to annotate all unlabeled data with the sensitive attribute, which makes it more suitable for SSL with a large number of unlabeled data.
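As a concrete illustration, here is a minimal PyTorch sketch of the linear softmax discriminator example above and the log-likelihood $\phi$; the class and method names are illustrative, not the paper's implementation.

```python
# Sketch of the example discriminator D_{w'} and phi(w, w'; A(x), a).
import torch
import torch.nn.functional as F

class LinearDiscriminator(torch.nn.Module):
    def __init__(self, d, K):
        super().__init__()
        # rows of the weight matrix play the role of w'_1, ..., w'_K
        self.linear = torch.nn.Linear(d, K, bias=False)

    def log_prob(self, v, a):
        """phi = log([D_{w'}(v)]_a) for a batch of representations v (B, d)
        and sensitive attributes a (LongTensor of shape (B,))."""
        logits = self.linear(v)                               # (B, K)
        return F.log_softmax(logits, dim=1).gather(1, a.view(-1, 1)).squeeze(1)
```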

Next, we present a theoretical justification of the minimax framework with a fairness-promoting regularizer. To formally quantify the fairness of learned representations, we define distributional representation fairness as follows.

Definition 1 (Distributional Representation Fairness).

For any random data $({\bf{x}},a)$, an encoder network $E$ is called fair in terms of representation if $p(E({\bf{x}})|a=k)=p(E({\bf{x}}))$ for all $k\in\{1,\ldots,K\}$.

The above definition of distributional representation fairness resembles demographic parity (DP), which requires $p(h({\bf{x}})|a=k)=p(h({\bf{x}}))$, where $h({\bf{x}})$ is a predictor of the class label for ${\bf{x}}$. However, distributional representation fairness is a stronger condition, which implies DP for downstream classification tasks if the encoder network is fixed. The argument is simple. Let $\mathcal{S}$ denote a random labeled training set for learning a linear classification model $h({\bf{x}})={\bf{w}}^{\top}E({\bf{x}})$ based on the encoded representation, and denote by ${\bf{w}}(\mathcal{S})$ the learned predictive model. With distributional representation fairness, we have $p(E({\bf{x}})^{\top}{\bf{w}}(\mathcal{S})|a=k)=p(E({\bf{x}})^{\top}{\bf{w}}(\mathcal{S}))$ for a random data point $({\bf{x}},a)$ that is independent of the random labeled training set $\mathcal{S}$. This is because $E({\bf{x}})^{\top}{\bf{w}}(\mathcal{S})|a=k$ and $E({\bf{x}})^{\top}{\bf{w}}(\mathcal{S})$ have the same moment generating function. Hence, DP holds, i.e., $p(\text{sign}(E({\bf{x}})^{\top}{\bf{w}}(\mathcal{S}))|a=k)=p(\text{sign}(E({\bf{x}})^{\top}{\bf{w}}(\mathcal{S})))$.
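To spell out the moment-generating-function step (a sketch of the reasoning, with $t\in\mathbb{R}$ a dummy variable and conditioning on the training set $\mathcal{S}$):

$$\mathbb{E}\big[e^{t\,E({\bf{x}})^{\top}{\bf{w}}(\mathcal{S})}\,\big|\,a=k,\,\mathcal{S}\big]=\mathbb{E}\big[e^{t\,E({\bf{x}})^{\top}{\bf{w}}(\mathcal{S})}\,\big|\,\mathcal{S}\big]\quad\text{for all }t,$$

since $p(E({\bf{x}})|a=k)=p(E({\bf{x}}))$ and $({\bf{x}},a)$ is independent of $\mathcal{S}$. Taking the expectation over $\mathcal{S}$ and matching MGFs gives equality of the two score distributions, which is preserved under $\text{sign}(\cdot)$.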

Next, we present a proposition justifying the zero-sum game framework. To this end, we abuse notation and write $F_{\text{fair}}(E,D)=\mathbb{E}_{{\bf{x}},a}\,\phi(E,D;{\bf{x}},a)$ for any encoder $E$ and any discriminator $D$.

Proposition 1.

Suppose $E$ and $D$ have enough capacity. Then the global optimal solution $(E_{*},D_{*})$ to the zero-sum game $\min_{E}\max_{D}F_{\text{fair}}(E,D)$ satisfies $p(E_{*}({\bf{x}})|a)=p(E_{*}({\bf{x}}))$ and $[D_{*}(E_{*}({\bf{x}}))]_{k}=p(a=k)$.

Remark: We attribute the credit of the above result to earlier works, e.g., [49]. It indicates that the distribution of encoded representations is independent of the sensitive attribute. The proof of the above proposition is similar to the analysis of GANs.

It is notable that the above result is different from that derived in [40, 36], which considers a fixed encoder $E$ and only ensures that the learned model recovers the true conditional distributions of the label and sensitive attribute given the representation; that guarantee by itself says nothing about fairness.

5 Stochastic Algorithm and Analysis

The optimization problem (3) deviates from existing studies of fair representation learning in that (i) the GCL is a compositional function that requires more advanced optimization techniques in order to ensure convergence without requiring a large batch size [43]; and (ii) the problem is a non-convex non-concave minimax compositional optimization problem, which is not a standard minimax optimization problem.

5.1 Algorithm Design

A major challenge lies in how to compute a stochastic gradient estimator of $F_{\text{GCL}}({\bf{w}})$. In particular, the term $\frac{1}{n}\sum_{i=1}^{n}f_{2}(g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))$ in the GCL is a finite-sum coupled compositional function [39]. The gradient $\nabla f_{2}(g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))\nabla g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-})$ is not easy to compute because the inner function depends on a large number of data in $\mathcal{S}_{i}^{-}$. Because $f_{2}$ is non-linear, an unbiased stochastic estimator of the gradient $\nabla F_{\text{GCL}}({\bf{w}})$ is not easily accessible using mini-batch samples. In particular, the standard mini-batch approach that uses $\nabla f_{2}(g({\bf{w}};{\bf{x}}_{i},\mathcal{B}_{i}^{-}))\nabla g({\bf{w}};{\bf{x}}_{i},\mathcal{B}_{i}^{-})$ as a gradient estimator for each sampled data point ${\bf{x}}_{i}$ will suffer a large optimization error because this estimator is biased.

Algorithm 1 Stochastic Optimization for Fair Contrastive Learning (SoFCLR)
1:  Initialization: ${\bf{w}}_{1},{\bf{w}}^{\prime}_{1},{\bf{u}}_{1},\tilde{{\bf{m}}}_{1}$.
2:  for $t=1,\cdots,T$ do
3:     Sample batches of data $\mathcal{B}\subset\mathcal{D}$, $\mathcal{B}_{a}\subset\mathcal{D}_{a}$.
4:     for ${\bf{x}}_{i}\in\mathcal{B}$ do
5:        Sample two data augmentations $\mathcal{A},\mathcal{A}^{\prime}\sim\mathcal{P}$
6:        Calculate $g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})$, $g({\bf{w}}_{t};\mathcal{A}^{\prime}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})$
7:        Update ${\bf{u}}_{i,t}$ according to (4)
8:     end for
9:     Compute ${\bf{m}}_{t+1}$ as in (5) and $\textbf{v}_{t+1}$ as in (6).
10:    Compute $\widetilde{{\bf{m}}}_{t+1}=(1-\beta)\widetilde{{\bf{m}}}_{t}+\beta{\bf{m}}_{t+1}$.
11:    Update ${\bf{w}}_{t+1}={\bf{w}}_{t}-\eta\widetilde{{\bf{m}}}_{t+1}$ (or an Adam update)
12:    Update ${\bf{w}}^{\prime}$: ${\bf{w}}^{\prime}_{t+1}={\bf{w}}^{\prime}_{t}+\eta^{\prime}\textbf{v}_{t+1}$
13: end for

To handle this challenge, we will follow the idea in SogCLR [43] by maintaining and updating moving-average estimators of the inner function values. Let us define $g({\bf{w}};\mathcal{A}({\bf{x}}_{i}),\tilde{{\bf{x}}})=\exp(E({\bf{w}};\mathcal{A}({\bf{x}}_{i}))^{\top}E({\bf{w}};\tilde{{\bf{x}}})/\tau)$. Then $\mathbb{E}_{\mathcal{A},\tilde{{\bf{x}}}}[g({\bf{w}};\mathcal{A}({\bf{x}}_{i}),\tilde{{\bf{x}}})]=g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-})$. We use the vector ${\bf{u}}=[{\bf{u}}_{1},\cdots,{\bf{u}}_{n}]\in\mathbb{R}^{n}$ to track the moving-average history of the stochastic estimator of $g({\bf{w}};{\bf{x}}_{i},\mathcal{S}^{-}_{i})$. For sampled ${\bf{x}}_{i}\in\mathcal{B}$, we update ${\bf{u}}_{i,t}$ by

$${\bf{u}}_{i,t+1}=(1-\gamma){\bf{u}}_{i,t}+\frac{\gamma}{2}\left[g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}^{-}_{i})+g({\bf{w}}_{t};\mathcal{A}^{\prime}({\bf{x}}_{i}),\mathcal{B}^{-}_{i})\right], \qquad (4)$$

where $\gamma\in(0,1)$ denotes the moving-average parameter. For unsampled ${\bf{x}}_{i}\not\in\mathcal{B}$, no update is needed, i.e., ${\bf{u}}_{i,t+1}={\bf{u}}_{i,t}$. Then, $\nabla f_{2}(g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))\nabla g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-})$ can be estimated by $\nabla f_{2}({\bf{u}}_{i,t})\nabla g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{B}_{i}^{-})$. Compared to the simple mini-batch estimator, this estimator ensures a diminishing average error. Thus, we compute a stochastic gradient estimator of $\nabla_{\bf{w}}F({\bf{w}}_{t},{\bf{w}}_{t}^{\prime})$ by

$${\bf{m}}_{t+1}=\frac{1}{|\mathcal{B}|}\sum\limits_{{\bf{x}}_{i}\in\mathcal{B}}\bigg\{\nabla_{{\bf{w}}}f_{1}({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{A},\mathcal{A}^{\prime})+\frac{\tau(\nabla_{\bf{w}}g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})+\nabla_{\bf{w}}g({\bf{w}}_{t};\mathcal{A}^{\prime}({\bf{x}}_{i}),\mathcal{B}_{i}^{-}))}{2(\epsilon_{0}^{\prime}+{\bf{u}}_{i,t})}\bigg\}$$
$$+\frac{\alpha}{2|\mathcal{B}_{a}|}\sum\limits_{{\bf{x}}_{i}\in\mathcal{B}_{a}}\bigg\{\nabla_{\bf{w}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}({\bf{x}}_{i}),a_{i})+\nabla_{\bf{w}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}^{\prime}({\bf{x}}_{i}),a_{i})\bigg\}. \qquad (5)$$

Then we update the primal variable ${\bf{w}}_{t+1}$ using either the momentum method or the Adam method. For updating the dual variable ${\bf{w}}^{\prime}$, we can employ a stochastic gradient ascent-type update based on the following stochastic gradient:

$$\textbf{v}_{t+1}=\frac{1}{2|\mathcal{B}_{a}|}\sum\limits_{{\bf{x}}_{i}\in\mathcal{B}_{a}}\Big(\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}({\bf{x}}_{i}),a_{i})+\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}^{\prime}({\bf{x}}_{i}),a_{i})\Big). \qquad (6)$$

Finally, we present the detailed steps of the proposed stochastic algorithm in Algorithm 1, which is referred to as SoFCLR. For simplicity of exposition, we use the stochastic momentum update for the primal variable ${\bf{w}}$ and the stochastic gradient ascent update for the dual variable ${\bf{w}}^{\prime}$.
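To make the update rules concrete, below is a minimal PyTorch sketch of one SoFCLR iteration: the ${\bf{u}}$ update in (4), a surrogate loss whose gradient matches the estimator in (5) because ${\bf{u}}$ is treated as a constant, and the ascent step (6). It reuses the hypothetical `log_prob` helper from the discriminator sketch, uses only the opposite augmented view of the other samples as negatives, and uses a single augmented view for $\phi$ for brevity; all names and hyperparameter values are illustrative and this is not the authors' implementation.

```python
# Sketch of one SoFCLR iteration under the above simplifying assumptions.
import torch

def sofclr_step(encoder, disc, opt_enc, opt_disc, u,
                batch, batch_a, idx, gamma=0.9, alpha=1.0, tau=0.1, eps0=1e-8):
    # batch = (x1, x2): two augmented views of the unlabeled mini-batch B
    # batch_a = (xa, a): augmented attribute mini-batch B_a and its attributes
    # u: length-n buffer of moving-average inner estimates; idx: dataset indices of B
    x1, x2 = batch
    xa, a = batch_a

    z1, z2 = encoder(x1), encoder(x2)            # normalized E(A(x_i)), E(A'(x_i))
    sim = torch.exp(z1 @ z2.T / tau)             # (|B|, |B|) exponentiated similarities
    off_diag = ~torch.eye(len(z1), dtype=torch.bool)
    g1 = (sim * off_diag).sum(1) / (len(z1) - 1)    # crude estimate of g for view A
    g2 = (sim.T * off_diag).sum(1) / (len(z1) - 1)  # crude estimate of g for view A'

    # (4): moving-average update of u_i for the sampled indices only
    with torch.no_grad():
        u[idx] = (1 - gamma) * u[idx] + gamma * 0.5 * (g1 + g2)

    # surrogate GCL loss; u is a constant buffer, so its gradient matches (5)
    f1 = -(z1 * z2).sum(1)                          # -E(A(x_i))^T E(A'(x_i))
    weight = tau / (eps0 + u[idx])
    loss_gcl = (f1 + 0.5 * weight * (g1 + g2)).mean()

    # fairness term: the encoder minimizes it, the discriminator maximizes it
    feat_a = encoder(xa)
    phi = disc.log_prob(feat_a, a).mean()

    opt_enc.zero_grad()
    (loss_gcl + alpha * phi).backward()
    opt_enc.step()                                  # momentum/Adam primal update

    # dual ascent (6) on w' with the encoder features frozen
    phi_d = disc.log_prob(feat_a.detach(), a).mean()
    opt_disc.zero_grad()
    (-phi_d).backward()
    opt_disc.step()
    return u
```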

5.2 Convergence Analysis

The convergence analysis is complicated by the non-convexity and non-concavity of the minimax structure and the coupled compositional structure. Our goal is to derive a convergence guarantee for finding an $\epsilon$-stationary point of the primal objective function $\Phi({\bf{w}})=\max_{{\bf{w}}^{\prime}}F({\bf{w}},{\bf{w}}^{\prime})$, i.e., a point ${\bf{w}}$ such that $\mathbb{E}[\|\nabla\Phi({\bf{w}})\|^{2}]\leq\epsilon^{2}$. We emphasize that it is generally impossible to prove such a result without imposing some conditions on the objective function.

Assumption 1.

We make the following assumptions.

  (a) $\Phi(\cdot)$ is $L$-smooth.

  (b) $E_{\bf{w}}({\bf{x}})$ is smooth and Lipschitz continuous w.r.t. ${\bf{w}}$. There exists a constant $C_{E}<\infty$ such that $\|E_{\bf{w}}({\bf{x}})\|\leq C_{E}$ for all ${\bf{w}}$ and ${\bf{x}}$.

  (c) $D_{{\bf{w}}^{\prime}}(\textbf{v})$ is smooth and Lipschitz continuous w.r.t. ${\bf{w}}^{\prime}$.

  (d) There exists $\Delta$ such that $\Phi({\bf{w}}_{1})-\min_{{\bf{w}}}\Phi({\bf{w}})\leq\Delta$.

  (e) The following variances are bounded:

  $$\mathbb{E}_{\mathcal{A},\widetilde{\bf{x}}}(g({\bf{w}};\mathcal{A}({\bf{x}}_{i}),\widetilde{\bf{x}})-g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))^{2}\leq\sigma^{2},$$
  $$\mathbb{E}_{{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}}\|\nabla_{{\bf{w}}}f_{1}({\bf{w}};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})-\nabla f_{1}({\bf{w}})\|^{2}\leq\sigma^{2},$$
  $$\mathbb{E}_{\mathcal{A},\widetilde{\bf{x}}}\|\nabla_{\bf{w}}g({\bf{w}};\mathcal{A}({\bf{x}}_{i}),\widetilde{\bf{x}})-\nabla_{{\bf{w}}}g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-})\|^{2}\leq\sigma^{2},$$
  $$\mathbb{E}_{\mathcal{A},{\bf{x}},a}\|\nabla_{\bf{w}}\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}),a)-\nabla_{{\bf{w}}}F_{\text{fair}}({\bf{w}},{\bf{w}}^{\prime})\|^{2}\leq\sigma^{2},$$
  $$\mathbb{E}_{\mathcal{A},{\bf{x}},a}\|\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}),a)-\nabla_{{\bf{w}}^{\prime}}F_{\text{fair}}({\bf{w}},{\bf{w}}^{\prime})\|^{2}\leq\sigma^{2}.$$

  (f) For any ${\bf{w}}$, $F_{\text{fair}}({\bf{w}},\cdot)$ satisfies

  $$-({\bf{w}}^{\prime}_{*}-{\bf{w}}^{\prime})^{\top}\nabla_{{\bf{w}}^{\prime}}F_{\text{fair}}({\bf{w}},{\bf{w}}^{\prime})\geq\lambda\|{\bf{w}}^{\prime}-{\bf{w}}^{\prime}_{*}\|^{2}, \qquad (7)$$

  where ${\bf{w}}^{\prime}_{*}\in\arg\max_{\textbf{v}}F_{\text{fair}}({\bf{w}},\textbf{v})$ is one optimal solution closest to ${\bf{w}}^{\prime}$.

Remark: Conditions (a, b, c) are simplifications to ensure that the objective function and each component function are smooth. Conditions (d, e) are standard conditions for stochastic non-convex optimization. Condition (f) is a special condition, which is called one-point strong convexity [44] or the restricted secant inequality [18]. It does not necessarily require convexity in terms of ${\bf{w}}^{\prime}$ and is much weaker than strong convexity [18]. It has been proved for wide neural networks [21], i.e., when $D_{{\bf{w}}^{\prime}}$ is a wide neural network.

Theorem 2.

Under the above assumption and the parameter setting $\eta=\mathcal{O}\left(\min\left\{\beta,\frac{|\mathcal{B}|\gamma}{n},\eta^{\prime}\right\}\right)$, $\beta=\mathcal{O}\left(\min\{|\mathcal{B}|,|\mathcal{B}_{a}|\}\epsilon^{2}\right)$, $\gamma=\mathcal{O}\left(|\mathcal{B}|\epsilon^{2}\right)$, $\eta^{\prime}=\mathcal{O}\left(\lambda|\mathcal{B}_{a}|\epsilon^{2}\right)$, after $T=\mathcal{O}\left(\max\left\{\frac{1}{\min\{|\mathcal{B}|,|\mathcal{B}_{a}|\}},\frac{n}{|\mathcal{B}|^{2}},\frac{1}{\lambda^{3}|\mathcal{B}_{a}|}\right\}\epsilon^{-4}\right)$ iterations, SoFCLR can find an $\epsilon$-stationary solution of $\Phi(\cdot)$.

Remark: The $\epsilon^{-4}$ complexity matches that of SGD for non-convex minimization problems. In addition, the factor $n/|\mathcal{B}|^{2}$ is the same as that of SogCLR for optimizing the GCL in [43]. The additional factor $1/(\lambda^{3}|\mathcal{B}_{a}|)$ is due to the maximization of a non-concave function. We refer to Appendix B for the detailed statement and proof of Theorem 2.

6 Experiments

Datasets. We use two face image datasets for our experiments, namely CelebA [22] and UTKFace [48]. CelebA is a large-scale face attributes dataset with more than 200K celebrity images, each with binary annotations of 40 attributes. UTKFace includes more than 20K face images labeled by gender, age, and ethnicity. These two datasets have been used in earlier works on fair representation learning [30, 25, 47]. We construct binary classification tasks on both datasets as detailed later.

Methodology of Evaluations. We evaluate our algorithm from two perspectives: (i) quantitative performance on downstream classification tasks; (ii) qualitative visualization of learned representations. For quantitative evaluation, we first perform SSL with our algorithm on an unlabeled dataset with partial sensitive attribute information. Then we utilize a labeled training dataset to learn a linear classifier on top of the learned representations, and evaluate the accuracy and fairness metrics on testing data. This approach is known as linear evaluation in the SSL literature [6].
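A minimal sketch of this linear-evaluation protocol, assuming a frozen pretrained `encoder` and standard PyTorch data loaders; function names and hyperparameters are illustrative.

```python
# Sketch of linear evaluation: fit a linear classifier on frozen features.
import torch
import torch.nn.functional as F

@torch.no_grad()
def extract_features(encoder, loader):
    """Run the frozen encoder over a loader and stack features/labels."""
    encoder.eval()
    feats, labels = [], []
    for x, y in loader:
        feats.append(encoder(x))
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)

def linear_evaluation(encoder, train_loader, feat_dim, num_classes,
                      epochs=30, lr=1e-2):
    """Train a linear classifier on the frozen features (full-batch for brevity)."""
    X, y = extract_features(encoder, train_loader)
    clf = torch.nn.Linear(feat_dim, num_classes)
    opt = torch.optim.Adam(clf.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = F.cross_entropy(clf(X), y)
        loss.backward()
        opt.step()
    return clf
```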

Table 2: Results on CelebA: accuracy of predicting Attractive and fairness metrics for two sensitive attributes, Male and Young.
(Attractive, Male) Acc Δ ED Δ EO Δ DP IntraAUC InterAUC GAUC WD KL
CE 80.20 (±\pm 0.31) 25.55 (±\pm 0.27) 22.53 (±\pm 0.47) 45.40 (±\pm 0.56) 0.0024 (±\pm 1e-3) 0.2745 (±\pm 3e-3) 0.3053 (±\pm 3e-3) 0.3131 (±\pm 3e-3) 0.7153 (±\pm 4e-3)
CE + EOD 79.70 (±\pm 0.41) 22.18 (±\pm 0.31) 16.75 (±\pm 0.28) 41.65 (±\pm 0.44) 0.0014 (±\pm 1e-3) 0.2372 (±\pm 4e-3) 0.2897 (±\pm 2e-3) 0.2804 (±\pm 4e-3) 0.6189 (±\pm 5e-3)
CE + DPR 80.08 (±\pm 0.28) 23.74 (±\pm 0.48) 17.15 (±\pm 0.21) 43.06 (±\pm 0.34) 0.0051 (±\pm 5e-4) 0.2571 (±\pm 3e-3) 0.2981 (±\pm 3e-3) 0.2924 (±\pm 5e-3) 0.6761 (±\pm 4e-3)
CE + EQL 79.63 (±\pm 0.29) 25.10 (±\pm 0.36) 20.10 (±\pm 0.35) 44.50 (±\pm 0.38) 0.0024 (±\pm 4e-4) 0.2738 (±\pm 4e-3) 0.3037 (±\pm 4e-3) 0.2975 (±\pm 3e-3) 0.7177 (±\pm 4e-3)
ML-AFL 79.44(±\pm 0.32) 32.12 (±\pm 0.33) 23.39 (±\pm 0.41) 48.70 (±\pm 0.35) 0.0030 (±\pm 8e-4) 0.3561 (±\pm 5e-3) 0.3382 (±\pm 3e-3) 0.3341 (±\pm 3e-3) 0.9551 (±\pm 3e-3)
Max-Ent 79.46 (±\pm 0.28) 30.72 (±\pm 0.29) 18.42 (±\pm 0.38) 47.42 (±\pm 0.40) 0.0046 (±\pm 2e-3) 0.3241 (±\pm 4e-3) 0.3289 (±\pm 5e-3) 0.3083 (±\pm 4e-3) 0.9215 (±\pm 3e-3)
SimCLR 80.11 (±\pm 0.28) 26.58 (±\pm 0.34) 17.34 (±\pm 0.38) 44.95 (±\pm 0.32) 0.0055 (±\pm 1e-3) 0.2835 (±\pm 5e-3) 0.3211 (±\pm 4e-3) 0.2458 (±\pm 4e-3) 0.8276 (±\pm 4e-3)
SogCLR 80.53 (±\pm 0.25) 25.38 (±\pm 0.28) 18.71 (±\pm 0.33) 44.51 (±\pm 0.31) 0.0035 (±\pm 6e-4) 0.2659 (±\pm 4e-3) 0.3167 (±\pm 3e-3) 0.2432 (±\pm 3e-3) 0.8055 (±\pm 4e-3)
Byol 79.58 (± 0.29) 24.51 (± 0.41) 20.99 (± 0.37) 47.02 (± 0.28) 0.0091 (± 5e-4) 0.2713 (± 7e-3) 0.3974 (± 5e-3) 0.2367 (± 5e-3) 0.7641 (± 5e-3)
SimCLR+CCL 79.91 (±\pm 0.28) 22.19 (±\pm 0.35) 18.59 (±\pm 0.32) 39.58 (±\pm 0.30) 0.0069 (±\pm 6e-4) 0.3146 (±\pm 5e-3) 0.3059 (±\pm 4e-3) 0.2143 (±\pm 4e-3) 0.6408 (±\pm 5e-3)
SoFCLR 79.95 (±\pm 0.19) 14.93(±\pm 0.22) 12.60 (±\pm 0.25) 36.50 (±\pm 0.24) 0.0032 (±\pm 3e-4) 0.1592 (±\pm 2e-3) 0.2566(±\pm 2e-3) 0.1402 (±\pm 2e-3) 0.4743 (±\pm 4e-3)
(Attractive, Young) Acc Δ ED Δ EO Δ DP IntraAUC InterAUC GAUC WD KL
CE 79.23 (±\pm 0.33) 22.79 (±\pm 0.31) 15.90 (±\pm 0.32) 40.47 (±\pm 0.34) 0.0358 (±\pm 2e-3) 0.2434 (±\pm 4e-4) 0.3129 (±\pm 3e-3) 0.3047 (±\pm 3e-3) 0.7275 (±\pm 4e-3)
CE + EOD 79.03 (±\pm 0.34) 22.78 (±\pm 0.28) 15.69 (±\pm 0.35) 41.81 (±\pm 0.38) 0.0403 (±\pm 3e-3) 0.2409 (±\pm 3e-4) 0.3142 (±\pm 4e-3) 0.3079 (±\pm 2e-3) 0.7434 (±\pm 3e-3)
CE + DPR 78.51 (±\pm 0.29) 22.08 (±\pm 0.30) 15.17 (±\pm 0.29) 40.83 (±\pm 0.41) 0.0396 (±\pm 2e-3) 0.2368 (±\pm 3e-4) 0.3110 (±\pm 3e-3) 0.3039 (±\pm 3e-3) 0.7165 (±\pm 3e-3)
CE + EQL 80.02 (±\pm 0.30) 22.09 (±\pm 0.34) 15.68 (±\pm 0.33) 41.53 (±\pm 0.35) 0.0390 (±\pm 2e-3) 0.2332 (±\pm 5e-4) 0.3095 (±\pm 4e-3) 0.3020 (±\pm 4e-3) 0.7082 (±\pm 4e-3)
ML-AFL 79.25 (±\pm 0.31) 31.97 (±\pm 0.31) 22.50 (±\pm 0.31) 48.70 (±\pm 0.29) 0.0451 (±\pm 3e-3) 0.3560 (±\pm 4e-4) 0.3380 (±\pm 4e-3) 0.3340 (±\pm 3e-3) 0.9551 (±\pm 3e-3)
MaxEnt-ALR 79.33 (±\pm 0.30 ) 30.59 (±\pm 0.29) 18.10 (±\pm 0.30) 46.99 (±\pm 0.34) 0.0420 (±\pm 4e-3) 0.2113 (±\pm 5e-4) 0.3117 (±\pm 3e-3) 0.2927 (±\pm 4e-3) 0.7285 (±\pm 3e-3)
SimCLR 79.97 (±\pm 0.29) 17.52 (±\pm 0.31) 18.50 (±\pm 0.31) 42.47 (±\pm 0.41) 0.0381 (±\pm 3e-3) 0.1909 (±\pm 4e-4) 0.2984 (±\pm 3e-3) 0.2098 (±\pm 3e-3) 0.6877 (±\pm 2e-3)
SogCLR 79.73 (±\pm 0.27) 17.21 (±\pm 0.28) 17.61 (±\pm 0.27) 42.01 (±\pm 0.32) 0.0365 (±\pm 3e-3) 0.1940 (±\pm 4e-4) 0.3021 (±\pm 4e-3 ) 0.2114 (±\pm 4e-3) 0.6782 (±\pm 3e-3)
Byol 79.83 (± 0.28) 17.03 (± 0.31) 18.03 (± 0.29) 43.01 (± 0.29) 0.0393 (± 4e-3) 0.2001 (± 8e-4) 0.3233 (± 3e-3) 0.2214 (± 5e-3) 0.6804 (± 4e-3)
SimCLR+CCL 79.87 (±\pm 0.26) 17.18 (±\pm 0.31) 17.25( ±\pm 0.28) 42.05 (±\pm 0.30) 0.0385 (±\pm 3e-3) 0.1891 (±\pm 5e-4) 0.2824 (±\pm 4e-3 ) 0.2158 (±\pm 5e-3) 0.6532 (±\pm 4e-3)
SoFCLR 79.93 (±\pm 0.25) 15.34 (±\pm 0.27) 14.10 (±\pm 0.25) 40.05 (±\pm 0.26) 0.0336 (±\pm 2e-3) 0.1652 (±\pm 3e-4) 0.2824 (±\pm 2e-3) 0.1506 (±\pm 3e-3) 0.5905(±\pm 2e-3)

Baselines. We compare with 10 baselines from different categories, including fairness-unaware SSL methods, fairness-aware SSL methods, and conventional fairness-aware supervised learning methods. The fairness-unaware SSL methods include SimCLR [6], SogCLR [43], and Byol [14]; SimCLR and SogCLR are contrastive learning methods and Byol is a non-contrastive SSL method. For fairness-aware SSL, most existing approaches either assume all data have sensitive attribute information [24] or rely on an image generator that is trained separately [47]. In order to be fair to our algorithm, we construct a strong baseline by combining the loss of SimCLR and CCL [24], where a mini-batch contrastive loss is defined on unlabeled data and a conditional contrastive loss is defined on the unlabeled data with sensitive attribute information. We refer to this baseline as SimCLR+CCL. For fairness-aware supervised learning, we consider two adversarial fair representation learning approaches, namely Max-Ent [37] and Max-AFL [41], and three fairness-regularized approaches that optimize the cross-entropy (CE) loss plus a fairness regularizer: an equalized odds regularizer (EOD), a demographic disparity regularizer (DPR), and an equalized loss regularizer (EQL). These methods have been considered in previous works [8, 9]. We also include a reference method that just minimizes the CE loss on labeled data. All experiments are conducted on a server with four GTX1080Ti GPUs.

Fairness Metrics. To evaluate the effectiveness of our algorithm, we report a total of 8 fairness metrics. These include the commonly used demographic disparity (Δ DP), equalized odds difference (Δ ED), and equalized opportunity (Δ EO); three AUC fairness metrics, namely group AUC fairness (GAUC) [42], inter-group AUC fairness (InterAUC), and intra-group AUC fairness (IntraAUC) [2]; and two distance metrics that measure the distance between the distributions of prediction scores of examples from different groups. For the latter, we discretize the prediction scores on testing examples from each group into 100 buckets and calculate the KL-divergence (KL) and Wasserstein distance (WD) between the two empirical distributions of the two groups. For all fairness metrics, smaller values indicate better performance.
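For reference, below is a small sketch of how the three basic group-fairness gaps can be computed from binary predictions and a binary sensitive attribute. Conventions for Δ ED vary (maximum vs. average of the TPR/FPR gaps); this sketch uses the average. The AUC- and distance-based metrics follow the cited references. Function names are illustrative.

```python
# Sketch of Delta DP, Delta EO, Delta ED for binary predictions/attributes.
import numpy as np

def fairness_gaps(y_pred, y_true, a):
    """y_pred, y_true, a are 0/1 numpy arrays of the same length."""
    g0, g1 = (a == 0), (a == 1)
    # demographic parity gap: difference in positive prediction rates
    dp = abs(y_pred[g0].mean() - y_pred[g1].mean())
    # equal opportunity gap: difference in true positive rates
    tpr = lambda g: y_pred[g & (y_true == 1)].mean()
    eo = abs(tpr(g0) - tpr(g1))
    # equalized odds gap: average of the TPR and FPR differences
    fpr = lambda g: y_pred[g & (y_true == 0)].mean()
    ed = 0.5 * (abs(tpr(g0) - tpr(g1)) + abs(fpr(g0) - fpr(g1)))
    return dp, eo, ed
```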

Neural networks and optimizers. We utilize ResNet18 as the backbone network. For SSL, we add a two-layer MLP projection head with widths 256 and 128. For our algorithm, we use a two-layer MLP with a hidden dimension of 512 for predicting the sensitive attribute. Model parameters are updated with Adam. The detailed information is described in Appendix C.

Hyperparameter tuning. Each method has some hyperparameters, including those in the objective function and the optimizer, e.g., the combination weight between the regular loss and the fairness regularizer and the learning rate of the optimizer; please see Appendix C for tuning details. In addition, there are multiple fairness metrics besides the accuracy performance. For hyperparameter tuning, we divide the data into training, validation and testing sets. The validation and testing sets have target labels and sensitive attributes. For each method, we tune its parameters in order to obtain an accuracy within a certain threshold (1%) of the standard CE baseline, and then report its different fairness metrics.

6.1 Prediction Results

Results on CelebA. Following the same setting as [30], we use the two attributes ‘Male’ and ‘Young’ to define the sensitive attribute. Each attribute divides the data into two groups according to its binary annotation. For target labels, we consider the three attributes that have the highest Pearson correlation with the sensitive attribute, i.e., Attractive, Big Nose, and Bags Under Eyes. Hence, we have six tasks in total. Due to the limit of space, we only report the results for two tasks, (Attractive, Male) and (Attractive, Young), and include results for the other tasks in Appendix D. We use 80%/10%/10% splits for constructing the training, validation and testing datasets. We assume a random subset of 5% of the training examples have sensitive attribute information for training by SoFCLR, SimCLR+CCL, and the supervised baseline methods. The supervised methods just use those 5% of images, including their target attribute labels and sensitive attributes, for learning a model.

Figure 1: Learned representations of 1000 testing examples from CelebA by different methods.

The results on CelebA are presented in Table 2. From the results, we can observe that: (1) our method SoFCLR yields much fairer results than the existing fairness-unaware SSL methods SimCLR, SogCLR and Byol; (2) compared with SimCLR+CCL, SoFCLR is more effective at obtaining fairer results; (3) the fairness-aware supervised methods are not that effective in our considered setting, probably because they only use 5% of the labeled data for learning the prediction model.

Table 3: Results on UTKFace: accuracy of predicting gender and fairness metrics in terms of two sensitive attributes, Age and Ethnicity.
(Gender, Age) Acc Δ ED Δ EO Δ DP IntraAUC InterAUC GAUC WD KL
SimCLR 85.74 (±\pm 0.31) 19.60 (±\pm 0.34) 28.32 (±\pm 0.41) 19.62 (±\pm 0.41) 0.0457 (±\pm 3e-4) 0.1287 (±\pm 4e-3) 0.1156 (±\pm 4e-3) 0.1512 (±\pm 3e-3) 0.1368 (±\pm 4e-3)
SogCLR 85.86 (±\pm 0.34) 17.83 (±\pm 0.32) 28.28 (±\pm 0.29) 17.58 (±\pm 0.31) 0.0471 (±\pm 4e-4) 0.1227 (±\pm 5e-3) 0.1145 (±\pm 3e-3) 0.1458 (±\pm 4e-3) 0.1416 (±\pm 3e-3)
Byol 85.37 (±\pm 0.37) 17.97(±\pm 0.36) 28.37(±\pm 0.25) 17.49 (±\pm 0.29) 0.0496 (±\pm 5e-4) 0.1221 (±\pm 4e-3) 0.1132 (±\pm 2e-3) 0.1467 (±\pm 5e-3) 0.1383 (±\pm 3e-3)
SimCLR+CCL 85.56 (±\pm 0.36) 16.83 (±\pm 0.33) 27.32 (±\pm 0.25) 17.08 (±\pm 0.28) 0.0483 (±\pm 4e-4) 0.1203 (±\pm 4e-3) 0.1098 (±\pm 3e-3) 0.1329 (±\pm 4e-3) 0.1374 (±\pm 3e-3)
SoFCLR 85.89 (±\pm 0.27) 15.42 (±\pm 0.28) 25.00 (±\pm 0.25) 15.49 (±\pm 0.26) 0.0466 (±\pm 3e-4) 0.1041 (±\pm 4e-3) 0.0901 (±\pm 2e-3) 0.1151 (±\pm 3e-3) 0.1012 (±\pm 2e-3)
(Gender, Ethnicity) Acc Δ ED Δ EO Δ DP IntraAUC InterAUC GAUC WD KL
SimCLR 83.58 (±\pm0.34) 17.23 (±\pm 0.24) 14.43 (±\pm 0.27) 17.21 (±\pm 0.34) 0.0091 (±\pm 6e-4) 0.1352 (±\pm 6e-3) 0.1375 (±\pm 5e-3) 0.1591 (±\pm 4e-3) 0.2017 (±\pm 5e-3)
SogCLR 84.03 (±\pm 0.37) 16.56 (±\pm 0.33) 13.83 (±\pm 0.28) 16.37 (±\pm 0.26) 0.0083 (±\pm 5e-4) 0.1284 (±\pm 5e-3) 0.1572 (±\pm 4e-3) 0.1353 (±\pm 5e-3) 0.1913 (±\pm 6e-3)
Byol 83.87 (±\pm 0.41) 16.92 (±\pm 0.29) 14.08 (±\pm 0.31) 16.85 (±\pm 0.27) 0.0087 (±\pm 6e-4) 0.1273 (±\pm 6e-3) 0.1523 (±\pm 4e-3) 0.1195 (±\pm 6e-3) 0.1237 (±\pm 5e-3)
SimCLR+CCL 83.59 (±\pm 0.30) 15.87 (±\pm 0.28) 13.70 (±\pm 0.29) 15.79 (±\pm 0.31) 0.0081 (±\pm 4e-4) 0.1192 (±\pm 4e-3) 0.1387 (±\pm 5e-3) 0.1195 (±\pm 4e-3) 0.1237 (±\pm 4e-3)
SoFCLR 84.42 (±\pm 0.27) 13.02 (±\pm 0.22) 13.23 (±\pm 0.24) 13.00 (±\pm 0.25) 0.0084 (±\pm 4e-4) 0.1013 (±\pm 4e-3) 0.1029 (±\pm 3e-3) 0.1195 (±\pm 4e-3) 0.1237 (±\pm 3e-3)
Table 4: Transfer learning results on UTKFace. SSL on CelebA, linear evaluation on UTKFace with accuracy of predicting gender and fairness metrics in terms of age.
Acc Δ ED Δ EO Δ DP IntraAUC InterAUC GAUC WD KL
CE 93.06 (±\pm 0.37) 9.16 (±\pm 0.34) 8.95 (±\pm 0.31) 7.63 (±\pm 0.44) 0.0375 (±\pm 4e-3) 0.0245 (±\pm 4e-3) 0.0451 (±\pm 5e-3) 0.0684 (±\pm 4e-3) 0.1356 (±\pm 4e-3)
SimCLR 89.04 (±\pm 0.41) 6.89 (±\pm 0.35) 8.54 (±\pm 0.38) 5.91 (±\pm 0.48) 0.0362 (±\pm 5e-3) 0.0325 (±\pm 5e-3) 0.0303 (±\pm 4e-3) 0.0437 (±\pm 5e-3) 0.0769 (±\pm 6e-3)
SogCLR 89.74 (±\pm 0.35) 6.32 (±\pm 0.37) 8.16 (±\pm 0.42) 6.31 (±\pm 0.36) 0.0296 (±\pm 4e-3) 0.0321 (±\pm 4e-3) 0.0298 (±\pm 4e-3) 0.0436 (±\pm 3e-3) 0.0743 (±\pm 5e-3)
Byol 89.83 (± 0.43) 6.21 (± 0.39) 8.09 (± 0.31) 5.87 (± 0.29) 0.0303 (± 3e-3) 0.0331 (± 7e-3) 0.0301 (± 3e-3) 0.0478 (± 4e-3) 0.0723 (± 7e-3)
SimCLR+CCL 89.79 (±\pm 0.39) 5.78 (±\pm 0.37) 7.35 (±\pm 0.41) 5.67 (±\pm 0.49) 0.0292 (±\pm 4e-3) 0.0197 (±\pm 5e-3) 0.0283 (±\pm 4e-3) 0.0391 (±\pm 4e-3) 0.0683 (±\pm 6e-3)
SoFCLR 89.42 (±\pm 0.29) 4.46 (±\pm 0.30) 5.89 (±\pm 0.23) 4.49 (±\pm 0.27) 0.0282 (±\pm 3e-3) 0.0176 (±\pm 3e-3) 0.0203 (±\pm 3e-3) 0.0271 (±\pm 3e-3) 0.0573 (±\pm 4e-3)

Results on UTKFace. There are three attributes for each image, i.e., gender, age, and ethnicity. In our experiments, we use gender as the target label and the other two as sensitive attributes. To control the imbalance across sensitive attribute groups, following the setting of [30], we manually construct training sets that are imbalanced in terms of the sensitive attribute, such that the sensitive attribute is highly correlated with the target label (e.g., Caucasian dominates the Male class), and keep the validation and testing data balanced. The details of the training data statistics are described in Table 7.

We consider two experimental settings. The first setting is to learn both the representation network and the classifier on the same data. We split train/validation/test with 10000/3200/3200 images. However, since UTKFace is a small dataset, 5% of the labeled training samples, totaling 500, is not sufficient to learn a good ResNet18. Hence, we only compare the different SSL methods on this dataset. For SoFCLR and SimCLR+CCL, we still assume 5% of the training data have sensitive attribute information. The linear evaluation uses all training images with target labels. The results shown in Table 3 indicate that SoFCLR outperforms all baselines by a large margin on all fairness metrics.

The second setting is to learn a representation network on CelebA using SSL and perform linear evaluation on the UTKFace data. This is valid because both datasets have attribute information related to age, namely the Young attribute in CelebA and the age attribute in UTKFace. For SoFCLR and SimCLR+CCL, we still assume 5% of the training examples of CelebA have sensitive attribute information. The results shown in Table 4 indicate that our SoFCLR still performs the best on all fairness metrics except for IntraAUC, while maintaining similar classification accuracy. The standard supervised method CE, which uses all training examples, clearly has better accuracy but is much less fair.

Figure 2: SoFCLR accuracy vs. Δ ED trade-off (left), and the evolution of the adversarial loss with varying α (right) on the UTKFace data.

6.2 Fair representation visualization

We compare SoFCLR with SimCLR and SogCLR on the CelebA dataset using ‘Attractive’ as the target label and ‘Male’ as the sensitive attribute. We extract 1000 samples from the test dataset and generate a t-SNE visualization, as depicted in Figure 1. The results indicate that the representations learned by SimCLR and SogCLR are highly related to the sensitive attribute (gender), as indicated by the color. In contrast, the representations learned by SoFCLR remove the disparate impact of the sensitive attribute information. Nevertheless, the learned representations still maintain discriminative power for classifying the target label (attractive vs. not attractive), as indicated by the shape.

6.3 Effectiveness of fairness regularizer

To demonstrate the impact of the adversarial fairness regularization in our method, we conduct an experiment on the UTKFace data by varying the regularization parameter $\alpha$. We use the target label ‘gender’ and the sensitive attribute ‘age’. We vary $\alpha$ across a range of values, specifically $\{0,0.1,0.3,0.5,0.7,0.9,1\}$. Notably, when $\alpha=0$, our proposed algorithm reduces to SogCLR. We report a Pareto curve of accuracy vs. Δ ED in Figure 2 (left) and the evolution of the adversarial loss for predicting the sensitive attribute in Figure 2 (right). We can see that by increasing the value of $\alpha$, our algorithm can effectively control the adversarial loss, which makes the downstream predictive model fairer, as shown in the left figure. To ensure the comprehensiveness of our experiments, we also compare our SSL method with a VAE-based approach and extend our method to multi-valued sensitive attributes in Appendix D.

7 Conclusions

We have proposed a zero-sum game formulation for fair self-supervised representation learning. We provided a theoretical justification in terms of distributional representation fairness, developed a stochastic algorithm for solving the minimax zero-sum game, and established a convergence guarantee under mild conditions. Experiments on face image datasets demonstrate the effectiveness of our algorithm. One limitation of this work is that it focuses on a single modality; extending it to multi-modal data would be an interesting direction for future work.

References

  • [1] Solon Barocas, Moritz Hardt, and Arvind Narayanan. Fairness and Machine Learning. fairmlbook.org, 2019. http://www.fairmlbook.org.
  • [2] Alex Beutel, Jilin Chen, Tulsee Doshi, Hai Qian, Li Wei, Yi Wu, Lukasz Heldt, Zhe Zhao, Lichan Hong, Ed H Chi, et al. Fairness in recommendation ranking through pairwise comparisons. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 2212–2220, 2019.
  • [3] Alex Beutel, Jilin Chen, Zhe Zhao, and Ed H. Chi. Data decisions and theoretical implications when adversarially learning fair representations. CoRR, abs/1707.00075, 2017.
  • [4] Alice Bizeul and Carl Allen. SimVAE: Narrowing the gap between discriminative & generative representation learning. In NeurIPS 2023 Workshop on Mathematics of Modern Machine Learning, 2023.
  • [5] Junyi Chai and Xiaoqian Wang. Self-supervised fair representation learning without demographics. In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural Information Processing Systems, 2022.
  • [6] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In International conference on machine learning, pages 1597–1607. PMLR, 2020.
  • [7] Xilun Chen, Ben Athiwaratkun, Yu Sun, Kilian Q. Weinberger, and Claire Cardie. Adversarial deep averaging networks for cross-lingual sentiment classification. CoRR, abs/1606.01614, 2016.
  • [8] Valeriia Cherepanova, Vedant Nanda, Micah Goldblum, John P Dickerson, and Tom Goldstein. Technical challenges for training fair neural networks. arXiv preprint arXiv:2102.06764, 2021.
  • [9] Michele Donini, Luca Oneto, Shai Ben-David, John Shawe-Taylor, and Massimiliano Pontil. Empirical risk minimization under fairness constraints, 2020.
  • [10] Harrison Edwards and Amos J. Storkey. Censoring representations with an adversary. In Yoshua Bengio and Yann LeCun, editors, 4th International Conference on Learning Representations, ICLR 2016, San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings, 2016.
  • [11] Yanai Elazar and Yoav Goldberg. Adversarial removal of demographic attributes from text data. In Conference on Empirical Methods in Natural Language Processing, 2018.
  • [12] Yaroslav Ganin, Evgeniya Ustinova, Hana Ajakan, Pascal Germain, Hugo Larochelle, Franccois Laviolette, Mario Marchand, and Victor Lempitsky. Domain-adversarial training of neural networks. J. Mach. Learn. Res., 17(1):2096–2030, jan 2016.
  • [13] Ziyu Gong, Ben Usman, Han Zhao, and David I. Inouye. Towards practical non-adversarial distribution alignment via variational bounds, 2023.
  • [14] Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Ávila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, Bilal Piot, Koray Kavukcuoglu, Rémi Munos, and Michal Valko. Bootstrap your own latent: A new approach to self-supervised learning. CoRR, abs/2006.07733, 2020.
  • [15] Zhishuai Guo, Yi Xu, Wotao Yin, Rong Jin, and Tianbao Yang. A novel convergence analysis for algorithms of the adam family and beyond. arXiv e-prints, pages arXiv–2104, 2021.
  • [16] Umang Gupta, Aaron M. Ferber, Bistra Dilkina, and Greg Ver Steeg. Controllable guarantees for fair outcomes via contrastive information estimation. CoRR, abs/2101.04108, 2021.
  • [17] Sungwon Han, Seungeon Lee, Fangzhao Wu, Sundong Kim, Chuhan Wu, Xiting Wang, Xing Xie, and Meeyoung Cha. Dualfair: Fair representation learning at both group and individual levels via contrastive self-supervision. In Proceedings of the ACM Web Conference 2023, WWW ’23, page 3766–3774, New York, NY, USA, 2023. Association for Computing Machinery.
  • [18] Hamed Karimi, Julie Nutini, and Mark Schmidt. Linear convergence of gradient and proximal-gradient methods under the polyak-łojasiewicz condition. CoRR, abs/1608.04636, 2016.
  • [19] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • [20] Naveen Kodali, James Hays, Jacob Abernethy, and Zsolt Kira. On convergence and stability of GANs, 2018.
  • [21] Chaoyue Liu, Dmitriy Drusvyatskiy, Mikhail Belkin, Damek Davis, and Yi-An Ma. Aiming towards the minimizers: fast convergence of SGD for overparametrized problems. CoRR, abs/2306.02601, 2023.
  • [22] Ziwei Liu, Ping Luo, Xiaogang Wang, and Xiaoou Tang. Deep learning face attributes in the wild. In Proceedings of International Conference on Computer Vision (ICCV), December 2015.
  • [23] Christos Louizos, Kevin Swersky, Yujia Li, Max Welling, and Richard S. Zemel. The variational fair autoencoder. In Yoshua Bengio and Yann LeCun, editors, 4th International Conference on Learning Representations, ICLR 2016, San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings, 2016.
  • [24] Martin Q Ma, Yao-Hung Hubert Tsai, Paul Pu Liang, Han Zhao, Kun Zhang, Ruslan Salakhutdinov, and Louis-Philippe Morency. Conditional contrastive learning for improving fairness in self-supervised learning. arXiv preprint arXiv:2106.02866, 2021.
  • [25] Martin Q. Ma, Yao-Hung Hubert Tsai, Han Zhao, Kun Zhang, Louis-Philippe Morency, and Ruslan Salakhutdinov. Conditional contrastive learning: Removing undesirable information in self-supervised representations. CoRR, abs/2106.02866, 2021.
  • [26] Ninareh Mehrabi, Fred Morstatter, Nripsuta Saxena, Kristina Lerman, and Aram Galstyan. A survey on bias and fairness in machine learning. CoRR, abs/1908.09635, 2019.
  • [27] Daniel Moyer, Shuyang Gao, Rob Brekelmans, Greg Ver Steeg, and Aram Galstyan. Invariant representations without adversarial training. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, NIPS’18, page 9102–9111, Red Hook, NY, USA, 2018. Curran Associates Inc.
  • [28] Maher Nouiehed, Maziar Sanjabi, Tianjian Huang, Jason D. Lee, and Meisam Razaviyayn. Solving a class of non-convex min-max games using iterative first order methods, 2019.
  • [29] Sungho Park, Jewook Lee, Pilhyeon Lee, Sunhee Hwang, D. Kim, and Hyeran Byun. Fair contrastive learning for facial attribute classification. 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 10379–10388, 2022.
  • [30] Sungho Park, Jewook Lee, Pilhyeon Lee, Sunhee Hwang, Dohyung Kim, and Hyeran Byun. Fair contrastive learning for facial attribute classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 10389–10398, 2022.
  • [31] Dana Pessach and Erez Shmueli. A review on fairness in machine learning. ACM Comput. Surv., 55(3), feb 2022.
  • [32] Qi Qi and Shervin Ardeshir. Improving identity-robustness for face models. arXiv preprint arXiv:2304.03838, 2023.
  • [33] Qi Qi, Yan Yan, Zixuan Wu, Xiaoyu Wang, and Tianbao Yang. A simple and effective framework for pairwise deep metric learning. In European Conference on Computer Vision, pages 375–391. Springer, 2020.
  • [34] Zi-Hao Qiu, Quanqi Hu, Yongjian Zhong, Lijun Zhang, and Tianbao Yang. Large-scale stochastic optimization of ndcg surrogates for deep learning with provable convergence, 2023.
  • [35] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual models from natural language supervision. In International conference on machine learning, pages 8748–8763. PMLR, 2021.
  • [36] Proteek Chandan Roy and Vishnu Naresh Boddeti. Mitigating information leakage in image representations: A maximum entropy approach. CoRR, abs/1904.05514, 2019.
  • [37] Proteek Chandan Roy and Vishnu Naresh Boddeti. Mitigating information leakage in image representations: A maximum entropy approach. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2586–2594, 2019.
  • [38] Christina Wadsworth, Francesca Vera, and Chris Piech. Achieving fairness through adversarial learning: an application to recidivism prediction. CoRR, abs/1807.00199, 2018.
  • [39] Bokun Wang and Tianbao Yang. Finite-sum coupled compositional stochastic optimization: Theory and applications. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato, editors, Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 23292–23317. PMLR, 17–23 Jul 2022.
  • [40] Qizhe Xie, Zihang Dai, Yulun Du, Eduard Hovy, and Graham Neubig. Controllable invariance through adversarial feature learning. In Proceedings of the 31st International Conference on Neural Information Processing Systems, NIPS’17, page 585–596, Red Hook, NY, USA, 2017. Curran Associates Inc.
  • [41] Qizhe Xie, Zihang Dai, Yulun Du, Eduard Hovy, and Graham Neubig. Controllable invariance through adversarial feature learning. Advances in neural information processing systems, 30, 2017.
  • [42] Yao Yao, Qihang Lin, and Tianbao Yang. Stochastic methods for auc optimization subject to auc-based fairness constraints. In International Conference on Artificial Intelligence and Statistics, pages 10324–10342. PMLR, 2023.
  • [43] Zhuoning Yuan, Yuexin Wu, Zihao Qiu, Xianzhi Du, Lijun Zhang, Denny Zhou, and Tianbao Yang. Provable stochastic optimization for global contrastive learning: Small batch does not harm performance. arXiv preprint arXiv:2202.12387, 2022.
  • [44] Zhuoning Yuan, Yan Yan, Rong Jin, and Tianbao Yang. Stagewise training accelerates convergence of testing error over SGD. In Hanna M. Wallach, Hugo Larochelle, Alina Beygelzimer, Florence d’Alché-Buc, Emily B. Fox, and Roman Garnett, editors, Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, pages 2604–2614, 2019.
  • [45] Rich Zemel, Yu Wu, Kevin Swersky, Toni Pitassi, and Cynthia Dwork. Learning fair representations. In Sanjoy Dasgupta and David McAllester, editors, Proceedings of the 30th International Conference on Machine Learning, volume 28 of Proceedings of Machine Learning Research, pages 325–333, Atlanta, Georgia, USA, 17–19 Jun 2013. PMLR.
  • [46] Brian Hu Zhang, Blake Lemoine, and Margaret Mitchell. Mitigating unwanted biases with adversarial learning. In Proceedings of the 2018 AAAI/ACM Conference on AI, Ethics, and Society, AIES ’18, page 335–340, New York, NY, USA, 2018. Association for Computing Machinery.
  • [47] Fengda Zhang, Kun Kuang, Long Chen, Yuxuan Liu, Chao Wu, and Jun Xiao. Fairness-aware contrastive learning with partially annotated sensitive attributes. In The Eleventh International Conference on Learning Representations, 2023.
  • [48] Zhifei Zhang, Yang Song, and Hairong Qi. Age progression/regression by conditional adversarial autoencoder. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017.
  • [49] Han Zhao and Geoffrey J. Gordon. Inherent tradeoffs in learning fair representation. CoRR, abs/1906.08386, 2019.

Appendix A Fairness Verification

Proof of Theorem 1.

Denote by Dk(E(𝐱))D_{k}(E({\bf{x}})) as the kk-th element of D(E(𝐱))D(E({\bf{x}})) for the kk-th value of the sensitive attribute. Then k=1KDk(E(𝐱))=1\sum_{k=1}^{K}D_{k}(E({\bf{x}}))=1. We define the minimax problem as:

minEmaxD𝔼𝐱,a\displaystyle\min_{E}\max_{D}\mathbb{E}_{{\bf{x}},a} [k=1Kδ(a,k)logDk(E(𝐱))]\displaystyle\left[\sum_{k=1}^{K}\delta(a,k)\log D_{k}(E({\bf{x}}))\right]

Let us first fix EE and optimize DD. The objective is equivalent to

𝔼xk=1Kp(a=k|E(𝐱))logDk(E(𝐱))\displaystyle\mathbb{E}_{x}\sum_{k=1}^{K}p(a=k|E({\bf{x}}))\log D_{k}(E({\bf{x}}))

By maximizing D(E(𝐱))D(E({\bf{x}})), we have Dk(E(𝐱))=p(a=k|E(𝐱))D_{k}(E({\bf{x}}))=p(a=k|E({\bf{x}})). Then we have the following objective for EE:

𝔼𝐱,a\displaystyle\mathbb{E}_{{\bf{x}},a} [k=1Kδ(a,k)logp(a=k|E)]=𝔼𝐱,a[k=1Kδ(a,k)logp(E|a=k)p(a=k)p(E)]\displaystyle\left[\sum_{k=1}^{K}\delta(a,k)\log p(a=k|E)\right]\ =\mathbb{E}_{{\bf{x}},a}\left[\sum_{k=1}^{K}\delta(a,k)\log\frac{p(E|a=k)p(a=k)}{p(E)}\right]\
=𝔼𝐱,a[k=1Kδ(a,k)logp(a=k)]+𝔼𝐱,a[k=1Kδ(a,k)logp(E|a=k)p(E)]\displaystyle=\mathbb{E}_{{\bf{x}},a}\left[\sum_{k=1}^{K}\delta(a,k)\log p(a=k)\right]+\mathbb{E}_{{\bf{x}},a}\left[\sum_{k=1}^{K}\delta(a,k)\log\frac{p(E|a=k)}{p(E)}\right]\
=C+𝔼𝐱,a[logp(E|a)p(E)]=C+𝔼a𝔼p(E|a)[logp(E|a)p(E)]=C+𝔼a[KL(p(E|a),p(E))]\displaystyle=C+\mathbb{E}_{{\bf{x}},a}\left[\log\frac{p(E|a)}{p(E)}\right]=C+\mathbb{E}_{a}\mathbb{E}_{p(E|a)}\left[\log\frac{p(E|a)}{p(E)}\right]=C+\mathbb{E}_{a}[\text{KL}(p(E|a),p(E))]

where CC is independent of EE. Hence by minimizing over EE we have the optimal EE_{*} satisfying p(E|a)=p(E)p(E_{*}|a)=p(E_{*}). As a result, [D(E(𝐱))]k=p(a=k|E(𝐱))=p(a=k)[D_{*}(E_{*}({\bf{x}}))]_{k}=p(a=k|E_{*}({\bf{x}}))=p(a=k). ∎
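As a sanity check (not part of the proof), the following NumPy sketch instantiates the argument on a discrete toy representation; the distributions, sizes, and variable names are our own illustrative choices. It verifies that the cross-entropy-optimal discriminator recovers p(a=k|E(x)), that the encoder objective reduces to a constant plus E_a[KL(p(E|a), p(E))], and that a representation independent of a yields D_*(E_*(x)) = p(a=k).

```python
import numpy as np

rng = np.random.default_rng(0)
K, n_codes = 3, 5                                 # K attribute values, 5 discrete embedding codes
p_a = np.array([0.5, 0.3, 0.2])                   # marginal p(a = k)
p_e_given_a = rng.dirichlet(np.ones(n_codes), K)  # hypothetical p(E = e | a = k) for a biased encoder

# Optimal discriminator for a fixed encoder: D_k(e) = p(a = k | E = e).
joint = p_a[:, None] * p_e_given_a                # p(a = k, E = e), shape (K, n_codes)
posterior = joint / joint.sum(axis=0, keepdims=True)

# Plugging D_* back, the encoder objective equals a constant plus E_a[KL(p(E|a), p(E))].
p_e = joint.sum(axis=0)
kl_per_group = (p_e_given_a * np.log(p_e_given_a / p_e)).sum(axis=1)
print("E_a[KL(p(E|a), p(E))] =", float(p_a @ kl_per_group))   # strictly positive for a biased encoder

# For a "fair" encoder with p(E|a) = p(E), the optimal discriminator outputs the marginal p(a = k).
joint_fair = p_a[:, None] * np.tile(p_e, (K, 1))
posterior_fair = joint_fair / joint_fair.sum(axis=0, keepdims=True)
assert np.allclose(posterior_fair, p_a[:, None])
```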

Appendix B Convergence Analysis of SoFCLR

For simplicity, we use the following notations in this section,

F1(𝐰)=𝔼𝐱,𝒜,𝒜f1(𝐰;𝐱,𝒜,𝒜),\displaystyle F_{1}({\bf{w}})=\mathbb{E}_{{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}}f_{1}({\bf{w}};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}),
ϕ(𝐰,𝐰)=𝔼(𝐱,a)𝒟a,𝒜𝒫{ϕ(𝐰,𝐰;𝒜(𝐱),a)},\displaystyle\phi({\bf{w}},{\bf{w}}^{\prime})=\mathbb{E}_{({\bf{x}},a)\sim\mathcal{D}_{a},\mathcal{A}\sim\mathcal{P}}\{\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}),a)\},
g(𝐰)=[g(𝐰;𝐱1,𝒮1),,g(𝐰;𝐱n,𝒮n)],\displaystyle g({\bf{w}})=[g({\bf{w}};{\bf{x}}_{1},\mathcal{S}_{1}^{-}),\dots,g({\bf{w}};{\bf{x}}_{n},\mathcal{S}_{n}^{-})],
F2(𝐰)=1ni=1nf2(g(𝐰;𝐱i,𝒮i)),\displaystyle F_{2}({\bf{w}})=\frac{1}{n}\sum_{i=1}^{n}f_{2}(g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-})),
F3(𝐰)=max𝐰dαϕ(𝐰,𝐰).\displaystyle F_{3}({\bf{w}})=\max_{{\bf{w}}^{\prime}\in\mathbb{R}^{d^{\prime}}}\alpha\phi({\bf{w}},{\bf{w}}^{\prime}).

Then the problem

min𝐰dmax𝐰dF(𝐰,𝐰):=FGCL(𝐰)+αFfair(𝐰,𝐰)\displaystyle\min_{{\bf{w}}\in\mathbb{R}^{d}}\max_{{\bf{w}}^{\prime}\in\mathbb{R}^{d^{\prime}}}F({\bf{w}},{\bf{w}}^{\prime}):=F_{\text{GCL}}({\bf{w}})+\alpha F_{\text{fair}}({\bf{w}},{\bf{w}}^{\prime}) (8)
=𝔼𝐱,𝒜,𝒜f1(𝐰;𝐱,𝒜,𝒜)+1ni=1nf2(g(𝐰;𝐱i,𝒮i))+α𝔼(𝐱,a)𝒟a,𝒜𝒫{ϕ(𝐰,𝐰;𝒜(𝐱),a)}\displaystyle=\mathbb{E}_{{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}}f_{1}({\bf{w}};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})+\frac{1}{n}\sum_{i=1}^{n}f_{2}(g({\bf{w}};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))+\alpha\mathbb{E}_{({\bf{x}},a)\sim\mathcal{D}_{a},\mathcal{A}\sim\mathcal{P}}\{\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}),a)\}

can be written as

min𝐰dmax𝐰dF(𝐰,𝐰)=min𝐰dΦ(𝐰)\displaystyle\min_{{\bf{w}}\in\mathbb{R}^{d}}\max_{{\bf{w}}^{\prime}\in\mathbb{R}^{d^{\prime}}}F({\bf{w}},{\bf{w}}^{\prime})=\min_{{\bf{w}}\in\mathbb{R}^{d}}\Phi({\bf{w}}) (9)

where

Φ(𝐰)=F1(𝐰)+F2(𝐰)+F3(𝐰).\Phi({\bf{w}})=F_{1}({\bf{w}})+F_{2}({\bf{w}})+F_{3}({\bf{w}}).
Assumption 2.

We make the following assumptions.

  1. (a)

    Φ()\Phi(\cdot) is LL-smooth.

  2. (b)

    f2()f_{2}(\cdot) is differentiable, Lf2L_{f_{2}}-smooth, and Cf2C_{f_{2}}-Lipschitz continuous;

  3. (c)

    g(;𝐱i,𝒮i)g(\cdot;{\bf{x}}_{i},\mathcal{S}_{i}^{-}) is differentiable and CgC_{g}-Lipschitz continuous for all 𝐱i𝒟{\bf{x}}_{i}\in\mathcal{D}.

  4. (d)

    ϕ(𝐰,𝐰)\nabla\phi({\bf{w}},{\bf{w}}^{\prime}) is LϕL_{\phi}-Lipschitz continuous.

  5. (e)

    There exists ΔΦ<\Delta_{\Phi}<\infty such that Φ(𝐰1)ΦΔΦ\Phi({\bf{w}}_{1})-\Phi^{*}\leq\Delta_{\Phi}.

  6. (f)

    The stochastic estimators g(𝐰;𝒜(𝐱i),𝐱i)g({\bf{w}};\mathcal{A}({\bf{x}}_{i}),{\bf{x}}_{i}^{-}), f1(𝐰;𝐱,𝒜,𝒜)f_{1}({\bf{w}};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}), 𝐰g(𝐰;𝒜(𝐱i),𝐱i)\nabla_{\bf{w}}g({\bf{w}};\mathcal{A}({\bf{x}}_{i}),{\bf{x}}_{i}^{-}), 𝐰ϕ(𝐰,𝐰;𝒜(𝐱i),ai)\nabla_{\bf{w}}\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}_{i}),a_{i}), 𝐰ϕ(𝐰,𝐰;𝒜(𝐱i),ai)\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}},{\bf{w}}^{\prime};\mathcal{A}({\bf{x}}_{i}),a_{i}) have bounded variance σ2\sigma^{2}.

  7. (g)

    For any 𝐰{\bf{w}}, ϕ(𝐰,)-\phi({\bf{w}},\cdot) is λ\lambda-one-point strongly convex, i.e.,

    (𝐰𝐰)𝐰ϕ(𝐰,𝐰)λ𝐰𝐰2,-({\bf{w}}^{\prime}-{\bf{w}}^{\prime}_{*})^{\top}\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}},{\bf{w}}^{\prime})\geq\lambda\|{\bf{w}}^{\prime}-{\bf{w}}^{\prime}_{*}\|^{2},

    where 𝐰argmaxvϕ(𝐰,v){\bf{w}}^{\prime}_{*}\in\arg\max_{\textbf{v}}\phi({\bf{w}},\textbf{v}) is the optimal solution closest to 𝐰{\bf{w}}^{\prime}.

Note that Assumption 2 is a generalization of Assumption 1 that applies to all problems in the formulation of Problem (8). Next, we show that Assumption 1 implies Assumption 2. In fact, it suffices to show that (b,c)(b,c) of Assumption 1 implies (b,c,d)(b,c,d) of Assumption 2. The exact formulation f2()=τlog(ϵ0+)f_{2}(\cdot)=\tau\log(\epsilon_{0}^{\prime}+\cdot) directly yields its differentiability, Lipschitz continuity, and smoothness. Given the formulation g(𝐰,𝐱i,𝒮i)=𝔼𝐱~𝒮i𝔼𝒜exp(E(𝒜(𝐱i))E1(𝐱~)/τ)g({\bf{w}},{\bf{x}}_{i},\mathcal{S}_{i}^{-})=\mathbb{E}_{\tilde{{\bf{x}}}\sim\mathcal{S}_{i}^{-}}\mathbb{E}_{\mathcal{A}}\exp(E(\mathcal{A}({\bf{x}}_{i}))^{\top}E_{1}(\tilde{{\bf{x}}})/\tau), together with the boundedness and Lipschitz continuity of E(𝐱),E1(𝐱)E({\bf{x}}),E_{1}({\bf{x}}), the gradient

𝐰g(𝐰,𝐱i,𝒮i)=𝔼𝐱~𝒮i𝔼𝒜E(𝒜(𝐱i))E1(𝐱~)+E(𝒜(𝐱i))E1(𝐱~)τexp(E(𝒜(𝐱i))E1(𝐱~)/τ)\nabla_{\bf{w}}g({\bf{w}},{\bf{x}}_{i},\mathcal{S}_{i}^{-})=\mathbb{E}_{\tilde{{\bf{x}}}\sim\mathcal{S}_{i}^{-}}\mathbb{E}_{\mathcal{A}}\frac{\nabla E(\mathcal{A}({\bf{x}}_{i}))^{\top}E_{1}(\tilde{{\bf{x}}})+E(\mathcal{A}({\bf{x}}_{i}))^{\top}\nabla E_{1}(\tilde{{\bf{x}}})}{\tau}\exp(E(\mathcal{A}({\bf{x}}_{i}))^{\top}E_{1}(\tilde{{\bf{x}}})/\tau)

is bounded by a finite constant. The smoothness of ϕ(𝐰,𝐰)\phi({\bf{w}},{\bf{w}}^{\prime}) follows from the smoothness and Lipschitz continuity of 𝒟𝐰\mathcal{D}_{{\bf{w}}^{\prime}} and E𝐰E_{{\bf{w}}}.
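To make this composition concrete, the following PyTorch-style sketch (an illustration only; the tensor shapes, names, and the mini-batch estimator are our own) evaluates a mini-batch estimate of g(w; x_i, S_i^-) and the outer function f_2(·) = τ log(ε_0' + ·) for a single anchor.

```python
import torch

def contrastive_inner(anchor_emb: torch.Tensor, neg_embs: torch.Tensor,
                      tau: float = 0.1, eps0: float = 1e-8):
    """Mini-batch estimate of g(w; x_i, S_i^-) and the outer value f_2(g).

    anchor_emb: (d,)   embedding E(A(x_i)) of the augmented anchor
    neg_embs:   (m, d) embeddings E_1(x~) of negatives sampled from S_i^-
    """
    sims = neg_embs @ anchor_emb / tau       # E(A(x_i))^T E_1(x~) / tau for each negative
    g_hat = torch.exp(sims).mean()           # estimate of the inner expectation g
    f2_val = tau * torch.log(eps0 + g_hat)   # outer function f_2(.) = tau * log(eps0' + .)
    return g_hat, f2_val
```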

It has been shown in the literature that smoothness and one-point strong convexity imply the PL condition [44].

Lemma 1.

[Lemma 9 in [44]] Suppose hh is LhL_{h}-smooth and μ\mu-one-point strongly convex w.r.t. 𝐰{\bf{w}}^{*} with h(𝐰)=0\nabla h({\bf{w}}^{*})=0, then

h(𝐰)22μLh(h(𝐰)h(𝐰)).\|\nabla h({\bf{w}})\|^{2}\geq\frac{2\mu}{L_{h}}(h({\bf{w}})-h({\bf{w}}^{*})).

Thus, under Assumption 2, ϕ(𝐰,)-\phi({\bf{w}},\cdot) satisfies λLϕ\frac{\lambda}{L_{\phi}}-PL condition. Following a similar proof to Lemma 17 in [15], we have the following lemma.

Lemma 2.

Suppose Assumption 2 holds. Consider the update 𝐰t+1=𝐰t+ηvt+1{\bf{w}}^{\prime}_{t+1}={\bf{w}}^{\prime}_{t}+\eta^{\prime}\textbf{v}_{t+1} in Algorithm 1. With ηmin{λ/(2Lϕ2),4/λ}\eta^{\prime}\leq\min\left\{\lambda/(2L_{\phi}^{2}),4/\lambda\right\}, we have that for any 𝐰(𝐰t)argmax𝐰ϕ(𝐰t,𝐰){\bf{w}}^{\prime}({\bf{w}}_{t})\in\operatorname*{arg\,max}_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}) there is a 𝐰(𝐰t+1)argmax𝐰ϕ(𝐰t+1,𝐰){\bf{w}}^{\prime}({\bf{w}}_{t+1})\in\operatorname*{arg\,max}_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t+1},{\bf{w}}^{\prime}) such that

𝔼[𝐰t+1𝐰(𝐰t+1)2](1ηλ4)𝔼[𝐰t𝐰(𝐰t)2]+η2σ2|a|+2Lϕ4ηλ3𝔼[𝐰t𝐰t+12]\mathbb{E}[\|{\bf{w}}^{\prime}_{t+1}-{\bf{w}}^{\prime}({\bf{w}}_{t+1})\|^{2}]\leq(1-\frac{\eta^{\prime}\lambda}{4})\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]+\frac{\eta^{\prime 2}\sigma^{2}}{|\mathcal{B}_{a}|}+\frac{2L_{\phi}^{4}}{\eta^{\prime}\lambda^{3}}\mathbb{E}[\|{\bf{w}}_{t}-{\bf{w}}_{t+1}\|^{2}]
Lemma 3 (Lemma 9 in [34]).

Suppose Assumption 2 holds. Consider the update

𝐮i,t+1={(1γ)𝐮i,t+γ2[g(𝐰t;𝒜(𝐱i),i)+g(𝐰t;𝒜(𝐱i),i)],𝐱i𝐮i,t,𝐱i\displaystyle{\bf{u}}_{i,t+1}=\begin{cases}(1-\gamma){\bf{u}}_{i,t}+\frac{\gamma}{2}\left[g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}^{-}_{i})+g({\bf{w}}_{t};\mathcal{A}^{\prime}({\bf{x}}_{i}),\mathcal{B}^{-}_{i})\right],\quad&{\bf{x}}_{i}\in\mathcal{B}\\ {\bf{u}}_{i,t},&{\bf{x}}_{i}\not\in\mathcal{B}\end{cases}

With γ<1/2\gamma<1/2, we have

𝔼[1n𝐮t+1g(𝐰t+1)2](1||γ2n)𝔼[1n𝐮tg(𝐰t)2]+4Bγ2σ22n|i|+4nCg2||γ𝔼[𝐰t𝐰t+12].\mathbb{E}[\frac{1}{n}\|{\bf{u}}^{t+1}-g({\bf{w}}_{t+1})\|^{2}]\leq(1-\frac{|\mathcal{B}|\gamma}{2n})\mathbb{E}[\frac{1}{n}\|{\bf{u}}^{t}-g({\bf{w}}_{t})\|^{2}]+\frac{4B\gamma^{2}\sigma^{2}}{2n|\mathcal{B}_{i}^{-}|}+\frac{4nC_{g}^{2}}{|\mathcal{B}|\gamma}\mathbb{E}[\|{\bf{w}}_{t}-{\bf{w}}_{t+1}\|^{2}].
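For concreteness, a minimal sketch of the moving-average update of u analyzed in Lemma 3 (our own illustration; the tensor names are hypothetical), in which only the coordinates of the sampled anchors are updated:

```python
import torch

def update_u(u: torch.Tensor, g_hat: torch.Tensor, batch_idx: torch.Tensor, gamma: float):
    """u:         (n,)   current estimates u_{i,t} of g(w; x_i, S_i^-)
    g_hat:     (|B|,) mini-batch estimates 0.5 * [g(w_t; A(x_i), B_i^-) + g(w_t; A'(x_i), B_i^-)]
    batch_idx: (|B|,) indices of the sampled anchors x_i in B
    gamma:     moving-average parameter, gamma < 1/2
    Coordinates outside the sampled batch keep their previous value u_{i,t}."""
    u = u.clone()
    u[batch_idx] = (1.0 - gamma) * u[batch_idx] + gamma * g_hat
    return u
```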
Lemma 4.

Suppose Assumption 2 holds. Considering the update 𝐰t+1=𝐰tη𝐦~t+1{\bf{w}}_{t+1}={\bf{w}}_{t}-\eta\widetilde{{\bf{m}}}_{t+1}, with η1/(2L)\eta\leq 1/(2L), we have

Φ(𝐰t+1)\displaystyle\Phi({\bf{w}}_{t+1}) Φ(𝐰t)+η2Φ(𝐰t)𝐦~t+12η2Φ(𝐰t)2η4𝐦~t+12.\displaystyle\leq\Phi({\bf{w}}_{t})+\frac{\eta}{2}\|\nabla\Phi({\bf{w}}_{t})-\tilde{{\bf{m}}}_{t+1}\|^{2}-\frac{\eta}{2}\|\nabla\Phi({\bf{w}}_{t})\|^{2}-\frac{\eta}{4}\|\tilde{{\bf{m}}}_{t+1}\|^{2}.

Now we present a formal statement of Theorem 2.

Theorem 3.

Suppose Assumption 2 holds. With β=min{1,min{||,|i|,|a|}ϵ212C1}\beta=\min\left\{1,\frac{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}\epsilon^{2}}{12C_{1}}\right\}, γ=min{12,|i|ϵ2384Cg2Lf22σ2}\gamma=\min\left\{\frac{1}{2},\frac{|\mathcal{B}_{i}^{-}|\epsilon^{2}}{384C_{g}^{2}L_{f_{2}}^{2}\sigma^{2}}\right\}, η=min{λ2Lϕ2,λ|a|ϵ2384α2Lϕ2σ2}\eta^{\prime}=\min\left\{\frac{\lambda}{2L_{\phi}^{2}},\frac{\lambda|\mathcal{B}_{a}|\epsilon^{2}}{384\alpha^{2}L_{\phi}^{2}\sigma^{2}}\right\}, η=min{12L,2βL,||γ32Lf2Cg2n,ηλ64αLϕκ}\eta=\min\left\{\frac{1}{2L},\frac{2\beta}{L},\frac{|\mathcal{B}|\gamma}{32L_{f_{2}}C_{g}^{2}n},\frac{\eta^{\prime}\lambda}{64\alpha L_{\phi}\kappa}\right\}, after

T4ΛPηϵ2\displaystyle T\geq\frac{4\Lambda_{P}}{\eta\epsilon^{2}} =4ΛPϵ2max{2L,L2β,32Lf2Cg2n||γ,64αLϕκηλ}\displaystyle=\frac{4\Lambda_{P}}{\epsilon^{2}}\max\left\{2L,\frac{L}{2\beta},\frac{32L_{f_{2}}C_{g}^{2}n}{|\mathcal{B}|\gamma},\frac{64\alpha L_{\phi}\kappa}{\eta^{\prime}\lambda}\right\}
=4ΛPϵ2max{2L,6C1Lmin{||,|i|,|a|}ϵ2,64Lf2Cg2n||,12288nCg4Lf23σ2|||i|ϵ2,128αLϕ3κλ2,24576α3Lϕ3κσ2λ2|a|ϵ2}\displaystyle=\frac{4\Lambda_{P}}{\epsilon^{2}}\max\left\{2L,\frac{6C_{1}L}{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}\epsilon^{2}},\frac{64L_{f_{2}}C_{g}^{2}n}{|\mathcal{B}|},\frac{12288nC_{g}^{4}L_{f_{2}}^{3}\sigma^{2}}{|\mathcal{B}||\mathcal{B}_{i}^{-}|\epsilon^{2}},\frac{128\alpha L_{\phi}^{3}\kappa}{\lambda^{2}},\frac{24576\alpha^{3}L_{\phi}^{3}\kappa\sigma^{2}}{\lambda^{2}|\mathcal{B}_{a}|\epsilon^{2}}\right\}

iterations, Algorithm 1 ensures 1Tt=1T𝔼[Φ(𝐰t)2]ϵ2\frac{1}{T}\sum_{t=1}^{T}\mathbb{E}[\|\nabla\Phi({\bf{w}}_{t})\|^{2}]\leq\epsilon^{2}.
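Algorithm 1 itself is not reproduced in this appendix; as a reading aid, the following sketch (our own reconstruction from the updates analyzed in Lemmas 2–4, with hypothetical gradient-estimator callables) shows the three coupled updates that the analysis tracks in a single iteration.

```python
def sofclr_iteration(w, w_prime, m_tilde,
                     primal_grad_estimator, dual_grad_estimator,
                     eta: float, eta_prime: float, beta: float):
    """One iteration of the updates analyzed in Lemmas 2-4 (illustrative sketch).

    primal_grad_estimator(w, w_prime) -> stochastic estimate m_{t+1} of the gradient w.r.t. w
    dual_grad_estimator(w, w_prime)   -> stochastic estimate v_{t+1} of grad_{w'} phi(w, w')
    """
    # Stochastic estimates at the current iterate (w_t, w'_t)
    v = dual_grad_estimator(w, w_prime)
    m = primal_grad_estimator(w, w_prime)

    # Discriminator ascent: w'_{t+1} = w'_t + eta' * v_{t+1}            (Lemma 2)
    w_prime = w_prime + eta_prime * v

    # Momentum estimator:   m~_{t+1} = (1 - beta) m~_t + beta m_{t+1}
    m_tilde = (1.0 - beta) * m_tilde + beta * m

    # Primal descent:       w_{t+1} = w_t - eta * m~_{t+1}              (Lemma 4)
    w = w - eta * m_tilde
    return w, w_prime, m_tilde
```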

B.1 Proof of Lemma 2

To prove Lemma 2, we need the following lemma.

Lemma 5.

[Lemma A.3 in [28]] Assume h(x,y)h(x,y) is LhL_{h} smooth and h(x,y)-h(x,y) satisfies μ\mu-PL condition w.r.t. yy. For any x1,x2x_{1},x_{2} and y(x1)argmaxyh(x1,y)y(x_{1})\in\operatorname*{arg\,max}_{y^{\prime}}h(x_{1},y^{\prime}), there exists y(x2)argmaxyh(x2,y)y(x_{2})\in\operatorname*{arg\,max}_{y^{\prime}}h(x_{2},y^{\prime}) such that

y(x1)y(x2)Lh2μx1x2.\|y(x_{1})-y(x_{2})\|\leq\frac{L_{h}}{2\mu}\|x_{1}-x_{2}\|.

Recall that we assume ϕ(𝐰,𝐰)-\phi({\bf{w}},{\bf{w}}^{\prime}) to be λ\lambda-one-point strongly convex in 𝐰{\bf{w}}^{\prime}, so Lemma 1 implies that ϕ(𝐰,𝐰)-\phi({\bf{w}},{\bf{w}}^{\prime}) satisfies the λLϕ\frac{\lambda}{L_{\phi}}-PL condition w.r.t. 𝐰{\bf{w}}^{\prime}. Thus, by Lemma 5, 𝐰(𝐰)argmax𝐰ϕ(𝐰,𝐰){\bf{w}}^{\prime}({\bf{w}})\in\operatorname*{arg\,max}_{{\bf{w}}^{\prime}}\phi({\bf{w}},{\bf{w}}^{\prime}) is Lϕ22λ\frac{L_{\phi}^{2}}{2\lambda}-Lipschitz continuous. We then follow the proof of Lemma 17 in [15] to prove Lemma 2. We emphasize that concavity is not needed in Lemma 2, whereas it is required in Lemma 17 of [15].

Proof.
𝔼[𝐰t+1𝐰(𝐰t)2]\displaystyle\mathbb{E}[\|{\bf{w}}^{\prime}_{t+1}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}] (10)
=𝔼[𝐰t+ηvt+1𝐰(𝐰t)2]\displaystyle=\mathbb{E}[\|{\bf{w}}^{\prime}_{t}+\eta^{\prime}\textbf{v}_{t+1}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]
=𝔼[𝐰t+ηvt+1𝐰(𝐰t)+η𝐰ϕ(𝐰t,𝐰t)η𝐰ϕ(𝐰t,𝐰t)2]\displaystyle=\mathbb{E}[\|{\bf{w}}^{\prime}_{t}+\eta^{\prime}\textbf{v}_{t+1}-{\bf{w}}^{\prime}({\bf{w}}_{t})+\eta^{\prime}\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})-\eta^{\prime}\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\|^{2}]
=𝔼[𝐰t𝐰(𝐰t)+η𝐰ϕ(𝐰t,𝐰t)2]+η2𝔼[vt+1𝐰ϕ(𝐰t,𝐰t)2]\displaystyle=\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})+\eta^{\prime}\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\|^{2}]+\eta^{\prime 2}\mathbb{E}[\|\textbf{v}_{t+1}-\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\|^{2}]
𝔼[𝐰t𝐰(𝐰t)+η𝐰ϕ(𝐰t,𝐰t)2]+η2σ22|a|\displaystyle\leq\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})+\eta^{\prime}\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\|^{2}]+\frac{\eta^{\prime 2}\sigma^{2}}{2|\mathcal{B}_{a}|}

where

𝔼[𝐰t𝐰(𝐰t)+η𝐰ϕ(𝐰t,𝐰t)2]\displaystyle\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})+\eta^{\prime}\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\|^{2}]
=𝔼[𝐰t𝐰(𝐰t)η𝐰ϕ(𝐰t,𝐰(𝐰t))+η𝐰ϕ(𝐰t,𝐰t)2]\displaystyle=\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})-\eta^{\prime}\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}({\bf{w}}_{t}))+\eta^{\prime}\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\|^{2}]
=𝔼[𝐰t𝐰(𝐰t)2]+η2𝔼[𝐰ϕ(𝐰t,𝐰(𝐰t))𝐰ϕ(𝐰t,𝐰t)2]\displaystyle=\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]+\eta^{\prime 2}\mathbb{E}[\|\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}({\bf{w}}_{t}))-\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\|^{2}]
+2η𝐰t𝐰(𝐰t),𝐰ϕ(𝐰t,𝐰t)\displaystyle\quad+2\eta^{\prime}\langle{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t}),\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\rangle
(a)𝔼[𝐰t𝐰(𝐰t)2]+η2𝔼[𝐰ϕ(𝐰t,𝐰(𝐰t))𝐰ϕ(𝐰t,𝐰t)2]ηλ𝔼[𝐰t𝐰(𝐰t)2]\displaystyle\stackrel{{\scriptstyle(a)}}{{\leq}}\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]+\eta^{\prime 2}\mathbb{E}[\|\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}({\bf{w}}_{t}))-\nabla_{{\bf{w}}^{\prime}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\|^{2}]-\eta^{\prime}\lambda\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]
(b)(1ηλ2)𝔼[𝐰t𝐰(𝐰t)2]\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}(1-\frac{\eta^{\prime}\lambda}{2})\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]

where inequality (a)(a) uses the λ\lambda-one-point strong convexity of ϕ(𝐰,)-\phi({\bf{w}},\cdot), and inequality (b)(b) uses the assumption ηλ/(2Lϕ2)\eta^{\prime}\leq\lambda/(2L_{\phi}^{2}).

Then we have

𝔼[𝐰t+1𝐰(𝐰t+1)2]\displaystyle\mathbb{E}[\|{\bf{w}}^{\prime}_{t+1}-{\bf{w}}^{\prime}({\bf{w}}_{t+1})\|^{2}]
(1+ηλ4)𝔼[𝐰t+1𝐰(𝐰t)2]+(1+4ηλ)𝔼[𝐰(𝐰t+1)𝐰(𝐰t)2]\displaystyle\leq(1+\frac{\eta^{\prime}\lambda}{4})\mathbb{E}[\|{\bf{w}}^{\prime}_{t+1}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]+(1+\frac{4}{\eta^{\prime}\lambda})\mathbb{E}[\|{\bf{w}}^{\prime}({\bf{w}}_{t+1})-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]
(c)(1ηλ4)𝔼[𝐰t𝐰(𝐰t)2]+(1+ηλ4)η2σ22|a|+(1+4ηλ)Lϕ44λ2𝔼[𝐰t+1𝐰t2]\displaystyle\stackrel{{\scriptstyle(c)}}{{\leq}}(1-\frac{\eta^{\prime}\lambda}{4})\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]+(1+\frac{\eta^{\prime}\lambda}{4})\frac{\eta^{\prime 2}\sigma^{2}}{2|\mathcal{B}_{a}|}+(1+\frac{4}{\eta^{\prime}\lambda})\frac{L_{\phi}^{4}}{4\lambda^{2}}\mathbb{E}[\|{\bf{w}}_{t+1}-{\bf{w}}_{t}\|^{2}]
(d)(1ηλ4)𝔼[𝐰t𝐰(𝐰t)2]+η2σ2|a|+2Lϕ4ηλ3𝔼[𝐰t+1𝐰t2].\displaystyle\stackrel{{\scriptstyle(d)}}{{\leq}}(1-\frac{\eta^{\prime}\lambda}{4})\mathbb{E}[\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2}]+\frac{\eta^{\prime 2}\sigma^{2}}{|\mathcal{B}_{a}|}+\frac{2L_{\phi}^{4}}{\eta^{\prime}\lambda^{3}}\mathbb{E}[\|{\bf{w}}_{t+1}-{\bf{w}}_{t}\|^{2}].

where inequality (c)(c) uses the Lϕ22λ\frac{L_{\phi}^{2}}{2\lambda}-Lipschitz continuity of 𝐰(){\bf{w}}^{\prime}(\cdot) and inequality 10, and inequality (d)(d) uses the assumption η4/λ\eta^{\prime}\leq 4/\lambda. ∎

B.2 Proof of Lemma 4

Proof.

By the smoothness of Φ(𝐰)\Phi({\bf{w}}), we have

Φ(𝐰t+1)\displaystyle\Phi({\bf{w}}_{t+1}) Φ(𝐰t)+Φ(𝐰t),𝐰t+1𝐰t+L2𝐰t+1𝐰t2\displaystyle\leq\Phi({\bf{w}}_{t})+\langle\nabla\Phi({\bf{w}}_{t}),{\bf{w}}_{t+1}-{\bf{w}}_{t}\rangle+\frac{L}{2}\|{\bf{w}}_{t+1}-{\bf{w}}_{t}\|^{2}
=Φ(𝐰t)+ηΦ(𝐰t),𝐦~t+1+Lη22𝐦~t+12\displaystyle=\Phi({\bf{w}}_{t})+\eta\langle\nabla\Phi({\bf{w}}_{t}),\tilde{{\bf{m}}}_{t+1}\rangle+\frac{L\eta^{2}}{2}\|\tilde{{\bf{m}}}_{t+1}\|^{2}
=Φ(𝐰t)+η2Φ(𝐰t)𝐦~t+12η2Φ(𝐰t)2+(Lη22η2)𝐦~t+12\displaystyle=\Phi({\bf{w}}_{t})+\frac{\eta}{2}\|\nabla\Phi({\bf{w}}_{t})-\tilde{{\bf{m}}}_{t+1}\|^{2}-\frac{\eta}{2}\|\nabla\Phi({\bf{w}}_{t})\|^{2}+(\frac{L\eta^{2}}{2}-\frac{\eta}{2})\|\tilde{{\bf{m}}}_{t+1}\|^{2}
Φ(𝐰t)+η2Φ(𝐰t)𝐦~t+12η2Φ(𝐰t)2η4𝐦~t+12\displaystyle\leq\Phi({\bf{w}}_{t})+\frac{\eta}{2}\|\nabla\Phi({\bf{w}}_{t})-\tilde{{\bf{m}}}_{t+1}\|^{2}-\frac{\eta}{2}\|\nabla\Phi({\bf{w}}_{t})\|^{2}-\frac{\eta}{4}\|\tilde{{\bf{m}}}_{t+1}\|^{2}

where the last inequality uses ηL1/2\eta L\leq 1/2.

B.3 Proof of Theorem 3

Proof.

The formulation of the gradient is given by

Φ(𝐰t)\displaystyle\nabla\Phi({\bf{w}}_{t}) =F1(𝐰t)+F2(𝐰t)+F3(𝐰t)\displaystyle=\nabla F_{1}({\bf{w}}_{t})+\nabla F_{2}({\bf{w}}_{t})+\nabla F_{3}({\bf{w}}_{t})
=𝔼𝐱,𝒜,𝒜f1(𝐰t;𝐱,𝒜,𝒜)+1ni=1ng(𝐰t;𝐱i,𝒮i)f2(g(𝐰t;𝐱i,𝒮i))+α𝐰ϕ(𝐰t,𝐰(𝐰t))\displaystyle=\mathbb{E}_{{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}}\nabla f_{1}({\bf{w}}_{t};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})+\frac{1}{n}\sum_{i=1}^{n}\nabla g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-})\nabla f_{2}(g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))+\alpha\nabla_{{\bf{w}}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}({\bf{w}}_{t}))

Recall the formulation of 𝐦t+1{\bf{m}}_{t+1},

𝐦t+1\displaystyle{\bf{m}}_{t+1} =1||𝐱if1(𝐰t;𝐱,𝒜,𝒜)+1||𝐱i[𝐰g(𝐰t;𝒜(𝐱i),i)+𝐰g(𝐰t;𝒜(𝐱i),i)]f2(𝐮i,t)\displaystyle=\frac{1}{|\mathcal{B}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}}\nabla f_{1}({\bf{w}}_{t};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})+\frac{1}{|\mathcal{B}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}}\big{[}\nabla_{\bf{w}}g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})+\nabla_{\bf{w}}g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})\big{]}\nabla f_{2}({\bf{u}}_{i,t})
+α2|a|𝐱ia{𝐰ϕ(𝐰t,𝐰t;𝒜(𝐱i),ai)+𝐰ϕ(𝐰t,𝐰t;𝒜(𝐱i),ai)}\displaystyle\quad+\frac{\alpha}{2|\mathcal{B}_{a}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}_{a}}\bigg{\{}\nabla_{\bf{w}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}({\bf{x}}_{i}),a_{i})+\nabla_{\bf{w}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}^{\prime}({\bf{x}}_{i}),a_{i})\bigg{\}}

Define

Gt+1\displaystyle G_{t+1} =𝔼𝐱,𝒜,𝒜f1(𝐰t;𝐱,𝒜,𝒜)+1ni=1ng(𝐰t;𝐱i,𝒮i)f2(𝐮i,t)+α𝐰ϕ(𝐰t,𝐰t)\displaystyle=\mathbb{E}_{{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}}\nabla f_{1}({\bf{w}}_{t};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})+\frac{1}{n}\sum_{i=1}^{n}\nabla g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-})\nabla f_{2}({\bf{u}}_{i,t})+\alpha\nabla_{{\bf{w}}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})

so that we have 𝔼t[𝐦t+1]=Gt+1\mathbb{E}_{t}[{\bf{m}}_{t+1}]=G_{t+1}, where 𝔼t\mathbb{E}_{t} denotes the expectation over the randomness at tt-th iteration.

Consider

𝔼t[Φ(𝐰t)𝐦~t+12]\displaystyle\mathbb{E}_{t}\left[\|\nabla\Phi({\bf{w}}_{t})-\tilde{{\bf{m}}}_{t+1}\|^{2}\right]
=𝔼t[Φ(𝐰t)(1β)𝐦~tβ𝐦t+12]\displaystyle=\mathbb{E}_{t}\left[\|\nabla\Phi({\bf{w}}_{t})-(1-\beta)\tilde{{\bf{m}}}_{t}-\beta{\bf{m}}_{t+1}\|^{2}\right]
=𝔼t[(1β)(Φ(𝐰t1)𝐦~t)+(1β)(Φ(𝐰t)Φ(𝐰t1))+β(Φ(𝐰t)Gt+1)+β(Gt+1𝐦t+1)2]\displaystyle=\mathbb{E}_{t}\left[\|(1-\beta)(\nabla\Phi({\bf{w}}_{t-1})-\tilde{{\bf{m}}}_{t})+(1-\beta)(\nabla\Phi({\bf{w}}_{t})-\nabla\Phi({\bf{w}}_{t-1}))+\beta(\nabla\Phi({\bf{w}}_{t})-G_{t+1})+\beta(G_{t+1}-{\bf{m}}_{t+1})\|^{2}\right]
=(a)(1β)(Φ(𝐰t1)𝐦~t)+(1β)(Φ(𝐰t)Φ(𝐰t1))+β(Φ(𝐰t)Gt+1)2\displaystyle\stackrel{{\scriptstyle(a)}}{{=}}\|(1-\beta)(\nabla\Phi({\bf{w}}_{t-1})-\tilde{{\bf{m}}}_{t})+(1-\beta)(\nabla\Phi({\bf{w}}_{t})-\nabla\Phi({\bf{w}}_{t-1}))+\beta(\nabla\Phi({\bf{w}}_{t})-G_{t+1})\|^{2}
+β2𝔼t[Gt+1𝐦t+12]\displaystyle\quad+\beta^{2}\mathbb{E}_{t}\left[\|G_{t+1}-{\bf{m}}_{t+1}\|^{2}\right]
(b)(1+β)(1β)2Φ(𝐰t1)𝐦~t2+2(1+1β)[(1β)2Φ(𝐰t)Φ(𝐰t1)2+β2Φ(𝐰t)Gt+12]\displaystyle\stackrel{{\scriptstyle(b)}}{{\leq}}(1+\beta)(1-\beta)^{2}\|\nabla\Phi({\bf{w}}_{t-1})-\tilde{{\bf{m}}}_{t}\|^{2}+2(1+\frac{1}{\beta})\big{[}(1-\beta)^{2}\|\nabla\Phi({\bf{w}}_{t})-\nabla\Phi({\bf{w}}_{t-1})\|^{2}+\beta^{2}\|\nabla\Phi({\bf{w}}_{t})-G_{t+1}\|^{2}\big{]}
+β2𝔼t[Gt+1𝐦t+12]\displaystyle\quad+\beta^{2}\mathbb{E}_{t}\left[\|G_{t+1}-{\bf{m}}_{t+1}\|^{2}\right]
(c)(1β)Φ(𝐰t1)𝐦~t2+4L2β𝐰t𝐰t12+4βΦ(𝐰t)Gt+12+β2𝔼t[Gt+1𝐦t+12]\displaystyle\stackrel{{\scriptstyle(c)}}{{\leq}}(1-\beta)\|\nabla\Phi({\bf{w}}_{t-1})-\tilde{{\bf{m}}}_{t}\|^{2}+\frac{4L^{2}}{\beta}\|{\bf{w}}_{t}-{\bf{w}}_{t-1}\|^{2}+4\beta\|\nabla\Phi({\bf{w}}_{t})-G_{t+1}\|^{2}+\beta^{2}\mathbb{E}_{t}\left[\|G_{t+1}-{\bf{m}}_{t+1}\|^{2}\right]

where equality (a)(a) uses 𝔼t[𝐦t+1]=Gt+1\mathbb{E}_{t}[{\bf{m}}_{t+1}]=G_{t+1}, inequality (b)(b) is due to a+b2(1+β)a2+(1+1β)b2\|a+b\|^{2}\leq(1+\beta)\|a\|^{2}+(1+\frac{1}{\beta})\|b\|^{2}, inequality (c)(c) uses the assumption β1\beta\leq 1 and the smoothness of Φ()\Phi(\cdot).

Furthermore, one may bound the last two terms as follows.

Φ(𝐰t)Gt+12\displaystyle\|\nabla\Phi({\bf{w}}_{t})-G_{t+1}\|^{2} =1ni=1ng(𝐰t;𝐱i,𝒮i)f2(g(𝐰t;𝐱i,𝒮i))1ni=1ng(𝐰t;𝐱i,𝒮i)f2(𝐮i,t)\displaystyle=\bigg{\|}\frac{1}{n}\sum_{i=1}^{n}\nabla g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-})\nabla f_{2}(g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))-\frac{1}{n}\sum_{i=1}^{n}\nabla g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-})\nabla f_{2}({\bf{u}}_{i,t})
+α[𝐰ϕ(𝐰t,𝐰(𝐰t))𝐰ϕ(𝐰t,𝐰t)]2\displaystyle\quad+\alpha\big{[}\nabla_{{\bf{w}}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}({\bf{w}}_{t}))-\nabla_{{\bf{w}}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})\big{]}\bigg{\|}^{2}
2Cg2Lf221ni=1ng(𝐰t;𝐱i,𝒮i)𝐮i,t2+2α2Lϕ2𝐰(𝐰t)𝐰t2\displaystyle\leq 2C_{g}^{2}L_{f_{2}}^{2}\frac{1}{n}\sum_{i=1}^{n}\|g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-})-{\bf{u}}_{i,t}\|^{2}+2\alpha^{2}L_{\phi}^{2}\|{\bf{w}}^{\prime}({\bf{w}}_{t})-{\bf{w}}^{\prime}_{t}\|^{2}
2Cg2Lf22ng(𝐰t)𝐮t2+2α2Lϕ2𝐰(𝐰t)𝐰t2.\displaystyle\leq\frac{2C_{g}^{2}L_{f_{2}}^{2}}{n}\|g({\bf{w}}_{t})-{\bf{u}}_{t}\|^{2}+2\alpha^{2}L_{\phi}^{2}\|{\bf{w}}^{\prime}({\bf{w}}_{t})-{\bf{w}}^{\prime}_{t}\|^{2}.
𝔼t[Gt+1𝐦t+12]\displaystyle\mathbb{E}_{t}\left[\|G_{t+1}-{\bf{m}}_{t+1}\|^{2}\right]
=𝔼t[𝔼𝐱,𝒜,𝒜f1(𝐰t;𝐱,𝒜,𝒜)+1ni=1ng(𝐰t;𝐱i,𝒮i)f2(𝐮i,t)+α𝐰ϕ(𝐰t,𝐰t)\displaystyle=\mathbb{E}_{t}\bigg{[}\bigg{\|}\mathbb{E}_{{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}}\nabla f_{1}({\bf{w}}_{t};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})+\frac{1}{n}\sum_{i=1}^{n}\nabla g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-})\nabla f_{2}({\bf{u}}_{i,t})+\alpha\nabla_{{\bf{w}}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})
1||𝐱if1(𝐰;𝐱,𝒜,𝒜)1||𝐱i[𝐰g(𝐰t;𝒜(𝐱i),i)+𝐰g(𝐰t;𝒜(𝐱i),i)]f2(𝐮i,t)\displaystyle\quad-\frac{1}{|\mathcal{B}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}}\nabla f_{1}({\bf{w}};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})-\frac{1}{|\mathcal{B}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}}\big{[}\nabla_{\bf{w}}g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})+\nabla_{\bf{w}}g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})\big{]}\nabla f_{2}({\bf{u}}_{i,t})
α2|a|𝐱ia{𝐰ϕ(𝐰t,𝐰t;𝒜(𝐱i),ai)+𝐰ϕ(𝐰t,𝐰t;𝒜,𝐱i,ai)}2]\displaystyle\quad-\frac{\alpha}{2|\mathcal{B}_{a}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}_{a}}\bigg{\{}\nabla_{\bf{w}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}({\bf{x}}_{i}),a_{i})+\nabla_{\bf{w}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}^{\prime},{\bf{x}}_{i},a_{i})\bigg{\}}\bigg{\|}^{2}]
3σ2||+3𝔼t[1ni=1ng(𝐰t;𝐱i,𝒮i)f2(𝐮i,t)1||𝐱i[𝐰g(𝐰t;𝒜(𝐱i),i)+𝐰g(𝐰t;𝒜(𝐱i),i)]f2(𝐮i,t)2]\displaystyle\leq\frac{3\sigma^{2}}{|\mathcal{B}|}+3\mathbb{E}_{t}\bigg{[}\bigg{\|}\frac{1}{n}\sum_{i=1}^{n}\nabla g({\bf{w}}_{t};{\bf{x}}_{i},\mathcal{S}_{i}^{-})\nabla f_{2}({\bf{u}}_{i,t})-\frac{1}{|\mathcal{B}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}}\big{[}\nabla_{\bf{w}}g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})+\nabla_{\bf{w}}g({\bf{w}}_{t};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})\big{]}\nabla f_{2}({\bf{u}}_{i,t})\bigg{\|}^{2}\bigg{]}
+3𝔼t[α𝐰ϕ(𝐰t,𝐰t)α2|a|𝐱ia{𝐰ϕ(𝐰t,𝐰t;𝒜(𝐱i),ai)+𝐰ϕ(𝐰t,𝐰t;𝒜,𝐱i,ai)}2\displaystyle\quad+3\mathbb{E}_{t}\bigg{[}\bigg{\|}\alpha\nabla_{{\bf{w}}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t})-\frac{\alpha}{2|\mathcal{B}_{a}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}_{a}}\bigg{\{}\nabla_{\bf{w}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}({\bf{x}}_{i}),a_{i})+\nabla_{\bf{w}}\phi({\bf{w}}_{t},{\bf{w}}^{\prime}_{t};\mathcal{A}^{\prime},{\bf{x}}_{i},a_{i})\bigg{\}}\bigg{\|}^{2}
3σ2||+24Cg2Cf22||+6Cf22σ22|i|+3α2σ22|a|C1min{||,|i|,|a|}\displaystyle\leq\frac{3\sigma^{2}}{|\mathcal{B}|}+\frac{24C_{g}^{2}C_{f_{2}}^{2}}{|\mathcal{B}|}+\frac{6C_{f_{2}}^{2}\sigma^{2}}{2|\mathcal{B}_{i}^{-}|}+\frac{3\alpha^{2}\sigma^{2}}{2|\mathcal{B}_{a}|}\leq\frac{C_{1}}{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}}

where C1=max{9σ2+72Cg2Cf22,9Cf22σ2,9α2σ22}C_{1}=\max\{9\sigma^{2}+72C_{g}^{2}C_{f_{2}}^{2},9C_{f_{2}}^{2}\sigma^{2},\frac{9\alpha^{2}\sigma^{2}}{2}\}.

For simplicity, we denote

δ𝐦t=Φ(𝐰t)𝐦~t+12,δ𝐰t=𝐰t𝐰(𝐰t)2,δ𝐮t=1n𝐮tg(𝐰t)2.\displaystyle\delta_{\bf{m}}^{t}=\|\nabla\Phi({\bf{w}}_{t})-\tilde{{\bf{m}}}_{t+1}\|^{2},\quad\delta_{{\bf{w}}^{\prime}}^{t}=\|{\bf{w}}^{\prime}_{t}-{\bf{w}}^{\prime}({\bf{w}}_{t})\|^{2},\quad\delta_{{\bf{u}}}^{t}=\frac{1}{n}\|{\bf{u}}^{t}-g({\bf{w}}_{t})\|^{2}.

Combining the results above yields

𝔼[δ𝐦t+1](1β)𝔼[δ𝐦t]+4L2η2β𝔼[𝐦~t+12]+4β(2Cg2Lf22𝔼[δ𝐮t+1]+2α2Lϕ2𝔼[δ𝐰t+1])+β2C1min{||,|i|,|a|}\displaystyle\mathbb{E}\left[\delta_{\bf{m}}^{t+1}\right]\leq(1-\beta)\mathbb{E}\left[\delta_{\bf{m}}^{t}\right]+\frac{4L^{2}\eta^{2}}{\beta}\mathbb{E}[\|\tilde{{\bf{m}}}_{t+1}\|^{2}]+4\beta\left(2C_{g}^{2}L_{f_{2}}^{2}\mathbb{E}[\delta_{{\bf{u}}}^{t+1}]+2\alpha^{2}L_{\phi}^{2}\mathbb{E}[\delta_{{\bf{w}}^{\prime}}^{t+1}]\right)+\frac{\beta^{2}C_{1}}{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}} (11)

By Lemma 2 we have

𝔼[δ𝐰t+1](1ηλ4)𝔼[δ𝐰t]+η2σ2|a|+2Lϕ4η2ηλ3𝔼[𝐦~t+12]\mathbb{E}[\delta_{{\bf{w}}^{\prime}}^{t+1}]\leq(1-\frac{\eta^{\prime}\lambda}{4})\mathbb{E}[\delta_{{\bf{w}}^{\prime}}^{t}]+\frac{\eta^{\prime 2}\sigma^{2}}{|\mathcal{B}_{a}|}+\frac{2L_{\phi}^{4}\eta^{2}}{\eta^{\prime}\lambda^{3}}\mathbb{E}[\|\tilde{{\bf{m}}}_{t+1}\|^{2}] (12)

By Lemma 3 we have

𝔼[δ𝐮t+1](1||γ2n)𝔼[δ𝐮t]+4||γ2σ22n|i|+4nCg2η2||γ𝔼[𝐦~t+12].\mathbb{E}\left[\delta_{{\bf{u}}}^{t+1}\right]\leq(1-\frac{|\mathcal{B}|\gamma}{2n})\mathbb{E}\left[\delta_{{\bf{u}}}^{t}\right]+\frac{4|\mathcal{B}|\gamma^{2}\sigma^{2}}{2n|\mathcal{B}_{i}^{-}|}+\frac{4nC_{g}^{2}\eta^{2}}{|\mathcal{B}|\gamma}\mathbb{E}[\|\tilde{{\bf{m}}}_{t+1}\|^{2}]. (13)

By Lemma 4 we have

𝔼[Φ(𝐰t+1)]\displaystyle\mathbb{E}[\Phi({\bf{w}}_{t+1})] 𝔼[Φ(𝐰t)]+η2𝔼[δ𝐦t]η2𝔼[Φ(𝐰t)2]η4𝔼[𝐦~t+12]\displaystyle\leq\mathbb{E}[\Phi({\bf{w}}_{t})]+\frac{\eta}{2}\mathbb{E}[\delta_{\bf{m}}^{t}]-\frac{\eta}{2}\mathbb{E}[\|\nabla\Phi({\bf{w}}_{t})\|^{2}]-\frac{\eta}{4}\mathbb{E}[\|\tilde{{\bf{m}}}_{t+1}\|^{2}] (14)

Summing (14)(\ref{ineq:a4}), ηβ×(11)\frac{\eta}{\beta}\times(\ref{ineq:a1}), 16Cg2Lf22nη||γ×(13)\frac{16C_{g}^{2}L_{f_{2}}^{2}n\eta}{|\mathcal{B}|\gamma}\times(\ref{ineq:a3}), 32α2Lϕ2ηηλ×(12)\frac{32\alpha^{2}L_{\phi}^{2}\eta}{\eta^{\prime}\lambda}\times(\ref{ineq:a2}) yields

𝔼[Φ(𝐰t+1)]Φ+ηβ𝔼[δ𝐦t+1]+16Cg2Lf22nη||γ(1||γ2n)𝔼[δ𝐮t+1]+32α2Lϕ2ηηλ(1ηλ4)𝔼[δ𝐰t+1]\displaystyle\mathbb{E}[\Phi({\bf{w}}_{t+1})]-\Phi^{*}+\frac{\eta}{\beta}\mathbb{E}\left[\delta_{\bf{m}}^{t+1}\right]+\frac{16C_{g}^{2}L_{f_{2}}^{2}n\eta}{|\mathcal{B}|\gamma}(1-\frac{|\mathcal{B}|\gamma}{2n})\mathbb{E}\left[\delta_{{\bf{u}}}^{t+1}\right]+\frac{32\alpha^{2}L_{\phi}^{2}\eta}{\eta^{\prime}\lambda}(1-\frac{\eta^{\prime}\lambda}{4})\mathbb{E}[\delta_{{\bf{w}}^{\prime}}^{t+1}] (15)
𝔼[Φ(𝐰t)]Φ+ηβ(1β2)𝔼[δ𝐦t]+16Cg2Lf22nη||γ(1||γ2n)𝔼[δ𝐮t]+32α2Lϕ2ηηλ(1ηλ4)𝔼[δ𝐰t]\displaystyle\leq\mathbb{E}[\Phi({\bf{w}}_{t})]-\Phi^{*}+\frac{\eta}{\beta}(1-\frac{\beta}{2})\mathbb{E}\left[\delta_{\bf{m}}^{t}\right]+\frac{16C_{g}^{2}L_{f_{2}}^{2}n\eta}{|\mathcal{B}|\gamma}(1-\frac{|\mathcal{B}|\gamma}{2n})\mathbb{E}\left[\delta_{{\bf{u}}}^{t}\right]+\frac{32\alpha^{2}L_{\phi}^{2}\eta}{\eta^{\prime}\lambda}(1-\frac{\eta^{\prime}\lambda}{4})\mathbb{E}[\delta_{{\bf{w}}^{\prime}}^{t}]
η2𝔼[Φ(𝐰t)2]+η(4L2η2β2+64Lf22Cg4n2η2||2γ2+64α2Lϕ6η2η2λ414)𝔼[𝐦~t+12]\displaystyle\quad-\frac{\eta}{2}\mathbb{E}[\|\nabla\Phi({\bf{w}}_{t})\|^{2}]+\eta\left(\frac{4L^{2}\eta^{2}}{\beta^{2}}+\frac{64L_{f_{2}}^{2}C_{g}^{4}n^{2}\eta^{2}}{|\mathcal{B}|^{2}\gamma^{2}}+\frac{64\alpha^{2}L_{\phi}^{6}\eta^{2}}{\eta^{\prime 2}\lambda^{4}}-\frac{1}{4}\right)\mathbb{E}[\|\tilde{{\bf{m}}}_{t+1}\|^{2}]
+ηβC1min{||,|i|,|a|}+32Cg2Lf22γσ2η|i|+32α2Lϕ2ηησ2λ|a|\displaystyle\quad+\frac{\eta\beta C_{1}}{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}}+\frac{32C_{g}^{2}L_{f_{2}}^{2}\gamma\sigma^{2}\eta}{|\mathcal{B}_{i}^{-}|}+\frac{32\alpha^{2}L_{\phi}^{2}\eta\eta^{\prime}\sigma^{2}}{\lambda|\mathcal{B}_{a}|}

Set

η=min{12L,2βL,||γ32Lf2Cg2n,ηλ216αLϕ3}\eta=\min\left\{\frac{1}{2L},\frac{2\beta}{L},\frac{|\mathcal{B}|\gamma}{32L_{f_{2}}C_{g}^{2}n},\frac{\eta^{\prime}\lambda^{2}}{16\alpha L_{\phi}^{3}}\right\}

so that 4L2η2β2+64Lf22Cg4n2η2||2γ2+64α2Lϕ6η2η2λ4140\frac{4L^{2}\eta^{2}}{\beta^{2}}+\frac{64L_{f_{2}}^{2}C_{g}^{4}n^{2}\eta^{2}}{|\mathcal{B}|^{2}\gamma^{2}}+\frac{64\alpha^{2}L_{\phi}^{6}\eta^{2}}{\eta^{\prime 2}\lambda^{4}}-\frac{1}{4}\leq 0. Define potential function

Pt=𝔼[Φ(𝐰t)]Φ+ηβ𝔼[δ𝐦t]+16Cg2Lf22nη||γ(1||γ2n)𝔼[δ𝐮t]+32α2Lϕ2ηηλ(1ηλ4)𝔼[δ𝐰t],P_{t}=\mathbb{E}[\Phi({\bf{w}}_{t})]-\Phi^{*}+\frac{\eta}{\beta}\mathbb{E}\left[\delta_{\bf{m}}^{t}\right]+\frac{16C_{g}^{2}L_{f_{2}}^{2}n\eta}{|\mathcal{B}|\gamma}(1-\frac{|\mathcal{B}|\gamma}{2n})\mathbb{E}\left[\delta_{{\bf{u}}}^{t}\right]+\frac{32\alpha^{2}L_{\phi}^{2}\eta}{\eta^{\prime}\lambda}(1-\frac{\eta^{\prime}\lambda}{4})\mathbb{E}[\delta_{{\bf{w}}^{\prime}}^{t}],

then we have

1Tt=1T𝔼[Φ(𝐰t)2]2P1ηT+2βC1min{||,|i|,|a|}+64Cg2Lf22γσ2|i|+64α2Lϕ2ησ2λ|a|\frac{1}{T}\sum_{t=1}^{T}\mathbb{E}[\|\nabla\Phi({\bf{w}}_{t})\|^{2}]\leq\frac{2P_{1}}{\eta T}+\frac{2\beta C_{1}}{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}}+\frac{64C_{g}^{2}L_{f_{2}}^{2}\gamma\sigma^{2}}{|\mathcal{B}_{i}^{-}|}+\frac{64\alpha^{2}L_{\phi}^{2}\eta^{\prime}\sigma^{2}}{\lambda|\mathcal{B}_{a}|} (16)

If we initialize 𝐮i,1=12[g(𝐰1;𝒜(𝐱i),i)+g(𝐰1;𝒜(𝐱i),i)]{\bf{u}}_{i,1}=\frac{1}{2}\left[g({\bf{w}}_{1};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}^{-}_{i})+g({\bf{w}}_{1};\mathcal{A}^{\prime}({\bf{x}}_{i}),\mathcal{B}^{-}_{i})\right] for all ii, then we have 𝔼[δ𝐮1]σ22|i|\mathbb{E}[\delta_{{\bf{u}}}^{1}]\leq\frac{\sigma^{2}}{2|\mathcal{B}^{-}_{i}|}. Moreover, we run multiple steps of stochastic gradient ascent to approximate 𝐰(𝐰1){\bf{w}}^{\prime}({\bf{w}}_{1}) so that 𝔼[δ𝐰1]1\mathbb{E}[\delta_{{\bf{w}}^{\prime}}^{1}]\leq 1. We set 𝐦~1=𝐦2\tilde{{\bf{m}}}_{1}={\bf{m}}_{2} so that

𝔼[δ𝐦1]\displaystyle\mathbb{E}[\delta_{{\bf{m}}}^{1}] =𝔼[∇Φ(𝐰1)𝐦~22]\displaystyle=\mathbb{E}[\|\nabla\Phi({\bf{w}}_{1})-\tilde{{\bf{m}}}_{2}\|^{2}]
=𝔼[∇Φ(𝐰1)𝐦22]\displaystyle=\mathbb{E}[\|\nabla\Phi({\bf{w}}_{1})-{\bf{m}}_{2}\|^{2}]
=𝔼[𝔼𝐱,𝒜,𝒜f1(𝐰1;𝐱,𝒜,𝒜)+1ni=1ng(𝐰1;𝐱i,𝒮i)f2(g(𝐰1;𝐱i,𝒮i))+α𝐰ϕ(𝐰1,𝐰(𝐰1))\displaystyle=\mathbb{E}\bigg{[}\bigg{\|}\mathbb{E}_{{\bf{x}},\mathcal{A},\mathcal{A}^{\prime}}\nabla f_{1}({\bf{w}}_{1};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})+\frac{1}{n}\sum_{i=1}^{n}\nabla g({\bf{w}}_{1};{\bf{x}}_{i},\mathcal{S}_{i}^{-})\nabla f_{2}(g({\bf{w}}_{1};{\bf{x}}_{i},\mathcal{S}_{i}^{-}))+\alpha\nabla_{{\bf{w}}}\phi({\bf{w}}_{1},{\bf{w}}^{\prime}({\bf{w}}_{1}))
1||𝐱if1(𝐰1;𝐱,𝒜,𝒜)1||𝐱i[𝐰g(𝐰1;𝒜(𝐱i),i)+𝐰g(𝐰1;𝒜(𝐱i),i)]f2(𝐮i,t)\displaystyle\quad-\frac{1}{|\mathcal{B}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}}\nabla f_{1}({\bf{w}}_{1};{\bf{x}},\mathcal{A},\mathcal{A}^{\prime})-\frac{1}{|\mathcal{B}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}}\big{[}\nabla_{\bf{w}}g({\bf{w}}_{1};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})+\nabla_{\bf{w}}g({\bf{w}}_{1};\mathcal{A}({\bf{x}}_{i}),\mathcal{B}_{i}^{-})\big{]}\nabla f_{2}({\bf{u}}_{i,t})
α2|a|𝐱ia{𝐰ϕ(𝐰1,𝐰1;𝒜(𝐱i),ai)+𝐰ϕ(𝐰1,𝐰1;𝒜,𝐱i,ai)}2]\displaystyle\quad-\frac{\alpha}{2|\mathcal{B}_{a}|}\sum_{{\bf{x}}_{i}\in\mathcal{B}_{a}}\bigg{\{}\nabla_{\bf{w}}\phi({\bf{w}}_{1},{\bf{w}}^{\prime}_{1};\mathcal{A}({\bf{x}}_{i}),a_{i})+\nabla_{\bf{w}}\phi({\bf{w}}_{1},{\bf{w}}^{\prime}_{1};\mathcal{A}^{\prime},{\bf{x}}_{i},a_{i})\bigg{\}}\bigg{\|}^{2}\bigg{]}
3σ2||+9Cg2Lf22𝔼[δ𝐮1]+9Cg2Cf22||+σ2Cf222|i|+6α2Lϕ2𝔼[δ𝐰1]+6α2σ22|a|\displaystyle\leq\frac{3\sigma^{2}}{|\mathcal{B}|}+9C_{g}^{2}L_{f_{2}}^{2}\mathbb{E}[\delta_{{\bf{u}}}^{1}]+\frac{9C_{g}^{2}C_{f_{2}}^{2}}{|\mathcal{B}|}+\frac{\sigma^{2}C_{f_{2}}^{2}}{2|\mathcal{B}_{i}^{-}|}+6\alpha^{2}L_{\phi}^{2}\mathbb{E}[\delta_{{\bf{w}}^{\prime}}^{1}]+\frac{6\alpha^{2}\sigma^{2}}{2|\mathcal{B}_{a}|}
3σ2||+9Cg2Lf22σ22|i|+9Cg2Cf22||+σ2Cf222|i|+6α2Lϕ2+6α2σ22|a|=:C2\displaystyle\leq\frac{3\sigma^{2}}{|\mathcal{B}|}+9C_{g}^{2}L_{f_{2}}^{2}\frac{\sigma^{2}}{2|\mathcal{B}^{-}_{i}|}+9\frac{C_{g}^{2}C_{f_{2}}^{2}}{|\mathcal{B}|}+\frac{\sigma^{2}C_{f_{2}}^{2}}{2|\mathcal{B}_{i}^{-}|}+6\alpha^{2}L_{\phi}^{2}+\frac{6\alpha^{2}\sigma^{2}}{2|\mathcal{B}_{a}|}=:C_{2}

Define ΛP:=ΔΦ+2C2L+Lf2σ24|i|+αλ2P1\Lambda_{P}:=\Delta_{\Phi}+\frac{2C_{2}}{L}+\frac{L_{f_{2}}\sigma^{2}}{4|\mathcal{B}^{-}_{i}|}+\frac{\alpha\lambda}{2}\geq P_{1}. Then

1Tt=1T𝔼[Φ(𝐰t)2]2ΛPηT+2βC1min{||,|i|,|a|}+64Cg2Lf22γσ2|i|+64α2Lϕ2ησ2λ|a|.\frac{1}{T}\sum_{t=1}^{T}\mathbb{E}[\|\nabla\Phi({\bf{w}}_{t})\|^{2}]\leq\frac{2\Lambda_{P}}{\eta T}+\frac{2\beta C_{1}}{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}}+\frac{64C_{g}^{2}L_{f_{2}}^{2}\gamma\sigma^{2}}{|\mathcal{B}_{i}^{-}|}+\frac{64\alpha^{2}L_{\phi}^{2}\eta^{\prime}\sigma^{2}}{\lambda|\mathcal{B}_{a}|}. (17)

Set

β=min{1,min{||,|i|,|a|}ϵ212C1},γ=min{12,|i|ϵ2384Cg2Lf22σ2},η=min{4λ,λ2Lϕ2,λ|a|ϵ2384α2Lϕ2σ2},\beta=\min\left\{1,\frac{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}\epsilon^{2}}{12C_{1}}\right\},\quad\gamma=\min\left\{\frac{1}{2},\frac{|\mathcal{B}_{i}^{-}|\epsilon^{2}}{384C_{g}^{2}L_{f_{2}}^{2}\sigma^{2}}\right\},\quad\eta^{\prime}=\min\left\{\frac{4}{\lambda},\frac{\lambda}{2L_{\phi}^{2}},\frac{\lambda|\mathcal{B}_{a}|\epsilon^{2}}{384\alpha^{2}L_{\phi}^{2}\sigma^{2}}\right\},

then after

T4ΛPηϵ2\displaystyle T\geq\frac{4\Lambda_{P}}{\eta\epsilon^{2}} =4ΛPϵ2max{2L,L2β,32Lf2Cg2n||γ,16αLϕ3ηλ2}\displaystyle=\frac{4\Lambda_{P}}{\epsilon^{2}}\max\left\{2L,\frac{L}{2\beta},\frac{32L_{f_{2}}C_{g}^{2}n}{|\mathcal{B}|\gamma},\frac{16\alpha L_{\phi}^{3}}{\eta^{\prime}\lambda^{2}}\right\}
=4ΛPϵ2max{2L,6C1Lmin{||,|i|,|a|}ϵ2,64Lf2Cg2n||,12288nCg4Lf23σ2|||i|ϵ2,4αLϕ3λ,32αLϕ5λ3,6144α3Lϕ5σ2λ3|a|ϵ2}\displaystyle=\frac{4\Lambda_{P}}{\epsilon^{2}}\max\left\{2L,\frac{6C_{1}L}{\min\{|\mathcal{B}|,|\mathcal{B}_{i}^{-}|,|\mathcal{B}_{a}|\}\epsilon^{2}},\frac{64L_{f_{2}}C_{g}^{2}n}{|\mathcal{B}|},\frac{12288nC_{g}^{4}L_{f_{2}}^{3}\sigma^{2}}{|\mathcal{B}||\mathcal{B}_{i}^{-}|\epsilon^{2}},\frac{4\alpha L_{\phi}^{3}}{\lambda},\frac{32\alpha L_{\phi}^{5}}{\lambda^{3}},\frac{6144\alpha^{3}L_{\phi}^{5}\sigma^{2}}{\lambda^{3}|\mathcal{B}_{a}|\epsilon^{2}}\right\}

iterations, we have 1Tt=1T𝔼[Φ(𝐰t)2]ϵ2\frac{1}{T}\sum_{t=1}^{T}\mathbb{E}[\|\nabla\Phi({\bf{w}}_{t})\|^{2}]\leq\epsilon^{2}.

Appendix C More Details about Experiments

In all our experiments, we tuned the learning rate (lr) of the Adam optimizer [19] within the range {1e-3, 1e-4, 1e-5}. The batch size is set to 128 for CelebA and 64 for UTKFace. For baseline end-to-end regularized methods that require task labels (CE, CE+EOD, CE+DPR, and CE+EQL), we trained for 100 epochs, and the regularizer weights were tuned in the range {0.1, 0.3, 0.5, 0.7, 0.9, 1}. For adversarial baselines with task labels (ML-AFL and Max-Ent), we trained for 100 epochs each for feature representation learning and for the adversarial head optimization, sequentially. For contrastive-based baselines and our method SoFCLR, we trained for 100 epochs of contrastive learning and 20 epochs of linear evaluation, with a stagewise learning rate decay at the 10th epoch by a factor of 10. The combination weights in ML-AFL, Max-Ent, SimCLR+CCL, and SoFCLR were tuned within the range {0.1, 0.3, 0.5, 0.7, 0.9, 1}. All results are reported based on 4 independent runs.
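For reference, a minimal sketch of the linear-evaluation optimizer setup described above (illustrative only; the module and variable names are hypothetical and not taken from the released code):

```python
import torch

# Hypothetical stand-ins for the frozen encoder output dimension and the linear head.
feature_dim, num_classes = 128, 2
linear_head = torch.nn.Linear(feature_dim, num_classes)

# Learning rate tuned over {1e-3, 1e-4, 1e-5}; Adam optimizer [19].
optimizer = torch.optim.Adam(linear_head.parameters(), lr=1e-3)

# Linear evaluation for 20 epochs, with a stagewise decay of the learning rate
# by a factor of 10 at the 10th epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)

for epoch in range(20):
    # ... one epoch of training the linear head on frozen representations ...
    scheduler.step()
```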

Appendix D More Experimental Results

D.1 More quantitative results on the CelebA dataset in Section 6.1.

Table 5: Results on CelebA: accuracy of predicting Big Nose and fairness metrics for two sensitive attributes, Male and Young.
(Big Nose, Male) Acc Δ\Delta ED Δ\Delta EO Δ\DeltaDP IntraAUC InterAUC GAUC WD KL
CE 82.21 (±\pm 0.42) 22.37 (±\pm 0.35) 28.31 (±\pm 0.36) 23.01(±\pm 0.39) 0.0433 (±\pm 9e-4) 0.3610 (±\pm 5e-3) 0.2973(±\pm 6e-3) 0.204(±\pm 5e-3) 0.6771 (±\pm 6e-3)
CE + EOD 82.15 (±\pm 0.40) 19.09 (±\pm 0.33) 27.32 (±\pm 0.39) 22.71(±\pm 0.33) 0.0463 (±\pm 8e-4) 0.3671 (±\pm 7e-3) 0.2985(±\pm 7e-3) 0.1965(±\pm 6e-3) 0.6843(±\pm 8e-3)
CE + DPR 82.29 (±\pm 0.38) 18.92 (±\pm 0.37) 24.31 (±\pm 0.40) 22.62(±\pm 0.38) 0.0425 (±\pm 7e-4) 0.3860 (±\pm 9e-3) 0.3046(±\pm 6e-3) 0.1949(±\pm 7e-3) 0.706(±\pm 6e-3)
CE + EQL 81.91 (±\pm 0.39) 20.92 (±\pm 0.41) 27.15 (±\pm 0.43) 25.55(±\pm 0.34) 0.0403 (±\pm 8e-4) 0.3571 (±\pm 7e-3) 0.2971(±\pm 7e-3) 0.1975(±\pm 5e-3) 0.637(±\pm 8e-3)
ML-AFL 81.81 (±\pm 0.36) 29.66 (±\pm 0.37) 18.76 (±\pm 0.38) 24.07(±\pm 0.41) 0.0502 (±\pm 1e-3) 0.5660(±\pm 6e-3) 0.3683 (±\pm 7e-3) 0.2124(±\pm 4e-3) 1.2163(±\pm 7e-3)
Max-Ent 81.71 (±\pm 0.43) 16.16 (±\pm 0.34) 23.98 (±\pm 0.46) 16.44(±\pm 0.45) 0.0505 (±\pm 8e-4) 0.526(±\pm 7e-3) 0.3618(±\pm 8e-3) 0.1977(±\pm 7e-3) 1.1411(±\pm 6e-3)
SimCLR 82.72 (±\pm 0.37) 18.49 (±\pm 0.39) 28.71 (±\pm 0.35) 21.13(±\pm 0.38) 0.0664 (±\pm 6e-4) 0.4462(±\pm 7e-3) 0.3488(±\pm 6e-3) 0.1777(±\pm 5e-3) 0.9986(±\pm 7e-3)
SogCLR 82.64 (±\pm 0.41) 16.35 (±\pm 0.30) 26.12 (±\pm 0.31) 19.91(±\pm 0.42) 0.0636 (±\pm 7e-4) 0.4584(±\pm 8e-3) 0.3656(±\pm 9e-3) 0.1728(±\pm 8e-3) 0.9986(±\pm 7e-3)
Boyl 82.62 (±\pm 0.35) 16.48 (±\pm 0.32) 25.13 (±\pm 0.36) 19.67(±\pm 0.39) 0.0647 (±\pm 6e-4) 0.4567(±\pm 6e-3) 0.3435(±\pm 7e-3) 0.1745(±\pm 5e-3) 0.8989(±\pm 9e-3)
SimCLR + CCL 82.11 (±\pm 0.38) 15.34 (±\pm 0.35) 24.24 (±\pm 0.33) 16.56(±\pm 0.36) 0.0589 (±\pm 8e-4) 0.3678(±\pm 7e-3) 0.2897(±\pm 6e-3) 0.1544(±\pm 7e-3) 0.6691(±\pm 6e-3)
SoFCLR 81.83 (±\pm 0.33) 8.61 (±\pm 0.28) 18.64 (±\pm 0.29) 12.91(±\pm 0.30) 0.0341 (±\pm 5e-4) 0.1538(±\pm 5e-3) 0.2299(±\pm 4e-3) 0.0816(±\pm 5e-3) 0.5809(±\pm 4e-3)
(Big Nose, Young) Acc Δ\Delta ED Δ\Delta EO Δ\DeltaDP IntraAUC InterAUC GAUC WD KL
CE 81.78 (±\pm 0.38) 20.96 (±\pm 0.55) 20.01 (±\pm 0.48) 23.22(±\pm 0.43) 0.0401 (±\pm 8e-4) 0.2169(±\pm 9e-3) 0.2046(±\pm 6e-3) 0.1495 (±\pm 7e-3) 0.2851(±\pm 8e-3)
CE + EOD 81.59 (±\pm 0.47) 18.96 (±\pm 0.43) 18.31 (±\pm 0.41) 24.14(±\pm 0.38) 0.0389 (±\pm 7e-4) 0.2345(±\pm 8e-3) 0.2037(±\pm 5e-3) 0.1476 (±\pm 7e-3) 0.2651(±\pm 7e-3)
CE + DPR 82.11 (±\pm 0.51) 20.32 (±\pm 0.37) 18.01 (±\pm 0.47) 23.12(±\pm 0.41) 0.0433 (±\pm 8e-4) 0.3610(±\pm 9e-3) 0.2973(±\pm 4e-3) 0.1603 (±\pm 6e-3) 0.2573(±\pm 7e-3)
CE + EQL 81.58 (±\pm 0.38) 17.23 (±\pm 0.41) 18.59 (±\pm 0.39) 25.40(±\pm 0.45) 0.0369 (±\pm 9e-4) 0.2185(±\pm 7e-3) 0.2014(±\pm 7e-3) 0.1477 (±\pm 8e-3) 0.2753(±\pm 9e-3)
ML-AFL 81.66 (±\pm 0.44) 22.51 (±\pm 0.35) 16.2 (±\pm 0.41) 24.93(±\pm 0.36) 0.0360 (±\pm 1e-3) 0.3331(±\pm 5e-3) 0.2392(±\pm 6e-3) 0.1449(±\pm 9e-3) 0.4033(±\pm 6e-3)
Max-Ent 81.72 (±\pm 0.35) 21.69 (±\pm 0.29) 16.37 (±\pm 0.35) 25.55(±\pm 0.33) 0.0487 (±\pm 6e-4) 0.2919(±\pm 8e-3) 0.2289(±\pm 6e-3) 0.1529(±\pm 9e-3) 0.3661(±\pm 7e-3)
SimCLR 82.57 (±\pm 0.40) 12.59 (±\pm 0.31) 17.41 (±\pm 0.37) 16.70(±\pm 0.34) 0.0564 (±\pm 7e-4) 0.2214(±\pm 8e-3) 0.2208(±\pm 8e-3) 0.1139(±\pm 6e-3) 0.3325(±\pm 6e-3)
SogCLR 82.48(±\pm 0.36) 12.05 (±\pm 0.43) 16.21 (±\pm 0.39) 15.37(±\pm 0.39) 0.0559 (±\pm 5e-4) 0.2333(±\pm 6e-3) 0.2268(±\pm 5e-3) 0.1141(±\pm 7e-3) 0.3635(±\pm 6e-3)
Boyl 82.31 (±\pm 0.43) 12.39 (±\pm 0.37) 16.46 (±\pm 0.34) 16.01(±\pm 0.41) 0.0567 (±\pm 7e-4) 0.2345(±\pm 6e-3) 0.2249(±\pm 7e-3) 0.1201(±\pm 6e-3) 0.3647(±\pm 6e-3)
SimCLR + CCL 82.37 (±\pm 0.39) 11.88 (±\pm 0.36) 15.90 (±\pm 0.37) 14.89(±\pm 0.35) 0.0536 (±\pm 6e-4) 0.2264(±\pm 7e-3) 0.2187(±\pm 6e-3) 0.1101(±\pm 7e-3) 0.2893(±\pm 7e-3)
SoFCLR 82.36 (±\pm 0.31) 9.71 (±\pm 0.27) 14.61 (±\pm 0.30) 12.90(±\pm 0.31) 0.0545 (±\pm 5e-4) 0.2165(±\pm 5e-3) 0.2090(±\pm 3e-3) 0.1042(±\pm 6e-3) 0.1869(±\pm 5e-3)
Table 6: Results on CelebA: accuracy of predicting Bags Under Eyes and fairness metrics for two sensitive attributes, Male and Young.
(Bags Under Eyes, Male) Acc Δ\Delta ED Δ\Delta EO Δ\DeltaDP IntraAUC InterAUC GAUC WD KL
CE 81.49 ( ±\pm 0.34) 5.67 ( ±\pm 0.32) 11.23 ( ±\pm 0.33) 7.09 ( ±\pm 0.39) 0.0919 (±\pm 4e-3) 0.2781 (±\pm 5e-3) 0.2921 (±\pm 5e-3) 0.1355 (±\pm 4e-3) 0.6803 (±\pm 5e-3)
CE+EOD 81.24 ( ±\pm 0.33) 6.71 ( ±\pm 0.33) 11.12 ( ±\pm 0.34) 6.75 ( ±\pm 0.36) 0.0848 (±\pm 5e-3) 0.3045 (±\pm 6e-3) 0.2933 (±\pm 6e-3) 0.1347 (±\pm 6e-3) 0.6767 (±\pm 6e-3)
CE + DPR 81.27 ( ±\pm 0.41) 6.39 ( ±\pm 0.37) 10.06 ( ±\pm 0.37) 6.63 (±\pm 0.39) 0.1034 (±\pm 6e-3) 0.2902 (±\pm 5e-3) 0.2968 (±\pm 6e-3) 0.1414 (±\pm 6e-3) 0.6951 (±\pm 5e-3)
CE + EQL 81.36 ( ±\pm 0.38) 6.57 ( ±\pm 0.35) 10.20 ( ±\pm 0.40) 6.95 (±\pm 0.43) 0.0937 (±\pm 3e-3) 0.309 (±\pm 8e-3) 0.3024 (±\pm 6e-3) 0.1443 (±\pm 5e-3) 0.7278 (±\pm 7e-3)
ML-AFL 81.74 ( ±\pm 0.40) 6.31 ( ±\pm 0.33) 12.10 ( ±\pm 0.39) 7.23 ( ±\pm 0.33) 0.0963 (±\pm 5e-3) 0.3231 (±\pm 6e-3) 0.2392 (±\pm 5e-3) 0.1449 (±\pm 5e-3) 0.6883 (±\pm 6e-3)
Max-Ent 81.36 ( ±\pm 0.35) 6.19 ( ±\pm 0.34) 11.29 (±\pm 0.38) 6.83 ( ±\pm 0.37) 0.0883 (±\pm 4e-3) 0.2923 (±\pm 5e-3) 0.2439 (±\pm 7e-3) 0.1529 (±\pm 4e-3) 0.6723 (±\pm 5e-3)
SimCLR 82.14 ( ±\pm 0.43) 10.09 ( ±\pm 0.31) 16.10 (±\pm 0.41) 10.25 ( ±\pm 0.39) 0.0905 (±\pm 4e-3) 0.3257 (±\pm 6e-3) 0.3271 (±\pm 5e-3) 0.1568 (±\pm 5e-3) 0.8885 (±\pm 5e-3)
SogCLR 81.63 ( ±\pm 0.35) 8.78 ( ±\pm 0.30) 14.15 ( ±\pm 0.33) 9.83 ( ±\pm 0.36) 0.0881 (±\pm 5e-3) 0.3241 (±\pm 5e-3) 0.3259 (±\pm 6e-3) 0.1499 (±\pm 4e-3) 0.8652 (±\pm 6e-3)
Boyl 81.73 ( ±\pm 0.38) 8.63 ( ±\pm 0.33) 13.23 ( ±\pm 0.36) 9.64 ( ±\pm 0.37) 0.0923 (±\pm 6e-3) 0.3345 (±\pm 5e-3) 0.3149 (±\pm 6e-3) 0.1423 (±\pm 6e-3) 0.8211 (±\pm 6e-3)
SimCLR + CCL 81.58 ( ±\pm 0.37) 7.67 ( ±\pm 0.35) 11.89 ( ±\pm 0.37) 8.92 ( ±\pm 0.35) 0.0911 (±\pm 5e-3) 0.2911 (±\pm 5e-3) 0.3041 (±\pm 5e-3) 0.1213 (±\pm 5e-3) 0.7611 (±\pm 6e-3)
SoFCLR 81.43 ( ±\pm 0.29) 5.19 ( ±\pm 0.27) 7.47 ( ±\pm 0.31) 6.53 ( ±\pm 0.30) 0.0902 (±\pm 3e-3) 0.2348 (±\pm 4e-3) 0.2583 (±\pm 4e-3) 0.0838 (±\pm 3e-3) 0.4794 (±\pm 4e-3)
(Bags Under Eyes, Young) Acc Δ\Delta ED Δ\Delta EO Δ\Delta DP IntraAUC InterAUC GAUC WD KL
CE 83.13 (±\pm 0.41) 12.51 (±\pm 0.39) 18.12 (±\pm 0.37) 14.45 (±\pm 0.42) 0.0376 (±\pm 8e-4) 0.1861 (±\pm 6e-3) 0.1842 (±\pm 8e-3) 0.1195 (±\pm 5e-3) 0.2391 (±\pm 6e-3)
CE + EOD 83.03 (±\pm 0.38) 10.96 (±\pm 0.37) 14.84 (±\pm 0.38) 14.23 (±\pm 0.39) 0.0372 (±\pm 9e-4) 0.1682 (±\pm 8e-3) 0.1767(±\pm 9e-3) 0.1131 (±\pm 4e-3) 0.2177 (±\pm 7e-3)
CE + DPR 82.67 (±\pm 0.37) 8.32 (±\pm 0.43) 11.05 (±\pm 0.40) 11.33 (±\pm 0.41) 0.0413 (±\pm 8e-4) 0.1622 (±\pm 7e-3) 0.1745(±\pm 9e-3) 0.1043 (±\pm 5e-3) 0.2103 (±\pm 6e-3)
CE + EQL 82.57 (±\pm 0.40) 9.02 (±\pm 0.37) 11.68 (±\pm 0.37) 12.33 (±\pm 0.38) 0.0434 (±\pm 7e-4) 0.1579 (±\pm 5e-3) 0.1704(±\pm 7e-3) 0.1037 (±\pm 5e-3) 0.1991 (±\pm 6e-3)
ML-AFL 81.91 (±\pm 0.38) 8.08 (±\pm 0.41) 12.22 (±\pm 0.51) 8.92 (±\pm 0.43) 0.0427 (±\pm1e-3) 0.1926 (±\pm 7e-3) 0.1823(±\pm 8e-3) 0.0963 (±\pm 6e-3) 0.2543 (±\pm 6e-3)
Max-Ent 81.56 (±\pm 0.42) 10.61 (±\pm 0.39) 16.16 (±\pm 0.38) 9.12 (±\pm 0.37) 0.0442 (±\pm 7e-4) 0.1939 (±\pm 9e-3) 0.1993(±\pm 7e-3) 0.1993 (±\pm 4e-3) 0.2737 (±\pm 7e-3)
SimCLR 82.13 (±\pm 0.37) 10.01 (±\pm 0.40) 17.01 (±\pm 0.43) 9.81 (±\pm 0.39) 0.0484 (±\pm1e-3) 0.1901 (±\pm 8e-3) 0.2011(±\pm 6e-3) 0.1057 (±\pm 6e-3) 0.2808 (±\pm 7e-3)
SogCLR 81.63 (±\pm 0.38) 8.13 (±\pm 0.38) 14.21 (±\pm 0.41) 9.63 (±\pm 0.36) 0.0494 (±\pm 8e-4) 0.1906 (±\pm 6e-3) 0.1990(±\pm 7e-3) 0.1029(±\pm 4e-3) 0.2827 (±\pm 6e-3)
Boyl 81.56 (±\pm 0.34) 8.54(±\pm 0.37) 15.32 (±\pm 0.42) 9.71 (±\pm 0.37) 0.0467 (±\pm 7e-4) 0.1924 (±\pm 6e-3) 0.1948 (±\pm 6e-3) 0.1037 (±\pm 5e-3) 0.2748(±\pm 7e-3)
SimCLR + CCL 81.43 (±\pm 0.36) 7.93 (±\pm 0.37) 13.91 (±\pm 0.39) 9.13 (±\pm 0.34) 0.0443 (±\pm 7e-4) 0.1877 (±\pm 7e-3) 0.1849 (±\pm 8e-3) 0.0837 (±\pm 5e-3) 0.2564 (±\pm 6e-3)
SoFCLR 81.62 (±\pm 0.33) 6.91 (±\pm 0.34) 10.32 (±\pm 0.35) 7.89 (±\pm 0.31) 0.0377 (±\pm 5e-4) 0.1729 (±\pm 6e-3) 0.1701 (±\pm 5e-3) 0.0565 (±\pm 3e-3) 0.1944 (±\pm 5e-3)
Table 7: UTKFace training data statistics for two different tasks with ethnicity and age as the sensitive attribute, respectively.
ethnicity Caucasian Others age <=35<=35 >35>35
Female 1000 4000 Female 3250 1750
Male 4000 1000 Male 1750 3250

The convergence curves of our algorithm on the UTKFace data are shown in Figure 3. We also plot the prediction score distributions for the positive and negative classes on UTKFace in Figure 4. The results on other tasks of the CelebA data are shown in Tables 5 and 6.

Figure 3: Convergence curves of the different objective components optimized by SoFCLR with varying α\alpha values on the UTKFace dataset, using gender as the target label and different sensitive attributes.
Figure 4: Prediction score distributions for the positive and negative classes on UTKFace with gender as the target and ethnicity as the sensitive attribute.

D.2 Ablation Studies on Multi-valued Sensitive Attribute

While adding experiments for a multi-valued sensitive attribute adds value to this paper, it does not change the main contributions: the analysis for a multi-valued sensitive attribute is a straightforward extension. Indeed, the optimization algorithm presented in Section 5 is generic enough to cover a multi-valued sensitive attribute as long as the adversarial loss is replaced by the one below. We present the fairness analysis and some empirical results. For a multi-valued sensitive attribute with KK possible values, we let D(E(x))KD(E(x))\in\mathbb{R}^{K} denote the predicted probabilities for the KK possible values of the sensitive attribute, denote by Dk(E(x))D_{k}(E(x)) its kk-th element, and note that k=1KDk(E(x))=1\sum_{k=1}^{K}D_{k}(E(x))=1. We define the minimax problem as:

minEmaxD𝔼x,a\displaystyle\min_{E}\max_{D}\mathbb{E}_{x,a} [k=1Kδ(a,k)logDk(E(x))]\displaystyle\left[\sum_{k=1}^{K}\delta(a,k)\log D_{k}(E(x))\right]

The fairness verification is identical to the proof of Theorem 1 in Appendix A: fixing E and maximizing over D gives the optimal discriminator D_k(E(x)) = p(a=k|E(x)); substituting it back, the objective for E reduces to a constant plus E_a[KL(p(E|a), p(E))], so the optimal encoder E_* satisfies p(E_*|a) = p(E_*), and hence [D_*(E_*(x))]_k = p(a=k|E_*(x)) = p(a=k).

We conduct an experiment on the UTKFace dataset with age as the sensitive attribute. Departing from the conventional binary division at age 35, we stratify age into four distinct groups, delineated at ages 20, 35, and 60. Other settings are the same as in the paper. We compare SoFCLR with the SogCLR baseline.

Table 8: Experimental results on UTKFace predicting binary task label Gender and fairness metrics for multi-valued Age attribute.
Acc ΔED\Delta ED ΔEO\Delta EO ΔDP\Delta DP IntraAUC InterAUC GAUC WD KL
SogCLR 87.52 19.01 10.34 11.22 0.0910 0.0402 0.0515 0.1195 0.5832
SoFCLR 87.84 16.97 9.84 10.01 0.0880 0.0399 0.0473 0.1503 0.5173

The results are reported in Table 8. The first three and the last two fairness metrics are computed in a similar way; we take ΔED\Delta ED as an example. Given four groups g0, g1, g2, g3, we compute the pairwise fairness metric and average it over consecutive pairs of values of the sensitive attribute, i.e., ΔED=(ΔED(g0,g1)+ΔED(g1,g2)+ΔED(g2,g3))/3\Delta ED=(\Delta ED(g0,g1)+\Delta ED(g1,g2)+\Delta ED(g2,g3))/3. For the AUC-based fairness metrics, we convert the problem into four one-vs-all binary tasks (g0 vs. not g0, g1 vs. not g1, g2 vs. not g2, g3 vs. not g3) and compute the averaged fairness metrics.
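A minimal sketch of this averaging scheme (our own illustrative implementation; ΔDP over binary predictions is used as the example pairwise metric, and the helper names are hypothetical):

```python
import numpy as np

def pairwise_avg_dp(y_pred: np.ndarray, groups: np.ndarray, group_ids=(0, 1, 2, 3)) -> float:
    """Average demographic-parity gap over consecutive group pairs, as described above.

    y_pred: (n,) binary predictions; groups: (n,) group id of the multi-valued attribute.
    """
    rates = {g: y_pred[groups == g].mean() for g in group_ids}
    gaps = [abs(rates[a] - rates[b]) for a, b in zip(group_ids[:-1], group_ids[1:])]
    return float(np.mean(gaps))

def one_vs_all_attribute_labels(groups: np.ndarray, group_ids=(0, 1, 2, 3)):
    """Binary one-vs-all sensitive labels used for the averaged AUC-based fairness metrics."""
    return {g: (groups == g).astype(int) for g in group_ids}
```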

D.3 Comparison to a VAE-based method.

We compare with one VAE-based method from Louizos et al. [23]. It is worth mentioning that no image-based VAE code was released by Louizos et al. For a fair comparison, we adopt a ResNet18-based encoder-decoder framework in an unsupervised setup where yy is unavailable and the sensitive attribute aa is only partially available. For the samples whose sensitive attribute labels are unavailable, we train the model using the original VAE-type loss, i.e., without the MMD (Maximum Mean Discrepancy) fairness regularizer. We train for 100 epochs on the UTKFace dataset, followed by linear evaluation as in the paper.

Table 9: Experimental results on UTKFace of predicting Gender and fairness metric for Ethnicity sensitive attribute.
Acc ΔED\Delta ED ΔEO\Delta EO ΔDP\Delta DP IntraAUC InterAUC GAUC WD KL
Louizos et al., 2015 60.05 4.79 6.68 9.6 0.0237 0.0755 0.0182 0.003 0.0129
SoFCLR 84.42 13.02 13.23 13.00 0.0084 0.1013 0.1029 0.1195 0.1237

The results are reported in Table 9. We can see that the accuracy of the VAE-based method is much worse than that of our method, by about 24%. On the other hand, its fairness metrics are better, which is expected due to the tradeoff between accuracy and fairness.