
DEMix Layers: Disentangling Domains for Modular Language Modeling

Suchin Gururangan†♢  Mike Lewis  Ari Holtzman  Noah A. Smith†♠  Luke Zettlemoyer†♢
Paul G. Allen School of Computer Science & Engineering, University of Washington
Allen Institute for AI
Facebook AI Research
Seattle, WA, USA
[email protected]
Abstract

We introduce a new domain expert mixture (DEMix) layer that enables conditioning a language model (LM) on the domain of the input text. A DEMix layer is a collection of expert feedforward networks, each specialized to a domain, that makes the LM modular: experts can be mixed, added or removed after initial training. Extensive experiments with autoregressive transformer LMs (up to 1.3B parameters) show that DEMix layers reduce test-time perplexity, increase training efficiency, and enable rapid adaptation with little overhead. We show that mixing experts during inference, using a parameter-free weighted ensemble, allows the model to better generalize to heterogeneous or unseen domains. We also show that experts can be added to iteratively incorporate new domains without forgetting older ones, and that experts can be removed to restrict access to unwanted domains, without additional training. Overall, these results demonstrate benefits of explicitly conditioning on textual domains during language modeling.

1 Introduction

Figure 1: Illustration of a DEMix layer in a single transformer block. During training, expert feedforward networks are conditionally activated based on the domain (here, document provenance) of the input sequence (i.e., scientific papers or court opinions). At inference time, the language model has new modular functions: domain experts can be mixed to handle heterogeneous domains, added to adapt to novel domains, or removed to “forget” unwanted domains. Image attribution: news icon from emojipedia.org; all other icons from istockphoto.com.

Conventional language model (LM) training algorithms assume data homogeneity: all parameters are updated to minimize the loss on all of the data. We refer to this approach as dense training. Yet human language is as varied as human experience, a fact researchers often refer to obliquely when they use the term domain to describe distinct underlying subpopulations of the corpus. Dense training leaves variation in the data to be implicitly discovered (Aharoni and Goldberg, 2020), assuming that models will be able to fit all domains equally well.

While dense training is convenient, and densely trained LMs achieve impressive results (Brown et al., 2020), the approach has drawbacks with respect to generalization, efficiency, and flexibility. Even if training data is sourced from many domains, dense training can in practice emphasize subsets of the data in proportion to their ease of access (Oren et al., 2019; Fan et al., 2020), limiting generalization to less prevalent domains. Updating all parameters of the network gets substantially more expensive as model size grows (Strubell et al., 2019), making fine-tuning or domain-adaptive pretraining (DAPT; Gururangan et al., 2020) harder to perform with smaller computational budgets. It is also difficult to adapt to new domains without forgetting the original data (McCloskey and Cohen, 1989; Aghajanyan et al., 2021) or restrict access to certain domains the LM has been exposed to during training (e.g., those that contain hate speech; Bender et al. 2021), leading to risks of unwanted behavior (Gehman et al., 2020).

To address these limitations of dense training, we argue that LMs should be designed with modularity. We propose a modular LM that has components specialized to distinct domains in the training data, and can be customized at inference-time by mixing, adding, or removing these separated components as needed. This design principle emphasizes the ability to rapidly adapt the LM after training, a need that has been broadly advocated for language systems (Dinan et al., 2021; Lazaridou et al., 2021).

We introduce modularity into an LM with a new domain expert (DEMix) layer that explicitly conditions the LM on the domain of the input text (when it is known), or estimates the input domain during inference (when it is not known). A DEMix layer is a drop-in substitute for a feedforward layer in a transformer LM (e.g., GPT-3), creating a specialized version of the layer (or expert) per domain (see Figure 1; §3). This is an example of conditional computation (Fedus et al., 2021; Lepikhin et al., 2020; Lewis et al., 2021; Roller et al., 2021), which follows prior literature on mixtures of experts (Jacobs et al., 1991; Shazeer et al., 2017). Unlike dense training, conditional computation activates different parameters for different inputs. Instead of learning how to route data to experts, the DEMix layer routing mechanism follows from a natural, observable segmentation of the data. We find that replacing every feedforward layer in the transformer with a DEMix layer offers new affordances for modularity, addressing the challenges above, while improving performance in both training domains and novel test-time domains.

Although the concept of a domain lacks a rigorous definition in NLP, we use coarse provenance categories (e.g., whether a document is a medical research paper or a Reddit post) as a conditioning variable when training an LM with DEMix layers (§2). Training on data from eight different domains, we find that DEMix layers consistently improve in-domain performance (§4). However, because these categories may not be an optimal segmentation of the training data, or may lack coverage of test-time domains, naively selecting a single domain expert at test time can hurt generalization. Instead, we introduce a parameter-free probabilistic approach to dynamically estimate a weighted mixture of domains during inference (§5). Mixing experts improves DEMix performance not only on novel test-time domains, but also on test data from the training domains, which may themselves be heterogeneous. Our results suggest that introducing modularity into an LM need not come at a cost to generalization performance.

Because DEMix forces experts to specialize to domains, the overall model can be (partially) disentangled after training. Beyond mixing, we can add (§6) or remove (§7) domain experts, resulting in predictable changes in model behavior at inference time: adding experts allows for model adaptation without updating all parameters (hence avoiding forgetting), and removing experts allows for simulating the removal of training domains without additional training. Overall, DEMix layers demonstrate benefits of explicitly conditioning on textual domains during language modeling, and our results suggest that these benefits persist at scale. Our code is publicly available at http://github.com/kernelmachine/demix.

Domain Corpus # Train (Eval.) Tokens
Training 1B 30M NewsWire sentences (Chelba et al., 2014) 700M (10M)
CS 1.89M full-text CS papers from S2ORC (Lo et al., 2020) 4.5B (10M)
Legal 2.22M U.S. court opinions, 1658 to 2018 (Caselaw Access Project, 2018) 10.5B (10M)
Med 3.2M full-text medical papers from S2ORC (Lo et al., 2020) 9.5B (10M)
WebText 8M Web documents (Gokaslan and Cohen, 2019) 6.5B (10M)
RealNews 35M articles from RealNews Zellers et al. (2019) 15B (10M)
Reddit Reddit comments from pushshift.io (Baumgartner et al., 2020) 25B (10M)
Reviews 30M Amazon product reviews (Ni et al., 2019) 2.1B (10M)
Total    73.8B (80M)
Domain Corpus # Train (Eval.) Tokens
Novel ACL Papers 1.5K NLP papers from ACL (Dasigi et al., 2021) 1M (1M)
Breaking News 20K latest articles from 400 English news sites (Baly et al., 2018) 11M (1M)
Contracts 500 commercial legal contracts (Hendrycks et al., 2021) 1.5M (1M)
CORD-19 400K excerpts from COVID-19 research papers (Wang et al., 2020) 60M (10M)
Github 230K public Github repository contents (Github Archive Project, ) 200M (10M)
Gutenberg 3.2M copyright-expired books (Project Gutenberg, ) 3B (10M)
Tweets 1M English tweets from 2013-2018 8M (1M)
Yelp Reviews 6M Yelp restaurant reviews (Yelp Reviews, ) 600M (10M)
Table 1: Domains that make up our multi-domain training corpus, including the size of our training and evaluation (i.e., validation and test) data, in whitespace-separated tokens. † indicates datasets that we (partially) anonymize (§2). Reddit was extracted and obtained by a third party and made available on pushshift.io, and was anonymized by Xu et al. (2020); we use their version. See Appendix §A.1 for more details on how these data were collected.

2 Multi-Domain Corpus

We center this study around a large, multi-domain corpus we constructed with explicit provenance metadata (Table 1). While other multi-domain corpora (Koh et al., 2021; Gao et al., 2020) cover many more domains and tasks, the corpus we introduce contains substantial metadata-tagged text for language modeling, as well as datasets with friendly licensing to support reproducibility.

2.1 Document Provenance as a Domain Label

While a growing body of work has attempted to address the structure and composition of language domains (Eisenstein et al., 2014; Plank, 2016; Aharoni and Goldberg, 2020; Gururangan et al., 2020), fundamentally what a domain is remains a matter of debate. In this work, we focus on the provenance of a document, operationalized coarsely by the dataset we used to access it, which approximates a social process that produced it. Defining domains this way is easy and intuitive, conveys a great deal about the variation in a document’s language, and aligns with common practice in NLP research. However, other accounts of variation in language (e.g., Lucy and Bamman, 2021), and richer notions of relationships among domains (e.g., hierarchies; Gururangan et al., 2020), may be studied in future work.

2.2 Corpus Description

The multi-domain corpus we use in this study consists of two parts. The first is a collection of training domains: text from eight domains of largely English text, listed at the top of Table 1, each of which varies in complexity and coverage and has been the subject of study in NLP. The metadata for each document includes at least its provenance, and in some cases more information (e.g., URLs, publication venue, or legal jurisdiction); future work might explore more fine-grained notions of domain.

The second part is a collection of novel domains: text from eight domains also of largely English text, listed at the bottom of Table 1, which may or may not align with the training domains. The novel domains allow us to measure how models generalize to a more challenging data distribution shift, where domain boundaries may be less clear.

See Appendix §A.1 for more details on how these data were collected. To support future work with the data, we also release a standard API (https://github.com/kernelmachine/demix-data) to download and preprocess it into a format compatible with Fairseq (Ott et al., 2019). We replace user-identifiable information (e.g., email addresses, user handles, social security numbers, credit card numbers, phone numbers) with dummy tokens. While it is difficult to anonymize data perfectly, especially at scale, we use a suite of regexes to identify commonly occurring identifiable information on the Internet. See Appendix §A.2 for more details.

3 DEMix Layer

3.1 Background: Mixture-of-Experts Transformers

The transformer architecture is composed of interleaved multi-head self-attention, layer-norms, and feedforward networks (Vaswani et al., 2017). Each of these layers produces a vector representation for each of the input tokens. Our focus is on the feedforward component:

\mathbf{h}_{t,\ell} = \mathrm{FFN}(\mathbf{h}_{t,\ell-1}), \qquad (1)

where $\mathbf{h}_{t,\ell}$ is the vector for the $t$th token produced by layer $\ell$.

Shazeer et al. (2017) propose a formulation of one or more feedforward layers as an ensemble of $n$ experts $\mathrm{FFN}_1, \ldots, \mathrm{FFN}_n$, assigned weights respectively by functions $g_1, \ldots, g_n$:

\mathrm{FFN}(\mathbf{h}_{t,\ell-1}) = \sum_{j=1}^{n} g_j(\mathbf{h}_{t,\ell-1}) \cdot \mathrm{FFN}_j(\mathbf{h}_{t,\ell-1}) \qquad (2)

The $g$ function routes tokens to different experts, usually each a separate instance of the original feedforward network. If $g$ routes to a single expert, then the computational cost (in floating-point operations; FLOPs) will be the same as the original feedforward network, even though it has slightly more than $n$ times as many parameters.

3.2 DEMix Routing

Previous approaches learn the weighting functions $g$ at the token level, and either assign at most one (Fedus et al., 2021) or two (Lepikhin et al., 2020) experts per token. This necessitates load balancing and other techniques to encourage the model to use all experts instead of relying on just a few (Fedus et al., 2021; Lewis et al., 2021).

We instead use domain metadata provided with training documents to route data to experts at the document (i.e., sequence) level. During training, every token in the same sequence is assigned to the same expert based on the domain label.

Let $\mathcal{D}$ denote the set of domain labels (i.e., the eight labels in Table 1). If we index the experts by $\mathcal{D}$ and $d \in \mathcal{D}$ is the domain label for the current training instance, then

g_j(\mathbf{h}_{t,\ell}) = \begin{cases} 1 & \text{if } j = d \\ 0 & \text{otherwise} \end{cases} \qquad (5)

While we assume that each training document is associated with a single domain label, we relax this requirement at inference time (§5), which improves model performance in mixed and unknown domain scenarios.

3.3 DEMix Architecture

Our design results in one expert in a DEMix layer per domain (i.e., eight experts for eight training domains in our multi-domain corpus).

We replace every feedforward layer in the transformer with a DEMix layer, in contrast to previous work (Fedus et al., 2021; Lepikhin et al., 2020) that interleaves shared and expert layers. Preliminary experiments showed that interleaving led to worse in-domain performance with DEMix layers. We hypothesize that shared layers may serve as a bottleneck to find shared features between domains, and may impact performance adversely when training domains are highly different from one another. Indeed, preliminary experiments suggest that interleaving expert layers causes large performance hits in the most distinct domains, i.e., those with lower vocabulary overlap with other domains in the corpus. Future work might perform careful comparisons of different architectural choices.

In this study, each expert $\mathrm{FFN}_j$ is a two-layer MLP with the same dimensions as the original $\mathrm{FFN}$ layer of the transformer. As with other conditional computation models (Fedus et al., 2021; Lepikhin et al., 2020), this means that the effective number of parameters in the overall DEMix LM increases (Table 2). While this incurs memory costs, the computational budget we consider in this study centers around runtime costs. DEMix layers decrease the runtime costs of training the LM.
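Below is a minimal PyTorch-style sketch of a DEMix layer, not the released Fairseq implementation: a ModuleList of per-domain two-layer MLP experts, with hard routing by domain label (Equation 5) during training and an optional weighted mixture at inference time (§5). The class name, dimensions, and GELU activation are illustrative assumptions.

```python
import torch
import torch.nn as nn

class DEMixFFN(nn.Module):
    """Sketch of a DEMix layer: one two-layer MLP expert per training domain."""

    def __init__(self, num_domains: int, d_model: int, d_ff: int):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_domains)
        ])

    def forward(self, h, domain_id=None, mixture_weights=None):
        # h: (batch, sequence_length, d_model)
        if domain_id is not None:
            # Hard routing (Equation 5): every token in the sequence block
            # is sent to the expert of its (known) domain.
            return self.experts[domain_id](h)
        # Inference-time mixing (Section 5): weight each expert's output
        # by the current estimate of the domain posterior.
        out = torch.zeros_like(h)
        for weight, expert in zip(mixture_weights, self.experts):
            out = out + weight * expert(h)
        return out

# Example with 8 training domains and GPT-3 small-like dimensions (illustrative).
layer = DEMixFFN(num_domains=8, d_model=768, d_ff=3072)
h = torch.randn(2, 1024, 768)
y_known = layer(h, domain_id=3)                              # domain label revealed
y_mixed = layer(h, mixture_weights=torch.full((8,), 1 / 8))  # uniform mixture
```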

3.4 DEMix Training

DEMix layers increase the total parameters of the LM while also reducing GPU latency costs during training, which lowers the overall runtime cost of training the LM.

Dense training (also referred to as data-parallel) is usually implemented by copying model parameters to every GPU, feeding a different mini-batch of shuffled data to each GPU, computing a stochastic gradient for each mini-batch, and updating all parameters synchronously with the average stochastic gradient from across all GPUs.

To train an LM with DEMix layers, we instead partition the GPUs among the domains, so that each GPU is assigned a single domain (along with its corresponding expert). During training, we fill a mini-batch with $k$ sequences, where each sequence represents data from a particular domain, and we send each mini-batch to its dedicated domain expert. We use larger batch sizes by performing data-parallel training between expert parameters on GPUs assigned to the same domain; we assign $n/8$ GPUs to each domain (Table 2). To reduce overfitting, we ensure that each of these $n/8$ GPUs is assigned a different shard of its domain's training data.

We compare the training efficiency of Dense and DEMix models up to 1.3B parameters per GPU in Table 2. Compared to Dense LMs, DEMix layers achieve the same or slightly higher throughput (measured in TFLOPs/GPU) for the same total FLOPs per update, despite adding significantly more parameters.

DEMix achieves higher throughput because we only synchronize expert parameters allocated to the same domain (shared parameters are synchronized across all GPUs). As we increase model size, this results in a reduction of latency costs between GPUs, and hence, faster training; instead of synchronizing parameters over $n$ GPUs, we perform eight synchronizations over $n/8$ GPUs each. While this technique reduces latency costs, the bandwidth costs are the same between DEMix and Dense models.
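The following sketch shows one way to express this synchronization pattern with torch.distributed. It assumes a contiguous rank-to-domain assignment and the `experts` parameter naming from the sketch in §3.3; it is not the exact released implementation.

```python
import torch.distributed as dist

def build_domain_groups(world_size: int, num_domains: int = 8):
    """Partition GPU ranks into contiguous per-domain groups (illustrative layout)."""
    per_domain = world_size // num_domains
    return [dist.new_group(ranks=list(range(d * per_domain, (d + 1) * per_domain)))
            for d in range(num_domains)]

def sync_gradients(model, local_domain: int, domain_groups):
    """All-reduce expert gradients within a domain group; shared gradients across all GPUs."""
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        if "experts" in name:
            # Expert parameters: average only over the n/8 GPUs that share this domain.
            group = domain_groups[local_domain]
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
            param.grad /= dist.get_world_size(group=group)
        else:
            # Shared parameters (attention, embeddings, etc.): average over all n GPUs.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= dist.get_world_size()
```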

In this work, we assume that there is sufficient data for each training domain that each expert can be exposed to the same amount of data, and load balancing between experts is not necessary. Future work may consider how varying the amount of data per domain influences absolute and relative performance across domains, especially in the long tail of rare domains.

While the total number of parameters of a DEMix LM is substantially larger than that of its Dense counterpart, the practical training costs are essentially the same, so we compare baselines in all subsequent experiments based on parameters per GPU, as we do in Table 2.

4 In-Domain Performance

Parameters per GPU
125M 350M 760M 1.3B
Dense GPUs 32 64 128 128
Total Experts 0 0 0 0
GPUs/expert 0 0 0 0
Total params 125M 350M 760M 1.3B
TFLOPs/update 556 3279 13,637 23,250
TFLOPs/GPU 31 37 45 51
DEMix GPUs 32 64 128 128
Total Experts 8 8 8 8
GPUs/expert 4 8 16 16
Total params 512M 1.8B 3.8B 7.0B
TFLOPs/update 556 3279 13,637 23,250
TFLOPs/GPU 31 37 48 55
Table 2: Our specifications for training Dense and DEMix LMs. All models are trained for about 48 hours on V100 GPUs. DEMix layers increase the total parameters of the LM while maintaining (or increasing) throughput, measured in TFLOPs/GPU. We use the formula described in Narayanan et al. (2021) to calculate these metrics. See Appendix §A.3 for more details.

The first set of experiments in this study considers the impact of replacing the conventional feedforward layers in a transformer LM with DEMix layers. We run all experiments in this section with the training domains (Table 1).

4.1 Experimental Setup

Architecture and Input

The model architecture is a randomly-initialized LM with the GPT-3 (Brown et al., 2020) architecture implemented in Fairseq (Ott et al., 2019). We experiment with multiple architectures (i.e., those of GPT-3 small, medium, large, and XL), at a maximum size of about 1.3B parameters per GPU. We use the GPT-2 (Radford et al., 2019) vocabulary of 50,264 BPE types, and train with 1,024-token sequences, with cross-document boundaries. Each document has a beginning-of-sentence token prepended to it.

Hyperparameters

We set the total number of training steps based on the allocated runtime (see Computational Budget below), set 8% of these steps to be warm-up, and use the Adam optimizer (Kingma and Ba, 2017) with a polynomial learning rate decay. Learning rates are tuned for each model separately over {0.0001, 0.0003, 0.0005}, taking the fastest learning rate that avoids divergence. Each worker processes two sequences of length 1,024, and gradients are accumulated over 8 updates. We clip gradients if their $L_2$ norm exceeds 0.1. See Appendix §A.4 for more details. These settings are inspired by Lewis et al. (2021).

Computational Budget

We follow previous work in using runtime as the primary computational budget, which provides a better comparison of the practical costs of training conditional compute and dense models (Lewis et al., 2021). We assume a fixed budget of about 48 hours on NVIDIA V100 32GB GPUs. We display the number of GPUs used for each model size in Table 2; we chose these GPU budgets because larger models require more compute to train properly (Lewis et al., 2021; Kaplan et al., 2020), and found them to result in stable training for each model size given mostly fixed hyperparameters.

Evaluation

We report test-set perplexities after about 48 hours of training. In all tables, we report each result with respect to a set number of parameters per GPU, as in Table 2. As mentioned in §3.4, a DEMix LM has a larger effective total size than the Dense LM with the same parameters per GPU, at the same (or higher) throughput.

Parameters per GPU
125M 350M 760M 1.3B
Dense 20.6 16.5 14.5 13.8
Dense (Balanced) 19.9 15.8 14.3 13.6
+Domain-Token 19.2 15.9 14.3 13.4
DEMix (naive) 18.4 15.5 14.2 13.8
DEMix (cached; §5.4) 17.8 14.7 13.9 13.4
Table 3: Average of in-domain test-set perplexity. We discuss the last row in §5.4.

4.2 Compared Models

Dense

The first baseline is a Dense model that treats the data as homogeneous, i.e., it shares all parameters across all domains. Under this setup, the language model parameters are copied across all GPUs, and gradients computed during training are all-reduced across every GPU. There is no explicit conditioning on domain.

Dense (Balanced)

Under this setting, we train densely but ensure that the model is exposed to an equal amount of data from each domain. While there is still no explicit conditioning on domain, the gradient updates that the model makes during training are an average of those computed across all domains represented in a batch.

+Domain-Token

This model is trained identically to Dense (Balanced), but we prepend a token indicating the sequence’s domain to every sequence block (during training and test time). A variant of this domain token is explored in some previous studies (Zellers et al., 2019; Keskar et al., 2019). This baseline provides domain information to the language model in the form of input supervision. We ignore the domain token when computing perplexity during evaluation.
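A small sketch of how this baseline can be implemented (the reserved token ids below are hypothetical): the domain token is prepended to each sequence block, so the model conditions on it, but it is never scored as a prediction target.

```python
import torch
import torch.nn.functional as F

# Hypothetical reserved vocabulary ids for the eight training domains.
DOMAIN_TOKEN_ID = {"1b": 50265, "cs": 50266, "legal": 50267, "med": 50268,
                   "webtext": 50269, "realnews": 50270, "reddit": 50271, "reviews": 50272}

def add_domain_token(block: torch.Tensor, domain: str) -> torch.Tensor:
    """Prepend the domain token to a 1D tensor of token ids (one sequence block)."""
    prefix = torch.tensor([DOMAIN_TOKEN_ID[domain]], dtype=block.dtype)
    return torch.cat([prefix, block])

def block_nll(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Average next-token NLL; the domain token at index 0 is conditioned on but never scored."""
    # logits[i] is the prediction for tokens[i + 1], so the domain token is never a target.
    return F.cross_entropy(logits[:-1], tokens[1:])
```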

DEMix (naive)

We replace every feedforward layer in the transformer with a DEMix layer, as detailed in §3. Under this setting, the domain of the test data is known and revealed to the model (e.g., the CS expert is used for CS test data), which we refer to as naive. We also ensure that the model is exposed to an equal amount of data from each domain.

1.3B parameters per GPU
Domain  Dense  DEMix (naive)  DEMix (cached prior; §5.4)
1B 11.8 11.5 11.3
CS 13.5 12.2 12.1
Legal 6.8 6.7 6.7
Med 9.5 9.2 9.1
WebText 13.8 14.6 14.3
RealNews 12.5 13.3 13.1
Reddit 28.4 30.6 28.1
Reviews 14.0 12.6 12.5
Average 13.8 13.8 13.4
Table 4: Test-set perplexity by domain, for an LM with 1.3B parameters per GPU. We discuss the last column in §5.4.

4.3 Results

Table 3 shows test-set perplexities, averaged across the eight training domains. First, we observe that domain balancing is consistently helpful for Dense training. We find that balancing is especially important in cases in which there is an imbalance of domain prevalence, confirming similar observations from previous studies (Arivazhagan et al., 2019). Balancing improves performance on most domains, but hurts performance relative to a Dense baseline on the Reddit domain (Appendix §A.5); in the multi-domain corpus, there is far more Reddit text than anything else (Table 1).

Next, we observe that the benefits of additional domain information (i.e., domain tokens or DEMix layers) are clearest for the smallest model; for larger models, the benefits are smaller but consistent. This result suggests that domain-specific information enables the model to better specialize to different domains in its training data. However, as the model size grows, the Dense baseline becomes increasingly better at fitting the training domains, catching up to models with additional domain information, in the average case.

4.4 Domain Heterogeneity

A more complete view of the experiments with the largest model is shown in Table 4. We see that even at scale, most training domains benefit from DEMix layers in a naive setting (where the domain label is revealed at test time), but some do not; WebText, RealNews, and Reddit fare worse than the Dense baseline. We believe that this variation can be explained by heterogeneity within domains and varying degrees of similarity between them. Dense training may be advantageous for domains that have a higher degree of overlap with other domains in the corpus (and therefore, benefit from parameter sharing).

Figure 2: Domain experts in DEMix specialize to their domain. We compute the above heatmap with a DEMix LM with 1.3B parameters per GPU. Each cell of the heatmap is the ratio of an expert's test perplexity on a domain to that of the expert trained on that domain. The diagonal indicates that each expert has the best performance on its assigned domain. While some experts (e.g., 1B, Med) do not transfer well to most domains in the training corpus, WebText and RealNews experts transfer much better, confirming their heterogeneity. Key: LG → Legal, MD → Med, WT → WebText, RN → RealNews, RD → Reddit, RV → Reviews.

To provide further evidence for this explanation, we measure the heterogeneity of domains in the multi-domain corpus, according to a DEMix LM. We plot a matrix of the perplexity changes across all domain experts in Figure 2, comparing all experts against the expert explicitly trained for each domain. The lower the perplexity change, the higher the corresponding expert's affinity to the target domain.

First, we observe that domain experts have the highest affinity to their assigned domain, indicating that they do specialize. We also observe that some experts, e.g., WebText, RealNews, and Reddit, have relatively high affinities to many domains, suggesting that these domains are heterogeneous. Separately, we observe that an expert's affinity to a domain correlates positively with bigram overlap between the expert domain and the target domain ($r=0.40$, $t=3.45$, $p=0.001$). This further suggests that similar domains have more closely aligned domain experts.
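As a rough sketch of this kind of analysis (the exact overlap measure and affinity scores used in the paper are not reproduced here; the numbers below are placeholders), one can compute bigram-type overlap between domain samples and correlate it with expert affinity:

```python
from scipy.stats import pearsonr

def bigram_overlap(tokens_a, tokens_b):
    """Fraction of bigram types in sample A that also occur in sample B (illustrative measure)."""
    bigrams_a = set(zip(tokens_a, tokens_a[1:]))
    bigrams_b = set(zip(tokens_b, tokens_b[1:]))
    return len(bigrams_a & bigrams_b) / max(len(bigrams_a), 1)

# One (expert domain, target domain) pair per entry; all values are placeholders.
overlaps = [0.52, 0.21, 0.43, 0.11, 0.37]
affinities = [0.91, 0.40, 0.72, 0.18, 0.66]
r, p = pearsonr(overlaps, affinities)
print(f"r = {r:.2f}, p = {p:.3f}")
```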

These findings suggest that a discrete notion of domain, while usually helpful on average (in our artificially constructed population of eight training domains), is too rigid. In the next section, we introduce new ways of softening Equation 5 into a mixture over domain experts, to improve performance on heterogeneous domains.

5 Mixing Experts at Inference Time

The previous section establishes that incorporating DEMix layers improves LM performance on test data from known training domains. At inference time, the domain label was revealed to the model and used to select an expert within each DEMix layer. In practice, however, text may not come with a domain label, may straddle multiple domains, or may not belong to any of the domains constructed at training time; the provenance of the data may even be unknown.

In these cases, rather than a hard choice among experts (Equation 5), we propose to treat $g_1, \ldots, g_n$ as mixture coefficients, transforming the domain membership of an input text into a matter of probabilistic belief. Unlike previously proposed mixture-of-experts formulations (Shazeer et al., 2017; Lepikhin et al., 2020), this approach introduces no new parameters and the weights are computed only at test time. We choose to explore inference-time mechanisms instead of training mechanisms to mix experts because (1) we want to avoid substantially increasing training costs, i.e., GPU communication between domain experts, and (2) we want to maintain the modularity of experts. Exploring mechanisms for training expert mixtures while satisfying these desiderata is a rich area for future work.

To analyze inference-time behavior in mixed or unknown domain scenarios, we turn to the corpus of novel domains in the multi-domain corpus (Table 1). As mentioned in §2, these domains have fuzzier boundaries, compared to the training domains.

Figure 3: Illustration of inference with domain expert mixing. For a given input text $\boldsymbol{x}_{<t}$ from CORD-19, we estimate posterior domain probabilities $p(D_t \mid \boldsymbol{x}_{<t})$, informed by a prior that is either iteratively updated during inference, or precomputed and cached on held-out data. In this example, the model assigns the highest domain probabilities to the medical and news domains. We use these probabilities in a weighted mixture of expert outputs to compute the hidden representation $\mathbf{h}_t$.

5.1 Dynamically Estimating Domain Membership

Consider the probabilistic view of language modeling, where we estimate $p(X_t \mid \boldsymbol{x}_{<t})$. We introduce a domain variable, $D_t$, alongside each word. We assume that this hidden variable depends on the history, $\boldsymbol{x}_{<t}$, so that:

p(X_t \mid \boldsymbol{x}_{<t}) = \sum_{j=1}^{n} p(X_t \mid \boldsymbol{x}_{<t}, D_t = j) \cdot \underbrace{p(D_t = j \mid \boldsymbol{x}_{<t})}_{g_j} \qquad (6)

This model is reminiscent of class-based $n$-gram LMs (Brown et al., 1992) and their derivatives (Saul and Pereira, 1997).

We have already designed the DEMix LM to condition on a domain label, giving a form for $p(X_t \mid \boldsymbol{x}_{<t}, D_t = j)$. The modification is to treat $g_1, \ldots, g_n$ as a posterior probability over domains, calculated at each timestep, given the history so far.

To do this, we apply Bayes’ rule:

p(D_t = j \mid \boldsymbol{x}_{<t}) = \frac{p(\boldsymbol{x}_{<t} \mid D_t = j) \cdot p(D_t = j)}{p(\boldsymbol{x}_{<t})} \qquad (7)
= \frac{p(\boldsymbol{x}_{<t} \mid D_t = j) \cdot p(D_t = j)}{\sum_{j'=1}^{n} p(\boldsymbol{x}_{<t} \mid D_t = j') \cdot p(D_t = j')} \qquad (8)

The conditional probabilities of word sequences given a domain label, as noted above, are already defined by the DEMix LM. For the prior over domain labels, we consider three alternatives:

Figure 4: Estimates of posteriors $p(D_t \mid \boldsymbol{x}_{<t})$ with a DEMix LM with 1.3B parameters per GPU, after 100 sequences (i.e., 102,400 tokens) of data in training domains (top heatmap) and new domains (bottom heatmap). Key: LG → Legal, MD → Med, WT → WebText, RN → RealNews, RD → Reddit, RV → Reviews, CD → CORD-19, GH → Github, GT → Gutenberg, BN → Breaking News, LC → Contracts, AP → ACL Papers, TW → Tweets, YR → Yelp Reviews.

Uniform

Fix the prior to be uniform across the known domains.

Updating

Set the prior at timestep $t$ to be an exponentially weighted moving average of the posteriors from previous timesteps:

p(D_t = j) \propto \sum_{t'=1}^{t-1} \lambda^{t-t'} \cdot p(D_{t'} = j \mid \boldsymbol{x}_{t'}) \qquad (9)

During evaluation, this moving average is calculated over the posterior at the end of each sequence block. The decay factor avoids putting too much weight on calculations made early in the dataset, when posterior calculations are noisier (Appendix §A.6). We performed a small grid search over {0.1, 0.3, 0.5, 1.0} to set the value of $\lambda$, and found that 0.3 worked well for most settings.

Cached

If, prior to testing, some data from the test distribution is available, we calculate the posterior over domain labels from that data, and fix the prior to that estimate. Under this setting, we use 100 sequences (i.e., 102,400 tokens) from the development set to estimate the prior, which we found to result in stable posterior probabilities (see Appendix §A.6 for more details).

We display an illustration of the mixture technique in Figure 3.
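A minimal sketch of this inference procedure, assuming access to each expert's log-likelihood of the preceding block (variable names are illustrative): Equation 8 turns per-expert sequence log-probabilities and a prior into a posterior, and Equation 9 updates the prior as an exponentially weighted moving average; a cached prior is the same estimate computed once on held-out blocks and then frozen.

```python
import torch

def domain_posterior(log_likelihoods: torch.Tensor, prior: torch.Tensor) -> torch.Tensor:
    """Equation 8: posterior over domains from per-expert log p(x_<t | D=j) and a prior."""
    return torch.softmax(log_likelihoods + torch.log(prior), dim=-1)

def ema_prior(posterior_history, lam: float = 0.3) -> torch.Tensor:
    """Equation 9: exponentially weighted moving average of past block posteriors."""
    weights = torch.tensor([lam ** (len(posterior_history) - i)
                            for i in range(len(posterior_history))])
    prior = (weights.unsqueeze(-1) * torch.stack(posterior_history)).sum(dim=0)
    return prior / prior.sum()

# Example with 8 experts; random numbers stand in for each expert's actual
# log-likelihood of a 1,024-token block under the DEMix LM.
num_experts = 8
prior = torch.full((num_experts,), 1.0 / num_experts)          # "uniform" option
posterior_history = []
for _ in range(100):
    block_log_likelihoods = -1000.0 * torch.rand(num_experts)  # placeholder values
    posterior = domain_posterior(block_log_likelihoods, prior)
    posterior_history.append(posterior)
    prior = ema_prior(posterior_history)                       # "updating" option
# The "cached" option computes this estimate once on ~100 held-out blocks and freezes it.
# The resulting posterior supplies the mixture weights g_j used in the DEMix layer.
```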

5.2 Visualizing Domain Membership

In Figure 4, we plot the posteriors, calculated using the updating method above after 100 sequences of development data, each from training and novel domains. This evaluation is carried out using the DEMix LM with 1.3B parameters per GPU from §4, with no modifications.

For known domains (top heatmap of Figure 4), the correct label has the highest posterior, but these datasets do not appear to be as distinct or mutually exclusive as we assume. For example, Reddit data is estimated to be around 80% Reddit, 11% WebText, and 8% RealNews. More variation in the estimates is expected and observed for the new domains (bottom heatmap of Figure 4). While ACL Papers is mostly associated with the CS domain, and Breaking News mostly with the WebText and RealNews domains, CORD-19 is spread across Med, RealNews, and 1B; Yelp Reviews across Reviews, WebText, and Reddit. The alignment of multiple domains like Github and Contracts primarily to WebText suggests the benefit of including a relatively heterogeneous domain in training.

Parameters per GPU
125M 350M 760M 1.3B
Dense 25.9 21.4 18.4 17.8
Dense (B) 25.3 19.6 18.3 17.1
+Domain-Token 24.8 20.4 18.4 18.0
DEMix (naive) 28.8 23.8 21.8 21.1
DEMix (average) 27.2 22.4 21.5 20.1
DEMix (uniform) 24.5 20.5 19.6 18.7
DEMix (updating) 21.9 18.7 17.6 17.1
DEMix (cached) 21.4 18.3 17.4 17.0
Table 5: Average perplexity on domains unseen during training. Mixing domain experts with a prior estimated using a small amount of data in the target domain outperforms all other baselines.

5.3 Experimental Setup

We experiment with the corpus of novel domains (Table 1) to test out-of-distribution performance. We evaluate the three mixture treatments of DEMix layers (i.e., uniform, updating, and cached priors) against five baselines. Note that no new models are trained for this experiment beyond those used in §4.

Dense and Dense (Balanced)

These are the basic baselines trained as in §4; there is no explicit reasoning about domain.

+Domain-Token

Here test data is evaluated using each domain label token, and we choose the lowest among these perplexity values per test set.

DEMix (naive)

Similar to +Domain-Token, we evaluate the data separately with each of the eight experts, and report the lowest among these perplexity values per test set.

DEMix (average)

At every timestep, we take a simple average of the eight experts’ predictions.

5.4 Results

Novel Domain Performance

Results averaged across the eight novel domains are summarized in Table 5. Ensembling DEMix experts outperforms Dense baselines and using experts individually (i.e., the "naive" baseline), and caching an estimated prior before evaluation results in the best average performance. While +Domain-Token is competitive with naively using DEMix layers in-domain (Table 3), it consistently underperforms DEMix with a weighted mixture on the novel domains. We observe that ensembling DEMix experts with a cached prior allows smaller models to match or outperform much larger Dense models. We also find that weighted ensembling outperforms simple averaging, confirming the importance of sparsity in the expert mixture.

Examining per-domain performance (Appendix §A.5), we find that DEMix LMs with a cached prior either outperform Dense baselines or closely match them. The largest improvement against Dense baselines comes from the Tweets domain, which is on average 67% better across all model sizes. This domain is heterogeneous according to the DEMix model (Figure 4), confirming the importance of mixing experts for heterogeneous domains. These results demonstrate that conditioning the LM on domains during training need not come at a large cost to generalization to new domains, and in many cases can provide large boosts in performance over Dense baselines.

In-Domain Performance

We can also apply the expert mixture variant of inference (using a cached prior) to the training domains. We find that doing so is beneficial; see the last line of Table 3.

We see improvements in performance across all domains for every scale, though the largest improvements seem to come from heterogeneous domains (across all model sizes, Reddit improves on average 10.7%, WebText 2.4%, RealNews 1.9%), again confirming our intuition that domain metadata may not perfectly align with the most effective domain boundaries.

Figure 5: Illustration of DEMix-DAPT. First, we estimate domain posteriors on a held out sample of the target domain (in this case, CORD-19). We then initialize a new expert with the parameters of the most probable expert under the domain posterior distribution. Finally, we adapt the parameters of the newly initialized expert to the target domain, keeping all other parameters in the LM frozen.

6 Adaptive Pretraining with New Experts

Domain-adaptive, continued pretraining of a language model (DAPT) is a way to use unannotated, in-domain text to improve task performance (Gururangan et al., 2020); this approach typically precedes supervised fine-tuning on task data, hence "pretraining." However, for a large model, DAPT with Dense training (which we refer to as Dense-DAPT) is expensive and may not be feasible on some computational budgets. Furthermore, Dense-DAPT may result in forgetting what was learned during earlier training phases, limiting reusability.

The modular approach of DEMix LMs allows the model to avoid forgetting training domains and adapt cheaply: we can train a new expert and add it to the DEMix layers of the network without updating the other experts or the shared parameters. Because the original model is not changed, forgetting is impossible. We refer to this method of adaptation as DEMix-DAPT; the technique is reminiscent of Progressive Neural Networks (Rusu et al., 2016).

We display an illustration of DEMix-DAPT in Figure 5. We instantiate a new expert in each DEMix feedforward layer and initialize it with the parameters of the pretrained expert nearest to the new domain, using the posterior calculations from §5 on a held-out sample to choose the most probable expert. We then train the added expert on target data, updating only the new expert parameters. For inference, we use the weighted mixture of domain experts with a cached prior (§5).
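A sketch of DEMix-DAPT under the illustrative DEMixFFN module from §3 (any module exposing an `experts` ModuleList); the released Fairseq code implements this differently in detail. A new expert is appended to every DEMix layer, initialized from the most probable existing expert, and only its parameters receive gradients.

```python
import copy
import torch.nn as nn

def add_expert(model: nn.Module, nearest_expert_idx: int):
    """Append a new expert to every DEMix layer and freeze everything else."""
    # Freeze all existing parameters: shared attention/embeddings and old experts.
    for p in model.parameters():
        p.requires_grad = False

    new_params = []
    for module in model.modules():
        if hasattr(module, "experts") and isinstance(module.experts, nn.ModuleList):
            new_expert = copy.deepcopy(module.experts[nearest_expert_idx])
            for p in new_expert.parameters():
                p.requires_grad = True
            module.experts.append(new_expert)
            new_params.extend(new_expert.parameters())
    return new_params  # only these are passed to the optimizer

# Usage sketch: choose the initializer from the cached domain posterior (Section 5)
# and adapt with a 10x smaller learning rate than pretraining, e.g.:
#   new_params = add_expert(lm, nearest_expert_idx=int(torch.argmax(cached_prior)))
#   optimizer = torch.optim.Adam(new_params, lr=pretraining_lr / 10)
```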

6.1 Experimental Setup

We compare DEMix-DAPT to Dense-DAPT on all novel domains. We report final test-set perplexity after adapting to each domain for 1 hour with 8 NVIDIA V100 32GB GPUs, tracking validation perplexity every 10 minutes for early stopping. We adapt to each novel domain with the same hyperparameters as the original phase of training (§4), except for a 10x smaller learning rate.

Figure 6: Adapting LMs with 125M parameters per GPU to CORD-19 or Gutenberg. Top row: when performing Dense-DAPT on a new domain (Target), average perplexity on all pretraining domains degrades. Bottom row: DEMix-DAPT avoids that degradation while achieving close (in the case of Gutenberg) or better (in the case of CORD-19) performance. The new CORD-19 expert was initialized with the Med expert, and the new Gutenberg expert was initialized with a WebText expert.

6.2 Results

Adding one expert

We display examples of DEMix-DAPT and Dense-DAPT on a single additional domain in Figure 6. We observe that while Dense-DAPT reduces perplexity on the novel domain, its performance on the training domains progressively worsens, displaying the forgetting effect (we show similar results in larger models in Appendix §A.7). In contrast, DEMix-DAPT reduces perplexity on the novel domain without forgetting.

We generally observe that DEMix-DAPT outperforms Dense-DAPT for some domains (e.g., CORD-19 and ACL Papers), while it closely approaches Dense-DAPT for others (e.g., Gutenberg; Appendix §A.5). Overall, the parameters for the additional expert comprise about 10% of the total parameters in the DEMix model, and Dense-DAPT involves updating all the parameters of the model towards the target domain, so we would expect Dense-DAPT to outperform DEMix-DAPT in some cases. The strong performance of DEMix-DAPT on domains like CORD-19 and ACL Papers suggests that DEMix-DAPT is especially helpful when the target domain strongly aligns with one of the experts (Figure 4).

Parameters per GPU
Domains # Experts 125M 350M 760M 1.3B
Training 8 17.8 14.7 13.9 13.4
16 17.7 14.6 13.7 13.4
Novel 8 21.4 18.3 17.4 17.0
16 16.0 14.0 13.5 12.5
Table 6: Average perplexity in training and novel domains before and after adding 8 experts adapted to the novel domains (via DEMix-DAPT). Adding experts reduces perplexity on all domains, even those previously seen.

Adding eight experts

With expert mixing (§5), newly added experts can be combined with existing ones in the model at test time. To more thoroughly understand the effect of adding more experts to the system, we add all experts adapted to novel domains to the DEMix model from §4. We display the performance of a DEMix LM with 16 experts (8 experts trained on training domains, 8 additional experts adapted to novel domains) in Table 6. We generally observe that DEMix-DAPT reduces perplexity on all domains for all model sizes, again without forgetting.

Adding the eight additional experts in fact reduces perplexity on previously seen domains. For example, across all model sizes, on average, we see a 2.4% reduction on Med, a 1.8% reduction on RealNews, and a 2% reduction on Reddit (Appendix §A.5). These improvements are small, which is expected given that we only performed DEMix-DAPT for at most one hour with eight GPUs. Even so, these results suggest that DEMix layers can enable the LM to incorporate knowledge from novel domains to improve its performance on previously seen domains.

7 Language Models with Removable Parts

Current LM pretraining datasets are rife with undesirable content, from hate speech to extremism (Gehman et al., 2020; Bender et al., 2021). Another consequence of Dense training is that it is difficult to restrict the model's access to these problematic domains after training, as might be desirable for many user-facing tasks (Xu et al., 2020; Dinan et al., 2021).

DEMix layers offer new capabilities for lightweight control over the domains in the training data that LMs use to make predictions at inference time. In particular, since DEMix layer experts specialize to their domain (Figure 2), experts that are assigned to domains that are unwanted at test-time can be simply disabled and unused.

A key question is whether disabling an expert can simulate a model that has not been exposed to that domain, which we study in this section. However, since the self-attention and input embedding parameters in the DEMix LM are shared across domains, removing an expert offers no guarantee of having fully forgotten content from the removed domain. Establishing such bounds is an important avenue for future work.
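Continuing the illustrative inference code from §5, removal amounts to excluding the unwanted expert from the mixture, e.g., by zeroing its prior weight and renormalizing (its parameters could equally be dropped from the checkpoint):

```python
import torch

def remove_expert(prior: torch.Tensor, expert_idx: int) -> torch.Tensor:
    """Zero out one domain's mixture weight and renormalize over the remaining experts."""
    restricted = prior.clone()
    restricted[expert_idx] = 0.0
    return restricted / restricted.sum()

# Example: disable one training-domain expert before running inference.
cached_prior = torch.full((8,), 1.0 / 8)
cached_prior = remove_expert(cached_prior, expert_idx=6)  # index of the unwanted domain
```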

125M Parameters per GPU
Domain +Expert –Expert –Domain
1B 13.7 25.5 30.4
CS 15.7 22.4 25.4
Legal 8.9 20.9 22.7
Med 12.4 18.6 21.9
WebText 20.9 27.3 25.4
RealNews 18.9 26.7 25.0
Reddit 34.4 47.8 51.3
Reviews 20.5 39.0 43.0
Average 18.2 28.5 30.6
Table 7: In a 125M parameter model, removing a domain expert (–Expert) results in perplexity degradation on the corresponding domain, approaching the performance of an LM that has not been exposed to that domain (–Domain). Here we bold the worst performing model for each domain, i.e., the one that gets the highest perplexity.

7.1 Experimental Setup

To evaluate whether we can simulate models that have not been exposed to a particular domain, we compare three settings:

+Expert

A DEMix LM with all experts active.

–Expert

A DEMix LM with a domain expert deactivated.

–Domain

A DEMix LM retrained from scratch without a particular domain. We replace the removed domain with Gutenberg; our cluster requires that jobs are allocated with eight GPUs, necessitating eight experts and hence the substitution.

We evaluate expert removal (+Expert and –Expert) with the DEMix LM with 125M parameters per GPU from §4, with no modifications. For all settings, we evaluate using expert mixing with a cached prior (§5).

7.2 Results

Removing a domain expert harms model performance on the associated domain, in most cases approaching the performance of a model that has not been exposed to data from that domain (Table 7). In some cases (e.g., WebText and RealNews), –Expert even underperforms –Domain. This leads us to conjecture that most domain-specific learning happens within the DEMix layer, despite the fact that other parts of the model are affected by all training domains.

8 Related Work

Incorporating Metadata

Document metadata has been commonly used to improve the quality of topic models (Mimno and McCallum, 2012; Ramage et al., 2009; Zhu et al., 2012), and previous works have used metadata for adapting RNN-based language models (Jaech and Ostendorf, 2018) or learning better document representations (Card et al., 2018). Zellers et al. (2019) and Keskar et al. (2019) prepend document metadata in the input text (similar to our +Domain-Token setting) while training transformer LMs to provide better inference-time control of text generation.

Inference-time Control

DEMix layers provide a simple mechanism for inference-time control of language model behavior. Previously proposed methods for inference-time control are either expensive to use (Dathathri et al., 2020), or rely on densely trained models (e.g., Keskar et al., 2019). Liu et al. (2021) use multiple experts for inference-time text generation control. This method may be applied to DEMix layers to steer text generation with experts trained on different domains.

Multilinguality

Related to variation across domains is crosslingual variation. Past work has suggested that multilingual models benefit from language-specific parameters (Fan et al., 2020; Pfeiffer et al., 2020; Chau et al., 2020). Here, we investigate the effect of incorporating domain-specific parameters into the LM. Though the boundaries between languages are (often) more clear than those among domains, DEMix layers draw inspiration from multilingual research, and future work might explore a compositional approach with both language experts and domain experts.

Continual Learning

DEMix-DAPT is a type of continual learning, in which the model learns incrementally on new data (Chen et al., 2018). Previously proposed techniques to support continual learning include regularization (Kirkpatrick et al., 2017), meta-learning (Munkhdalai and Yu, 2017), episodic memory modules (Lopez-Paz and Ranzato, 2017; de Masson d'Autume et al., 2019), and data replay (Sun et al., 2019), all of which may be combined with DEMix layers. Model expansion techniques that incorporate new reinforcement learning or visual tasks (Rusu et al., 2016; Draelos et al., 2017) are especially related to DEMix-DAPT. Our results suggest that continual learning in LMs is naturally enabled with modular domain experts; this may be further explored using temporally-relevant domains (Lazaridou et al., 2021).

LM Adapters

Also related to DEMix-DAPT is the line of work into adapter modules for pretrained LMs (Houlsby et al., 2019; Pfeiffer et al., 2020). Similar to the setting in which we add experts for new domains, adapter modules involve freezing the pretrained language model and updating a small number of additional parameters that are appended to certain parts of the network. This study confirms previous findings that only a subset of LM parameters need to be fine-tuned to a target dataset (Zaken et al., 2021). Expert addition may be performed with adapter modules to further improve efficiency.

Multi-Domain Models

Multi-domain models have been studied extensively in the context of machine translation, first with statistical systems (Banerjee et al., 2010; Sennrich et al., 2013), and more recently with neural networks (Pham et al., 2021). Other works have explored multi-domain settings with smaller models and explicit domain labels, using supervision (e.g., Wright and Augenstein, 2020; Guo et al., 2018; Zeng et al., 2018) or dense training (e.g., Maronikolakis and Schütze, 2021). Previous studies have shown the importance of considering domains when adapting LMs (Ramponi and Plank, 2020; Gururangan et al., 2020). Our study establishes the importance of considering domains when training LMs from scratch.

9 Conclusion

We introduce DEMix layers for language models, which provide modularity at inference time, addressing limitations of dense training by providing a rapidly adaptable system. DEMix layer experts can be mixed to handle heterogeneous or unseen domains, added to iteratively incorporate new domains, and removed to restrict unwanted domains.

There are many exciting directions for future work, in addition to those described throughout the paper. They include combining domain- and token-level routing, to realize the benefits of modularity while scaling models efficiently. The design of DEMix layers assumes access to coarse provenance labels (or other metadata) to identify domains in pretraining data; an alternative option is to use unsupervised learning to discover domains in the corpus, which, in concert with domain metadata, may lead to better DEMix expert assignments. Furthermore, in this work, we study DEMix layers with a dataset that has a few large domains. In practice, textual domains usually contain many diverse subdomains of varying prevalence. Training DEMix layers on datasets with a long tail of domains may require automatic measures to cluster smaller domains, or hierarchical experts that are specialized to progressively narrower data distributions.

Acknowledgments

The authors thank members of UWNLP, FAIR, and AI2, specifically Shruti Bhosale, Tim Dettmers, Emily Dinan, Doug Downey, Margaret Li, Myle Ott, Ofir Press, and Swabha Swayamdipta, for helpful comments. At UW, this work was partially supported by NSF grant 1562364, the Office of Naval Research under MURI grant N00014-18-1-2670, and an Amazon research award.

References

Appendix A Appendix

A.1 Collecting Domains

For most domains, we use the associated sources, listed in Table 1, without modification. For Tweets, we use the Twitter Academic API. For Gutenberg, we use the scraping tool provided in https://github.com/aparrish/gutenberg-dammit. For Breaking News, we identify a list of factually reliable English news sources, using the list curated by Baly et al. (2018). Specifically, we filter on "high" factuality in the data provided in this repository: https://github.com/ramybaly/News-Media-Reliability. We then use Newspaper3K (https://newspaper.readthedocs.io/en/latest/) to scrape the latest 1000 articles from each site. After dropping duplicates, we arrive at about 20K articles from 400 news sources. We provide downloading links and general instructions at https://github.com/kernelmachine/demix-data/blob/main/DOWNLOAD_DATA.md.
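A sketch of this scraping step with the newspaper (Newspaper3K) library; the source list and filtering come from the repositories linked above, and the snippet is illustrative rather than the exact collection script.

```python
import newspaper  # the Newspaper3K package

def scrape_latest(source_url: str, max_articles: int = 1000):
    """Download and parse up to max_articles recent articles from one news site."""
    source = newspaper.build(source_url, memoize_articles=False, language="en")
    articles = []
    for article in source.articles[:max_articles]:
        try:
            article.download()
            article.parse()
        except Exception:
            continue  # skip articles that fail to download or parse
        if article.text:
            articles.append({"url": article.url, "title": article.title, "text": article.text})
    return articles

# e.g., articles = scrape_latest("https://www.example-news-site.com")
```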

A.2 Dataset Anonymization

To anonymize certain datasets, we apply a suite of regexes that aim to identify common patterns of user-identifiable data and substitute them with dummy tokens. We display anonymization regexes and associated dummy tokens in Table 8.

Category Link to Regex Dummy Token
Email https://regex101.com/r/ZqsF9x/1 <EMAIL>
DART https://regex101.com/r/0tQ6EN/1 <DART>
FB User ID https://regex101.com/r/GZl5EZ/1 <FB_USERID>
Phone Number https://regex101.com/r/YrDpPD/1 <PHONE_NUMBER>
Credit Card Number https://regex101.com/r/9NTO6W/1 <CREDIT_CARD_NUMBER>
Social Security Number https://regex101.com/r/V5GPNL/1 <SSN>
User handles https://regex101.com/r/vpey04/1 <USER>
Table 8: Anonymization schema. We anonymize text using the regexes provided in the above links for the categories listed.
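A minimal sketch of how such substitutions can be applied; the patterns below are simplified stand-ins, not the exact regexes linked in Table 8.

```python
import re

# Simplified, illustrative patterns; the authoritative regexes are linked in Table 8.
ANON_PATTERNS = [
    (re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+"), "<EMAIL>"),
    (re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "<SSN>"),
    (re.compile(r"\b(?:\d[- ]?){13,16}\b"), "<CREDIT_CARD_NUMBER>"),
    (re.compile(r"\b(?:\+?1[-. ]?)?\(?\d{3}\)?[-. ]?\d{3}[-. ]?\d{4}\b"), "<PHONE_NUMBER>"),
    (re.compile(r"@\w{1,15}"), "<USER>"),
]

def anonymize(text: str) -> str:
    """Replace common patterns of user-identifiable information with dummy tokens."""
    # Order matters: emails are replaced before the user-handle pattern can match them.
    for pattern, token in ANON_PATTERNS:
        text = pattern.sub(token, text)
    return text

print(anonymize("Contact someone@example.com or 555-123-4567; handle: @some_user"))
```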

A.3 Calculating TFLOPs/GPU

We use the formula presented in Narayanan et al. (2021) to calculate TFLOPs/GPU and TFLOPs/update. The spreadsheet that contains the calculations and formula can be accessed here: https://docs.google.com/spreadsheets/d/1NO-Lz_VqZGF2fpJTFxtXyjhmaoYi6qnz50Xr8W8hgGw/edit?usp=sharing.
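As a rough sketch of that calculation (the linked spreadsheet is authoritative), one form of the approximation in Narayanan et al. (2021), assuming activation recomputation, estimates the FLOPs of one update as 96·B·s·l·h²·(1 + s/(6h) + V/(16·l·h)); plugging in the reported 125M-per-GPU settings roughly reproduces the 556 TFLOPs/update in Table 2.

```python
def flops_per_update(batch_size, seq_len, num_layers, hidden_size, vocab_size):
    """Approximate forward + backward FLOPs per update, with activation recomputation,
    following one form of the approximation in Narayanan et al. (2021)."""
    B, s, l, h, V = batch_size, seq_len, num_layers, hidden_size, vocab_size
    return 96 * B * s * l * h ** 2 * (1 + s / (6 * h) + V / (16 * l * h))

# 125M-per-GPU configuration: GPT-3 small-like (12 layers, hidden size 768);
# 32 GPUs x 2 sequences per GPU x 8 gradient accumulation steps = 512 sequences per update.
flops = flops_per_update(batch_size=512, seq_len=1024, num_layers=12,
                         hidden_size=768, vocab_size=50264)
print(f"~{flops / 1e12:.0f} TFLOPs per update")  # ~557, close to the 556 in Table 2
```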

A.4 Hyperparameter Assignments

We display hyperparameter assignments for LM pretraining in Tables 9, 10, 11, and 12.

Computing Infrastructure 32 Volta 32GB GPUs
Hyperparameter Assignment
architecture GPT-3 small
tokens per sample 1024
batch size 2
number of workers 2
learning rate [5e–4, 3e–4, 1e–4]
clip norm 0.1
gradient accumulation steps 8
number of steps 300,000
save interval updates 6,000
validation interval 3,000
number of warmup steps 24,000
learning rate scheduler polynomial decay
learning rate optimizer Adam
Adam beta weights (0.9, 0.95)
Adam epsilon 10e-8
weight decay 0.1
Table 9: Hyperparameters for pretraining the LM with 125M parameters per GPU. All hyperparameters are the same for DEMix and Dense training.
Computing Infrastructure 64 Volta 32GB GPUs
Hyperparameter Assignment
architecture GPT-3 medium
tokens per sample 1024
batch size 2
number of workers 2
learning rate [5e–4, 3e–4, 1e–4]
clip norm 0.1
gradient accumulation steps 8
number of steps 120,000
save interval updates 3,000
validation interval 2,000
number of warmup steps 9,600
learning rate scheduler polynomial decay
learning rate optimizer Adam
Adam beta weights (0.9, 0.95)
Adam epsilon 10e-8
weight decay 0.1
Table 10: Hyperparameters for pretraining the LM with 350M parameters per GPU. All hyperparameters are the same for DEMix and Dense training.
Computing Infrastructure 128 Volta 32GB GPUs
Hyperparameter Assignment
architecture GPT-3 large
tokens per sample 1024
batch size 2
number of workers 2
learning rate [5e–4, 3e–4, 1e–4]
clip norm 0.1
gradient accumulation steps 8
number of steps 65,000
save interval updates 2,000
validation interval 1,000
number of warmup steps 5,200
learning rate scheduler polynomial decay
learning rate optimizer Adam
Adam beta weights (0.9, 0.95)
Adam epsilon 10e-8
weight decay 0.1
Table 11: Hyperparameters for pretraining the LM with 760M parameters per GPU. All hyperparameters are the same for DEMix and Dense training.
Computing Infrastructure 128 Volta 32GB GPUs
Hyperparameter Assignment
architecture GPT-3 XL
tokens per sample 1024
batch size 2
number of workers 2
learning rate [5e–4, 3e–4, 1e–4]
clip norm 0.1
gradient accumulation steps 8
number of steps 50,000
save interval updates 2,000
validation interval 500
number of warmup steps 4,000
learning rate scheduler polynomial decay
learning rate optimizer Adam
Adam beta weights (0.9, 0.95)
Adam epsilon 10e-8
weight decay 0.1
Table 12: Hyperparameters for pretraining the LM with 1.3B parameters per GPU. All hyperparameters are the same for DEMix and Dense training.

A.5 Per-Domain Results

We display per-domain test results in the spreadsheets at the following link: https://docs.google.com/spreadsheets/d/1yNMZGSPAvhTi3JttLamiCULaOIGTJ4QGEOajO3b5kt8/edit?usp=sharing

A.6 Domain Posterior Calculations

We track calculated domain posteriors over blocks of development data in Figure 7 (training domains) and Figure 8 (novel domains). The calculated domain posteriors are noisier for earlier blocks, stabilizing usually after around 50 blocks. For all experiments, we conservatively use 100 blocks of data to compute the domain posterior, though one may be able to accurately calculate the domain posterior for some domains with less data.

Figure 7: Calculated domain posteriors for 8 training domains.
Figure 8: Calculated domain posteriors for 8 novel domains.
Parameters
125M 350M 760M 1.3B
Dense-DAPT T +70.1% +21.4% +16.7% +20.6%
N –55.1% –46.6% –38.3% -44.4%
Table 13: Average change in perplexity in training (T) and novel (N) domains after Dense-DAPT. Negative values indicate better performance relative to the original Dense LM. While average perplexity in the novel domains decreases more for Dense-DAPT, this comes at the cost of a significant deterioration in performance in training domains.

A.7 Perplexity changes after Dense-DAPT

In Table 13, we display the average perplexity change after performing Dense-DAPT on a new domain. We observe that across all model sizes, Dense-DAPT improves performance in the novel domain, at the cost of a large performance hit in the training domains.