
Technical intro

1 Implementing the forward pass of a Hypernetwork with Transformers

1.1 Linear attention block

We first give a brief overview of the linear transformer architecture.

Given input tokens $E\in\mathbb{R}^{L\times d_{m}}$ for a sequence of length $L$, a transformer block consists of a self-attention layer followed by a multi-layer perceptron. The transformation is performed by first computing queries, keys and values $Q,K,V=EW_{q},EW_{k},EW_{v}$, with which we then update $E$ as

$E \leftarrow E+QK^{\top}VW_{P}$ (1)
$E \leftarrow E+\sigma(EW_{1})W_{2}$ (2)

where $W_{q},W_{k},W_{v}\in\mathbb{R}^{d_{m}\times d_{k}}$ and $W_{P}\in\mathbb{R}^{d_{k}\times d_{m}}$, as well as $W_{1}\in\mathbb{R}^{d_{m}\times d_{h}}$ and $W_{2}\in\mathbb{R}^{d_{h}\times d_{m}}$, are learnable parameter matrices, and $\sigma$ is a nonlinearity applied row-wise. In practice, there are $H$ heads that perform the attention operation in parallel, each with its own parameters $W_{q}^{(h)},W_{k}^{(h)},W_{v}^{(h)},W_{P}^{(h)}$, resulting in the following forward function

$E \leftarrow E+\sum_{h}Q^{(h)}K^{(h)\top}V^{(h)}W_{P}^{(h)}$ (3)
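
To make the block concrete, the following short NumPy sketch implements the multi-head linear attention block of Eqs. (1)-(3); the parameter shapes follow the text, while the concrete sizes, the random initialization and the choice of $\sigma$ as a ReLU are illustrative assumptions only:

import numpy as np

def linear_transformer_block(E, heads, W1, W2, sigma=lambda v: np.maximum(v, 0.0)):
    # E: (L, d_m) token matrix; heads: list of per-head (W_q, W_k, W_v, W_P).
    attn = np.zeros_like(E)
    for W_q, W_k, W_v, W_P in heads:
        Q, K, V = E @ W_q, E @ W_k, E @ W_v      # (L, d_k) each
        attn += (Q @ K.T) @ V @ W_P              # linear attention, no softmax (Eq. 3)
    E = E + attn                                 # residual update (Eq. 1, summed over heads)
    return E + sigma(E @ W1) @ W2                # position-wise MLP (Eq. 2)

# Illustrative sizes, not taken from the paper.
rng = np.random.default_rng(0)
L, d_m, d_k, d_h, H = 2, 16, 4, 32, 3
heads = [tuple(rng.normal(size=s) for s in
               [(d_m, d_k), (d_m, d_k), (d_m, d_k), (d_k, d_m)])
         for _ in range(H)]
W1, W2 = rng.normal(size=(d_m, d_h)), rng.normal(size=(d_h, d_m))
out = linear_transformer_block(rng.normal(size=(L, d_m)), heads, W1, W2)
print(out.shape)  # (2, 16)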

1.2 Construction

We now show a construction of the linear transformer above that allows it to implement the forward pass of a given hypernetwork for any input $x\in\mathbb{R}^{d}$ and latent $z\in\mathbb{R}^{M}$.

Hypernetwork

Let us consider the following linear hypernetwork:

$x,z \rightarrow A\sigma(\omega(z)x)$ (5)

where $\omega(z)=\sum_{m=1}^{M}z^{(m)}\Theta^{(m)}$ with $\Theta^{(m)}\in\mathbb{R}^{h\times d}$ for all $m$, and $A\in\mathbb{R}^{o\times h}$.
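
For reference, a direct NumPy implementation of this hypernetwork could look as follows; the concrete sizes and the ReLU choice for $\sigma$ are placeholder assumptions:

import numpy as np

def hypernetwork_forward(x, z, Theta, A, sigma=lambda v: np.maximum(v, 0.0)):
    # x: (d,), z: (M,), Theta: (M, h, d), A: (o, h).
    omega = np.einsum("m,mhd->hd", z, Theta)   # omega(z) = sum_m z^(m) Theta^(m)
    return A @ sigma(omega @ x)                # A sigma(omega(z) x), shape (o,)

rng = np.random.default_rng(0)
d, h, o, M = 5, 7, 3, 4
x, z = rng.normal(size=d), rng.normal(size=M)
Theta, A = rng.normal(size=(M, h, d)), rng.normal(size=(o, h))
print(hypernetwork_forward(x, z, Theta, A).shape)  # (3,)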

Token construction

We assume there are only $2$ tokens, $e_{1}=(x^{\top},0_{M},1_{h+o})^{\top}$ and $e_{2}=(0_{d},z^{\top},0_{h+o})^{\top}$, where $0_{k}$ and $1_{k}$ denote the $k$-dimensional row vectors of zeros and ones, respectively. The output will be computed on the residual stream of $e_{2}$.
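
In code, this token layout can be written as a small helper (a sketch; the function name and the NumPy representation are ours):

import numpy as np

def build_tokens(x, z, h, o):
    # Returns E = [e_1; e_2] with d_m = d + M + h + o columns.
    d, M = x.shape[0], z.shape[0]
    e1 = np.concatenate([x, np.zeros(M), np.ones(h + o)])   # (x, 0_M, 1_{h+o})
    e2 = np.concatenate([np.zeros(d), z, np.zeros(h + o)])  # (0_d, z, 0_{h+o})
    return np.stack([e1, e2])                               # shape (2, d_m)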

Linear attention

First, the attention layer computes the forward pass $\omega(z)x$. To do this, let us fix $H=M$ heads, $d_{q}=d_{k}=1$ and $d_{v}=h$. For each head $m$, we can construct the value matrix such that the first token has value vector $\Theta^{(m)}x$ while the second has $0$. By choosing the key and query matrices appropriately, the attention score between the first and second token can be made to be exactly $z^{(m)}$. By letting the projection matrix be constant across heads, the attention operation is then

$e_{2} \leftarrow e_{2}+\sum_{m=1}^{M}z^{(m)}(\Theta^{(m)}x)^{\top}W_{P}$ (6)

By appropriately choosing $W_{P}$, the residual stream of the second token then equals $(0_{d},z^{\top},(\omega(z)x)^{\top},0_{o})^{\top}$ after the attention layer.
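
The following NumPy sketch spells out one concrete choice of $W_{q}^{(m)},W_{k}^{(m)},W_{v}^{(m)}$ and $W_{P}$ matching this construction and checks the resulting residual stream of the second token; the dimensions and index placements are illustrative choices of ours:

import numpy as np

rng = np.random.default_rng(0)
d, M, h, o = 5, 4, 7, 3
d_m = d + M + h + o
x, z = rng.normal(size=d), rng.normal(size=M)
Theta = rng.normal(size=(M, h, d))

# Tokens as defined above.
e1 = np.concatenate([x, np.zeros(M), np.ones(h + o)])
e2 = np.concatenate([np.zeros(d), z, np.zeros(h + o)])
E = np.stack([e1, e2])

# Shared projection: write the h-dimensional value into the block after z.
W_P = np.zeros((h, d_m))
W_P[:, d + M:d + M + h] = np.eye(h)

attn = np.zeros_like(E)
for m in range(M):
    W_q = np.zeros((d_m, 1)); W_q[d + m, 0] = 1.0       # query of e_2 reads z^(m)
    W_k = np.zeros((d_m, 1)); W_k[d + M, 0] = 1.0       # key of e_1 reads a 1
    W_v = np.zeros((d_m, h)); W_v[:d, :] = Theta[m].T   # value of e_1 is Theta^(m) x
    Q, K, V = E @ W_q, E @ W_k, E @ W_v
    attn += (Q @ K.T) @ V @ W_P                         # head m adds z^(m) Theta^(m) x to e_2
E = E + attn

omega_x = np.einsum("m,mhd,d->h", z, Theta, x)          # omega(z) x
expected = np.concatenate([np.zeros(d), z, omega_x, np.zeros(o)])
print(np.allclose(E[1], expected))                      # True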

MLP

Finally, the MLP layer applies the correct nonlinearity $\sigma$ to $\omega(z)x$ and the readout weight $A$, writing the result $A\sigma(\omega(z)x)$ onto the remaining $0_{o}$ block of the residual stream.
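
Continuing the sketch above (same variables), one possible choice of MLP weights realizing this step, assuming $\sigma$ is a ReLU and $A$ is the hypernetwork readout, is:

sigma = lambda v: np.maximum(v, 0.0)        # assumed nonlinearity
A = rng.normal(size=(o, h))                 # hypernetwork readout weight

W_1 = np.zeros((d_m, h)); W_1[d + M:d + M + h, :] = np.eye(h)  # read the omega(z) x block
W_2 = np.zeros((h, d_m)); W_2[:, d + M + h:] = A.T             # write A sigma(.) to last o slots
E = E + sigma(E @ W_1) @ W_2

print(np.allclose(E[1, -o:], A @ sigma(omega_x)))              # True

The residual stream of $e_{2}$ then reads $(0_{d},z^{\top},(\omega(z)x)^{\top},(A\sigma(\omega(z)x))^{\top})^{\top}$, so its final $o$ coordinates contain the hypernetwork output of Eq. (5).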