
An NMF-Based Building Block for Interpretable Neural Networks With Continual Learning

Brian K. Vogel
[email protected]
Abstract

Existing learning methods often struggle to balance interpretability and predictive performance. While models like nearest neighbors and non-negative matrix factorization (NMF) offer high interpretability, their predictive performance on supervised learning tasks is often limited. In contrast, neural networks based on the multi-layer perceptron (MLP) support the modular construction of expressive architectures and tend to have better recognition accuracy but are often regarded as black boxes in terms of interpretability. Our approach aims to strike a better balance between these two aspects through the use of a building block based on NMF that incorporates supervised neural network training methods to achieve high predictive performance while retaining the desirable interpretability properties of NMF. We evaluate our Predictive Factorized Coupling (PFC) block on small datasets and show that it achieves competitive predictive performance with MLPs while also offering improved interpretability. We demonstrate the benefits of this approach in various scenarios, such as continual learning, training on non-i.i.d. data, and knowledge removal after training. Additionally, we show examples of using the PFC block to build more expressive architectures, including a fully-connected residual network as well as a factorized recurrent neural network (RNN) that performs competitively with vanilla RNNs while providing improved interpretability. The PFC block uses an iterative inference algorithm that converges to a fixed point, making it possible to trade off accuracy vs. computation after training, but this also currently prevents its use as a general MLP replacement in some scenarios, such as training on very large datasets. We provide source code at https://github.com/bkvogel/pfc.

1 Introduction

Current neural networks are often the best performing models on a range of challenging problems. Commonly used neural architectures include fully-connected and convolutional networks, various Recurrent Neural Networks (RNNs), and transformers [1]. These architectures and others make use of a relatively small number of basic building block types so that the differences between the various architectures are mainly due to which blocks are used and how they are inter-connected. Perhaps the single most fundamental building block is the Multi-Layer Perceptron (MLP) [2] [3] [4] since it is the smallest block that provides for learning arbitrarily complex non-linear functions and also tends to account for the bulk of the learnable parameters in existing architectures [5] [6]. The remaining building blocks are mainly intended to make the optimization process more efficient and/or serve as regularizers to help prevent over-fitting; examples of these include residual (i.e., skip) connections, LayerNorm [7], and Dropout [8] layers. Additional building blocks with learnable parameters are sometimes used for specialized architectures, such as linear embedding layers for the case of discrete-valued inputs and to provide the sequence-positional information in the transformer, for example. The transformer includes all of the above-mentioned blocks, as well as the attention block, which itself has an MLP interpretation, as described on the last page of [1]. We therefore consider the currently popular neural architectures as being “MLP-based”.

The existing MLP-based models do have some drawbacks, however. The first is that these architectures are often criticized as being “black boxes” lacking interpretability. Additionally, and perhaps partially as a consequence of the first issue, they tend to experience training difficulties as the distribution of examples starts to deviate from the i.i.d. assumption. This is often referred to as the “catastrophic forgetting” problem in the literature, and it results in poor continual learning performance [9].

As an alternative to the MLP, there are other existing machine learning methods with better interpretability properties. A simple example is the k-nearest-neighbors (k-NN) algorithm for classification and regression [10]. Since there is no explicit learning step other than retaining the training examples as they become available to the model, k-NN trivially supports continual learning as well as knowledge removal. Other prototype-based algorithms such as LVQ [11] [12] can be considered learnable extensions of nearest neighbors and also have good interpretability.

Another method that is often noted for its desirable interpretability properties is non-negative matrix factorization (NMF). In contrast to other prototype-based learning methods such as k-NN and LVQ in which the prototypes represent entire examples, in NMF they instead represent prototypical “parts” exemplars which are additively combined to reconstruct the input.

These more interpretable methods unfortunately also have drawbacks that have prevented them from competing with MLP-based models in terms of predictive accuracy on supervised tasks. Methods such as k-NN and LVQ are not fully differentiable, preventing them from being used as general neural network building blocks that can be composed like the MLP to create expressive architectures supporting arbitrary differentiable loss functions tailored to the task at hand. Although NMF is potentially differentiable, the literature is mainly concerned with its usage for unsupervised learning tasks in which the factor matrices are optimized to minimize an input reconstruction loss only, rather than a supervised classification or regression loss.

1.1 Contributions

The main point of this work is to make some progress on developing learning methods with a better balance of interpretability and predictive performance compared to the existing MLP-based approaches. In particular, we make the following contributions:

  • In Section 2.4 we introduce a new neural network building block called the Predictive Factorized Coupling (PFC) block as a more interpretable alternative to the MLP. Since its declarative model is still matrix factorization, it potentially retains the interpretable parts-based nature of NMF, but extends it to support its use as a general differentiable predictive module.

  • In Section 5.2 we demonstrate that the PFC block has competitive accuracy with the (fully-connected) MLP on MNIST, Fashion MNIST, and CIFAR10 classification. In Section 5.7 we demonstrate that a (fully-connected) residual network consisting of two PFC blocks is also able to be trained without optimization difficulties and that it performs similarly. Section 5.6 shows an example of the increased interpretability by visualizing in-domain vs. out-of-domain input examples.

  • In Section 3 we develop a factorized RNN by starting from the well-known vanilla RNN and then replacing its MLP building block with our PFC blocks. This results in an RNN modeled as a single matrix factorization, effectively extending the parts-based nature of NMF to the modeling of sequential data. In Section 5.8 we demonstrate its interpretability advantages compared to the standard vanilla RNN on a simple sequential learning task with a known interpretable solution and observe that it is consistently able to learn the minimal transition model of the solution, even when the network is heavily over-parameterized. In all of our sequence learning tasks we find that the factorized RNN performs competitively with or better than the vanilla RNN when both are trained with the usual BPTT. This includes learning a repeating sequence (Section 5.8), the Copy Task (Section 5.9), Sequential MNIST (Section 5.10), and audio source separation (Section 5.11). We also perform some ablations and find two unexpected results of particular interest: the factorized RNN is able to solve simple tasks such as learning a repeating sequence and the Copy Task using only alternating NMF update rules, so that backpropagation is not used at all. We also observe that in both the factorized and conventional RNNs, using backpropagation but disabling BPTT often results in a surprisingly minimal degradation in accuracy. The models do tend to become significantly less parameter efficient without BPTT, however.

  • In Section 5.3 we show that our non-replay-based continual learning method is competitive with approaches that rely on replay on the split MNIST Class-IL scenario [13]. Our PFC-based model performs better than the MLP on this task and we introduce a sliding window optimizer in Section 4 for PFC-based models that results in further accuracy improvements.

  • In Section 5.4 we show the superiority of a PFC-based model over the MLP on a non-i.i.d. training task.

  • In Section 5.5 we show how a PFC-based model can support knowledge removal after training by leveraging the sliding window optimizer that we introduce in Section 4.

We also mention the following limitations:

  • The PFC block is slower and consumes more memory during training than the MLP block due to the use of an iterative inference algorithm that effectively turns each replaced MLP into a corresponding RNN. It is therefore not intended to be a general MLP replacement, but rather it is intended to be used in modeling tasks where its better interpretability, parts-based modeling prior, and/or suitability to continual learning is needed.

  • Since the PFC block is modeled as a matrix factorization, it similarly cannot discriminate between two inputs that differ only by a scale factor, since the corresponding outputs would also only differ by the same scale factor. The NMF modeling assumption also constrains the input features to being non-negative, although this can potentially be relaxed to semi-NMF to support negative data values at the expense of possibly reduced interpretability.

  • This work is preliminary. Our experiments only involve relatively small datasets. In terms of the suitability of the PFC block as an MLP replacement, we only consider two multi-block architectures in this paper: a 2-block residual network and a factorized RNN. Evaluating on larger datasets and/or more complex multi-block architectures is left as future research.

2 Matrix-factorization-based layers

We can motivate our approach by first reviewing some related methods that will contribute toward its design. These include non-negative matrix factorization (NMF) in Section 2.1, k-nearest-neighbors (k-NN) prediction in Section 2.2, and the multi-layer perceptron (MLP) in Section 2.3. We then present the Predictive Factorized Coupling (PFC) block in Section 2.4.

2.1 Review of non-negative matrix factorization (NMF)

Non-negative matrix factorization (NMF) is a matrix decomposition method where a data matrix $V \in \mathbb{R}^{M \times N}$ is factorized into two matrices $W \in \mathbb{R}^{M \times R}$ and $H \in \mathbb{R}^{R \times N}$, with the property that all three matrices have no negative elements. The objective is to find two non-negative matrices whose product closely approximates the initial matrix, such that:

$V \approx WH$ (2.1)

Here, $R$ is a tunable hyperparameter that leads to a more compressed approximation as it is reduced. NMF was originally proposed by [14] as positive matrix factorization and later popularized by [15]. NMF’s strength lies in its ability to provide interpretable decompositions. In particular, the non-negativity constraint is found to lead to sparse and parts-based representations because only additive, not subtractive, combinations are allowed [15]. In many real-world applications, such as image representation or document analysis, this aligns well with the nature of the data where the measured quantities (like pixel intensity or word counts) are inherently non-negative.

It is also possible to relax the non-negativity constraint so that only the factor matrix $H$ is constrained to be non-negative, which is called semi-NMF [16]. This allows for more flexibility in representing data but can also reduce the interpretability.

The factorization is not unique and is typically computed using iterative algorithms, such as gradient descent or multiplicative updates, for example [17]. This involves first specifying a differentiable reconstruction loss function such as mean-squared error (MSE), for example, and then jointly optimizing $W$ and $H$ to minimize the loss. These algorithms alternately fix one of the factor matrices and update the other, until approximate convergence to a fixed point. In this paper we will use the same terminology from [18] to discriminate between these two updates, so that the NMF left update step refers to applying the update rule once to update the $W$ factor matrix, which we also refer to as the NMF learning update step. Likewise, the NMF right update step refers to applying the update rule once to update the $H$ factor matrix, which we also refer to as the NMF inference update step.
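As a concrete illustration, the following minimal NumPy sketch implements this alternating scheme with the well-known multiplicative update rules for the MSE loss [17]; the function name, iteration count, and initialization are our own illustrative choices.

```python
import numpy as np

def nmf(V, R, n_iters=200, eps=1e-9):
    """Sketch of NMF for the MSE loss: factor the non-negative (M x N)
    matrix V as V ~= W @ H using alternating multiplicative updates."""
    M, N = V.shape
    rng = np.random.default_rng(0)
    W = rng.random((M, R))  # non-negative random initialization
    H = rng.random((R, N))
    for _ in range(n_iters):
        # Right (inference) update: refine H with W held fixed.
        H *= (W.T @ V) / (W.T @ W @ H + eps)
        # Left (learning) update: refine W with H held fixed.
        W *= (V @ H.T) / (W @ H @ H.T + eps)
    return W, H
```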

The usual modeling convention is that the columns of $V$ correspond to $N$ input feature vectors, which are $M$-dimensional. For example, the neural network interpretation of NMF shown in Figure 3 of [15] corresponds to:

$v_n \approx W h_n$ (2.2)

where $v_n$ and $h_n$ are now individual column vectors (at column index $n$) in $V$ and $H$, respectively. Here, $v_n$ represents the observed input features (visible variables), while $W$ represents the neural network weights and $h_n$ represents the hidden variables which are inferred from $v_n$ via the repeated application of the NMF right update rule until convergence. The columns of $W$ are often referred to as the learned basis vectors in the literature and we will also sometimes refer to them as (parts-based) prototypes in this paper.

We can see from the basic NMF formulation in Eq. 2.1 that it attempts to learn an approximation to a supplied data matrix. While this can be suitable for many tasks, it also tends to limit its use to unsupervised learning and tasks where NMF is able to provide sufficient modeling expressiveness. For these reasons, NMF is not typically used for supervised learning or for tasks requiring more expressiveness, such as modeling sequential data, for example.

2.2 Review of nearest neighbor prediction

The k-nearest-neighbors (k-NN) algorithm [10] is one of the simplest and most interpretable machine learning methods used for classification and regression. It is considered an instance or prototype-based method since the model stores the entire training dataset and then makes predictions using a similarity measure.

Suppose that we have a set of training example pairs $\{(x_1, y_1), (x_2, y_2), \ldots, (x_N, y_N)\}$ where $x_n$ is the input feature vector and $y_n$ is the corresponding target output value or vector. In the classification setting, $y_n$ would contain the ground truth class label and could be represented either as an integer label or a 1-hot vector. In the regression setting, $y_n$ could more generally be an arbitrary real-valued vector. Since the following formulation could apply to either, we will refer to it as nearest neighbor prediction.

In typical machine learning algorithms, the model adjusts its weights (i.e., parameters) to minimize a loss function that measures the difference between the model’s predictions and the actual target values. However, there is no such loss function in nearest neighbors. Rather, the learning procedure is extremely simple as it only consists of storing the training examples in a suitable data structure as they become available. We will still refer to this data structure as the “weights” since it contains the learned knowledge extracted from the training set, which in this case is just the training examples themselves. For easier connection to the matrix-factorization-based models that follow in later sections, we will formulate nearest neighbor prediction as a special case of matrix factorization. For each training example $(x_n, y_n)$, we create a corresponding column vector $w_n$ in the weights matrix $W$ as the vertical concatenation of the target $y_n$ on top of input feature $x_n$. Specifically, let $y_n \in \mathbb{R}^C$ (i.e., $C$ class labels) and let $x_n \in \mathbb{R}^M$. We then construct $w_n \in \mathbb{R}^{C+M}$ as:

$w_n = \begin{bmatrix} y_n \\ x_n \end{bmatrix}$ (2.3)

As a result, $W$ will be a $(C+M) \times N$ matrix containing all of the training examples as its columns. We will find it useful to split $W$ into an upper “prediction” sub-matrix $W_y$ consisting of the first $C$ rows (containing the $y_n$ targets), and a lower “recognition” sub-matrix $W_x$ consisting of the last $M$ rows (containing the $x_n$ input features):

$W = \begin{bmatrix} W_y \\ W_x \end{bmatrix}$ (2.4)

The nearest neighbor algorithm also requires us to define a suitable distance or similarity metric. For this, we can specify a similarity function $sim(x_q, x_k)$ which computes a similarity score between two supplied vectors based on their Euclidean distance or cosine similarity, for example.

Given a new input example $x$, we can use the model to perform inference and predict the output (either classification or regression) as follows. In the first step, we use $sim(x, x_k)$ to compute the similarity score of $x$ against each of the columns $x_k$ in $W_x$. The general form of the algorithm finds the k-nearest neighbors (k-NN) but we consider only the 1-nearest neighbor (1-NN) method here for simplicity. Let $m$ denote the index of the best-matching column in $W_x$ (which contains the training example $x_m$). We refer to this step as “recognition” since the input $x$ was recognized as its nearest neighbor $x_m$, which represents the model’s reconstruction of $x$. The second step then involves selecting the same column of $W_y$ so that the model will output $y_m$ as its prediction. We therefore refer to this as the “prediction” step.

We can express the inference solution to the 1-NN in the form of a matrix-vector product that will make the reconstructive-predictive aspect of the model explicit. We introduce a 1-hot vector $h \in \mathbb{R}^N$ having the same dimensionality as the column count in $W$, where only index $m$ is set to 1. This allows us to express the model’s output prediction in the form of the following matrix-vector product:

$y_{pred} = W_y h$ (2.5)

We can likewise express the model’s input prediction as:

$x_{pred} = W_x h$ (2.6)

Using Eq. 2.4, these can be combined so that both the input and output predictions are given by a single product:

$\begin{bmatrix} y_{pred} \\ x_{pred} \end{bmatrix} = \begin{bmatrix} W_y \\ W_x \end{bmatrix} h$ (2.7)

Note that given the (1-hot) inference solution in $h$, it picks out a single column of the weights to use as the predictions for both the output and input. For the more general k-NN algorithm, we could consider computing $h$ by setting all elements to 0 except those at indices corresponding to the k-nearest neighbors, for which we would use the corresponding value of the similarity function and then normalize these $k$ values to sum to 1 in $h$. Several interpretable aspects of the 1-NN prediction algorithm are worth mentioning. For example, it provides confidence estimation by making use of the similarity score. When the model makes a wrong prediction (or when an out-of-distribution example is supplied), we can inspect the reconstructed input $x_{pred}$ since it shows what the input was recognized as. Continual learning and knowledge removal are easily supported as well, since the training examples comprise the columns of $W$ and can therefore easily be added or removed as necessary.
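The following minimal sketch makes this matrix-factorization view of 1-NN concrete, using cosine similarity as the choice of $sim()$; the function name is ours, and the columns of W_x are assumed to already be normalized to unit L2 norm.

```python
import numpy as np

def one_nn_predict(W_y, W_x, x):
    """1-NN prediction as the matrix-vector products of Eqs. 2.5 and 2.6.
    Columns of W_x hold the training inputs (assumed unit L2 norm) and the
    corresponding columns of W_y hold their targets."""
    sims = W_x.T @ (x / np.linalg.norm(x))  # cosine similarity to each column
    h = np.zeros(W_x.shape[1])
    h[np.argmax(sims)] = 1.0                # 1-hot inference solution
    y_pred = W_y @ h                        # Eq. 2.5: output prediction
    x_pred = W_x @ h                        # Eq. 2.6: input reconstruction
    return y_pred, x_pred
```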

A drawback of the 1-NN predictor is that the output $y_{pred}$ is not a differentiable function of $W$ and the input $x$. This prevents us from connecting it to a loss function such as a classification loss in order to learn the weights using backpropagation, and also prevents its use as a building block for more complex architectures. The reliance on a manually specified similarity function can limit the prediction accuracy. In addition, inference can be expensive for large datasets since the number of columns in $W$ grows with the number of training examples. Despite these drawbacks, 1-NN and the more general k-NN can still sometimes produce a good balance of predictive performance and interpretability depending on the dataset. It is also worth mentioning that learnable versions of nearest neighbor prediction exist and are referred to as Learning Vector Quantization (LVQ) [11] [12]. However, we are not aware of a fully differentiable version that would enable its use as a building block for more expressive architectures.

2.3 Review of MLP-based Neural Networks

In contrast to nearest neighbor methods, which are non-parametric, most machine learning algorithms use a set of learnable parameters. In modern neural networks, also known as deep neural networks (DNNs), most of the learnable parameters tend to be contained in the MLP blocks, as discussed in Section 1. The versatility of DNNs lies in their ability to use backpropagation [19] [20] [21] for learning parameters that minimize a chosen loss function, ranging from classification to regression losses. This flexibility allows us to take a foundational component like the MLP and scale it to create architectures of varying complexity such as convolutional networks, deep residual networks, recurrent neural networks (RNNs), and transformer architectures. It is this ability to optimize the parameters so as to minimize a desired loss function that appears to contribute to DNNs’ superior predictive performance over other more interpretable methods.

The basic MLP corresponds to the sequential connection of two affine transformations (linear layers) with a non-linear activation function in between. The hidden layer activation vector $h$ is expressed as the following (differentiable) function of the parameters and input vector $x$:

$h = \sigma(W_{xh}^T x + b_h)$ (2.8)

where the weights matrix $W_{xh}$ and bias vector $b_h$ are the parameters of the first linear layer and $\sigma()$ can be an arbitrary differentiable activation function. Common choices for $\sigma()$ include ReLU, GELU [22], and tanh, for example. The MLP’s predicted output, $y_{pred}$, is then obtained by applying a second affine transformation to $h$:

$y_{pred} = W_{hy} h + b_y$ (2.9)

where the weights matrix $W_{hy}$ and bias vector $b_y$ are the parameters of the second linear layer. Note that $y_{pred}$ is differentiable with respect to the parameters of both linear layers, the hidden activations, and the input. This allows it to be used as a building block to create expressive neural architectures. Also note that unlike a single linear layer or Single Layer Perceptron (SLP), an MLP with at least one hidden layer can approximate complex non-linear functions [5] [6]. This ability comes from the non-linear activation function $\sigma()$ used in the hidden layer. As discussed in Section 1, a drawback of MLP-based models is that they are considered black-box models lacking interpretability and have challenges dealing with continual learning and training on non-i.i.d. examples.

Although the above presentation of MLPs might make them seem completely unrelated to nearest neighbor methods, we can make a connection between them. We do this by showing that a particular choice of weights initialization and activation function for the MLP makes it equivalent to the k-NN predictor. This is why we overload $h$: it refers to the MLP hidden activations in this section, while also referring to the 1-hot vector (or k-nonzero vector in the case of k-NN) $h$ in Eqs. 2.5, 2.6, and 2.7. To see the connection, we first ignore the MLP bias terms. Instead of using the usual backpropagation procedure to learn the weights, we initialize the MLP weights $W_{xh}$ and $W_{hy}$ to the corresponding nearest neighbor weights $W_x$ and $W_y$ that were used in Eqs. 2.5, 2.6, and 2.7 and also normalize the columns of $W_{xh}$ to have unit L2 norm. Recall that $W_x$ contained all of the training input vectors $x_n$ as its columns while $W_y$ contained the corresponding target output vectors $y_n$ as its columns, for a total of $N$ columns in each weight matrix. For the activation function $\sigma()$, we use a k-max activation followed by softmax, which has the effect of only passing the k largest values and then normalizing them to have unit sum. For example, with $k=1$, the output $h$ of $\sigma()$ will be a 1-hot vector similar to the $h$ vector that we used for nearest neighbors. The final step is that we need to ensure that the input $x$ to the MLP is normalized to have unit L2 norm. Note that with this setup, the matrix product in Eq. 2.8 computes the cosine similarity between each of the training inputs in $W_{xh}$ and $x$, corresponding to this choice of similarity measure in nearest neighbors. The activation function then has non-zero outputs only for the indices corresponding to the k-nearest neighbors in $W_{xh}$, where these values must sum to 1 in $h$. As a result, we see from Eq. 2.9 that the MLP then outputs a linear combination of the corresponding $k$ columns in $W_{hy}$ (which contain the target training vectors $y_n$). Without using backpropagation, this “nearest-neighbor” MLP remains interpretable. However, if we attempt to train it, the interpretation of the weights as containing prototypes is potentially lost. For example, with the usual backpropagation training, there is no input reconstruction loss used by default, and so the input prediction $x_{pred} = W_{xh} h$ (analogous to Eq. 2.6) may not necessarily result in a good reconstruction.
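The following sketch spells out the forward pass of this “nearest-neighbor” MLP, implementing the k-max-plus-softmax activation directly; the function name is an illustrative choice of ours.

```python
import numpy as np

def nn_mlp_forward(W_xh, W_hy, x, k=1):
    """Forward pass of the 'nearest-neighbor' MLP described above. W_xh
    holds (column-normalized) training inputs, W_hy the matching targets."""
    scores = W_xh.T @ (x / np.linalg.norm(x))  # Eq. 2.8 (no bias): cosine sims
    top_k = np.argsort(scores)[-k:]            # indices of the k nearest neighbors
    h = np.zeros_like(scores)
    e = np.exp(scores[top_k] - scores[top_k].max())
    h[top_k] = e / e.sum()                     # k-max activation followed by softmax
    return W_hy @ h                            # Eq. 2.9 (no bias): blend of k targets
```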

2.4 The Predictive Factorized Coupling (PFC) block

In this section we introduce the Predictive Factorized Coupling (PFC) block and discuss how it relates to the models in the previous background sections. Having covered the related models, we can now easily introduce the PFC block simply by re-interpreting the matrix product shown in Eq. 2.7 with the NMF modeling interpretation of Eq. 2.2 instead of k-NN, so that it represents the predicted input and output vectors after using an NMF algorithm to solve for $h$:

$v = \begin{bmatrix} y_{target} \\ x \end{bmatrix} \approx \begin{bmatrix} y_{pred} \\ x_{pred} \end{bmatrix} = \begin{bmatrix} W_y \\ W_x \end{bmatrix} h$ (2.10)

Interpreted as NMF, it shows the prediction for a single column vector $v$ of the data matrix $V$ in the factorization $V \approx WH$, where $V$ contains the input vectors $x_n$ to be recognized as the columns of its lower sub-matrix $X$ and the corresponding target output vectors ${y_{target}}_n$ to be predicted as the columns of its upper sub-matrix $Y_{targets}$:

$V = \begin{bmatrix} Y_{targets} \\ X \end{bmatrix}$ (2.11)

When used as a neural building block, we consider $X$ to contain the input vectors to the PFC block. The block then infers (solves for) $H$ and predicts $Y_{pred}$ as its output. Similar to neural network training, the corresponding targets $Y_{targets}$ are not available during the prediction (i.e., inference) process. $V$ is therefore partially observed during inference, since only its $X$ sub-matrix is available to the block as input. $Y_{targets}$ is then only used for the purpose of computing the prediction loss when it is available.

Recall that under the NMF modeling constraints, the three matrices of $V \approx WH$ are only required to be non-negative (the non-negativity constraints for $V$ and $W$ can also potentially be removed if we allow semi-NMF, but we will assume NMF here for the purpose of describing the model). We see that $H$ is now less constrained compared to the nearest neighbor model that required $h$ to be 1-hot for 1-NN or to contain only $k$ non-zero elements in the k-NN case. Since $W$ is learnable under NMF, we no longer need to use the same number of column vectors (which are often called basis vectors in the NMF literature) as training examples. Similar to Section 2.1, we let the hyperparameter $R$ specify the number of learnable basis vectors in $W$, and these can now be initialized to random non-negative values as an alternative to initializing with training examples. $R$ also specifies the dimension of the hidden vector $h$. By keeping $h$ internal to the block, we are free to later add or remove basis vectors from $W$ without changing the external interface of the block. With this interpretation, $h$ corresponds to the inferred hidden activations that represent an encoding of the input in terms of the (parts-based) basis vectors in $W$. From Eq. 2.10 we see that $W$ is also composed of two sub-matrices, $W_y$ and $W_x$, which represent the learned basis vectors as coupled “input-output” parts or prototypes. These could also be interpreted as learned key-value factors, where the input vector $x$ serves as the query, and the block is seen to perform a kind of factorized attention over parameters in predicting its output.

The factorization expression in Eq. 2.10 corresponding to the PFC block was first proposed in our earlier work [18] where we referred to it as a “coupling module” or “coupling factorization” and only considered its use in sequential data models. This contrasts with its more general usage as a building block for both sequential and non-sequential models in the current work. Refer to Section 6 for a more detailed discussion of related work.

2.4.1 Training and inference

We will show that the PFC block’s inference procedure is differentiable. This allows it to be used as a general neural building block in arbitrary computation graphs and trained with the usual backpropagation, similar to how existing MLP-based models are trained. From inspection of Eqs. 2.10 and 2.11, it may not initially be clear that the prediction process is differentiable. We see that the block’s predicted output, $y_{pred}$, is expressed as:

$y_{pred} = W_y h$ (2.12)

and so $y_{pred}$ is clearly differentiable with respect to $W_y$ and $h$, but what about with respect to $W_x$ and $x$? Eq. 2.10 is simply a declarative expression stating that the input $x$ is approximately a linear function of the inferred $h$:

$x \approx W_x h$ (2.13)

We need to show that the corresponding reverse-direction imperative process of inferring $h$ from $x$ (and from $W_x$) is also differentiable. Letting $f()$ represent this inference process, we need to show that $h = f(x, W_x)$ is differentiable with respect to $x$ and $W_x$. Recall from Section 2.1 that $h$ is computed by an iterative NMF algorithm consisting of a sequence of right-update steps until approximate convergence of $h$ to a fixed point. Let $h_{k+1} = g(h_k, x, W_x)$ denote the function that computes a single right-update step (the subscript $k$ denotes the iteration number here, not the column index). If it takes $K$ iterations to converge, then $f()$ corresponds to the $K$-fold composition of $g()$ so that we have:

$f = g \circ g \circ \cdots \circ g = g^{(K)}$ (2.14)

It then only remains to show that $g()$ is differentiable. There are several options for $g()$ but we will use the simple and well-known SGD update steps which we review in Appendix A. Eq. A.6 shows the right-update step for the more general case of (batched) matrix input $X$ rather than vector input $x$, and so we repeat it here for the vector case:

$h_{k+1} = relu(h_k - \eta_H W^T (W h_k - x))$ (2.15)

The inference learning rate $\eta_H$ controls the step size. With the choice of Eq. 2.15 as $g()$, we see that it is indeed differentiable. In summary, the inference procedure given an input $x$ is as follows. We first apply the NMF right-update rule in Eq. 2.15 $K$ times (assuming convergence is reached by then) to infer $h$. We then apply the final linear prediction step in Eq. 2.12 to compute the predicted output. Note that the targets $y_{target}$ in Eq. 2.10 are masked while performing inference, since we are predicting them as $y_{pred}$. For this reason, we will also refer to this extension of NMF as masked predictive NMF.
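A minimal PyTorch sketch of this unrolled inference is shown below, using a fixed step size and iteration count rather than the accelerated scheme of Appendix B; since the targets are masked during inference, only the $W_x$ sub-matrix participates in the right-updates. The function name and hyperparameter values are illustrative.

```python
import torch
import torch.nn.functional as F

def pfc_forward(W_y, W_x, x, K=50, eta_h=0.05):
    """PFC-block inference: unroll the right-update of Eq. 2.15 for K steps
    to infer h, then apply the linear prediction step of Eq. 2.12. Because
    every step is differentiable, gradients flow back to W_y, W_x, and x."""
    h = torch.zeros(W_x.shape[1], device=x.device)
    for _ in range(K):
        # One SGD right-update on the reconstruction loss (Eq. 2.15).
        h = F.relu(h - eta_h * W_x.T @ (W_x @ h - x))
    y_pred = W_y @ h  # Eq. 2.12: predicted output
    x_pred = W_x @ h  # input reconstruction (useful for interpretability)
    return y_pred, x_pred, h
```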

Since NMF is used to infer $h$, we can interpret it as follows. If the input $x$ is well modeled as consisting of a mixture of parts, then the NMF solution can potentially discover those parts (assuming an appropriate learning algorithm for $W_x$). Using NMF terminology, the columns of $W_x$ can be interpreted as the learned parts or “basis vectors” so that the inferred $h$ then specifies an additive encoding of the input in terms of the basis vectors. Intuitively, we can interpret Eq. 2.13 as expressing that the input $x$ is approximately generated (reconstructed) as a linear function of the inferred $h$. That is, the reconstruction is given as $x_{pred} = W_x h$. Thus, the process of running an optimization algorithm to solve for $h$ corresponds to the model trying to recognize its input in terms of the already-learned “parts”. When the recognition is successful, the model is able to find an encoding $h$ of the input in terms of these parts that results in a low reconstruction error $e_{reconstruction} = x - x_{pred}$. When it is unsuccessful, $e_{reconstruction}$ will be large. So although we now need to do more computation compared to the MLP, we get a useful new property: the PFC block can now give us feedback through the reconstruction and its error to tell us how well it was able to recognize its input.

We note that it is possible to automate the selection of the learning rate and to accelerate the inference procedure so that fewer iterations are required by leveraging a modified version of the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) [23] that removes the shrinkage step and adapts it to NMF, as we describe in Appendix B. We used this method of unrolling in all of our experiments.

To learn the weights, we can use the PFC block in an arbitrary computation graph and train with the usual backpropagation algorithm. For example, we could compute a classification or regression loss between $y_{pred}$ and $y_{target}$, perform backpropagation to compute the error gradients, and use an existing optimizer such as SGD, RMSprop [24], etc. to update the weights. When using the NMF modeling constraint, we also need to clip any negative values in the weights to zero after each optimizer update. We can also consider training on mini-batches instead of individual examples by replacing $x$ with a matrix $X_n$ containing a batch of examples.
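Continuing the sketch from above, a single training step might look as follows, where W_y and W_x are assumed to be leaf tensors created with requires_grad=True and pfc_forward is the unrolled inference sketched earlier; the equal loss weighting and hyperparameter values are illustrative.

```python
import torch
import torch.nn.functional as F

# Hypothetical setup: W_y, W_x are non-negative parameter tensors.
optimizer = torch.optim.RMSprop([W_y, W_x], lr=3e-4, weight_decay=1e-4)

def train_step(x, y_target):
    optimizer.zero_grad()
    y_pred, x_pred, _ = pfc_forward(W_y, W_x, x)
    # Prediction loss plus input reconstruction loss (equal weighting).
    loss = F.mse_loss(y_pred, y_target) + F.mse_loss(x_pred, x)
    loss.backward()        # backpropagate through the unrolled inference
    optimizer.step()
    with torch.no_grad():  # enforce the NMF non-negativity constraint
        W_y.clamp_(min=0.0)
        W_x.clamp_(min=0.0)
    return loss.item()
```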

The general idea of unfolding an iterative and differentiable optimization algorithm, such as NMF in our case, into a computation graph and using backpropagation to learn its parameters is not a new idea. It is sometimes referred to as algorithm unrolling or unrolled neural networks in the literature [25]. For details on related work, refer to Section 6.

3 Factorized RNNs

In this section we develop a simple matrix-factorization-based replacement for the vanilla RNN [26]. Our “factorized” RNN is modeled as a single matrix factorization of the form $V \approx WH$ without any activation functions, and containing all of the model’s input activations, weights, hidden state activations, and output activations. This makes its declarative model one of the conceptually simplest RNNs that we are aware of. We will often be interested in the case where these matrices are constrained to be non-negative to improve interpretability, so that the model will then more specifically correspond to an instance of non-negative matrix factorization (NMF). Since the model is NMF, it retains all of the desirable properties of NMF while also more directly supporting the modeling of sequential data. The models we develop in this section are based on the similar factorized recurrent approach used in Section 3 of [18], although here we use a different recurrent architecture and propose modified and more effective backpropagation-based learning methods using the algorithm unrolling method from Section 2.4.

3.1 Review of the vanilla RNN

To motivate the idea, we first need to review the vanilla RNN. In the following, we initially assume that the weights ($W$ matrices) have already been learned, so that we only need to consider the inference or forward pass through the network to compute its output predictions from its inputs. Given an input sequence of length $T$ containing the feature vectors $x_0, x_1, \dots, x_{T-1}$, the RNN maps them in order, one at a time, into a corresponding output sequence of vectors $y_0, y_1, \dots, y_{T-1}$. The subscript denoting the position in the sequence is often called the “time slice”, even when there is no notion of time involved. The reason they must be mapped one at a time is because the RNN maintains an internal hidden state vector $h_k$ which is consumed as an additional (hidden) input during each time slice, modified, and produced as a (hidden) output for use in the next time slice. So, in time slice $k$, the RNN will consume input $x_k$ and previous hidden state input $h_{k-1}$. It then produces a new hidden state $h_k$ and output $y_k$. The vanilla RNN does this in two stages. In the first stage, we update the hidden state:

$h_k = \sigma(W_h^T h_{k-1} + W_x^T x_k + b_h)$ (3.1)

where $\sigma()$ denotes an arbitrary activation function and $b_h$ is the bias vector. In the second stage, we compute the output from the updated hidden state:

$y_k = W_y h_k + b_y$ (3.2)

Note that Eq. 3.1 can be rewritten as a single linear layer followed by a nonlinear activation function:

$h_k = \sigma\left(\begin{bmatrix} W_h^T & W_x^T \end{bmatrix} \begin{bmatrix} h_{k-1} \\ x_k \end{bmatrix} + b_h\right) = \sigma(W_z^T z_k + b_h)$ (3.3)

where we let $W_z$ refer to the combined weights:

$W_z = \begin{bmatrix} W_h \\ W_x \end{bmatrix}$ (3.4)

and let $z_k$ refer to the combined inputs:

$z_k = \begin{bmatrix} h_{k-1} \\ x_k \end{bmatrix}$ (3.5)

Combining Eq. 3.2 and Eq. 3.3, we can express a single time slice of the RNN computation as:

$y_k = W_y \sigma(W_z^T z_k + b_h) + b_y$ (3.6)

This shows that each time slice $k$ of the RNN can be interpreted as an MLP that takes input $z_k$ and produces output $y_k$. From Eq. 3.5, we see that the previous hidden state $h_{k-1}$ appears together with $x_k$ in the input $z_k$, and the updated hidden state $h_k$ corresponds to the hidden layer of the MLP after the $\sigma()$ activation as shown in Eq. 3.3.

3.2 The factorized RNN

Section 3.1 showed that each time slice of the vanilla RNN can be interpreted as an MLP. With this interpretation, it is now interesting to consider the model that results from replacing each of these MLP blocks with a corresponding PFC block. This corresponds to keeping the output linear layer in Eq. 3.2 (although with the bias term removed). We replace the input linear layer and activation function from Eq. 3.3 with the following vector factorization for the $k$’th time slice:

$z_k \approx W_z h_k$ (3.7)

With this change, we have now reversed the direction of the linear mapping compared to the MLP so that the input $z_k$ is approximately a linear function of the hidden state $h_k$ (contrast this with the MLP in Eq. 3.3, where the pre-activation $h_k$ is a linear function of $z_k$). Our factorized representation is also simplified compared to Eq. 3.3 since we have removed the need for an activation function and bias term. The tradeoff is that now when we are given an input $z_k$, we will require an iterative NMF update algorithm to solve for $h_k$, which could be more computationally costly compared to the MLP.

With the computed $h_k$, we then compute the output as in the vanilla RNN using Eq. 3.2. Applying Eq. 3.2 and Eq. 3.7 to all time slices $k \in 0 \dots T-1$, we finally arrive at the factorized vanilla RNN expressed as a single matrix factorization of the form $V \approx WH$:

$\begin{bmatrix} y_0 & y_1 & y_2 & \dots & y_{T-1} \\ z_0 & z_1 & z_2 & \dots & z_{T-1} \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_z \end{bmatrix} \begin{bmatrix} h_0 & h_1 & h_2 & \dots & h_{T-1} \end{bmatrix}$ (3.8)

Using Eq. 3.5, we can also express the left matrix $V$ in terms of $y_k$, $h_{k-1}$, and $x_k$. This brings us to the key result, which lets us express the factorized RNN with all of the model inputs, outputs, hidden states, and weights together in a single matrix factorization:

$\begin{bmatrix} y_0 & y_1 & y_2 & \dots & y_{T-1} \\ h_{-1} & h_0 & h_1 & \dots & h_{T-2} \\ x_0 & x_1 & x_2 & \dots & x_{T-1} \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} \begin{bmatrix} h_0 & h_1 & h_2 & \dots & h_{T-1} \end{bmatrix}$ (3.9)

Note that for a single time slice, our model corresponds to the following vector factorization:

$\begin{bmatrix} y_k \\ h_{k-1} \\ x_k \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} h_k$ (3.10)

If we use $W$ to denote the three stacked weights sub-matrices, the notation simplifies even further to the following:

$\begin{bmatrix} y_k \\ h_{k-1} \\ x_k \end{bmatrix} \approx W h_k$ (3.11)

Now contrast the simplicity of the factorized RNN for a single time slice in Eq. 3.11 with the corresponding vanilla RNN expression for a single time slice in Eq. 3.6. Note that they differ in that Eq. 3.11 is a declarative representation while Eq. 3.6 is imperative. That is, for the factorized RNN we will still need to find suitable algorithms to actually solve the factorization, whereas the vanilla RNN expression explicitly tells us the steps needed to produce the output.

We can further simplify the notation by replacing each of the sequences in Eq. 3.9 with their respective sub-matrices. Letting $Y = \begin{bmatrix} y_0 & y_1 & y_2 & \dots & y_{T-1} \end{bmatrix}$, $H_{prev} = \begin{bmatrix} h_{-1} & h_0 & h_1 & \dots & h_{T-2} \end{bmatrix}$, and $X = \begin{bmatrix} x_0 & x_1 & x_2 & \dots & x_{T-1} \end{bmatrix}$ results in the following:

$\begin{bmatrix} Y \\ H_{prev} \\ X \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} H$ (3.12)

Regarding the hidden states, we see that the initial “previous” state $h_{-1}$ only appears in the left $V$ matrix while the final state $h_{T-1}$ only appears in the right $H$ matrix. The other hidden states are duplicated, since $h_k$ for $k \in [0 \dots T-2]$ appears in both $H_{prev}$ and $H$. Also note that $W_h$ must be an $R \times R$ square sub-matrix of $W$, since if the $h_k$ are $R$-dimensional then each of $W_y$, $W_h$, and $W_x$ must also have $R$ columns.

3.3 Training with alternating NMF update rules

A simple method of training the factorized RNN consists of performing alternating NMF updates to $W$ and $H$, while also copying the inferred states from $H$ to the corresponding duplicated positions in the data matrix after each inference update step to satisfy the consistency constraints. This is similar to the approach that we used for the sequential models in [18]. The details of this are as follows. We use the simple SGD-based algorithm which is reviewed in Appendix A for our experiments, but a variety of algorithms can potentially be used. Starting from the factorization in Eq. 3.9, we begin by initializing the weights $W = \{W_y, W_h, W_x\}$ to small random values and initializing the hidden states to either zeros or small random values. This corresponds to setting sub-matrices $H_{prev}$ and $H$ to zeros or small random values in Eq. 3.12. Let $X = \begin{bmatrix} x_0 & x_1 & x_2 & \dots & x_{T-1} \end{bmatrix}$ represent the training inputs and $Y = \begin{bmatrix} y_0 & y_1 & y_2 & \dots & y_{T-1} \end{bmatrix}$ represent the corresponding target output values that we want to predict. Since we want to predict the outputs using only the provided inputs, we first need to infer the hidden states. For that we use the subpart of Eq. 3.9 corresponding to Eq. 3.7:

$\begin{bmatrix} h_{-1} & h_0 & h_1 & \dots & h_{T-2} \\ x_0 & x_1 & x_2 & \dots & x_{T-1} \end{bmatrix} \approx \begin{bmatrix} W_h \\ W_x \end{bmatrix} \begin{bmatrix} h_0 & h_1 & h_2 & \dots & h_{T-1} \end{bmatrix}$ (3.13)

The task is then to solve for the $h_k$, which we can do by alternating between matrix factorization updates of the right $H$ matrix followed by enforcing the constraint that the duplicated $h_k$ must have equal values. We can do this by simply copying the $h_k$ from $H$ (the right factor matrix) to $H_{prev}$ (in the left data matrix) after each NMF update to $H$. Once the updates converge, we can then use the top-most sub-factorization of Eq. 3.9 to compute the predicted outputs $\hat{y}_k$ as a linear function of the $h_k$:

$\begin{bmatrix} \hat{y}_0 & \hat{y}_1 & \hat{y}_2 & \dots & \hat{y}_{T-1} \end{bmatrix} = W_y \begin{bmatrix} h_0 & h_1 & h_2 & \dots & h_{T-1} \end{bmatrix}$ (3.14)

Now that the inference part of the algorithm is complete, we perform the learning updates. We replace the predicted outputs above with the target outputs and perform an NMF update on $W_y$. Similarly, we perform NMF updates on $W_h$ and $W_x$ in Eq. 3.13. We then repeat the above procedure until convergence.
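A compact NumPy sketch of one round of this procedure is given below, using the SGD-style update rules of Appendix A with a fixed number of inner inference iterations in place of a convergence test; the initial state $h_{-1}$ is taken to be zeros, and the function name and step sizes are illustrative choices of ours.

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def alternating_update_round(W_y, W_h, W_x, X, Y, H,
                             eta_h=0.05, eta_w=0.001, n_inner=10):
    """One round of alternating NMF training for the factorized RNN.
    X: (M, T) inputs, Y: (C, T) targets, H: (R, T) hidden states."""
    R, T = H.shape
    W_z = np.vstack([W_h, W_x])
    for _ in range(n_inner):
        # Copy inferred states into the duplicated H_prev positions.
        H_prev = np.hstack([np.zeros((R, 1)), H[:, :-1]])
        Z = np.vstack([H_prev, X])                  # data matrix of Eq. 3.13
        # Right (inference) update on H.
        H = relu(H - eta_h * W_z.T @ (W_z @ H - Z))
    H_prev = np.hstack([np.zeros((R, 1)), H[:, :-1]])
    # Left (learning) updates, with targets substituted for predictions.
    W_y = relu(W_y - eta_w * (W_y @ H - Y) @ H.T)
    W_h = relu(W_h - eta_w * (W_h @ H - H_prev) @ H.T)
    W_x = relu(W_x - eta_w * (W_x @ H - X) @ H.T)
    return W_y, W_h, W_x, H
```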

Since the hidden state updates propagate one slice forward for each NMF update, it will take the same number of updates as the sequence length for information from the first slice to potentially reach the last slice. As a result, if the sequence is long, the initial NMF updates late in the sequence could be considered wasted computation since they will be operating on hidden states that only contain information from nearby slices. Whether or not this actually becomes an issue in practice would seem to depend on how far information can actually propagate in an RNN, which is outside the scope of this paper.

As an alternative to performing the “full batch” updates as above, we can consider performing the inference “one time slice at a time”. Specifically, we start at the first slice $k=0$ and wait for the inference procedure to converge on that slice before continuing with the next. Since the (output) inferred hidden state $h_0$ has now converged, we then copy it into the duplicated (input) location in the next ($k=1$) slice of $H_{prev}$. We can now increment the current slice to $k=1$ and continue in the same way so that the input $h_k$ states in the current slice of $V$ can always be considered to have already converged. This is similar to how inference is carried out in the vanilla RNN, since both RNNs share the same temporal dependency ordering between the hidden states. Although the slice-at-a-time inference option seems less able to take advantage of parallel hardware, it also seems potentially more efficient when the sequence length is extremely long since we perform the minimum number of iterations needed for convergence on each state before moving on to the next. Both options seem interesting but a detailed empirical comparison of their relative efficiency is outside the scope of this paper. Perhaps a batch-wise version could also be interesting as a topic for future research. We only consider the latter slice-at-a-time option in the experiments in Section 5.8.

3.4 Training by unrolling NMF inference and backpropagation

The training algorithm in Section 3.3 used standard NMF update rules to learn the model weights. These updates optimize the local reconstruction loss of the data matrix. Depending on the task, we empirically observed that such algorithms can sometimes be sufficient, and we demonstrate examples of this in the experiments in Sections 5.8 and 5.9. However, we generally found this method to fail to perform well on more complex tasks. For this reason we use the algorithm unrolling approach as discussed in Section 2.4 for all other experiments.

It is straightforward to apply unrolling to the factorized RNN since the inference algorithm remains unchanged: we evaluate the computation graph and backpropagate through it. Since we replaced each MLP of the vanilla RNN with a corresponding PFC block, the computation graph dependencies result in the inference progressing one time slice at a time, similar to the vanilla RNN. However, note that in performing the inference in a given time slice, the PFC block’s unrolled NMF update steps form another RNN (internal to each PFC block) with length corresponding to the number of unrolled iterations. We therefore have a computation graph corresponding to an RNN within an RNN.

We compute the loss using the MSE loss between the targets and the predicted values. Since the factorized RNN has three predicted sub-matrices in $V$ (Eq. 3.12), this means we will have three loss terms. Once the inference procedure converges so that the inferred hidden states are available in $H$, we can predict $\hat{V}$ as follows:

$\begin{bmatrix} \hat{Y} \\ \hat{H}_{prev} \\ \hat{X} \end{bmatrix} = \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} H$ (3.15)

The total loss is the sum of the three loss terms, each with an arbitrary non-negative scale factor to adjust the relative strength of each term:

$loss = \lambda_y MSE(Y, \hat{Y}) + \lambda_h MSE(H_{prev}, \hat{H}_{prev}) + \lambda_x MSE(X, \hat{X})$ (3.16)

We then backpropagate through the loss to compute the error gradients and use an existing SGD-based optimizer such as RMSprop to update the weights. Since the inference procedure operated one time slice at a time, we can see that the gradients then flow backward through the entire sequence, effectively making it an instance of backpropagation through time (BPTT).
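In PyTorch-style code, once the unrolled inference has produced H (and the corresponding H_prev), the loss of Eq. 3.16 might be assembled as in the sketch below; we use equal weighting of the terms in our experiments, and the function name is an illustrative choice.

```python
import torch.nn.functional as F

def factorized_rnn_loss(W_y, W_h, W_x, H, H_prev, X, Y,
                        lam_y=1.0, lam_h=1.0, lam_x=1.0):
    """Total loss of Eq. 3.16, computed from the predictions of Eq. 3.15.
    Calling .backward() on the result backpropagates through the unrolled
    inference of every time slice, i.e., BPTT."""
    return (lam_y * F.mse_loss(W_y @ H, Y)
            + lam_h * F.mse_loss(W_h @ H, H_prev)
            + lam_x * F.mse_loss(W_x @ H, X))
```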

4 An optimizer for continual learning and non-i.i.d. training

To motivate our approach, recall that k-NN simply stores the training examples as is and then performs classification by finding the nearest training examples to the current input. Such an instance-based model can easily support continual learning since we simply append the new examples to the model weights as they become available. Knowledge removal after training (unlearning) is also easily supported by simply removing any desired subset of “bad” training examples from the model weights. In contrast, the knowledge representation in the PFC-based model is more distributed since the SGD optimizer update from any particular training example or batch could potentially result in modifications to any of the weights. Additionally, each optimizer update only slightly modifies the weight values, so that many updates are needed before the weights can be considered fully learned. With this understanding, we can modify the optimizer update so that only a narrow unmasked “learnable window” of basis vectors in each weights matrix $W_i$ in the model is able to be modified. We can allow the window to slowly sweep through $W_i$ (e.g., from left to right), advancing slightly with each optimizer update, so that each basis vector ideally remains inside the window long enough to be effectively learned, but not so long as to be overwritten if or when the distribution of training examples changes. If we keep track of the mapping from training batch index to window position during training, then we will be able to identify the (small) subset of weights that can potentially contain the corresponding learned knowledge. For example, this would be the case if the ordering of training examples within each epoch does not change, such as when using a fixed random shuffling. If the training distribution changes (e.g., in continual learning and/or non-i.i.d. training), the learned basis vectors will be protected from being overwritten because they are only learnable for a limited number of optimizer updates, over which we assume the training distribution to be relatively unchanging.

We now introduce a new “sliding learnable window” (SLW) optimizer that can be used with PFC-based models to support improved continual learning, training with non-i.i.d. examples, and knowledge removal after training. Specifically, suppose each weights matrix $W_i$ has $R$ columns (basis vectors) in total. The learnable window will have a width of $L$ basis vectors, where $L$ is a tunable hyperparameter and $L \ll R$. We denote the current position of the window by the index $r$ of its left-most column in $W$. When training starts, we initialize the learnable window to consist only of the leftmost $L$ basis vectors of each weight matrix by setting $r=0$. As training progresses, we then increment $r$ by some small fractional amount, which is specified by the hyperparameter $sweep\_speed$. This results in the learnable window slowly sweeping to the right within its weight matrix. We can see that this will have the effect of leaving all basis vectors to the left of $r$ frozen at their current values. Only the $L$ basis vectors at column indices $[r, r+L)$ receive optimizer updates from the current training batch. Likewise, all basis vectors to the right of the window (i.e., with index $k$ such that $k \geq r+L$) consist of unused weights which have not yet made their way into the learnable region. If the sweep speed is chosen too small, we will end up with many unused basis vectors at the end of training. However, if it is chosen too large then we will run out of weights storage before training can complete, unless we dynamically allocate additional weights storage and concatenate it to the right of $W$ to ensure a continuous supply of unused weights. If $L$ is chosen too small, any given basis vector might not remain under the learnable window long enough to be fully learned. However, if it is chosen too large then the basis columns could remain learnable too long, so that they would then be vulnerable to being overwritten as the training distribution shifts. In the case of non-i.i.d. training or when unlearning capability is required (and assuming the ordering of examples does not change from epoch to epoch), resetting $r=0$ at the beginning of each epoch will ensure that the optimizer update for any given training batch is always mapped to the same location $r$ in $W_i$.
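One simple way to realize the SLW behavior on top of an existing optimizer is to zero out the gradients of all columns outside the current window before each optimizer step, as in the PyTorch sketch below; the class name, method names, and default hyperparameter values are our illustrative choices.

```python
import torch

class SlidingLearnableWindow:
    """Sketch of the SLW idea: only columns (basis vectors) of W with
    indices in [r, r+L) receive optimizer updates; r advances slowly."""
    def __init__(self, L=64, sweep_speed=0.05):
        self.L = L
        self.sweep_speed = sweep_speed
        self.r = 0.0

    def mask_grads(self, W):
        """Call after loss.backward() and before optimizer.step()."""
        r = int(self.r)
        if W.grad is not None:
            W.grad[:, :r] = 0.0           # frozen, already-learned columns
            W.grad[:, r + self.L:] = 0.0  # unused columns to the right
        self.r += self.sweep_speed        # advance the window

    def reset(self):
        self.r = 0.0  # call at each epoch start for non-i.i.d. / unlearning use
```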

5 Experiments

In this section we present experimental results demonstrating the use of the PFC block as a replacement for the MLP block. Section 5.1 first establishes baseline accuracy results of an MLP-based classifier. In Section 5.2 we then replace the MLP with a PFC block and evaluate its accuracy on the same datasets. We also conduct ablation experiments to compare the effect of different modeling constraints such as NMF vs. semi-NMF, as well as the effect of either including or disabling the input reconstruction loss term. We compare these MLP and PFC-based models on a continual learning task in Section 5.3 and on a non-i.i.d. training task in Section 5.4. We demonstrate how PFC-based models can support knowledge removal after training in Section 5.5. We show another example of interpretability by visualizing in-domain vs. out-of-domain inputs in Section 5.6.

A usable MLP replacement must also be capable of supporting architectures with multiple blocks. The remaining experiments consider two simple multi-block architectures. In Section 5.7 we show results for a 2-block (fully-connected) residual network using PFC blocks instead of MLP blocks and demonstrate that it produces competitive accuracy with corresponding MLP-based models. Replacing the MLP blocks of the vanilla RNN with PFC blocks results in a factorized RNN as introduced in Section 3, and we present results for these sequential architectures in Sections 5.8, 5.9, 5.10, and 5.11. We also perform the following ablations: we compare either using the default backpropagation (i.e., unrolled NMF inference algorithm) or disabling it and using NMF-based weight update rules similar to [18]. When using backpropagation-based training, we also evaluate the effect of disabling BPTT.

We conduct the image classification experiments using fully-connected models instead of architectures such as convolutional networks that arguably have a more suitable inductive prior. Our accuracy results will therefore be significantly below state of the art on the datasets used. Since the purpose of these experiments is to evaluate the suitability of the PFC block as a more interpretable replacement for the MLP block, we are concerned with comparing the relative accuracy of these two blocks rather than attempting to achieve state of the art results. For the same reason, we compare the relative accuracy of the vanilla RNN against the corresponding factorized RNN rather than using more sophisticated RNN or transformer models on the sequence modeling tasks.

All experiments were carried out using a single RTX 4090 GPU. Due to the limited compute, we did little hyperparameter tuning and worked with small datasets. Consequently, it seems possible that further improvements to accuracy and/or efficiency could be obtained with additional hyperparameter tuning, and it remains unknown how well these results will scale to larger datasets. We did not perform ablations on the number of iterations required for reliable convergence. It may be possible to improve efficiency by dynamically unrolling the inference algorithm for only the number of iterations needed for reasonable convergence on the current inputs. However, we do not attempt this in these experiments and leave it as future research. Results for the MLP-based models were generally run 3 times, but due to the limited compute, the PFC-based model results are shown for a single run and are not averaged unless otherwise mentioned.

We trained the models using early stopping with an 85%-15% train-validation split. For simplicity and convenience, we use the MSE loss everywhere, for both the classification/regression loss and the input reconstruction loss (applicable to PFC-based models only). We used equal weighting between the reconstruction and prediction loss terms for simplicity. We use the RMSprop optimizer [24] since it is a simple optimizer that we found to perform well with minimal hyperparameter tuning. For the PFC-based models, parameters are initialized to uniform random values in [0, 1e-2] by default, corresponding to the NMF modeling constraint. In some experiments we allow negative parameters, corresponding to the semi-NMF modeling constraint. When negative parameters are allowed, they are initialized to uniform random values in [-1e-2, 1e-2]. The inferred values for the $H$ factor matrices are always constrained to be non-negative. We initialize the $H$ values to zeros, but note that it is also an option to initialize them to small random values. Unless otherwise mentioned, we set the learning rate to 3e-4 for the PFC-based models and 1e-4 for the MLP models. The weight decay was set to 1e-4 unless otherwise mentioned. For the MLP experiments, weights were initialized using the default PyTorch Linear layer initializer, and negative parameters were always allowed since attempting to use non-negative weights with MLPs resulted in optimization difficulties. The GELU activation [22] was used as the MLP hidden-layer activation function.
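For concreteness, a minimal sketch of these defaults follows. The tensor shapes and variable names are illustrative assumptions rather than our exact implementation.

import torch

batch_size, R, in_dim, out_dim = 50, 2000, 784, 10

# NMF constraint: non-negative parameters, uniform in [0, 1e-2].
W_x = (torch.rand(in_dim, R) * 1e-2).requires_grad_()
W_y = (torch.rand(out_dim, R) * 1e-2).requires_grad_()
# Semi-NMF variant: uniform in [-1e-2, 1e-2] instead.

opt = torch.optim.RMSprop([W_x, W_y], lr=3e-4, weight_decay=1e-4)

# H is inferred per batch rather than learned: initialized to zeros and
# constrained to remain non-negative during the iterative inference.
H = torch.zeros(R, batch_size)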

5.1 MLP baseline for image classification

We first evaluate a simple 1-hidden-layer MLP-based classifier as a baseline model on the MNIST [27], Fashion MNIST [28], and CIFAR10 [29] datasets. For each dataset, we train models for hidden dimension sizes of 300, 2000, and 5000. The input feature size is equal to the number of image pixels when the image is flattened into a vector. This is 28x28 = 784 for MNIST and Fashion MNIST, since they contain 28x28 grayscale images, and 32x32x3 = 3072 for CIFAR10, since color images are used. The output layer dimension is 10 since all three datasets have 10 class labels. Table 1 shows the accuracy results.
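For reference, the baseline has the following simple form (a sketch; the activation and layer sizes follow the training details above):

import torch.nn as nn

def make_mlp(in_dim=784, hidden_dim=2000, num_classes=10):
    # 1-hidden-layer MLP with a GELU hidden activation, as described above.
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, num_classes),
    )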

Table 1: Results of the baseline MLP image classifier model (averaged over 5 training runs).
Dataset Hidden Dimension Test Accuracy
MNIST 300 98.01%
MNIST 2000 98.26%
MNIST 5000 98.32%
Fashion MNIST 300 88.07%
Fashion MNIST 2000 88.72%
Fashion MNIST 5000 88.85%
CIFAR10 300 51.59%
CIFAR10 2000 52.78%
CIFAR10 5000 53.17%

5.2 PFC network for image classification

We train a 1-block PFC-based network on the same datasets and with the same parameter sizes as the MLP from Section 5.1. Recall that the PFC basis vector count corresponds to the MLP hidden dimension. We also train with and without enforcing non-negative parameters (i.e., NMF vs semi-NMF) and also evaluate the effect of including vs disabling the input reconstruction loss term.

We train with early stopping, ending a run after 20 epochs with no validation loss improvement. Table 2 shows the accuracy results, averaged over 3 training runs. We see that using semi-NMF generally leads to slightly better accuracy. The impact of the input reconstruction loss on accuracy is less clear and seems to vary depending on the specific dataset and parameter configuration. Since both the NMF constraint and enabling the reconstruction loss can potentially lead to better interpretability, we will enable the reconstruction loss in all remaining experiments. We will also use the NMF parameter constraint in the remaining experiments unless otherwise mentioned.

Comparing the PFC results in Table 2 with the MLP results in Table 1 shows the PFC-based models to perform competitively. We see that the PFC-based model performs significantly better than the MLP on CIFAR10, slightly better on Fashion MNIST, and similarly on MNIST, although the MLP does perform slightly better on MNIST at the 300-dimensional hidden size when the NMF constraint is used on the PFC-based network.

Table 2: Results of the 1-block PFC-based image classifier model (averaged over 3 training runs).
Dataset Basis Vector Count Parameter Constraints Input Reconstruction Loss Test Accuracy
MNIST 300 NMF Yes 96.83%
MNIST 300 NMF No 97.54%
MNIST 300 Semi-NMF Yes 97.51%
MNIST 300 Semi-NMF No 98.68%
MNIST 2000 NMF Yes 97.78%
MNIST 2000 NMF No 98.14%
MNIST 2000 Semi-NMF Yes 98.47%
MNIST 2000 Semi-NMF No 98.84%
MNIST 5000 NMF Yes 98.08%
MNIST 5000 NMF No 98.27%
MNIST 5000 Semi-NMF Yes 98.65%
MNIST 5000 Semi-NMF No 98.88%
Fashion MNIST 300 NMF Yes 88.62%
Fashion MNIST 300 NMF No 88.24%
Fashion MNIST 300 Semi-NMF Yes 88.60%
Fashion MNIST 300 Semi-NMF No 89.77%
Fashion MNIST 2000 NMF Yes 90.21%
Fashion MNIST 2000 NMF No 90.02%
Fashion MNIST 2000 Semi-NMF Yes 90.58%
Fashion MNIST 2000 Semi-NMF No 90.56%
Fashion MNIST 5000 NMF Yes 90.67%
Fashion MNIST 5000 NMF No 90.64%
Fashion MNIST 5000 Semi-NMF Yes 90.82%
Fashion MNIST 5000 Semi-NMF No 90.78%
CIFAR10 300 NMF Yes 53.25%
CIFAR10 300 NMF No 54.42%
CIFAR10 300 Semi-NMF Yes 51.54%
CIFAR10 300 Semi-NMF No 54.32%
CIFAR10 2000 NMF Yes 58.69%
CIFAR10 2000 NMF No 58.30%
CIFAR10 2000 Semi-NMF Yes 57.46%
CIFAR10 2000 Semi-NMF No 55.78%
CIFAR10 5000 NMF Yes 60.12%
CIFAR10 5000 NMF No 59.68%
CIFAR10 5000 Semi-NMF Yes 59.69%
CIFAR10 5000 Semi-NMF No 57.90%

5.3 Continual learning on the Split MNIST task

In this experiment we evaluate and compare the performance of the PFC and MLP-based models on the Split MNIST task [30] under the Class-IL scenario as described in [13]. In Split MNIST, the MNIST dataset is split into 5 task partitions, so that each task only contains two digits: Split 0 contains digits 0 and 1, Split 1 contains digits 2 and 3, and so on up to Split 4. The Class-IL scenario is the most difficult of the three considered in [13], as it requires the model to solve all tasks that have appeared so far as well as to infer the task ID. Since the model is not told which of the 5 tasks it needs to solve, it only receives the image pixels as input and needs to predict the correct digit label. The model therefore has 28x28 = 784 inputs corresponding to the image pixels flattened into a vector and 10 outputs corresponding to the digit labels. We train the model on one split at a time, ordered by the split number so that Split 0 is the first task. Within each task the examples are presented i.i.d. to the model. The validation loss is computed over all tasks seen so far, and early stopping is used to end the current task and move on to the next once the best validation loss is achieved.
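As an illustration, the five tasks can be constructed by filtering MNIST on label pairs. The following sketch uses torchvision as a stand-in for our data pipeline:

from torchvision import datasets, transforms
from torch.utils.data import Subset

mnist = datasets.MNIST(root="data", train=True, download=True,
                       transform=transforms.ToTensor())
splits = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
tasks = [
    Subset(mnist, [i for i, y in enumerate(mnist.targets) if int(y) in pair])
    for pair in splits
]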

5.3.1 Continual learning performance of baseline MLP-based model

In [13], the authors found that regularization-based continual learning methods such as EWC [31] fail completely on the Class-IL scenario and that memory replay-based methods were needed in order to achieve acceptable accuracy. Specifically, on the Split MNIST task under Class-IL, they found that both a baseline MLP model and regularization-based methods such as EWC resulted in accuracies in the 19-20% range (Table 4 of [13]).

We also evaluate an MLP as a baseline model on this task. We experimented with a range of hidden dimension sizes from 300 to 2000, but only report results for a hidden dimension of 1357 since it corresponds to the same parameter size as the PFC-based model used in Section 5.3.2. We use the RMSprop optimizer with weight decay set to 1e-4. When we tried using a learning rate of 1e-4, which worked well in the MLP for other experiments, it resulted in 19-20% accuracy on the test set, roughly matching the results reported in [13]. However, when we tuned the learning rate to maximize the validation set accuracy, we were surprised to find much better performance when the learning rate was reduced to 2e-6. This resulted in a test set accuracy of 40.11%. We therefore suspect that the difference in accuracies between our baseline MLP and that reported in [13] could be due to hyperparameter tuning. Regardless, even 40.11% is a poor result considering that the upper bound on accuracy when the examples of all tasks are combined together and presented i.i.d. to the model is 97.94% [13]. In the following sections, we will attempt to improve on this 40.11% baseline accuracy without resorting to replay-based methods.

5.3.2 Continual learning performance of baseline PFC-based model

We evaluate a one-block PFC-based model using non-negative weights and 1357 basis vectors, resulting in approximately the same parameter sizes as the baseline MLP. We continue to use the RMSprop optimizer with weight decay of 1e-4 and a learning rate of 1e-5 found through hyperparameter search on the validation set. This model achieved 67.07% accuracy on the test set, which is significantly higher than the baseline MLP but still far below the upper bound accuracy.

Since this model uses non-negative weights and can potentially learn parts-based representations, it is interesting to visualize the learned weights after training on each of the five tasks. Perhaps we might then better understand what causes the model to gradually forget the earlier learned tasks. Figure 1 shows the first 100 weight basis vectors, reshaped into images, after learning to classify the two MNIST digits in each of the 5 consecutive split tasks. Comparing these images, we immediately notice a problem: some of the “digit” images learned in earlier tasks gradually fade away as the model learns newer tasks. For example, after completing the first split task, the model can classify the digits 0 and 1, and so we see that the weights in Figure 1(a) look like zeros, ones, or noise. The model learns to classify digits 2 and 3 during the next split task, and so we see these new digits appear in the weights after this task has completed, as expected. However, notice that some of the original 0 and 1 patterns have also started to fade or degrade slightly. By the time the model has completed training on the final split task (i.e., classifying digits 8 and 9), many of the original 0 and 1 patterns are quite degraded, although a few of them still seem relatively unaffected. We can also see that digits 8 and 9 are difficult to find after learning the final split task in Figure 1(e). It seems this is due to the model training for only a single epoch on the final task, which minimized the overall validation loss (now computed over all splits). It was apparently a better accuracy tradeoff to have relatively poor performance in classifying digits 8 and 9, rather than learning to classify them well and significantly reducing the performance (through forgetting) on all previous tasks. In the next section, we will introduce a simple method to prevent the earlier learned weights from degrading as new tasks are learned.

Refer to caption
(a) Split 1/5: digits 0 and 1
Refer to caption
(b) Split 2/5: digits 2 and 3
Refer to caption
(c) Split 3/5: digits 4 and 5
Refer to caption
(d) Split 4/5: digits 6 and 7
Refer to caption
(e) Split 5/5: digits 8 and 9
Figure 1: Reconstruction weights of the PFC-based model after learning to classify two MNIST digits in each of the 5 consecutive split tasks. The first 100 weight basis vectors are shown reshaped as images. Notice that two new digit features appear after training on each split. We see that image features (such as the 0 and 1 digits in the first split) of earlier tasks start to degrade as the model learns additional tasks.

5.3.3 Continual learning performance of PFC-based model with a learnable sliding window optimizer

We have implemented the sliding learnable window idea from Section 4 in a customized RMSprop optimizer, which we refer to as RMSpropSLW. This optimizer takes two additional hyperparameters compared to the standard one: the learnable width $L$ of the window, and $sweep\_speed$, which specifies the (fractional) number of columns that the window advances to the right on each optimizer update.

We then repeat the Split MNIST experiment from Section 5.3.2, replacing the RMSprop optimizer with RMSpropSLW. We set the initial (unused) number of basis vectors to 2000 in each of the two weight matrices. We used a sweep speed of 0.25 and a learnable width of 15, adjusting the sweep speed until training was able to complete without using all 2000 basis vectors. Using these values resulted in the right-most column of the learnable window arriving at the 1357’th basis column at the end of training on the last task. Note that since early stopping was used to decide when to switch to the next task, the number of in-use basis vectors at the end of training varies slightly from run to run. Thus, 643 columns remained unused at their initialized random values. This resulted in an accuracy of 93.73% on the test set, demonstrating the effectiveness of the approach compared to the standard RMSprop optimizer. Figure 2 shows the weights (reshaped as images) immediately to the left of the learnable window as each of the split tasks is completed. Since these weights are now to the left of the sliding learnable window, they remain frozen during the remainder of training. Thus, weights learned during the first split task (i.e., classifying digits 0 and 1), shown in Figure 2(a), were protected from being overwritten during later tasks. Notice that the image features visible after training each split only contain the two digits learned during that split. For example, Figure 2(c) only contains image features for the digits 4 and 5, since the training examples for this split only contain these two digits. Table 3 summarizes the results on this task for the various approaches discussed. We see that the PFC-based model performs significantly better than the MLP even when both use the same RMSprop optimizer. Switching to the RMSpropSLW optimizer further increases the accuracy of the PFC-based model, much closer to the upper bound accuracy. Note that neither the RMSpropSLW optimizer nor the model is given any information about which task is active, as required by the Class-IL scenario. Also note that the MLP-based model cannot use the RMSpropSLW optimizer since the modeling assumptions appear to be incompatible. For reference, replay-based approaches were reported to achieve between 90.79% and 91.79%, and replay + exemplars were reported to achieve 94.57% accuracy in Table 4 of [13]. This shows that our non-replay-based PFC + RMSpropSLW approach is competitive even with replay-based approaches.

Refer to caption
(a) Split 1/5: digits 0 and 1
Refer to caption
(b) Split 2/5: digits 2 and 3
Refer to caption
(c) Split 3/5: digits 4 and 5
Refer to caption
(d) Split 4/5: digits 6 and 7
Refer to caption
(e) Split 5/5: digits 8 and 9
Figure 2: Reconstruction weights of the PFC-based model after learning to classify two MNIST digits in each of the 5 consecutive split tasks, using the sliding learnable window optimizer. The first 25 weight basis vectors are shown, reshaped as 28x28 images. Each sub-figure shows the weights learned during the corresponding split task by extracting the 25 basis vectors to the left of the learnable window just after training of the corresponding task has completed. Since these weights are outside (i.e., to the left of) the sliding learnable window, they remain frozen at their current values and thus protected from degradation during the remainder of training.
Table 3: Comparison of non-replay-based approaches on Split MNIST task under the Class-IL scenario.
Approach Optimizer Test Accuracy
MLP offline i.i.d. training ([13]) (upper bound) RMSprop 97.94%
MLP baseline (ours) RMSprop 40.11%
MLP baseline ([13]) ADAM 19.90%
EWC ([13]) ADAM 20.01%
PFC RMSprop 67.07%
PFC RMSpropSLW 93.73%

5.4 Non-i.i.d. training: MNIST classification with label-ordered examples

Neural networks are most effectively trained when the stream of training examples is presented i.i.d. and the class labels are balanced. Training difficulties begin to arise when distribution shifts are introduced. In this experiment, we train MLP and PFC-based models on the MNIST classification task. However, rather than the usual i.i.d. training in which the examples are presented in shuffled order, we instead present the examples in a fixed label-sorted order. Specifically, we sort the training examples ascending by class label so that all of the digit 0 examples are presented first, followed by all of the digit 1 examples, and so on, until finally presenting all of the digit 9 examples. We also train the networks with the usual i.i.d. (shuffled) ordering for comparison to provide an upper bound on the achievable accuracy.

With the PFC-based model, we also have the option of using the sliding learnable window optimizer from Section 4, which is implemented in RMSpropSLW. Since the examples are presented in the same order each epoch, we simply reset the learnable window to the starting position (i.e., the left-most column of each weight matrix) at the beginning of each epoch. This can potentially improve performance when training on non-i.i.d. data since optimizer updates for a particular example (or its batch) are constrained to a small learnable window of the weights. As a result, only nearby training batches (in terms of the number of training iterations between them) are capable of overwriting previously learned knowledge. More widely spaced batches cannot interfere with each other.

5.4.1 Training details

The batch size was set to 50 for all models. The MLP network has a single hidden layer with a size of 2600. We use the RMSprop optimizer. For the i.i.d. training case, the learning rate is 1e-4 and the weight decay is 1e-4. For the label-ordered case, we found a lower learning rate of 5e-7 and no weight decay to give the best validation performance.

For the PFC-based network, we trained networks under the following 4 combinations: using either i.i.d. or label-sorted examples, and using either RMSprop or RMSpropSLW optimizers. When using RMSpropSLW, we set the maximum weight basis vector count to 3000, of which 2600 were actually used in training. We used a learning rate of 5e-6 and no weight decay. When using RMSprop, we used 2600 basis vectors. This resulted in the PFC-based networks having the same parameter count as the MLP network. The weights were constrained to be non-negative. We used a learning rate of 2e-3 and weight decay of 1e-4.

5.4.2 Results

The results are summarized in Table 4. We see that the MLP and PFC-based models perform similarly when trained on i.i.d. examples using the RMSprop optimizer. When the PFC-based model is instead trained with the RMSpropSLW optimizer, accuracy is reduced slightly. This is not surprising, since the narrow learnable window restricts optimizer updates to only a small subset of the weights at a time. We see that the MLP accuracy suffers worse degradation when label-sorted training is used, only reaching 84.92%, compared to the PFC-based model’s 88.19% when both use RMSprop. However, with the RMSpropSLW optimizer the PFC-based model reaches 96.06% under label-sorted training, essentially matching its i.i.d. accuracy of 96.02%. In summary, the PFC-based model was more robust to non-i.i.d. examples than the MLP when both used the same optimizer. When using the sliding window RMSpropSLW (which is only compatible with the PFC-based model), this robustness was further increased.

Table 4: Non-i.i.d. training results on label-ordered MNIST classification task. Corresponding i.i.d. results are also shown as an accuracy upper bound.
Model Training Method Optimizer Test Accuracy
MLP i.i.d. training (upper bound) RMSprop 97.91%
PFC i.i.d. training (upper bound) RMSprop 97.88%
PFC i.i.d. training (upper bound) RMSpropSLW 96.02%
MLP label-ordered RMSprop 84.92%
PFC label-ordered RMSprop 88.19%
PFC label-ordered RMSpropSLW 96.06%

5.5 Unlearning: removing knowledge from a trained model

It is sometimes desirable to remove learned knowledge from a model. For example, it might be discovered that certain training batches contained errors, or certain knowledge might need to be removed for legal reasons. A large model might take a long time to train, and so the ability to quickly remove specific knowledge could be preferable to retraining from scratch.

In this experiment, we train a network on the MNIST classification task in which some of the training examples have incorrect class labels. We show that this reduces the classification performance, as expected. Provided that the bad examples/batches can be identified after training has completed, we show that it is possible to perform unlearning by removing the subset of model weights that were influenced by these bad examples during training, restoring much of the lost classification accuracy.

We use a model with 1 PFC block and 3000 available basis columns for each of the two weight matrices. For simplicity, we constrain the corrupted training examples so that they appear in a contiguous range of batches. We use the same fixed shuffling of examples each epoch so that the i’th batch of each epoch always contains the same examples. There are then a total of 1020 batches per epoch (with 50 examples per batch). Batches 500 through 800 are corrupted by changing the class label to a different (and incorrect) class. This is accomplished by incrementing the label modulo the number of classes.
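The corruption step itself is simple; a sketch follows (the function name and defaults are illustrative):

def corrupt_labels(labels, batch_index, num_classes=10, bad_range=(500, 800)):
    # Batches in bad_range get their labels incremented modulo the number
    # of classes, producing consistently incorrect targets.
    if bad_range[0] <= batch_index <= bad_range[1]:
        return (labels + 1) % num_classes
    return labels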

We use the same sliding-window RMSpropSLW optimizer from the continual learning experiments. We reset the sliding window to the left-most position at the beginning of each epoch and use a deterministic dataset loader so that the i’th batch always contains the same examples in each epoch. This causes the learning updates corresponding to the i’th batch to be stored within a knowable range of columns in the weight matrices, corresponding to the position of the sliding window while that batch was active.

After the model is trained, 2600 of the 3000 basis columns are in use, as Figure 3(a) shows. The right-most 400 columns remain at their randomly initialized values, but this is not a problem and does not affect the results since these columns are not activated during the NMF inference process. Since the training data included a significant fraction of corrupted examples, the classification accuracy is a somewhat low 83.81% on the MNIST test set.

Next, we perform the unlearning operation, which will attempt to remove the knowledge obtained from the corrupted training examples from the network. This works as follows. First, recall that training batches 500 through 800 were identified as containing the corrupted labels. Since our SLW optimizer constrains the optimizer update for each batch to a small learnable window of the weights, we need to find the corresponding union of all positions of the learnable window during learning updates for these batches. We find that batch index 500 corresponds to the left-most column of the learnable window being at index 1250, and batch index 800 corresponds to the right-most column of the learnable window being at index 2050. That is, during this range of (corrupted example) batch updates, the sliding learnable window covered columns in the weight matrices ranging from 1250 through 2050, so that any knowledge learned from these batches must be in this subset of the weights. It is then straightforward to remove the knowledge, such as by deleting these weights, setting them to zero, or reinitializing to random values. For this experiment, we set these weights to zero, as shown in Figure 3(b). With the corrupted knowledge removed, we then re-evaluate the network on the test set and see that the accuracy has improved to 92.77%. For reference, if we train the model from scratch excluding the corrupted examples, we get 95.35% on the test set, which sets an upper bound on the unlearning accuracy. These results are summarized in Table 5. This shows that our method is effective in restoring accuracy without retraining, provided that we are able to identify which training batches were bad. Note that if any training examples are identified as bad, then the entire batch and corresponding region of the weights must be thrown out. This method is therefore most effective when the subset of batches to remove corresponds to a contiguous sequence of batches in the training data.
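The removal step reduces to mapping the bad batch range to a column range in the weights and zeroing it. The following is a sketch under the per-epoch window reset described above (so that batch $i$ always sees the window with left edge floor(i * sweep_speed)); the function and its arguments are illustrative rather than our exact implementation.

import torch

def unlearn_batches(weights, first_bad, last_bad, sweep_speed, L):
    lo = int(first_bad * sweep_speed)     # window's left edge at the first bad batch
    hi = int(last_bad * sweep_speed) + L  # window's right edge at the last bad batch
    with torch.no_grad():
        for w in weights:                 # zero the affected columns in each
            w[:, lo:hi] = 0.0             # weight matrix
    return lo, hi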

Refer to caption
Refer to caption
(a) Weights before unlearning
Refer to caption
Refer to caption
(b) Weights after unlearning
Figure 3: Model weights before and after unlearning. 2600 of the 3000 columns are in use, leaving the right-most 400 columns at their unused, randomly initialized values. After unlearning, notice that column indices 1250 through 2050 have been zeroed out.
Table 5: Unlearning performance on MNIST Classification
Model Weights at Evaluation Test Accuracy
Including only good data (upper bound) 95.35%
Before unlearning 83.81%
After unlearning 92.77%

5.6 Visualizing out of domain interpretability

In this section, we train a network containing 1 PFC block on an image classification task and then evaluate it on both in-domain and out-of-domain (OOD) inputs. Since the network produces reconstructed input features during the recognition process, it is interesting to visualize and compare these reconstructions on in-domain vs OOD inputs. We might expect that the learned weights would correspond to the parts of the in-domain images, so that the reconstruction quality should be better on in-domain vs OOD inputs. This is because when the network is given an OOD image, it is forced to attempt to reconstruct it using only the “in-domain” parts, potentially reducing the reconstruction quality compared to in-domain inputs.

We now empirically investigate these effects using MNIST as the in-domain images and Fashion MNIST as the OOD images. We train a small network on the MNIST classification task using 100 weight templates to enable easy visualization of all weights in a single figure. We use the combined input reconstruction and classification loss as usual, except that here we use an adjustable trade-off between the two in the form of a hyperparameter $\lambda \in [0, 1]$:

L=\lambda L_{classification}+(1-\lambda)L_{reconstruction} \qquad (5.1)
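In code, the trade-off of Eq. 5.1 is a single line; a sketch using MSE for both terms, as elsewhere in this paper (the function name is illustrative):

import torch.nn.functional as F

def combined_loss(y_pred, y_target, x_recon, x, lam=0.5):
    # Eq. 5.1: lam weights classification against input reconstruction.
    return lam * F.mse_loss(y_pred, y_target) + (1 - lam) * F.mse_loss(x_recon, x)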

Figure 4 shows the weights, reshaped into images, for models trained using different values of $\lambda$. As $\lambda$ is increased, the strength of the classification loss increases and the reconstruction loss decreases. Thus, sub-figure 4(a) corresponds to a loss that emphasizes classification accuracy since only a small reconstruction loss is used. This is reflected in the good classification accuracy. We see that some of the weights resemble MNIST digits, but the images appear somewhat “noisy”. The middle sub-figure 4(b) shows the weights using the default blend of classification and reconstruction loss used in most of the other experiments. Here we see that the weights resemble MNIST digits or parts thereof. Finally, sub-figure 4(c) shows the effect of a strong reconstruction loss. We see that the weights now appear as more localized image parts and that the classification accuracy is significantly lower.

Refer to caption
(a) $\lambda = 0.9$, accuracy = 96.21%
Refer to caption
(b) $\lambda = 0.5$, accuracy = 95.21%
Refer to caption
(c) $\lambda = 0.1$, accuracy = 85.94%
Figure 4: Visualization of model weights for different values of $\lambda \in [0, 1]$, trading off classification accuracy against input reconstruction quality. The corresponding classification accuracy is also shown. As $\lambda$ is increased, the strength of the classification loss increases and the reconstruction loss decreases.

We now compare the input reconstructions produced by the model for in-distribution vs OOD examples. For the following, $\lambda$ is kept at the default value of 0.5. For the in-distribution visualization, we train on MNIST and also evaluate on a batch of MNIST test images, as shown in Figure 5. Sub-figure 5(a) shows a batch of input MNIST test images and sub-figure 5(b) shows the corresponding input reconstructions generated by the model. We see that many of the reconstructed images resemble the corresponding inputs, although some are less recognizable.

Refer to caption
(a) Input images (in-distribution, MNIST test set)
Refer to caption
(b) Reconstructed images
Figure 5: Model response to in-distribution inputs. The model was trained on an MNIST classification task and also evaluated on MNIST images here.

We now move on to the visualization of the model’s reconstructed images when given OOD inputs. For this we use a batch of images from the Fashion MNIST test set, which contains grayscale images of fashion products of the same size as MNIST images. Recall that the model is constrained to reconstruct an input image by solving for the best additive combination of its weights (i.e., using the 100 images in Figure 4(b)). However, since the available weights correspond to MNIST images and/or their parts, we might expect the reconstructed Fashion MNIST inputs to have less resemblance to the actual images. Indeed, this is what we observe in Figure 6, which shows the OOD inputs and their reconstructions.

Refer to caption
(a) Input images (OOD, Fashion MNIST)
Refer to caption
(b) Reconstructed images
Figure 6: Model response to OOD inputs. The model was trained on an MNIST classification task but evaluated on Fashion MNIST images here.

5.7 Residual PFC network for image classification

In this experiment we use more than one PFC block to build a more complex architecture. The main purpose is to verify that, in the absence of optimization difficulties, the network should produce accuracy similar to the 1-block network of Section 5.2. We construct a simple fully-connected residual network containing 2 PFC blocks in which the first block uses a skip connection. That is, we let $x$ represent the input to the first PFC block, which then produces output $y_1$. The input to the second PFC block is then given by $x_2 = \mathrm{relu}(x - y_1)$. The relu is optional when using the semi-NMF assumption but is needed when using the NMF constraint to prevent the input $x_2$ from becoming negative. The second PFC block then outputs the final prediction $y_{pred}$.
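A sketch of this forward pass follows, where pfc1 and pfc2 stand in for PFC blocks that return their output prediction:

import torch

def residual_forward(x, pfc1, pfc2):
    y1 = pfc1(x)                 # first block's prediction
    x2 = torch.relu(x - y1)      # residual input; relu keeps it non-negative
    return pfc2(x2)              # second block produces y_pred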

We train on the same datasets as in Section 5.2 using the same basis vector sizes (equivalent to MLP hidden dimension). Since there are two blocks, the total number of parameters is doubled from the 1-block model. Here we only train using the NMF modeling constraint (non-negative parameters) and we use both prediction and input reconstruction MSE loss terms for each PFC block. Table 6 shows the accuracy results on the test set. We see that the accuracy results appear in a similar range compared to those of the 1-block model in Table 2.

Table 6: Results of the residual 2-block PFC-based image classifier model. Parameters are constrained to be non-negative.
Dataset Basis Vector Count Test Accuracy
MNIST 300 97.81%
MNIST 2000 98.30%
MNIST 5000 98.09%
Fashion MNIST 300 88.92%
Fashion MNIST 2000 89.89%
Fashion MNIST 5000 90.08%
CIFAR10 300 54.00%
CIFAR10 2000 56.26%
CIFAR10 5000 58.15%

5.8 Memorizing a deterministic sequence with a factorized RNN

We developed the factorized RNN in Section 3.2 and modeled it as a matrix factorization of the form $V \approx WH$, shown in Eq. 3.9 and repeated here:

\begin{bmatrix}y_{0}&y_{1}&y_{2}&\dots&y_{T-1}\\ h_{-1}&h_{0}&h_{1}&\dots&h_{T-2}\\ x_{0}&x_{1}&x_{2}&\dots&x_{T-1}\end{bmatrix}\approx\begin{bmatrix}W_{y}\\ W_{h}\\ W_{x}\end{bmatrix}\begin{bmatrix}h_{0}&h_{1}&h_{2}&\dots&h_{T-1}\end{bmatrix} \qquad (5.2)

We also provided a more compact notation in which the sequences are replaced by their respective sub-matrices in Eq. 3.12, which we repeat here:

\begin{bmatrix}Y\\ H_{prev}\\ X\end{bmatrix}\approx\begin{bmatrix}W_{y}\\ W_{h}\\ W_{x}\end{bmatrix}H \qquad (5.3)

We then provided two basic training procedures. Here we consider the first method, which uses matrix factorization update rules for both inference and learning as described in Section 3.3. It is initially unclear whether such simple update rules could even learn a task spanning several time slices, since there is no backward flow of error gradient-like information through the sequence. It therefore seems appropriate to start with a relatively simple temporal learning task.

5.8.1 Training on a repeating sequence

For this initial task, we present a fixed-length pattern that is repeated over and over in the training data and require the network to predict the next item in the sequence. Since we use a deterministic repeating pattern, the network only needs to identify and memorize this underlying pattern in order to predict with perfect accuracy. We also chose this task because we know the underlying generative model corresponds to a simple finite state machine (FSM) containing the same number of (deterministic) transitions as there are time slices in the repeating pattern. We therefore know the minimum number of model parameters needed in principle to solve it, and it is straightforward to imagine interpretable solutions to the corresponding RNN factorization ourselves. This task therefore also serves as a simple interpretability test, since we can train the model, visualize the learned weights, and see whether they align with the interpretable solutions that we know should be possible.

Specifically, we use the following fixed repeating pattern consisting of 25 4-dimensional 1-hot vectors, shown here as integer-valued tokens for easier readability:

[0,1,1,2,2,2,3,3,3,3,3,3,3,3,3,1,2,2,2,2,2,1,3,2,1] \qquad (5.4)

We then repeat this pattern 8 times to create the full training sequence $X = \begin{bmatrix}x_{0}&x_{1}&x_{2}&\dots&x_{T-1}\end{bmatrix}$ shown in Figure 7, which has a length of $T = 200$.

Refer to caption
Figure 7: A training sequence $X$ of length 200 consisting of a fixed repeating sub-sequence of 1-hot vectors of length 25.

We want the model to memorize the repeating pattern, so it needs to predict the next vector in the sequence. For this we can use an autoregressive model so that $Y = \begin{bmatrix}x_{1}&x_{2}&x_{3}&\dots&x_{T}\end{bmatrix}$ in Eq. 3.12. With this, Eq. 3.9 looks like the following:

\begin{bmatrix}x_{1}&x_{2}&x_{3}&\dots&x_{T}\\ h_{-1}&h_{0}&h_{1}&\dots&h_{T-2}\\ x_{0}&x_{1}&x_{2}&\dots&x_{T-1}\end{bmatrix}\approx\begin{bmatrix}W_{y}\\ W_{h}\\ W_{x}\end{bmatrix}\begin{bmatrix}h_{0}&h_{1}&h_{2}&\dots&h_{T-1}\end{bmatrix} \qquad (5.5)

We use NMF for the inference and learning updates. Specifically, we use SGD updates with negative value clipping and scaling of the rows and columns of the factor matrices to keep them from exploding, as described in Section A. For the results in this section, we do not use any weight decay or sparsity regularizers, so that only the non-negativity constraint is used. We initialize the weights $W = \{W_y, W_h, W_x\}$ to non-negative random values uniform in [0, 1e-2] and initialize the hidden states ($H_{prev}$ and $H$) to zero. Recall from Section 3.2 that the hyperparameter $R$ specifies both the dimensionality of the hidden state vectors $h_k$ and the number of basis column vectors in $W$, so that $W_h$ is an $R \times R$ sub-matrix of $W$. Since the repeating sub-sequence has length 25, that seems to be the minimum value of $R$ that could have any chance of learning the state transition model. We then run the training procedure, alternating between the NMF inference and learning update rules as described in Section 3.3. We ran several training runs for values of $R$ ranging between 25 and 500.
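A minimal sketch of one such alternating update follows (plain SGD steps on the factorization error, with clipping to enforce non-negativity; the row/column rescaling of Section A is omitted here for brevity, and the function name is illustrative):

import torch

def nmf_sgd_step(V, W, H, lr_h=1e-2, lr_w=1e-2):
    E = W @ H - V                # reconstruction error
    H -= lr_h * (W.t() @ E)      # inference (right) update for H
    H.clamp_(min=0.0)            # non-negativity clipping
    E = W @ H - V
    W -= lr_w * (E @ H.t())      # learning (left) update for W
    W.clamp_(min=0.0)
    return W, H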

5.8.2 Interpretation of the learned weights

We observed some interesting and surprising results when training with different values of $R$. The first is that training seems to become faster and more reliable as $R$ is increased. The model was consistently able to learn an exact or nearly exact factorization (training MSE below 1e-7 or so) for $R$ around 100 or larger. However, training gradually became less reliable as $R$ was decreased toward the lower limit of 25, often requiring multiple training attempts to successfully learn the factorization. Figure 8 shows one such unsuccessful training run with $R = 50$, where the training MSE only converged to 0.112. Still, training was sometimes successful even at the limiting value $R = 25$. Since we only observed unreliable training with small values of $R$, we did not attempt any hyperparameter optimization to fix it.

Refer to caption
Figure 8: A failure case: learned weights after training on the repeating deterministic sequence in Figure 7 with $R = 50$. Multiple training attempts were sometimes needed for values of $R$ between 25 and 100, and this shows one such example of unsuccessful training.

Perhaps our most surprising observation was that the learned models tended to be highly sparse and interpretable even though we did not use any sparsity regularization. Regardless of the value of $R$, a successful training run resulted in the model discovering that only 25 basis vectors were actually needed in $W$, with the other columns tending toward 0. Additionally, the learned columns of the state transition matrix $W_h$ appeared close to 1-hot vectors, as Figure 9 shows. Such a learned representation lets us understand the underlying state transition model that was used to generate the repeating sub-sequence by quick visual inspection of the model weights. Recall that for any particular activated basis column in $W$, the corresponding column of the bottom sub-matrix $W_x$ additively reconstructs the input $x_k$. Similarly, the same column of $W_h$ additively reconstructs the previous state $h_{k-1}$, and the same column of $W_y$ additively reconstructs $y_k$, which is the prediction for the next input $x_{k+1}$. As a concrete example, consider the right-most basis vector $W[:,24]$ in Figure 9(a). Since $W_x[:,24]$ has a 1 in the second row and $W_y[:,24]$ has a 1 in the third row, this corresponds to explaining a 1-hot input vector similarly having a 1 in its second dimension and predicting that it will transition to a 1-hot vector having a 1 in its third dimension in the next time step. In the integer token representation, this corresponds to $[1] \rightarrow [1, 2]$. The same column $W_h[:,24]$ additively reconstructs the previous hidden state (as a 1-hot vector with a 1 in its first dimension). When the right-most basis vector (i.e., column index 24) of $W$ is activated by a column in $H$ such as $H[24,k] = 1$, it causes the inferred next state to have a 1 in its last dimension. This inferred $h_k$ then becomes the previous-state input $h_{k-1}$ in the next time slice, and we see that the third basis column of $W$ from the left has a 1 in the final dimension (i.e., $W_h[24,2] = 1$), which would cause this basis column to be activated in the next time slice $k+1$. Continuing in this way, we can easily read out the underlying transition model from inspection of $W$. Note that the ordering of the basis vectors in $W$ is significant because activating column $r$ of $W$ results in a corresponding positive activated value in dimension $r$ of the inferred next state vector. With this understanding, it makes sense that the learned $W_h$ would tend to be sparse even without any explicit sparsity regularization. Each activated basis column of $W_h$ becomes a positive entry in the corresponding dimension of the input previous state $h_{k-1}$ in the next time slice (recalling that the inferred states in $H$ are copied into the next time slice of $H_{prev}$), which in turn needs to be explained as an additive combination of the basis vectors in $W_h$. That is, any activated columns in $W_h$ translate to corresponding non-zero (positive-valued) rows in the next time slice’s input state vector. If $W_h$ contains a non-zero column that has multiple positive values in different rows and this column is activated by $h_k$, it implies multiple columns of $W_h$ must have been activated in the previous time slice. Consider also the case where there are duplicated columns in $W_h$. The NMF inference algorithm might then choose to activate both of them with some positive strength, again resulting in a non-1-hot $h_k$ that would in turn need to be explained in the next time slice. Intuitively, it then makes sense that the model would tend toward the sparsest possible learned representation.

Refer to caption
(a) $R = 25$ basis columns
Refer to caption
(b) $R = 100$ basis columns
Refer to caption
(c) $R = 150$ basis columns
Figure 9: Learned weights after training on the repeating deterministic sequence in Figure 7 for three choices of the hidden state vector dimension $R$, which is equal to the number of basis columns in $W$. We see that in each case, 25 column vectors are learned.

5.8.3 Evaluation and visualization

With the model weights trained, we can now verify that the model has successfully memorized the sequence. We will provide a short “seed” sequence that contains part of the repeating sub-sequence and then ask the model to generate a continuation of it. For this evaluation, we use a model that was trained with $R = 100$-dimensional hidden state vectors. Figure 10 shows the seed sequence, for which we use the first 15 time slices. We will generate 50 additional time slices after the seed sequence. We then initialize the $X$ and $Y$ sub-matrices of $V$ so that they only contain the seed as the initial part at the left, as shown in Figure 11.

Refer to caption
Figure 10: The input seed sequence consisting of the first 15 time slices of the repeating sub-sequence the model was trained on.
Refer to caption
Figure 11: The sub-matrices of $V$ just before the generation task. The first 15 time slices of $X$ are initialized with the seed sequence from Figure 10. Likewise, the first 14 time slices of $Y$ are initialized with the seed shifted 1 slice to the left. All hidden states are initialized to zeros.

We then run the inference procedure starting from the first time slice, since the hidden states need to be inferred for all time slices. The generation procedure works as follows. For each time slice $k$, we iterate the NMF updates until the current state vector $h_k$ converges. We copy the inferred $h_k$ into the next time slice of the $H_{prev}$ sub-matrix of $V$. If $k$ is less than the seed length, we leave $y_k$ as is (since it is part of the seed). Otherwise, we also update the current $y_k = W_y h_k$. We also copy this predicted $y_k$ into the $x_{k+1}$ position in the next time slice of $X$, which serves to propagate the predicted sequence vectors forward. Note, however, that we do not sample from the predicted $y_k$; we simply copy the predicted vector directly into the following time slice. We then increment $k$ and repeat until reaching the end of the generated sequence length. Figure 12 shows the resulting generated sequence including the seed. From this we see that the model can successfully generate the memorized sequence, with a small amount of noise visible as slight yellow or red in the “hot” colormap that we used. Figure 13 shows all three matrices of $V_{predicted} = WH$ corresponding to the factorization in Eq. 3.12 after generating the sequence from the seed.
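In pseudocode form, the generation loop looks as follows. Here infer_h stands in for the iterative NMF inference of $h_k$ from the known rows of time slice $k$, and X, W_y, R, seed_len, and gen_len are assumed to be already defined; the sketch is illustrative rather than our exact implementation.

import torch

T = seed_len + gen_len                  # total sequence length
h_prev = torch.zeros(R)
for k in range(T):
    h_k = infer_h(h_prev, X[:, k])      # iterate NMF updates until converged
    if k >= seed_len:                   # past the seed: overwrite y_k
        y_k = W_y @ h_k                 # prediction for the next input
        if k + 1 < T:
            X[:, k + 1] = y_k           # copy prediction forward (no sampling)
    h_prev = h_k                        # becomes the H_prev entry for slice k+1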

Refer to caption
Figure 12: The generated sequence including the seed. The first 15 time slices are the seed and the remaining were generated.
Refer to caption
(a) Predicted sub-matrices of $V$
Refer to caption
(b) Learned weights $W$
Refer to caption
(c) Inferred hidden states $H$
Figure 13: The factorization $V_{predicted} = WH$ corresponding to Eq. 3.12 after generating the sequence from the seed.

5.9 Copy task

We now test the factorized RNN and compare it to the vanilla RNN on the more difficult Copy Task [32], using the setup from [33]. This involves generating a random sequence of tokens which are supplied one at a time in each time slice to the network. After this we supply a padding token for some fixed number of time slices. We then supply a “remember” token for a number of time slices equal to the length of the input sequence that was supplied earlier, during which the network must attempt to output the remembered tokens of the input sequence. This task thus tests the network’s ability to recall information that was presented earlier, and it gets more difficult as the task spans longer time intervals. We use the same parameters as [33], in which the input sequence has length 10 and there are 10 distinct token values, which we supply to the network as 1-hot vectors. Combined with the pad and remember tokens, this leads to 12 distinct token values in total, so that the input 1-hot vectors are 12-dimensional. We can then still adjust the difficulty by controlling the padding length $T_{pad}$.
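A sketch of the example generation follows (token values and lengths mirror the setup described above; the function name is illustrative):

import torch
import torch.nn.functional as F

def make_copy_example(seq_len=10, n_tokens=10, t_pad=10):
    PAD, REMEMBER = n_tokens, n_tokens + 1          # 12 token values total
    src = torch.randint(0, n_tokens, (seq_len,))    # random input sequence
    stream = torch.cat([src,
                        torch.full((t_pad,), PAD),
                        torch.full((seq_len,), REMEMBER)])
    x = F.one_hot(stream, num_classes=n_tokens + 2).float()  # 12-dim inputs
    return x, src   # loss is computed only during the remember phase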

5.9.1 Training details for the factorized RNN

We used the same model and training hyperparameters as in Section 5.8. The two setups differ only in how the loss is applied. In the model of Section 5.8, the RNN outputs a prediction at every time slice. However, in the copy task we only care about the predicted tokens during the time slices when the “remember” token is being supplied, and so we only compute the prediction loss over these time slices here. Note that since we are using the NMF learning (i.e., the NMF $W$ update rule) to learn the weights, this corresponds to having implicit MSE reconstruction loss terms on all of the hidden states and inputs as well. We set the hidden state dimension to $R = 1024$. We used 100 NMF iterations per time slice for inference.

With the factorized RNN and non-negative parameters (i.e., NMF), most training runs resulted in perfect validation accuracy for $T_{pad} = 5$. With $T_{pad} = 10$, multiple training runs were needed to reach perfect accuracy. For $T_{pad} = 15$, accuracy was no better than chance level. When we tried allowing negative parameters, the models failed to do better than chance accuracy. These results seem interesting in that the network is able to successfully learn the task with up to 10 padding tokens, even though there is no BPTT-like backward flow of error gradients.

5.9.2 Training details for the vanilla RNN

For comparison, we also train a vanilla RNN on the same task. For this we use the network described in Section 3.1, except that we used LayerNorm on the hidden states as suggested for RNNs in [7]. We use the GELU activation. We used the same hidden state dimension $R = 1024$ as the factorized network. Negative-valued parameters were allowed, as in all other experiments with MLP-based models. We observed that a weight decay of 1e-4 was needed for the best results, although we did not do much hyperparameter tuning.

With the vanilla RNN, we were also able to get perfect accuracy for $T_{pad} = 5$ when BPTT was used. We found that perfect accuracy was possible up to approximately $T_{pad} = 15$. However, when we disabled BPTT, we were not able to get perfect accuracy for any padding length. We were not able to do much hyperparameter tuning, though, and so it remains possible that performance could improve with better tuning.

5.10 Sequential MNIST classification

The Sequential MNIST classification task is a common task used for evaluating the performance of RNNs. It adapts the MNIST dataset [27], which consists of 28x28 pixel grayscale images of handwritten digits (0 to 9), for a sequential data processing context. In the standard MNIST task, the entire image is presented to the model at once, as in Section 5.2. However, in the Sequential MNIST task, the image is presented as a sequence of pixels, typically in a row or column-wise manner. We evaluate the column-wise version in this experiment. The image is unrolled column by column, resulting in a 28-time-slice sequence, where each slice is a 28-dimensional vector representing one column of pixels. Each time slice of the RNN outputs a 10-dimensional class prediction vector, although we only compute the loss on the final time slice of the sequence, ignoring the model predictions from earlier time slices.
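Concretely, the column-wise unrolling can be done with a single permute (a sketch; the random batch is a stand-in for MNIST data):

import torch

images = torch.rand(100, 28, 28)       # stand-in batch of MNIST images
seq = images.permute(2, 0, 1)          # (time=28, batch, features=28)
# Feed seq[k] to the RNN at time slice k; only the final slice's 10-dim
# prediction contributes to the loss, as described above.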

5.10.1 Training details

For the factorized RNN, we used 10 unrolled inference iterations since it allowed us to run more experiments. We noticed only slightly reduced accuracy in one experiment compared to using 50-100 iterations and so we made the trade-off based on our limited computational resources. We used a learning rate of 5e-5 and weight decay of 1e-5 for all of the training runs and for both the factorized and vanilla RNNs. Similar to the other experiments, we used the input reconstruction loss on the factorized model. We trained under both the NMF and semi-NMF parameter constraints. We trained both the factorized and vanilla RNNs with and without BPTT, and for hidden state dimensions of 512 and 2048.

5.10.2 Results

Table 7 shows the accuracy results on this task. For the results with BPTT disabled, the factorized RNN has significantly higher accuracy compared to the vanilla RNN (96.23% vs 83.33% at hidden dimension 512, and 97.16% vs 94.06% at hidden dimension 2048), although this required the semi-NMF parameter constraint. Using the NMF constraint without BPTT produced significantly worse accuracy. We were somewhat surprised to see both the factorized and vanilla RNNs performing so well without BPTT. Both models performed better with BPTT enabled, but only slightly, and we see that the factorized RNN slightly outperformed the vanilla RNN. Interestingly, when using BPTT, the factorized RNN performed similarly under both the NMF and semi-NMF constraints. Notice also that when BPTT is disabled, both the factorized and vanilla RNNs suffer a more severe accuracy degradation in going from a 2048 to a 512 hidden dimension, compared to when BPTT is enabled; it seems that disabling BPTT could be making the models less parameter efficient.

Table 7: Comparison of factorized vs vanilla RNNs on the Sequential MNIST task for various hyperparameter settings.
Model BPTT Hidden State Dimension Parameter Constraints Test Accuracy
Factorized RNN No 512 NMF 77.43%
Factorized RNN No 512 semi-NMF 96.23%
Factorized RNN No 2048 NMF 75.81%
Factorized RNN No 2048 semi-NMF 97.16%
Vanilla RNN No 512 N/A 83.33%
Vanilla RNN No 2048 N/A 94.06%
Factorized RNN Yes 512 NMF 98.56%
Factorized RNN Yes 512 semi-NMF 98.16%
Factorized RNN Yes 2048 NMF 98.66%
Factorized RNN Yes 2048 semi-NMF 99.00%
Vanilla RNN Yes 512 N/A 97.94%
Vanilla RNN Yes 2048 N/A 98.66%

5.11 Audio source separation on MUSDB18

Audio tends to be well modeled as a mixture of source components. Much of the music we listen to is explicitly mixed together from isolated recordings (often called stems). In the source separation problem, we are interested in separating multiple sources that have been mixed together. For example, suppose we have a recording by a small band in which the vocals, guitar, and drums have all been captured together in a single audio channel. A source separation system would then be tasked with taking this recording as input and producing an output containing only a single desired source such as the vocals.

A useful inductive bias in this case would be to have the model $f(\cdot)$ satisfy the linear superposition property: $\alpha_1 f(x_1) + \alpha_2 f(x_2) = f(\alpha_1 x_1 + \alpha_2 x_2)$. Any scaled mixture of the two sources should produce the corresponding scaled output. This is a desirable property since it implies that if the model is capable of recognizing the sources in isolation, it automatically generalizes to handling the mixture case. Such a property could potentially allow the model to generalize better from limited training data.
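As a quick numeric illustration, a purely linear map satisfies this property exactly (NMF-based models can at best approximate it, since their activations are constrained to be non-negative):

import torch

W = torch.rand(8, 4)
f = lambda x: W @ x                     # a purely linear "model"
x1, x2 = torch.rand(4), torch.rand(4)
a1, a2 = 0.7, 1.3
lhs = a1 * f(x1) + a2 * f(x2)
rhs = f(a1 * x1 + a2 * x2)
print(torch.allclose(lhs, rhs))         # True: superposition holds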

NMF can potentially satisfy this property, and there are several existing works in which it has been applied to related audio problems [34]. Since our factorized RNN is essentially a sequential extension of matrix factorization, it seems interesting to apply it to the source separation task and compare its generalization performance against a standard vanilla RNN (which does not have the superposition property due to its use of non-linear activation functions). We train both factorized and vanilla RNNs on the source separation task using the MUSDB18 dataset. MUSDB18 contains 150 music tracks of various genres corresponding to approximately 10 hours of music, split into a training dataset containing 100 tracks and a test dataset containing 50 tracks. It includes the isolated tracks (stems) for drums, bass, other, and vocals. This makes it straightforward to use for training and evaluating source separation models.

5.11.1 Training details

For training, we mix together the isolated stems to create the inputs and then supply just one of the stems as the target. In this experiment we chose vocals as the target. The stems are temporally aligned during validation and testing but not during training, to provide more variation in the examples. During training only, each source is scaled by a random per-example value between 0.5 and 2.0.

We used the same MSE output prediction loss (i.e., regression loss) for both models to ensure that the resulting MSE test loss would be comparable between the factorized and vanilla RNN. That is, the factorized RNN only used the MSE output prediction loss term, without the input reconstruction loss term, in this experiment. For the audio features, we use the following hyperparameters: the sample rate was 44100 Hz, and the short-time Fourier transform (STFT) used a 2048-sample window size and a 1024-sample hop size. We limited the audio feature vectors (time slices of the STFT) to contain only the lowest 350 frequency bins. Negative parameter values were allowed in both models. BPTT was used in both models. The hidden state dimension was 1024 in both models. The batch size was 100. The learning rate was 5e-4 for the factorized RNN and 1e-4 for the vanilla RNN. Weight decay was 5e-5 for both models.
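A sketch of the feature extraction follows, with torch.stft as a stand-in for our pipeline; taking the magnitude of the complex STFT is an assumption here rather than a detail stated above.

import torch

def stft_features(audio):               # audio: 1-D tensor at 44100 Hz
    spec = torch.stft(audio, n_fft=2048, hop_length=1024,
                      window=torch.hann_window(2048),
                      return_complex=True)
    return spec.abs()[:350]             # keep only the lowest 350 bins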

5.11.2 Results

The vanilla RNN had an MSE test loss of 3.657e-3. The factorized RNN had a test loss of 2.497e-3. These were averaged over 3 runs. Both models seem to be underfitting, since the training loss converges to a similar range as the validation loss. Although the factorized RNN produced a better (lower) test loss compared to the vanilla RNN, we should note that both models perform significantly below state of the art on this dataset due to the use of the simple RNN architecture.

6 Related work

6.1 Positive Factor Networks

Perhaps the single most related existing work is our previous work on Positive Factor Networks [18], which similarly proposed NMF-based building blocks that can be composed to build expressive and interpretable architectures. Although the factorized equivalent of the vanilla RNN that we introduce in this paper was not considered, various other factorized-RNN-like architectures (referred to there as dynamic positive factor networks) were considered, including some more sophisticated architectures such as a sequential data model for target tracking and a model employing dilated deconvolutional layers in a Wavenet-like [35] architecture. A block with the same factorized model as our PFC block appears in Eq. 3.1 of [18], where it was referred to as a “coupling module” or “coupling factorization”. Similar iterative NMF algorithms were used to perform inference in these models. However, a key distinction between that work and our present work is that it did not use backpropagation for learning the parameters, and instead relied on applying the NMF left-update rules to learn them. We also experimented with NMF left-update rules for learning the parameters as an ablation in Sections 5.8, 5.9, and 5.10 of the current work, but found them to underperform backpropagation on the more challenging learning tasks. As a result of not using backpropagation for parameter learning, the models presented in [18] failed to produce results competitive with other approaches on supervised learning tasks. The PFC block and the networks constructed from it in the present work can therefore be considered as unrolled positive factor networks, or positive factor networks employing backpropagation-based training.

6.2 Non-negative matrix factorization

NMF was originally proposed by [14] as positive matrix factorization and later popularized by [15] [17]. NMF is related to other methods such as sparse coding [36] that use a similar dictionary learning model but with differences in the modeling constraints, regularizers, and/or algorithms used. NMF is often observed to provide parts-based decompositions of data, which can be useful when interpretable decompositions are desired. It also potentially satisfies the (non-negative) linear superposition property and as a result has been applied to audio processing tasks such as source separation and music transcription in which the audio features in the input data matrix are assumed to be well modeled as an additive combination of audio sources (i.e., individual instruments and/or notes) [34].

Our PFC block is related in that its declarative model corresponds to a masked predictive NMF in which the data matrix is partitioned along the row dimension into “input” and “output” vectors of the block, with the right factor matrix $H$ corresponding to inferred hidden activations. The output vectors are masked during recognition (during inference of $H$) and the resulting inferred $H$ is then used to predict the outputs. This allows our block to retain the additive superpositional NMF model, while also supporting the construction of more expressive differentiable architectures such as factorized RNNs. The resulting neural networks (or positive factor networks) can then be interpreted as an extension of NMF that increases its modeling expressiveness and suitability for supervised learning tasks.
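The following NumPy sketch illustrates this masked predictive factorization view; the random matrices stand in for learned parameters and data, and the step size and update count are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
dx, dy, k, n = 8, 3, 5, 20
W_x = rng.random((dx, k))   # "input" rows of the basis (placeholder for learned weights)
W_y = rng.random((dy, k))   # "output" rows of the basis
X = rng.random((dx, n))     # observed input columns; the output rows are masked

# Infer non-negative H from X alone via projected gradient steps on ||X - W_x H||^2.
L = np.linalg.norm(W_x.T @ W_x, 2)   # Lipschitz constant of the gradient
H = rng.random((k, n))
for _ in range(200):
    H = np.maximum(H - (1.0 / L) * W_x.T @ (W_x @ H - X), 0.0)

Y_pred = W_y @ H   # predict the masked "output" vectors from the inferred H
```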

6.3 Unrolled neural networks

The key enabler of the increased predictive performance of our present work compared to the Positive Factor Networks of [18] is the modification of the learning algorithm to use backpropagation instead of relying on NMF left-update steps. As discussed in Section 2.4, backpropagation-based training is enabled by unrolling the iterative NMF inference steps into an RNN-like structure in the computation graph, making the PFC block differentiable and therefore compatible with backpropagation training.

This process of unrolling an iterative optimization algorithm into a neural network and training with backpropagation is referred to as algorithm unrolling or unrolled neural networks in the literature [25]. An existing example of unrolled NMF appears in [37]. The idea that improved results could be obtained on supervised tasks by adding a task-specific loss function to an iterative and differentiable optimization algorithm and using it to optimize the parameters was proposed in [38] and [39], with earlier related ideas appearing in [40] and [41]. More recent works have explored the use of unrolled convolutional sparse coding for improved robustness to corrupted and/or adversarial input images [42], [43]. Of these, [44] and [43] also use FISTA to accelerate the convergence of the unrolled inference. As far as we are aware, ours is the first work to explore a modular unrolled block (the PFC block) that supports the construction of arbitrary neural architectures, the first to propose unrolled factorized RNNs, and the first to explore the interpretability advantages of unrolled NMF-based networks for continual learning.

6.4 Nearest neighbor classification and regression

As discussed in Section 2.2, our PFC block shares some similarities with classification and regression based on the k-nearest neighbors (k-NN) algorithm [10]. We show that the PFC block can be derived by first formulating k-NN prediction as a matrix-vector product, and then generalizing and interpreting it as a matrix factorization.
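To sketch the connection (an illustrative reconstruction here, not the full derivation of Section 2.2), 1-NN prediction can be written as a matrix-vector product with a one-hot activation vector; relaxing that vector to a non-negative code inferred by factorizing $x \approx X_{train} h$ then yields the predictive factorization view:

```python
import numpy as np

rng = np.random.default_rng(0)
X_train = rng.random((4, 50))   # columns are stored training inputs
Y_train = rng.random((2, 50))   # columns are the corresponding targets

def one_nn_predict(x):
    dists = np.linalg.norm(X_train - x[:, None], axis=0)
    h = np.zeros(X_train.shape[1])
    h[np.argmin(dists)] = 1.0   # one-hot activation over stored examples
    return Y_train @ h          # prediction as a matrix-vector product

print(one_nn_predict(rng.random(4)))
```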

6.5 Learning Vector Quantization (LVQ)

Learning Vector Quantization (LVQ) [11] [12] is a prototype-based method that extends the k-nearest neighbors algorithm by introducing the concept of learnable prototypes. An advantage of this approach over matrix-factorization-based methods is its faster recognition since an iterative algorithm is not used. We are not aware of a fully differentiable version of LVQ that could match MLP predictive performance and make it possible to use as a building block for more complex architectures such as RNNs, however.

6.6 Future research

This work is preliminary and we leave several unanswered questions and possibilities for future research directions. We list some of them here:

  • As discussed in [45], methods such as ours that perform recognition by iterating to a fixed point can potentially make use of adaptive computation. This could potentially be used to improve recognition efficiency, as well as providing an additional form of confidence estimation, since “difficult” inputs could require more iterations to converge.

  • When unrolling the NMF inference updates, we observed that memory usage can potentially be reduced by computing the initial several iterations “without gradients” and only unrolling the last several iterations “with gradients”. In some cases, this resulted in improved efficiency, but in others it resulted in reduced accuracy, and so a better understanding is still needed.

  • Although we found FISTA to be an effective method for accelerating the NMF inference step of the PFC block, it could be interesting to also consider other approaches to further accelerate the inference.

  • Our sliding learnable window optimizer introduced in Section 4 enabled improved continual learning in PFC-based models. However, more sophisticated approaches based on basis-vector-specific learning rates could potentially lead to improved continual learning performance and/or parameter efficiency. For example, it could be interesting to consider methods that identify the important or in-use basis vectors and reduce their learning rates accordingly so as to make them more resistant to being overwritten or modified later in training.

  • We did not consider any form of sparsity regularization (e.g., L1 penalty) in our experiments, leaving regularization methods other than the non-negativity constraint and weight decay unexplored in the current work.

  • As Table 7 shows, disabling BPTT in both factorized and vanilla RNNs often resulted in only a small drop in accuracy, although more parameters were needed to recover accuracy. It is also unclear why the semi-NMF parameter constraint resulted in better accuracy compared to NMF when BPTT was disabled. It could be interesting to conduct a more detailed investigation as future research.

  • The declarative model of the PFC block in Eq. 2.10 is symmetric with respect to the inputs and outputs. This symmetry potentially allows the prediction direction of the block to be reversed, swapping the roles of inputs and outputs so that the block computes the “inputs” given the “outputs”. In addition to such reversible blocks/layers, it could be interesting to explore the case where only some subset of the input and output elements is observed and the block then jointly predicts all unobserved elements.

  • As discussed in Section 2.4, the PFC block can be interpreted as performing factorized attention over parameters. It could be interesting to also consider its use as an attention block that performs factorized attention over key and value activations (output from other upstream blocks) instead of over parameters, making it more like a factorized version of the attention block used in the transformer. Specifically, in Eq. 2.10 the $W_{x}$ matrix would be replaced by a keys matrix $K_{x}$ and $W_{y}$ would be replaced with a corresponding values matrix $V_{y}$, which would then represent output activations from other layers/blocks instead of learnable parameters.

  • Note that the PFC block has the inductive bias of NMF, which is different from that of the MLP. In our experiments, we observed it to perform similarly to, and sometimes better than, the MLP. However, it is possible that the NMF inductive bias could make PFC-based models particularly well suited or poorly suited to a given dataset, depending on its modeling assumptions. Also note that relaxing the non-negativity constraint to semi-NMF could support the use of datasets containing negative values, at the expense of possibly reduced interpretability. We leave such an exploration as future research.

7 Conclusion

In this paper, we presented the Predictive Factorized Coupling (PFC) block, a neural network building block that combines the interpretability of non-negative matrix factorization (NMF) with the predictive performance of multi-layer perceptron (MLP) networks. We demonstrated the versatility of the PFC block by using it to build various architectures, including a single-block network, a fully-connected residual network containing two PFC blocks, and a factorized RNN.

Our experiments showed that the PFC block achieves competitive accuracy with MLPs on small datasets while providing better interpretability. We also demonstrated the benefits of the PFC block in continual learning, training on non-i.i.d. data, and knowledge removal after training. Additionally, we showed that the factorized RNN outperforms vanilla RNNs in certain tasks while providing improved interpretability.

While the PFC block has limitations, such as slower training and inference and increased memory consumption during training, it offers a promising direction for developing more interpretable neural networks without sacrificing predictive performance. Future work includes evaluating the PFC block on larger datasets and exploring its suitability for use in more complex multi-block architectures.

Appendix A Alternating SGD updates for NMF

One of the simplest methods that can be used to solve for the factors $W$ and $H$ in Equation (2.1) is gradient descent (GD) or stochastic gradient descent (SGD). Whether GD or SGD applies depends on whether we are operating on the full training data or on batches, and so we will simply refer to both cases as SGD in the following. In SGD, we first choose a suitable loss function to quantify the approximation error and then compute its gradients with respect to each factor matrix. We alternately apply small additive update steps in the opposite direction of the gradient to each factor matrix, while clipping any negative values to zero, until convergence. A commonly used choice of loss is the squared Euclidean error and so we use it here. The squared error loss $\mathcal{L}$ between $V$ and the approximation $V_{pred}=WH$ is given by:

\mathcal{L}(V\,\|\,V_{pred})=\frac{1}{2}\sum_{i,j}(v_{ij}-{v_{pred}}_{ij})^{2}=\frac{1}{2}\sum_{i,j}e_{ij}^{2}   (A.1)

where $e_{ij}=v_{ij}-{v_{pred}}_{ij}$ is the approximation error of the $ij$’th element of $V$. Let $E$ denote the approximation error matrix with elements $e_{ij}$. Since $\mathcal{L}$ is differentiable with respect to $W$ and $H$, the loss gradients are:

\frac{\partial\mathcal{L}}{\partial H}=W^{T}(WH-V)=W^{T}E   (A.2)
\frac{\partial\mathcal{L}}{\partial W}=(WH-V)H^{T}=EH^{T}   (A.3)

The resulting SGD updates are then given by:

H \leftarrow relu\left(H-\eta_{H}\frac{\partial\mathcal{L}}{\partial H}\right)   (A.4)
W \leftarrow relu\left(W-\eta_{W}\frac{\partial\mathcal{L}}{\partial W}\right)   (A.5)

where $\eta_{H}$ and $\eta_{W}$ are learning rate hyperparameters set to small positive values. The $relu(\cdot)$ function is used to prevent negative values in the updated matrices. Substituting the gradients into the above updates gives:

H \leftarrow relu\left(H-\eta_{H}W^{T}(WH-V)\right)   (A.6)
W \leftarrow relu\left(W-\eta_{W}(WH-V)H^{T}\right)   (A.7)

We then initialize $W$ and $H$ with non-negative noise and iteratively perform the updates until convergence.
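The following NumPy sketch runs these alternating updates on random non-negative data; the fixed learning rates and iteration count here are illustrative (Section A.2 describes how we set the rates automatically):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, k = 20, 30, 5
V = rng.random((m, n))                          # non-negative data matrix
W, H = rng.random((m, k)), rng.random((k, n))   # initialize with non-negative noise

eta_h = eta_w = 1e-2
for _ in range(1000):
    H = np.maximum(H - eta_h * W.T @ (W @ H - V), 0.0)   # Eq. A.6
    W = np.maximum(W - eta_w * (W @ H - V) @ H.T, 0.0)   # Eq. A.7

print(np.linalg.norm(V - W @ H) / np.linalg.norm(V))     # relative approximation error
```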

A.1 Normalization and preventing numerical issues

We observed that numerical issues can sometimes be reduced by additionally replacing any zero-valued elements in the factor matrices with some small minimum allowable positive value, such as $\epsilon=10^{-5}$, immediately after the $relu(\cdot)$ in Eqs. A.6 and A.7.

We introduce the following two normalization methods, which we have observed to often perform well in our experiments, depending on whether we are performing alternating updates to jointly learn both $W$ and $H$, or performing unrolled inference to infer only $H$ followed by backpropagation-based learning of $W$.

A.1.1 Normalization for the joint factorization case

Note that we can scale $W$ by an arbitrary value $\alpha$ if we also scale $H$ by its inverse so that their product is unchanged. Depending on the update algorithm used, one of the factor matrices might be scaled slightly larger on each update while the other is scaled slightly smaller, so that one of $W$ and $H$ tends toward infinity while the other tends toward zero. To avoid this problem, one of the factor matrices (typically $W$) is usually normalized to have, e.g., unit column norm after each update [46]. However, this prevents the basis vectors from becoming arbitrarily small, potentially resulting in less sparse and/or less interpretable solutions.

We have empirically observed that optimization can be easier when the ranges of values in $W$ and $H$ are similar. We use NumPy/PyTorch notation in what follows, where “:” denotes selecting all rows or columns in the dimension where it is applied. Intuitively, if a particular column $W[:,i]$ does not contribute significantly to the approximation of $V$, then we would like its values, as well as the corresponding activations in row $H[i,:]$, to be small, and vice versa. We use the following normalization algorithm to achieve this.

For each column $W[:,i]$ in $W$, we compute its maximum value along with the maximum value of the corresponding row $H[i,:]$ in $H$. We then compute the mean of these two values and scale the column $W[:,i]$ and row $H[i,:]$ so that their updated maximum values are both equal to this mean. We have found this simple algorithm to work well in our experiments. After training, it tends to result in $W$ and $H$ having nearly identical maximum values, and in sparser learned factorizations and/or improved approximation error compared to the other normalization methods that we tried. We must be careful if mini-batch training is used, though, as the above maximum values are intended to be computed over the full $W$ and $H$ matrices.
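A minimal NumPy sketch of this rebalancing, written from the description above rather than taken from our implementation, is:

```python
import numpy as np

def rebalance(W, H, eps=1e-12):
    """Scale each column of W and matching row of H so both maxima equal their mean."""
    for i in range(W.shape[1]):
        w_max, h_max = W[:, i].max(), H[i, :].max()
        target = 0.5 * (w_max + h_max)     # mean of the two maxima
        W[:, i] *= target / max(w_max, eps)
        H[i, :] *= target / max(h_max, eps)
    return W, H
```

Note that rescaling both factors to the arithmetic mean does not exactly preserve the product $W[:,i]\,H[i,:]$ unless the two maxima already agree (a geometric-mean target would preserve it exactly); the subsequent alternating updates operate on the rescaled factors.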

A.1.2 Normalization for unrolled inference with backpropagation

For the unrolled algorithm in which backpropagation is used together with another optimizer (e.g., RMSprop) to update $W$, we have found the following update method to perform well. The basic idea is that for each column $v_{i}$ in the data matrix $V$, we limit the maximum value of the corresponding inferred column $h_{i}$ in $H$ so that it is not allowed to exceed the maximum value in $v_{i}$. This normalization step involves computing the column-wise maximum values of $V$ at the start of inference. After each unrolled NMF right-update to $H$, we then apply a column scaling step so that the column-wise maximum values of $H$ do not exceed those of $V$ (if the maximum value is already less, no scaling is performed). This is intended to prevent the inferred values in $H$ from exploding during the unrolled inference. Recall that backpropagation and another optimizer are used to update $W$. We observed that it was not necessary to explicitly apply any normalization to $W$, since we did not encounter exploding values. Although weight decay was used in the experiments, it was not needed to prevent numerical issues, and was only used due to its slightly beneficial effect on predictive performance in some cases.
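A sketch of this column clipping, based on our reading of the description above, is:

```python
import numpy as np

def clip_h_columns(H, v_col_max, eps=1e-12):
    """Rescale columns of H whose maxima exceed those of the matching columns of V."""
    h_col_max = H.max(axis=0)
    scale = np.minimum(1.0, v_col_max / np.maximum(h_col_max, eps))
    return H * scale   # broadcasts per column; a no-op where the max is already smaller

# Usage: v_col_max = V.max(axis=0) is computed once at the start of inference,
# and clip_h_columns is applied after each unrolled right-update to H.
```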

A.2 Automatic learning rate selection

We empirically observe that setting the SGD learning rates $\eta_{H}$ and $\eta_{W}$ in Eqs. A.6 and A.7 automatically, using the same method as in FISTA, often works well. Specifically, $\eta_{H}$ is set to $1/L_{H}$ where $L_{H}$ is the largest eigenvalue of $W^{T}W$. Likewise, $\eta_{W}$ is set to $1/L_{W}$ where $L_{W}$ is the largest eigenvalue of $HH^{T}$. In our experiments, we used the power method to compute the approximate largest eigenvalue.
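A sketch of the power-method estimate is given below; calling it on $W$ gives $L_{H}$, and the analogous call on $H^{T}$ gives $L_{W}$ (the iteration count and random initialization are illustrative assumptions):

```python
import numpy as np

def largest_eigenvalue(W, n_iters=30, seed=0):
    """Power-method estimate of the largest eigenvalue of W^T W."""
    v = np.random.default_rng(seed).random(W.shape[1])
    for _ in range(n_iters):
        v = W.T @ (W @ v)        # apply W^T W without forming it explicitly
        v /= np.linalg.norm(v)
    return float(v @ (W.T @ (W @ v)))   # Rayleigh quotient estimate

W = np.random.default_rng(0).random((20, 5))
eta_h = 1.0 / largest_eigenvalue(W)     # automatic learning rate for the H updates
```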

Appendix B Implementation details for FISTA-accelerated NMF inference in the PFC block

We adapt the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) to accelerate the NMF inference procedure in the PFC block. FISTA is an optimization method that combines gradient-based steps with an acceleration (momentum) technique [23]. Originally designed for sparse recovery, FISTA can be adapted to various applications, including matrix factorization. We use rectification (relu) as the proximal operator instead of the usual shrinkage-thresholding operator. The resulting algorithm combines the SGD NMF right-update steps with the FISTA steps that compute the momentum term. Both NMF and semi-NMF constraints are supported.

Given an input matrix $X$, we aim to infer the hidden matrix $H$ in the factorization $X\approx W_{x}H$. We assume that the weights $W_{x}$ remain fixed during the inference procedure. Relative to the standard SGD-based NMF procedure reviewed in Section A, the inference procedure amounts to the repeated application of the $H$ update step in Eq. A.6, followed by the normalization step and the FISTA momentum update steps. The procedure is as follows:

  1. Initialize $H^{0}$ and set $Y^{1}=H^{0}$, $t_{1}=1$.

  2. For each iteration $k$, update $H$ by:

    (a) Updating $H$ as $H^{k+1}=relu\left(Y^{k}-\frac{1}{L}W^{T}(WY^{k}-X)\right)$.

    (b) Applying the normalization scaling to $H$.

    (c) Updating $t$ as $t_{k+1}=\frac{1+\sqrt{1+4t_{k}^{2}}}{2}$.

    (d) Updating $Y$ as $Y^{k+1}=H^{k+1}+\left(\frac{t_{k}-1}{t_{k+1}}\right)(H^{k+1}-H^{k})$.

  3. Repeat the process until convergence is achieved.

We estimate the Lipschitz constant $L$ using the power method. The normalization scaling step is described in Section A.1.2. This step may not always be needed, but we leave it enabled in all experiments to prevent numerical issues in $H$.
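The following NumPy sketch implements the procedure above on random data; for brevity it uses a direct spectral-norm computation for $L$ and omits the normalization scaling step of Section A.1.2:

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 16, 8, 10
W = rng.random((m, k))            # weights held fixed during inference
X = rng.random((m, n))            # data to factorize as X ≈ W H

L = np.linalg.norm(W.T @ W, 2)    # Lipschitz constant (power method in practice)
H = rng.random((k, n))            # H^0
Y, t = H.copy(), 1.0              # Y^1 = H^0, t_1 = 1
for _ in range(100):
    H_next = np.maximum(Y - (1.0 / L) * W.T @ (W @ Y - X), 0.0)  # step (a)
    t_next = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))            # step (c)
    Y = H_next + ((t - 1.0) / t_next) * (H_next - H)             # step (d)
    H, t = H_next, t_next

print(np.linalg.norm(X - W @ H) / np.linalg.norm(X))             # relative error
```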

References

  • [1] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin, “Attention is all you need,” 2017. Version 3.
  • [2] F. Rosenblatt, “The perceptron: a probabilistic model for information storage and organization in the brain.,” Psychological review 65(6), p. 386, 1958.
  • [3] A. G. Ivakhnenko, V. G. Lapa, et al., “Cybernetic predicting devices,” 1965.
  • [4] S. Amari, “A theory of adaptive pattern classifiers,” IEEE Transactions on Electronic Computers (3), pp. 299–307, 1967.
  • [5] G. Cybenko, “Approximation by superpositions of a sigmoidal function,” Mathematics of Control, Signals, and Systems 2(4), pp. 303–314, 1989.
  • [6] K. Hornik, M. Stinchcombe, and H. White, “Multilayer feedforward networks are universal approximators,” Neural Networks 2(5), pp. 359–366, 1989.
  • [7] J. L. Ba, J. R. Kiros, and G. E. Hinton, “Layer normalization,” 2016.
  • [8] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov, “Dropout: a simple way to prevent neural networks from overfitting,” The journal of machine learning research 15(1), pp. 1929–1958, 2014.
  • [9] M. McCloskey and N. J. Cohen, “Catastrophic interference in connectionist networks: The sequential learning problem,” in Psychology of learning and motivation, 24, pp. 109–165, Elsevier, 1989.
  • [10] T. M. Cover and P. E. Hart, “Nearest neighbor pattern classification,” IEEE Trans. Inf. Theory 13, pp. 21–27, 1967.
  • [11] T. Kohonen, “The ‘neural’ phonetic typewriter,” IEEE Computer 21(3), pp. 11–22, 1988.
  • [12] T. Kohonen, “Improved versions of learning vector quantization,” in 1990 IJCNN International Joint Conference on Neural Networks, pp. 545–550, IEEE, 1990.
  • [13] G. M. van de Ven and A. S. Tolias, “Three scenarios for continual learning,” 2019.
  • [14] P. Paatero and U. Tapper, “Positive matrix factorization: A non-negative factor model with optimal utilization of error estimates of data values,” Environmetrics 5(2), pp. 111–126, 1994.
  • [15] D. Lee and H. Seung, “Learning the parts of objects by non-negative matrix factorization,” Nature 401, pp. 788–791, 1999.
  • [16] C. H. Ding, T. Li, and M. I. Jordan, “Convex and semi-nonnegative matrix factorizations,” IEEE transactions on pattern analysis and machine intelligence 32(1), pp. 45–55, 2008.
  • [17] D. Lee and H. S. Seung, “Algorithms for non-negative matrix factorization,” Advances in neural information processing systems 13, 2000.
  • [18] B. K. Vogel, “Positive factor networks: A graphical framework for modeling non-negative sequential data,” arXiv preprint arXiv:0807.4198 , 2008.
  • [19] D. E. Rumelhart, G. E. Hinton, and R. J. Williams, “Learning representations by back-propagating errors,” nature 323(6088), pp. 533–536, 1986.
  • [20] Y. LeCun, Learning processes in an asymmetric threshold network. PhD thesis, University of Paris 6, 1985.
  • [21] P. J. Werbos, Beyond regression: New tools for prediction and analysis in the behavioral sciences. PhD thesis, Harvard University, 1974.
  • [22] D. Hendrycks and K. Gimpel, “Gaussian error linear units (gelus),” arXiv preprint arXiv:1606.08415 , 2016.
  • [23] A. Beck and M. Teboulle, “A fast iterative shrinkage-thresholding algorithm for linear inverse problems,” SIAM journal on imaging sciences 2(1), pp. 183–202, 2009.
  • [24] G. Hinton, “Lecture 6a overview of mini-batch gradient descent.” https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf, 2018.
  • [25] V. Monga, Y. Li, and Y. C. Eldar, “Algorithm unrolling: Interpretable, efficient deep learning for signal and image processing,” IEEE Signal Processing Magazine 38(2), pp. 18–44, 2021.
  • [26] J. L. Elman, “Finding structure in time,” Cognitive science 14(2), pp. 179–211, 1990.
  • [27] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner, “Gradient-based learning applied to document recognition,” Proceedings of the IEEE 86(11), pp. 2278–2324, 1998.
  • [28] H. Xiao, K. Rasul, and R. Vollgraf, “Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms,” 2017.
  • [29] A. Krizhevsky, V. Nair, and G. Hinton, “Learning multiple layers of features from tiny images,” tech. rep., University of Toronto, 2009.
  • [30] F. Zenke, B. Poole, and S. Ganguli, “Continual learning through synaptic intelligence,” 2017.
  • [31] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, et al., “Overcoming catastrophic forgetting in neural networks,” Proceedings of the national academy of sciences 114(13), pp. 3521–3526, 2017.
  • [32] S. Hochreiter and J. Schmidhuber, “Long short-term memory,” Neural computation 9(8), pp. 1735–1780, 1997.
  • [33] M. Arjovsky, A. Shah, and Y. Bengio, “Unitary evolution recurrent neural networks,” 2016.
  • [34] G. Grindlay and D. P. W. Ellis, “Multi-voice polyphonic music transcription using eigeninstruments,” in IEEE Workshop on Applications of Signal Processing to Audio and Acoustics, pp. 53–56, 2009.
  • [35] A. v. d. Oord, S. Dieleman, H. Zen, K. Simonyan, O. Vinyals, A. Graves, N. Kalchbrenner, A. Senior, and K. Kavukcuoglu, “Wavenet: A generative model for raw audio,” arXiv preprint arXiv:1609.03499 , 2016.
  • [36] B. A. Olshausen and D. J. Field, “Sparse coding with an overcomplete basis set: A strategy employed by v1?,” Vision research 37(23), pp. 3311–3325, 1997.
  • [37] R. Nasser, Y. C. Eldar, and R. Sharan, “Deep unfolding for non-negative matrix factorization with application to mutational signature analysis,” 2021.
  • [38] J. Mairal, F. Bach, and J. Ponce, “Task-driven dictionary learning,” IEEE transactions on pattern analysis and machine intelligence 34(4), pp. 791–804, 2011.
  • [39] J. T. Rolfe and Y. LeCun, “Discriminative recurrent sparse auto-encoders,” arXiv preprint arXiv:1301.3775 , 2013.
  • [40] Y. Bengio and F. Gingras, “Recurrent neural networks for missing or asynchronous data,” Advances in neural information processing systems 8, 1995.
  • [41] H. S. Seung, “Learning continuous attractors in recurrent networks,” Advances in neural information processing systems 10, 1997.
  • [42] J. Sulam, R. Muthukumar, and R. Arora, “Adversarial robustness of supervised sparse coding,” Advances in neural information processing systems 33, pp. 2110–2121, 2020.
  • [43] M. Li, P. Zhai, S. Tong, X. Gao, S.-L. Huang, Z. Zhu, C. You, Y. Ma, et al., “Revisiting sparse convolutional model for visual recognition,” Advances in Neural Information Processing Systems 35, pp. 10492–10504, 2022.
  • [44] X. Sun, N. M. Nasrabadi, and T. D. Tran, “Supervised deep sparse coding networks,” in 2018 25th IEEE International Conference on Image Processing (ICIP), pp. 346–350, IEEE, 2018.
  • [45] T. Achler, “A flexible online classifier using supervised generative reconstruction during recognition,” 2012.
  • [46] W. Liu, N. Zheng, and Q. You, “Nonnegative matrix factorization and its applications in pattern recognition,” Chinese Science Bulletin 51, pp. 7–18, 2006.