
Adversarial network training using higher-order moments in a modified Wasserstein distance

Oliver Serang
A-Alpha Bio, Seattle, WA, USA
Abstract

Generative-adversarial networks (GANs) have been used to produce data closely resembling example data in a compressed, latent space that is close to sufficient for reconstruction in the original vector space. The Wasserstein metric has been used as an alternative to binary cross-entropy, producing more numerically stable GANs with greater mode-covering behavior. Here, a generalization of the Wasserstein distance that uses moments of higher order than the mean is derived. Training a GAN with this higher-order Wasserstein metric is demonstrated to exhibit superior performance, even when adjusted for slightly higher computational cost. This is illustrated by generating synthetic antibody sequences.

1 Introduction

1.1 Generative-adversarial network

The generative-adversarial network (GAN) is a game-theoretic technique for generating values according to a latent distribution estimated on \(n\) example data \(x\in\mathbb{R}^{n\times\ell\times u}\).[1] GANs employ a generator, \(g:\mathbb{R}^{y}\rightarrow\mathbb{R}^{\ell\times u}\), which maps high-entropy inputs to an imitation datum. These inputs, vectors in \(\mathbb{R}^{y}\), effectively determine a location in the latent space, which is then decoded into the imitation datum. GANs also employ a discriminator, \(d:\mathbb{R}^{\ell\times u}\rightarrow[0,1]\), which is used to evaluate the plausibility that a datum is genuine. Generator and discriminator are trained in an adversarial manner, with the goal of reaching an equilibrium where both implicitly encode the distribution of real data in the latent space. If training is successful, \(\hat{X}=g(Z)\) (where \(Z\sim\mathcal{N}(0,1)^{y}\)) will produce data resembling a row of \(x\). In addition, \(d(\mathbb{R}^{k})\) will correspond to the cumulative density in a unimodal latent space where the latent space density projects the empirical distribution of \(x_{i}\).

1.2 Cross-entropy loss

GANs are typically trained using a cross-entropy loss to optimize the parameters of both \(g\) and \(d\). Cross-entropy measures the expected bits of surprise that samples from a foreground distribution would produce if they had been drawn from a background distribution. The parameters \(\theta_{g}\) are optimized to minimize the surprise of the Bernoulli distribution \(\left(1-d(g(Z)),\,d(g(Z))\right)\) given the background distribution \((0,1)\), that is, the surprise relative to a background that scores \(d(\hat{X})=1\). The parameters \(\theta_{d}\) are optimized to minimize the surprise of the Bernoulli distribution \(\left(1-d(x),\,d(x)\right)\) given the background distribution with \(d(x_{i})=1\) and \(d(\hat{X})=0\).
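Below is a minimal sketch of the two cross-entropy objectives described above. It assumes the discriminator outputs probabilities in \((0,1)\). The function names are hypothetical, and the clamp constant eps is an added assumption that guards against \(\log 0\). Surprise is measured in bits, matching the wording above.

```python
import math

def bce(p_pred, target):
    """Binary cross-entropy (bits) of a predicted Bernoulli p_pred against a 0/1 target."""
    eps = 1e-12  # assumption: clamp to avoid log(0); not part of the original method
    p = min(max(p_pred, eps), 1.0 - eps)
    return -(target * math.log2(p) + (1 - target) * math.log2(1.0 - p))

def generator_bce_loss(d_fake):
    """g wants d(g(Z)) -> 1, i.e. generated data scored as genuine."""
    return sum(bce(p, 1) for p in d_fake) / len(d_fake)

def discriminator_bce_loss(d_real, d_fake):
    """d wants d(x_i) -> 1 for real data and d(g(Z)) -> 0 for generated data."""
    real = sum(bce(p, 1) for p in d_real)
    fake = sum(bce(p, 0) for p in d_fake)
    return (real + fake) / (len(d_real) + len(d_fake))
```

For example, generator_bce_loss([0.5]) is exactly 1 bit. A maximally unsure discriminator leaves the generator one bit of surprise per sample.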

1.3 Wasserstein metric loss

When two distributions are highly dissimilar from one another, their supports may be disjoint, in which case cross-entropy becomes numerically unstable. This also makes the loss uninformative. Two distributions whose supports are disjoint but close together are quantified identically to two distributions whose supports are disjoint and very far from one another. These factors lead to poor training, particularly because \(g\) will initially produce noise, which is likely to overlap poorly with real data in the latent space.

For this reason, the Wasserstein distance was proposed to replace cross-entropy.[2] Wasserstein distance is the continuous version of the discrete earth-mover distance, which solves an optimal transport problem. It measures the minimal total movement, in Euclidean distance, needed to transform one probability density into another. Earth-mover distance is well defined even when the two distributions have disjoint support, which helps avoid mode collapse during training.
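The sketch below makes the disjoint-support problem concrete using assumed toy data, not the paper's experiments. It compares KL divergence with the one-dimensional earth-mover distance between two point masses as one is shifted away from the other. KL divergence is infinite for every shift, whereas the earth-mover distance grows with the shift and so still indicates which direction improves the fit.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in bits for dicts mapping outcomes to probabilities."""
    total = 0.0
    for x, px in p.items():
        if px == 0:
            continue
        qx = q.get(x, 0.0)
        if qx == 0:
            return math.inf  # p places mass where q has none
        total += px * math.log2(px / qx)
    return total

def emd_1d(p, q):
    """Earth-mover distance between 1D discrete distributions: integral of |F_p - F_q|."""
    points = sorted(set(p) | set(q))
    cdf_p = cdf_q = dist = 0.0
    for left, right in zip(points, points[1:]):
        cdf_p += p.get(left, 0.0)  # CDFs are constant on [left, right)
        cdf_q += q.get(left, 0.0)
        dist += abs(cdf_p - cdf_q) * (right - left)
    return dist

for shift in (1, 10, 100):
    p, q = {0: 1.0}, {shift: 1.0}
    print(shift, kl_divergence(p, q), emd_1d(p, q))
```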

If earth-mover distance is used to measure the distance between distributions \(p_{A}\) and \(p_{B}\), then the set of candidate solutions \(\gamma\) will be functions with domain \(\operatorname{supp}(p_{A})\times\operatorname{supp}(p_{B})\) whose marginals equal \(p_{A}\) and \(p_{B}\). Thus, \(\Delta_{EM}(p_{A},p_{B})=\inf_{\gamma\in\Pi(p_{A},p_{B})}\mathbb{E}_{a,b\sim\gamma}\|a-b\|\), where \(\Pi(p_{A},p_{B})\) is the set of distributions with marginals \(p_{A},p_{B}\).

The discrete formulation can be solved combinatorially via linear programming (LP). The continuous formulation, Wasserstein distance, is instead computed via the Kantorovich-Rubinstein dual,[3] which we show below.
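Small discrete problems can be solved exactly in the primal. The sketch below assumes two equal-size samples with uniform weights. In that case an optimal coupling can be taken to be a one-to-one matching (a permutation), so a brute-force search over permutations stands in for the LP. In one dimension, the optimum reduces to pairing sorted values. The function names are illustrative.

```python
import itertools

def emd_bruteforce(a, b):
    """Exact EMD between equal-size, uniformly weighted samples by searching all matchings."""
    n = len(a)
    best = min(
        sum(abs(a[i] - b[j]) for i, j in enumerate(perm))
        for perm in itertools.permutations(range(n))
    )
    return best / n

def emd_sorted(a, b):
    """In 1D the optimal matching pairs the i-th smallest of a with the i-th smallest of b."""
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)

a = [0.0, 3.0, 1.0, 7.0]
b = [2.0, 2.5, 6.0, 4.0]
print(emd_bruteforce(a, b), emd_sorted(a, b))
```

The brute force is factorial in the sample size, which is why general discrete instances are solved as an LP (or assignment problem) instead.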

\[
\begin{aligned}
\Delta_{W}(p_{A},p_{B}) &= \inf_{\gamma\in\Pi(p_{A},p_{B})}\mathbb{E}_{a,b\sim\gamma}\|a-b\|\\
&= \inf_{\gamma}\mathbb{E}_{a,b\sim\gamma}\|a-b\|+\begin{cases}0,&\gamma\in\Pi(p_{A},p_{B})\\ \infty,&\text{else.}\end{cases}
\end{aligned}
\]

The penalty term, here named \(\lambda(p_{A},p_{B},\gamma)\), can be recreated using an adversarial critic function, \(f\), which has a unitless codomain:

λ(pA,pB,γ)=supf𝔼apA[f(a)]𝔼bpB[f(b)]𝔼a,bγ[f(a)f(b)]={0,γΠ(pA,pB),else.\lambda(p_{A},p_{B},\gamma)=\\ \sup_{f}\mathbb{E}_{a^{\prime}\sim p_{A}}\left[f(a^{\prime})\right]-\mathbb{E}_{b^{\prime}\sim p_{B}}\left[f(b^{\prime})\right]-\mathbb{E}_{a,b\sim\gamma}\left[f(a)-f(b)\right]=\\ \begin{cases}0,&\gamma\in\Pi(p_{A},p_{B})\\ \infty,&\text{else.}\end{cases}

\(\lambda(p_{A},p_{B},\gamma)=\infty\) is achieved when \(\gamma\notin\Pi(p_{A},p_{B})\). In that case, \(f\) can be chosen such that, w.l.o.g., \(|f(a)|\gg 1\) (with the sign of the deviation) at a value \(a\) where \(p_{A}(a)\neq\int_{-\infty}^{\infty}\gamma(a,b)\,db\). When \(\gamma\in\Pi(p_{A},p_{B})\), the three expectations cancel for every \(f\), giving 0.
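The following numerical illustration uses an assumed three-point space with hypothetical values. When \(\gamma\)'s first marginal disagrees with \(p_A\), scaling a critic concentrated on the mismatched point makes the penalty grow without bound. For a \(\gamma\) with the correct marginals, the penalty is exactly zero for every \(f\).

```python
def penalty(p_a, p_b, gamma, f):
    """E_{p_A}[f] - E_{p_B}[f] - E_gamma[f(a) - f(b)] on a finite space.

    p_a, p_b: lists of probabilities; gamma: matrix gamma[a][b]; f: list of critic values.
    """
    n = len(p_a)
    e_a = sum(p_a[i] * f[i] for i in range(n))
    e_b = sum(p_b[j] * f[j] for j in range(n))
    e_gamma = sum(gamma[i][j] * (f[i] - f[j]) for i in range(n) for j in range(n))
    return e_a - e_b - e_gamma

p_a = [0.5, 0.5, 0.0]
p_b = [0.0, 0.5, 0.5]
coupled = [[0.0, 0.5, 0.0],   # row sums equal p_a, column sums equal p_b
           [0.0, 0.0, 0.5],
           [0.0, 0.0, 0.0]]
wrong = [[0.0, 0.0, 0.0],     # row sums (0, 0.5, 0.5) differ from p_a
         [0.0, 0.5, 0.0],
         [0.0, 0.0, 0.5]]
for scale in (1.0, 10.0, 100.0):
    f = [scale, 0.0, 0.0]  # critic concentrated where p_a disagrees with gamma's marginal
    print(scale, penalty(p_a, p_b, coupled, f), penalty(p_a, p_b, wrong, f))
```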

Thus,

\[
\Delta_{W}(p_{A},p_{B})=\inf_{\gamma}\sup_{f}\ \mathbb{E}_{a,b\sim\gamma}\left[\|a-b\|+f(b)-f(a)\right]+\mathbb{E}_{a'\sim p_{A}}\left[f(a')\right]-\mathbb{E}_{b'\sim p_{B}}\left[f(b')\right].
\]

We can further reorder \(\inf_{\gamma}\sup_{f}\) to \(\sup_{f}\inf_{\gamma}\). For any function \(t\), let \(h(\beta)=\inf_{\alpha}t(\alpha,\beta)\). Then \(h(\beta)\leq t(\alpha,\beta)\) for all \(\alpha,\beta\), so \(\sup_{\beta}h(\beta)\leq\sup_{\beta}t(\alpha,\beta)\) for every \(\alpha\). Taking the infimum over \(\alpha\) gives \(\sup_{\beta}\inf_{\alpha}t(\alpha,\beta)\leq\inf_{\alpha}\sup_{\beta}t(\alpha,\beta)\) (i.e., weak duality). Furthermore, suppose \(t\) is convex in \(\alpha\) and concave in \(\beta\), and the usual regularity conditions hold (e.g., those of Sion's minimax theorem). Then the minimax principle yields \(\inf_{\alpha}\sup_{\beta}t(\alpha,\beta)=\sup_{\beta}\inf_{\alpha}t(\alpha,\beta)\) (i.e., strong duality). The objective is convex in \(\gamma\) (here manifest via convexity in \(a,b\)). It is also concave in \(f\) (manifest via concave uses of \(f\) rather than concavity of \(f\) itself). Therefore,
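Weak duality, and the role of convexity in closing the gap, can be seen on a finite game. The sketch uses matching pennies as an assumed example. With pure strategies, \(\sup\inf < \inf\sup\). Allowing mixed strategies makes the payoff linear (hence convex and concave) in each player's choice, and on a grid containing the equilibrium the two values coincide.

```python
def inf_sup(t):
    """inf over alpha (rows) of sup over beta (columns)."""
    return min(max(row) for row in t)

def sup_inf(t):
    """sup over beta (columns) of inf over alpha (rows)."""
    return max(min(row[j] for row in t) for j in range(len(t[0])))

# Matching pennies with pure strategies: a strict duality gap.
pennies = [[1, -1],
           [-1, 1]]
print(sup_inf(pennies), inf_sup(pennies))  # -1 1

def mixed_payoff(t, p, q):
    """Expected payoff when alpha plays row 0 w.p. p and beta plays column 0 w.p. q."""
    rows, cols = (p, 1 - p), (q, 1 - q)
    return sum(rows[i] * cols[j] * t[i][j] for i in range(2) for j in range(2))

# Mixed strategies on a grid that includes the equilibrium p = q = 0.5.
grid = [k / 4 for k in range(5)]
mixed = [[mixed_payoff(pennies, p, q) for q in grid] for p in grid]
print(sup_inf(mixed), inf_sup(mixed))
```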

\[
\begin{aligned}
\Delta_{W}(p_{A},p_{B}) &= \sup_{f}\inf_{\gamma}\ \mathbb{E}_{a,b\sim\gamma}\left[\|a-b\|+f(b)-f(a)\right]+\mathbb{E}_{a'\sim p_{A}}\left[f(a')\right]-\mathbb{E}_{b'\sim p_{B}}\left[f(b')\right]\\
&= \sup_{f}\ \mathbb{E}_{a'\sim p_{A}}\left[f(a')\right]-\mathbb{E}_{b'\sim p_{B}}\left[f(b')\right]+\inf_{\gamma}\mathbb{E}_{a,b\sim\gamma}\left[\|a-b\|+f(b)-f(a)\right].
\end{aligned}
\]

The \(\inf_{\gamma}\) is achieved by concentrating the mass of \(\gamma\) where \(\|a-b\|+f(b)-f(a)<0\) and setting \(\gamma=0\) wherever \(\|a-b\|+f(b)-f(a)\geq 0\). Thus \(\inf_{\gamma}\mathbb{E}_{a,b\sim\gamma}\left[\|a-b\|+f(b)-f(a)\right]\leq 0\). Wherever \(\frac{f(a)-f(b)}{\|a-b\|}>1\), the dual penalty term becomes \(-\infty\), because the mass placed there can be scaled without bound. We therefore need only consider \(f\) such that \(\frac{f(a)-f(b)}{\|a-b\|}\leq 1\). Constraining \(f\) so that all secants have slope at most 1 (i.e., Lipschitz \(\|f\|_{L}\leq 1\)) yields the weakest penalty, 0:

\[
\Delta_{W}(p_{A},p_{B})=\sup_{f:\|f\|_{L}\leq 1}\ \mathbb{E}_{a'\sim p_{A}}\left[f(a')\right]-\mathbb{E}_{b'\sim p_{B}}\left[f(b')\right].
\]
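In one dimension the dual can be checked directly. The sketch assumes equal-size, equally weighted samples. The primal optimum pairs sorted values. The 1-Lipschitz witness \(f\) has slope \(\operatorname{sign}(F_B-F_A)\) between consecutive support points, and it attains the same value of \(\mathbb{E}_{p_A}[f]-\mathbb{E}_{p_B}[f]\). This holds because that difference equals \(\int f'(t)\,(F_B(t)-F_A(t))\,dt\). The function names are hypothetical.

```python
def w1_primal_sorted(a, b):
    """Primal 1D W1 for equal-size, equally weighted samples: pair sorted values."""
    assert len(a) == len(b)
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)

def ecdf(sample, t):
    """Empirical CDF: fraction of the sample that is <= t."""
    return sum(1 for s in sample if s <= t) / len(sample)

def w1_dual_witness(a, b):
    """Evaluate E_A[f] - E_B[f] for a piecewise-linear 1-Lipschitz witness f."""
    grid = sorted(set(a) | set(b))
    f = {grid[0]: 0.0}
    for lo, hi in zip(grid, grid[1:]):
        # Both CDFs are constant on [lo, hi), so the optimal slope is constant there.
        diff = ecdf(b, lo) - ecdf(a, lo)
        slope = (diff > 0) - (diff < 0)  # sign(F_B - F_A), in {-1, 0, 1}
        f[hi] = f[lo] + slope * (hi - lo)
    return sum(f[x] for x in a) / len(a) - sum(f[y] for y in b) / len(b)

a, b = [0, 1, 3], [1, 2, 5]
print(w1_primal_sorted(a, b), w1_dual_witness(a, b))
```

Because every slope is in \(\{-1,0,1\}\), the witness satisfies \(\|f\|_{L}\leq 1\), so the matching values illustrate that the supremum is attained.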

In WGAN training, our critic functions as \(f\), exploiting differences between real and generated sequences. The critic loss function is simply the mean critic value of generated sequences minus the mean critic value of real sequences. Minimizing this loss maximizes discrimination, with real sequences awarded higher critic scores. To attain Lipschitz continuity of \(f\), we constrain its parameters \(\theta_{f}\) by clipping them to small values in \([-\tau,\tau]\) at the end of each batch step. A small enough \(\tau\) ensures Lipschitz continuity for any finite network. Furthermore, for any \(\zeta>0\), \(\|f\|_{L}\leq 1\leftrightarrow\|\zeta f\|_{L}\leq\zeta\). The set \(\{f:\|f\|_{L}\leq\zeta\}\) is the set \(\{f:\|f\|_{L}\leq 1\}\) scaled by \(\zeta\), so restricting the critic to it merely rescales the estimated distance by \(\zeta\). Therefore \(\tau\) can be chosen rather arbitrarily as long as \(\tau\ll 1\), although the choice of \(\tau\) will influence the optimal choice of learning rate.
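Below is a minimal sketch of weight clipping. It assumes a single linear critic \(f(x)=w\cdot x\) as a deliberately simplified stand-in for the convolutional critic. In PyTorch, the same step is commonly written as an in-place clamp_ over critic.parameters(). For this linear critic, the Lipschitz constant is \(\|w\|_2\), so clipping every weight to \([-\tau,\tau]\) bounds it by \(\tau\sqrt{d}\).

```python
import math

def clip_params(params, tau):
    """Clip every parameter to [-tau, tau], as done after each critic batch step."""
    return [min(max(w, -tau), tau) for w in params]

def linear_critic_lipschitz(weights):
    """Lipschitz constant (Euclidean norm) of the linear critic f(x) = w . x is ||w||_2."""
    return math.sqrt(sum(w * w for w in weights))

weights = [0.7, -2.0, 0.003, 1.5]
tau = 0.01
clipped = clip_params(weights, tau)
bound = tau * math.sqrt(len(weights))
print(clipped, linear_critic_lipschitz(clipped), bound)
```

For a deeper network, the bound compounds across layers, but it remains finite, which is why a sufficiently small \(\tau\) suffices.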

When training \(g\), the WGAN attempts to fool the critic and thus maximize the loss used by the critic \(f\). Thus, \(g\)'s loss is the negative of \(f\)'s loss. In practice, \(\theta_{g}\) does not influence the critic values of real data, \(f(x_{i})\). Consequently, \(g\)'s loss need only be \(-\mathbb{E}\left[f(\hat{X})\right]\), which maximizes critic scores of generated sequences.

All expectations are taken via Monte Carlo (i.e., by taking the mean of ff scores over each batch).
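The Monte Carlo form of the two losses is shown below as a minimal sketch. The critic scores are assumed to be already computed for one batch, and the function names are hypothetical.

```python
def mean(values):
    """Monte Carlo estimate of an expectation: the batch mean."""
    return sum(values) / len(values)

def critic_loss(scores_fake, scores_real):
    """Mean critic score of generated data minus mean critic score of real data (minimized by f)."""
    return mean(scores_fake) - mean(scores_real)

def generator_loss(scores_fake):
    """-E[f(X_hat)]: the real-data term is dropped since it does not depend on theta_g."""
    return -mean(scores_fake)

fake, real = [0.1, 0.3], [0.5, 0.7]
print(critic_loss(fake, real), generator_loss(fake))
```

The negated critic loss and the generator loss differ only by mean(real), a constant with respect to \(\theta_g\). Both therefore produce the same generator gradients.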

2 Methods

2.1 A WGAN using Wasserstein distance with higher moments

In this manuscript, we propose a modified WGAN, in which we consider other \(\lambda'\) satisfying

\[
\lambda'(p_{A},p_{B},\gamma)=\begin{cases}0,&\gamma\in\Pi(p_{A},p_{B})\\ \infty,&\text{else.}\end{cases}
\]

Wasserstein distance employs duality via an adversarial \(f\) that concentrates where (w.l.o.g.) \(p_{A}(a)\neq\int_{-\infty}^{\infty}\gamma(a,b)\,db\):

\[
\lambda(p_{A},p_{B},\gamma)=\sup_{f}\ \mathbb{E}_{a'\sim p_{A}}\left[f(a')\right]-\mathbb{E}_{b'\sim p_{B}}\left[f(b')\right]-\mathbb{E}_{a,b\sim\gamma}\left[f(a)-f(b)\right].
\]

This corresponds to \(f\) exploiting deviations in the first moment of \(f\) under distribution \(p_{A}\) (w.l.o.g.).

Motivated by the method of moments, we consider the first \(m\) moments, \(\mu_{1},\mu_{2},\ldots,\mu_{m}\). At WGAN convergence, deviations between \(p_{A},p_{B}\) and the marginals of \(\gamma\) should not be exploitable in any moment \(q\):

\[
\lambda'_{q}(p_{A},p_{B},\gamma)=\sup_{f}\ \mathbb{E}_{a'\sim p_{A}}\left[f(a')^{q}\right]-\mathbb{E}_{b'\sim p_{B}}\left[f(b')^{q}\right]-\mathbb{E}_{a,b\sim\gamma}\left[f(a)^{q}-f(b)^{q}\right]=0.
\]

We continue using a signed deviation for the first moment (i.e., \(f(\hat{X})\leq f(x_{i})\)), but unsigned deviations for the remaining moments:

\[
\lambda'(p_{A},p_{B},\gamma)=\lambda'_{1}(p_{A},p_{B},\gamma)+\sum_{j=2}^{m}\left|\lambda'_{j}(p_{A},p_{B},\gamma)\right|.
\]

Note that \(f\) is still used in a convex manner, as its outputs are either unconstrained (in \(\lambda'_{1}\)) or within a bounded polytope (in \(\lambda'_{j>1}\)); hence strong duality holds.

The same derivation holds under central moments, which are used in these experiments.

Lipschitz continuity under higher moments of \(f\) is achieved by decreasing \(\tau\) such that \(|f|\leq\epsilon\ll\frac{1}{2}\). In this case, both \(\mu_{1}\) and all central moments can be bounded: \(\forall j>1,\ |f^{j}-\mu_{j}|\leq 2\epsilon\ll 1\). Moreover, \(|f(a)^{j}-f(b)^{j}|=|f(a)-f(b)|\,\bigl|\sum_{k=0}^{j-1}f(a)^{k}f(b)^{j-1-k}\bigr|\leq j\epsilon^{j-1}|f(a)-f(b)|\), and \(j\epsilon^{j-1}<1\) for \(\epsilon<\frac{1}{2}\). Therefore \(\forall j>1,\ \frac{|f(a)^{j}-f(b)^{j}|}{\|a-b\|}\leq\frac{|f(a)-f(b)|}{\|a-b\|}\leq 1\), and Lipschitz continuity is ensured by the standard Wasserstein derivation.
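The contraction step above can also be checked numerically. The sketch below samples random pairs with \(|u|,|v|\leq\epsilon\) and tests \(|u^j-v^j|\leq|u-v|\) for \(j=2,\ldots,6\). The trial count and seed are arbitrary choices. For \(\epsilon<\frac{1}{2}\) no violation occurs, whereas for \(\epsilon=0.9\) violations are found quickly.

```python
import random

def check_power_contraction(eps, max_power=6, trials=10_000, seed=0):
    """Empirically test |u^j - v^j| <= |u - v| for |u|, |v| <= eps and 2 <= j <= max_power."""
    rng = random.Random(seed)
    for _ in range(trials):
        u, v = rng.uniform(-eps, eps), rng.uniform(-eps, eps)
        for j in range(2, max_power + 1):
            if abs(u ** j - v ** j) > abs(u - v):
                return False  # found a pair where the power is not a contraction
    return True

print(check_power_contraction(0.49), check_power_contraction(0.9))
```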

Since \(f\) concentrates at deviations between distributions, it should approach a Dirac delta before convergence if \(g\)'s training lags the training of \(f\). For this reason, we have not investigated using the higher moments in \(\lambda'\) when training \(f\). Instead, \(f\) is trained using the standard Wasserstein loss with \(\lambda\), and \(\lambda'\) is used only to train \(g\). This corresponds to replacing the standard code

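```python
import torch
def generator_loss(critic, ins_neg):
    preds_neg = critic(ins_neg)  # critic scores of the generated batch, f(g(Z))
    mean_neg = torch.mean(preds_neg)
    loss_gen = -mean_neg  # standard WGAN generator loss
    return loss_gen
```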

with new code

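```python
import torch

def higher_moment_generator_loss(critic, ins_neg, ins_pos, n_moments):
    preds_neg = critic(ins_neg)  # critic scores of generated data
    preds_pos = critic(ins_pos)  # critic scores of real data
    mean_pos = torch.mean(preds_pos)
    mean_neg = torch.mean(preds_neg)
    # signed difference between mu_1 values:
    loss_gen = mean_pos - mean_neg
    for j in range(2, n_moments + 1):
        # absolute difference between central mu_j values:
        loss_gen = loss_gen + torch.abs(
            torch.mean((preds_neg - mean_neg) ** j)
            - torch.mean((preds_pos - mean_pos) ** j)
        )
    return loss_gen
```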

This formulation allows learning from batches as a whole. When the number of computed moments \(m\) equals the batch size \(b\), there is sufficient information to recover the entire empirical distribution of critic scores in the batch. Furthermore, even relatively few moments can accurately summarize the distribution in practice, in a manner reminiscent of the fast multipole method.[4, 5]
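The claim that \(m=b\) moments suffice follows from Newton's identities. The first \(b\) raw moments of \(b\) critic scores determine their elementary symmetric polynomials. These polynomials determine the polynomial whose roots are the scores, and hence the multiset of scores itself. The sketch below checks this with exact rational arithmetic. It uses raw rather than central moments, which is an assumption; central moments together with \(\mu_1\) carry the same information. It also shows that too few moments cannot separate two different batches.

```python
from fractions import Fraction

def raw_moments(values, m):
    """First m raw moments (1/n) * sum(v**j) for j = 1..m, as exact fractions."""
    n = len(values)
    return [sum(Fraction(v) ** j for v in values) / n for j in range(1, m + 1)]

def elementary_from_moments(moments, n):
    """Newton's identities: k e_k = sum_{i=1..k} (-1)^(i-1) e_{k-i} p_i, with p_i = n * mu_i."""
    p = [n * mu for mu in moments]
    e = [Fraction(1)]
    for k in range(1, len(p) + 1):
        s = sum((-1) ** (i - 1) * e[k - i] * p[i - 1] for i in range(1, k + 1))
        e.append(s / k)
    return e

def elementary_direct(values):
    """Elementary symmetric polynomials e_0..e_n of values, built one factor at a time."""
    e = [Fraction(1)] + [Fraction(0)] * len(values)
    for v in values:
        for k in range(len(values), 0, -1):  # descending so e[k - 1] is still the old value
            e[k] += Fraction(v) * e[k - 1]
    return e

batch = [3, -1, 4, 1, 5]
print(elementary_from_moments(raw_moments(batch, 5), 5) == elementary_direct(batch))
```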

2.2 Impact on runtime

Considering higher moments at batch size \(b\) results in a per-batch runtime in \(\Omega(b^{2})\), or in \(\Omega(m\cdot b)\) if the layer of moments is used to separate the data from the critic output. The modified WGAN also needs to compute critic scores \(f(x_{i})\) when training the generator. Both factors increase computational cost.

Fractional moments can be informative and numerically stable. However, in the general case they require arithmetic on complex numbers, since central deviations can be negative, and may therefore negatively influence performance.

3 Results

To benchmark WGAN training methodology, we train WGANs to output heavy chain antibody sequences. The overall scheme for this WGAN is heavily inspired by the seminal neural network-based antibody sequence design work of Amimeur et al.;[6] in particular, our model architecture is inspired by the multi-layer convolutional network from that work.

3.1 Experimental setup

3.1.1 Sequence data

Heavy chain sequence examples from Observed Antibody Space[7, 8] are filtered for outliers based on sequence length and subsampled to \(2.5\times 10^{5}\) sequences. Note that sequences are not aligned using a multiple sequence alignment. Instead, every sequence is given the starting and ending characters '^' and '$' and then padded with '$' so that all sequences have the same length. Sequences are embedded via one-hot embedding.
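Below is a minimal sketch of the padding and one-hot embedding described above. The alphabet is an assumption: the 20 standard amino acids plus the start and end characters. The text lists neither the exact alphabet nor the padded length \(\ell\), so both are illustrative here.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # assumption: the 20 standard residues
ALPHABET = "^" + AMINO_ACIDS + "$"      # start/end characters from the text
INDEX = {ch: i for i, ch in enumerate(ALPHABET)}

def pad_sequence(seq, length):
    """Wrap seq in '^' ... '$' and right-pad with '$' to a fixed length."""
    framed = "^" + seq + "$"
    if len(framed) > length:
        raise ValueError("sequence longer than the padded length")
    return framed + "$" * (length - len(framed))

def one_hot(seq, length):
    """Return a length x u one-hot matrix (list of lists), with u = len(ALPHABET)."""
    matrix = []
    for ch in pad_sequence(seq, length):
        row = [0] * len(ALPHABET)
        row[INDEX[ch]] = 1
        matrix.append(row)
    return matrix

x = one_hot("EVQL", 8)
print(pad_sequence("EVQL", 8))  # ^EVQL$$$
```

In practice, the padded length would be at least the longest retained sequence plus two.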

3.1.2 Critic and generator model architectures

Critic: The critic is constructed of two 2D convolutional layers. For simplicity, 2D convolution is performed with padding such that the kernel does not move over the axis labeling amino acids in the one-hot embedding, and the length of the convolved vector equals the input padded sequence length \(\ell\). In this manner, each of the \(c\) channels of the first 2D convolution essentially passes a position-specific scoring matrix (PSSM) with \(u\) embedding characters over the sequence. These \(c\) channels are transposed to view them as a single matrix of one channel with an alternate embedding with \(c\) characters. This is again nonlinearized with leaky ReLU, and 2D convolved again to produce a single-channel output. This is akin to using a PSSM on \(k\)-mer motifs (using an alphabet of \(c\) possible motifs) rather than on amino acids, which is in turn equivalent to inferring an order-\((k-1)\) Markov model. Padding is performed in the same manner as in the first 2D convolution, so the output is a vector of the same length as the original amino acid sequence. This vector is condensed to a single value via feedforward layers. These are a series of linear layers, each halving the number of nodes and followed by a leaky ReLU transfer function to allow nonlinearity. Note that 2D convolution here is equivalent to several channels of 1D convolution and can be implemented as such.
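The sketch below illustrates the claim that each channel of the first convolution acts as a PSSM. It checks that a 'same'-padded convolution whose kernel spans the whole alphabet axis of a one-hot matrix produces the windowed sum of PSSM scores. The code is pure Python, and it assumes an odd kernel width and zero padding; the function names are hypothetical. Like deep-learning convolution layers, it computes a cross-correlation.

```python
def conv_full_alphabet(one_hot, kernel):
    """'Same'-padded convolution whose kernel spans the whole alphabet axis.

    one_hot: L x u matrix; kernel: k x u weights (k odd). Returns a length-L score vector.
    """
    L, k = len(one_hot), len(kernel)
    pad = k // 2
    out = []
    for i in range(L):
        total = 0.0
        for t in range(k):
            pos = i + t - pad
            if 0 <= pos < L:  # zero padding outside the sequence
                total += sum(w * x for w, x in zip(kernel[t], one_hot[pos]))
        out.append(total)
    return out

def pssm_scan(seq_indices, pssm):
    """Sum the PSSM score of each residue in the window centered at every position."""
    L, k = len(seq_indices), len(pssm)
    pad = k // 2
    return [
        sum(pssm[t][seq_indices[i + t - pad]] for t in range(k) if 0 <= i + t - pad < L)
        for i in range(L)
    ]

seq = [0, 2, 1, 3, 2]  # residue indices in a toy 4-letter alphabet
oh = [[1 if j == s else 0 for j in range(4)] for s in seq]
kern = [[0.1, -0.2, 0.3, 0.0], [0.5, 0.4, -0.1, 0.2], [-0.3, 0.2, 0.1, 0.6]]
print(conv_full_alphabet(oh, kern), pssm_scan(seq, kern))
```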

Generator: The generator is nearly identical to the critic in reverse. Thinking of the critic and generator as two halves of an autoencoder, inverting the critic's compression to a lower-dimensional latent space would call for deconvolution. However, as shown by the convolution theorem, deconvolution is itself a form of convolution, with a kernel whose values have been multiplicatively inverted in the frequency domain.[9]
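The deconvolution remark can be checked numerically. The sketch assumes circular convolution and a kernel whose discrete Fourier transform has no zeros. A naive \(O(n^2)\) DFT keeps the code self-contained. Deconvolving then amounts to an ordinary convolution with an inverse kernel whose spectrum is \(1/H\).

```python
import cmath

def dft(x):
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n)) for k in range(n)]

def idft(X):
    n = len(X)
    return [sum(X[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n for t in range(n)]

def circular_convolve(x, h):
    """y[i] = sum_t x[t] * h[(i - t) mod n]; by the convolution theorem DFT(y) = DFT(x) * DFT(h)."""
    n = len(x)
    return [sum(x[t] * h[(i - t) % n] for t in range(n)) for i in range(n)]

def inverse_kernel(h):
    """Kernel whose spectrum is 1/H (H = DFT of h); requires H to have no zeros."""
    return [v.real for v in idft([1 / Hk for Hk in dft(h)])]

def deconvolve(y, h):
    """Deconvolution as an ordinary circular convolution with the inverse kernel."""
    return circular_convolve(y, inverse_kernel(h))

x = [1.0, -2.0, 0.5, 3.0, 0.0, 1.5]
h = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0]
y = circular_convolve(x, h)
print([round(v, 6) for v in deconvolve(y, h)])
```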

A standard normal noise vector \(Z\) is inflated to a vector of length \(\ell\cdot u\), where \(u\) is the size of the alphabet used for one-hot embedding. A leaky ReLU is used to permit nonlinearity. The vector is then viewed as a matrix in \(\mathbb{R}^{\ell\times c}\) and is convolved in 2D to produce \(c\) channels of output (padding in the 2D convolution matches the approach used in the critic). This output is transposed to be viewed as a single matrix of one channel with an alternate embedding in \(c\) new characters. Nonlinearity is again induced with leaky ReLU. The matrix is then convolved in 2D again with the same padding strategy and \(u\) output channels, and transposed to form a matrix of one channel with \(u\) characters embedded. This is nonlinearized with leaky ReLU. Note that this matrix has the same shape as the sequence embedding. Softmax is then applied along the character embedding axis, forcing it into an embedding that resembles a one-hot encoding.

All leaky ReLUs have negative slope 0.2. Layers are delimited with dropout 0.1 during training, but not during evaluation.

3.1.3 Evaluation

After each epoch, the quality of \(\hat{x}\) is evaluated using the KL divergence of the categorical distribution of 6-mers in \(2\times 10^{4}\) sequences sampled from \(g\), given the background 6-mer distribution of the \(2.5\times 10^{5}\) heavy chain antibody sequences. For numeric stability, KL divergence is computed using a pseudocount of \(10^{-10}\) for values not in the background distribution's support.
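Below is a minimal sketch of the evaluation metric. It rests on three assumptions that the text does not fix: the generated sequences form the foreground distribution and the training sequences the background; 6-mers are counted on the raw amino acid strings, without the '^'/'$' padding; and the divergence is reported in nats. The sequences are toy examples.

```python
import math
from collections import Counter

def kmer_distribution(sequences, k=6):
    """Categorical distribution of overlapping k-mers across all sequences."""
    counts = Counter(seq[i:i + k] for seq in sequences for i in range(len(seq) - k + 1))
    total = sum(counts.values())
    return {kmer: c / total for kmer, c in counts.items()}

def kl_with_pseudocount(foreground, background, pseudocount=1e-10):
    """KL(foreground || background) in nats; k-mers missing from the background get a pseudocount."""
    return sum(
        p * math.log(p / background.get(kmer, pseudocount))
        for kmer, p in foreground.items()
    )

real = ["EVQLVESGGGLVQ", "QVQLQESGPGLVK"]
generated = ["EVQLVESGGGAAA"]
print(kl_with_pseudocount(kmer_distribution(generated), kmer_distribution(real)))
```

The tiny pseudocount makes any 6-mer absent from the training data very costly, so the metric strongly penalizes generated sequences that leave the observed 6-mer vocabulary.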

3.1.4 Hyperparameters

Learning rate and batch size are chosen by training a standard WGAN for several replicates with learning rates \(\in\{0.1,0.01,0.001,0.0001\}\) and batch sizes \(b\in\{64,128,256,512\}\). The learning rate that produced the best 6-mer KL divergence is 0.001. A batch size of \(b=128\) yielded the best 6-mer KL divergence while still maintaining \(>75\%\) GPU usage, as reported by nvidia-smi. These hyperparameters are used for training the WGANs using higher-order moments.

3.1.5 Reproducibility

The random seeds \(0,1,\ldots,4\) are used for replicate experiments. This includes seeding random, numpy.random, torch, and torch.utils.data.DataLoader. torch.use_deterministic_algorithms(True) is used, along with the accompanying recommended environment variable set by export CUBLAS_WORKSPACE_CONFIG=:4096:8.

3.1.6 Training details

Models are instantiated and trained with PyTorch 1.10. A shuffled DataLoader with 8 worker processes and pinned memory is used in training. In each epoch, \(d\) is trained on each batch and \(g\) is trained on \(\frac{1}{5}\) of batches, to avoid adjusting \(\theta_{g}\) with improper guidance from an uninformed critic.

Adaptive moment estimation (Adam) is used for gradient descent.[10]

Benchmarks are performed on an AWS g4dn.8xlarge instance using a single Nvidia T4. Storage IOPS and throughput are maximized.

3.2 Influence of higher moments on WGAN performance

Data are produced using 5 replicate trials of 200 epochs. For fairness, each loss function investigated used every random seed in \(\{0,1,\ldots,4\}\).

Figure 1 illustrates the relationship between sequence quality produced by gg at each epoch and the loss function used.

Figure 2 illustrates the relationship between sequence quality produced by \(g\) and the loss function value for each loss function used. Note that for any \(m>1\), the generator loss uses the penalty term \(\lambda'\geq\lambda\). A small loss for \(m>1\) therefore necessarily implies a small loss under the standard WGAN.

Figure 1: Sequence quality during training. A standard WGAN (with loss using only the first moment \(\mu_{1}\)) is compared to WGANs using further central moments \(m\leq 4\). Each scatter data point represents an estimate of the post-epoch 6-mer KL divergence on \(2\times 10^{4}\) sequences generated by \(g\). Five replicate experiments are performed for each model, each is fit with a curve, and the aggregate of all curves is plotted. For reference, the dashed line plots the standard WGAN shifted 25 epochs earlier.
Figure 2: Quality of generator loss function throughout training. A standard WGAN (with loss using only the first moment \(\mu_{1}\)) is compared to WGANs using further central moments \(m\in\{1,2,3,4\}\), using data obtained from Figure 1. The correspondence between the loss functions and the 6-mer KL divergence quantifies the quality of each loss function as a surrogate for optimizing sequence quality. Critic loss is standard throughout, but is affected by generator training; generator loss varies with \(m\). Loss functions are computed during training, and so include dropout and other stochastic effects.
m    Mean runtime (s)    Spearman ρ (KL vs. critic loss)    Spearman ρ (KL vs. generator loss)
1    2917.66             0.7059                             0.4169
2    3079.00             0.8961                             0.9257
3    3097.12             0.8388                             0.8722
4    3110.64             0.8921                             0.9205
Table 1: Runtime and quality of various generator loss functions. Standard WGAN loss with \(m=1\) is compared to losses with higher moments \(m>1\). Total 200-epoch runtime is averaged over five replicates. The Spearman rank correlation coefficient \(\rho\) between each loss function and the 6-mer KL divergence is also displayed. Critic loss is standard throughout, but is affected by generator training. Generator loss varies with \(m\in\{1,2,3,4\}\). Loss functions are computed during training, and so include dropout and other stochastic effects.

4 Discussion

Figure 1 demonstrates that early in training, the standard WGAN exhibits superior performance. Later on, however, using higher moments benefits sequence quality, specifically for the \(m=2\) and \(m=4\) loss functions. This benefit is roughly equivalent to a 25-epoch advantage.

Figure 2 and Table 1 demonstrate a greater correspondence between sequence quality and loss functions with higher moments.

Qualitatively, using higher moments incentivizes optimizing batches as a whole. One way this may manifest is by making the batch diversity of \(\hat{x}\) better match that of \(x\), thereby reducing mode collapse. This could also explain the slightly poorer performance in early epochs, as the \(m>1\) loss functions will initially be less inclined to seek a dominant nearby mode.

Interestingly, \(m=3\) did not perform as well as \(m=2\) and \(m=4\). This could be because training the critic inherently drives \(f\) toward a highly concentrated distribution resembling a Dirac delta. The first moment \(\mu_{1}\) is informative, and even moments describe spread: the variance \(\mu_{2}\) quantifies spread, and the fourth central moment \(\mu_{4}\), which underlies kurtosis, quantifies modality near \(\mu_{1}\). By contrast, the third central moment \(\mu_{3}\), which underlies skew, informs of direction, but in a way that may here be less useful or numerically stable than simply using \(\mu_{1}\).

Using higher moments increased runtimes, but not substantially. Training the modified WGAN with \(m=2\) moments in \(g\)'s loss required \(<5.6\%\) more runtime than the standard WGAN. At a cost of $2.176 per hour,[11] this corresponds to a cost of $1.76 per replicate of the standard WGAN, and less than $0.10 more to train the \(m=2\) variant. However, the \(m=2\) variant reaches comparable convergence in 87.5% of the training (175 of 200 epochs), and thus would cost roughly $1.63 per replicate. With larger data and more stringent convergence criteria, the gain in sequence quality from further epochs decays exponentially. This suggests that the 25-epoch advantage demonstrated by the \(m>1\) variants would produce far more dramatic benefits in that setting. Furthermore, it is likely that the benefits illustrated here would be stronger with more replicate experiments.
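The cost figures can be reproduced from Table 1. The short sketch below makes two assumptions: the on-demand price is charged in proportion to runtime, and stopping the \(m=2\) run 25 epochs early (after 175 of 200 epochs) scales its runtime proportionally.

```python
PRICE_PER_HOUR = 2.176  # USD per hour for the instance, as cited in the text
runtime_s = {1: 2917.66, 2: 3079.00}  # mean 200-epoch runtimes from Table 1

cost = {m: PRICE_PER_HOUR * s / 3600 for m, s in runtime_s.items()}
overhead = runtime_s[2] / runtime_s[1] - 1
early_stop_cost = cost[2] * (200 - 25) / 200  # m = 2 stopped 25 epochs early

print(f"m=1: ${cost[1]:.2f}, m=2: ${cost[2]:.2f}, overhead {overhead:.1%}, "
      f"m=2 at 175 epochs: ${early_stop_cost:.2f}")
```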

It is possible that deviations for different moments should receive their own weighting in computing the loss function. To this end, it may be desirable to perform batch-based discrimination. Each batch would be reduced to its constituent moments, and then a critic \(f'(\mu_{1},\mu_{2},\ldots,\mu_{m})\) would be computed on the moments \(\mu_{1},\mu_{2},\ldots,\mu_{m}\) of critic values \(f\) for the batch. Parameters \(\theta_{f'}\) could be learned and clamped during training to ensure \(\|f'\|_{L}\leq 1\).

5 Conclusion

Here we have shown that viewing distributions through several moments, rather than only the first moment \(\mathbb{E}[f]\), improves WGAN training. The critic \(f\) could also be trained using this strategy.

6 Acknowledgements

Thank you to Ryan Emerson and Randolph Lopez for the scientific discussion, James Harrang for the helpful comments, and to the entire A-Alpha Bio team.

7 Declarations

7.1 Conflicts of interest

O.S. is an employee of A-Alpha Bio and owns stock options in the company.

References

  • [1] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. Advances in neural information processing systems, 27, 2014.
  • [2] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In International conference on machine learning, pages 214–223. PMLR, 2017.
  • [3] John Thickstun. Kantorovich-Rubinstein Duality, 2019.
  • [4] Leslie Greengard and Vladimir Rokhlin. A fast algorithm for particle simulations. Journal of computational physics, 73(2):325–348, 1987.
  • [5] Julianus Pfeuffer and Oliver Serang. A bounded p-norm approximation of max-convolution for sub-quadratic bayesian inference on additive factors. The Journal of Machine Learning Research, 17(1):1247–1285, 2016.
  • [6] Tileli Amimeur, Jeremy M Shaver, Randal R Ketchem, J Alex Taylor, Rutilio H Clark, Josh Smith, Danielle Van Citters, Christine C Siska, Pauline Smidt, Megan Sprague, et al. Designing feature-controlled humanoid antibody discovery libraries using generative adversarial networks. BioRxiv, 2020.
  • [7] Aleksandr Kovaltsuk, Jinwoo Leem, Sebastian Kelm, James Snowden, Charlotte M Deane, and Konrad Krawczyk. Observed antibody space: a resource for data mining next-generation sequencing of antibody repertoires. The Journal of Immunology, 201(8):2502–2509, 2018.
  • [8] Tobias H Olsen, Fergus Boyles, and Charlotte M Deane. Observed antibody space: A diverse database of cleaned, annotated, and translated unpaired and paired antibody sequences. Protein Science, 31(1):141–146, 2022.
  • [9] John G Proakis. Digital signal processing: principles algorithms and applications. Pearson Education, 2001.
  • [10] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • [11] Amazon EC2 G4 Instances, August 2022. Archived at https://web.archive.org/web/20220809081441/https://aws.amazon.com/ec2/instance-types/g4/.