An NMF-Based Building Block for Interpretable Neural Networks With Continual Learning
Abstract
Existing learning methods often struggle to balance interpretability and predictive performance. While models like nearest neighbors and non-negative matrix factorization (NMF) offer high interpretability, their predictive performance on supervised learning tasks is often limited. In contrast, neural networks based on the multi-layer perceptron (MLP) support the modular construction of expressive architectures and tend to have better recognition accuracy but are often regarded as black boxes in terms of interpretability. Our approach aims to strike a better balance between these two aspects through the use of a building block based on NMF that incorporates supervised neural network training methods to achieve high predictive performance while retaining the desirable interpretability properties of NMF. We evaluate our Predictive Factorized Coupling (PFC) block on small datasets and show that it achieves competitive predictive performance with MLPs while also offering improved interpretability. We demonstrate the benefits of this approach in various scenarios, such as continual learning, training on non-i.i.d. data, and knowledge removal after training. Additionally, we show examples of using the PFC block to build more expressive architectures, including a fully-connected residual network as well as a factorized recurrent neural network (RNN) that performs competitively with vanilla RNNs while providing improved interpretability. The PFC block uses an iterative inference algorithm that converges to a fixed point, making it possible to trade off accuracy vs computation after training but also currently preventing its use as a general MLP replacement in some scenarios such as training on very large datasets. We provide source code at https://github.com/bkvogel/pfc
1 Introduction
Current neural networks are often the best performing models on a range of challenging problems. Commonly used neural architectures include fully-connected and convolutional networks, various Recurrent Neural Networks (RNNs), and transformers [1]. These architectures and others make use of a relatively small number of basic building block types so that the differences between the various architectures are mainly due to which blocks are used and how they are inter-connected. Perhaps the single most fundamental building block is the Multi-Layer Perceptron (MLP) [2] [3] [4] since it is the smallest block that provides for learning arbitrarily complex non-linear functions and also tends to account for the bulk of the learnable parameters in existing architectures [5] [6]. The remaining building blocks are mainly intended to make the optimization process more efficient and/or serve as regularizers to help prevent over-fitting; examples of these include residual (i.e., skip) connections, LayerNorm [7], and Dropout [8] layers. Additional building blocks with learnable parameters are sometimes used for specialized architectures, such as linear embedding layers for the case of discrete-valued inputs and to provide the sequence-positional information in the transformer, for example. The transformer includes all of the above-mentioned blocks, as well as the attention block, which itself has an MLP interpretation, as described on the last page of [1]. We therefore consider the currently popular neural architectures as being “MLP-based”.
The existing MLP-based models do have some drawbacks, however. The first is that these architectures are often criticized as being “black boxes” lacking interpretability. Additionally, and perhaps partially due to the first issue, they tend to experience training difficulties as the distribution of examples starts to deviate from the i.i.d. assumption. This is often referred to as the “catastrophic forgetting” problem in the literature and it results in poor continual learning performance [9].
As an alternative to the MLP, there are other existing machine learning methods with better interpretability properties. A simple example is the k-nearest-neighbors (k-NN) algorithm for classification and regression [10]. Since there is no explicit learning step other than retaining the training examples as they become available to the model, k-NN trivially supports continual learning as well as knowledge removal. Other prototype-based algorithms such as LVQ [11] [12] can be considered learnable extensions of nearest neighbors and also have good interpretability.
Another method that is often noted for its desirable interpretability properties is non-negative matrix factorization (NMF). In contrast to other prototype-based learning methods such as k-NN and LVQ in which the prototypes represent entire examples, in NMF they instead represent prototypical “parts” exemplars which are additively combined to reconstruct the input.
These more interpretable methods unfortunately also have drawbacks that have prevented them from competing with MLP-based models in terms of predictive accuracy on supervised tasks. Methods such as k-NN and LVQ are not fully differentiable, preventing them from being used as general neural network building blocks that can be composed like the MLP to create expressive architectures supporting arbitrary differentiable loss functions tailored to the task at hand. Although NMF is potentially differentiable, the literature is mainly concerned with its usage for unsupervised learning tasks in which the factor matrices are optimized to minimize an input reconstruction loss only, rather than a supervised classification or regression loss.
1.1 Contributions
The main point of this work is to make some progress on developing learning methods with a better balance of interpretability and predictive performance compared to the existing MLP-based approaches. In particular, we make the following contributions:
- In Section 2.4 we introduce a new neural network building block called the Predictive Factorized Coupling (PFC) block as a more interpretable alternative to the MLP. Since its declarative model is still matrix factorization, it potentially retains the interpretable parts-based nature of NMF, but extends it to support its use as a general differentiable predictive module.
- In Section 5.2 we demonstrate that the PFC block has competitive accuracy with the (fully-connected) MLP on MNIST, Fashion MNIST, and CIFAR10 classification. In Section 5.7 we demonstrate that a (fully-connected) residual network consisting of two PFC blocks can also be trained without optimization difficulties and that it performs similarly. Section 5.6 shows an example of the increased interpretability by visualizing in-domain vs out-of-domain input examples.
- In Section 3 we develop a factorized RNN by starting from the well-known vanilla RNN and then replacing its MLP building block with our PFC blocks. This results in an RNN modeled as a single matrix factorization, effectively extending the parts-based nature of NMF to the modeling of sequential data. In Section 5.8 we demonstrate its interpretability advantages compared to the standard vanilla RNN on a simple sequential learning task with a known interpretable solution and observe that it is consistently able to learn the minimal transition model of the solution, even when the network is heavily over-parameterized. In all of our sequence learning tasks we find that the factorized RNN performs competitively with or better than the vanilla RNN when both are trained with the usual BPTT. This includes learning a repeating sequence (Section 5.8), the Copy Task (Section 5.9), Sequential MNIST (Section 5.10), and audio source separation (Section 5.11). We also perform some ablations and find two unexpected results of particular interest: the factorized RNN is able to solve simple tasks such as learning a repeating sequence and the Copy Task using only alternating NMF update rules, so that backpropagation is not used at all. We also observe that in both the factorized and conventional RNNs, using backpropagation but disabling BPTT often results in a surprisingly minimal degradation in accuracy. The models do tend to become significantly less parameter-efficient without BPTT, however.
- In Section 5.3 we show that our non-replay-based continual learning method is competitive with approaches that rely on replay on the Split MNIST Class-IL scenario [13]. Our PFC-based model performs better than the MLP on this task, and we introduce a sliding window optimizer in Section 4 for PFC-based models that results in further accuracy improvements.
- In Section 5.4 we show the superiority of a PFC-based model over the MLP on a non-i.i.d. training task.
We also mention the following limitations:
- The PFC block is slower and consumes more memory during training than the MLP block due to the use of an iterative inference algorithm that effectively turns each replaced MLP into a corresponding RNN. It is therefore not intended to be a general MLP replacement, but rather it is intended to be used in modeling tasks where its better interpretability, parts-based modeling prior, and/or suitability to continual learning is needed.
- Since the PFC block is modeled as a matrix factorization, it similarly cannot discriminate between two inputs that differ only by a scale factor, since the corresponding outputs would also only differ by the same scale factor. The NMF modeling assumption also constrains the input features to being non-negative, although this can potentially be relaxed to semi-NMF to support negative data values at the expense of possibly reduced interpretability.
- This work is preliminary. Our experiments only involve relatively small datasets. In terms of the suitability of the PFC block as an MLP replacement, we only consider two multi-block architectures in this paper: a 2-block residual network and a factorized RNN. Evaluating on larger datasets and/or more complex multi-block architectures is left as future research.
2 Matrix-factorization-based layers
We can motivate our approach by first reviewing some related methods that will contribute toward its design. These include non-negative matrix factorization (NMF) in Section 2.1, k-nearest-neighbors (k-NN) prediction in Section 2.2, and the multi-layer perceptron (MLP) in Section 2.3. We then present the Predictive Factorized Coupling (PFC) block in Section 2.4.
2.1 Review of non-negative matrix factorization (NMF)
Non-negative matrix factorization (NMF) is a matrix decomposition method where a data matrix $V$ is factorized into two matrices $W$ and $H$, with the property that all three matrices have no negative elements. The objective is to find two non-negative matrices whose product closely approximates the initial matrix, such that:

$$V \approx W H \qquad (2.1)$$

Here, $R$, the inner dimension of the factorization (the number of columns of $W$ and rows of $H$), is a tunable hyperparameter that leads to a more compressed approximation as it is reduced. NMF was originally proposed by [14] as positive matrix factorization and later popularized by [15]. NMF’s strength lies in its ability to provide interpretable decompositions. In particular, the non-negativity constraint is found to lead to sparse and parts-based representations because only additive, not subtractive, combinations are allowed [15]. In many real-world applications, such as image representation or document analysis, this aligns well with the nature of the data, where the measured quantities (like pixel intensity or word counts) are inherently non-negative.
It is also possible to relax the non-negativity constraint so that only the factor matrix $H$ is constrained to be non-negative, which is called semi-NMF [16]. This allows for more flexibility in representing data but can also reduce the interpretability.
The factorization is not unique and is typically computed using iterative algorithms, such as gradient descent or multiplicative updates, for example [17]. This involves first specifying a differentiable reconstruction loss function such as mean-squared error (MSE), for example, and then jointly optimizing $W$ and $H$ to minimize the loss. These algorithms alternately fix one of the factor matrices and update the other, until approximate convergence to a fixed point. In this paper we will use the same terminology from [18] to discriminate between these two updates, so that the NMF left update step refers to applying the update rule once to update the left factor matrix $W$, which we also refer to as the NMF learning update step. Likewise, the NMF right update step refers to applying the update rule once to update the right factor matrix $H$, which we also refer to as the NMF inference update step.
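To make these update steps concrete, the following is a minimal NumPy sketch of NMF using simple projected-gradient (SGD-style) updates; the step size, iteration count, and initialization scale are illustrative assumptions, and multiplicative updates [17] could be substituted for either step.

```python
import numpy as np

def nmf_alternating_updates(V, R, n_iters=200, lr=1e-2, seed=0):
    """Jointly optimize W and H to minimize ||V - W H||^2, with all
    matrices constrained to be non-negative (enforced by clipping)."""
    rng = np.random.default_rng(seed)
    M, N = V.shape
    W = rng.uniform(0.0, 0.1, size=(M, R))
    H = rng.uniform(0.0, 0.1, size=(R, N))
    for _ in range(n_iters):
        # Right (inference) update: gradient step on H, then clip to >= 0.
        H = np.maximum(0.0, H - lr * W.T @ (W @ H - V))
        # Left (learning) update: gradient step on W, then clip to >= 0.
        W = np.maximum(0.0, W - lr * (W @ H - V) @ H.T)
    return W, H
```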
The usual modeling convention is that the columns of $V$ correspond to input feature vectors. For example, the neural network interpretation of NMF shown in Figure 3 of [15] corresponds to:

$$v_j \approx W h_j \qquad (2.2)$$

where $v_j$ and $h_j$ are now individual column vectors (at column index $j$) in $V$ and $H$, respectively. Here, $v_j$ represents the observed input features (visible variables), while $W$ represents the neural network weights and $h_j$ represents the hidden variables which are inferred from $v_j$ via the repeated application of the NMF right update rule until convergence. The columns of $W$ are often referred to as the learned basis vectors in the literature and we will also sometimes refer to them as (parts-based) prototypes in this paper.
We can see from the basic NMF formulation in Eq. 2.1 that it attempts to learn an approximation to a supplied data matrix. While this can be suitable for many tasks, it also tends to limit its use to unsupervised learning and tasks where NMF is able to provide sufficient modeling expressiveness. For these reasons, NMF is not typically used for supervised learning or for tasks requiring more expressiveness, such as modeling sequential data, for example.
2.2 Review of nearest neighbor prediction
The k-nearest-neighbors (k-NN) algorithm [10] is one of the simplest and most interpretable machine learning methods used for classification and regression. It is considered an instance or prototype-based method since the model stores the entire training dataset and then makes predictions using a similarity measure.
Suppose that we have a set of training example pairs $(x_i, y_i)$, $i = 1, \ldots, N$, where $x_i$ is the input feature vector and $y_i$ is the corresponding target output value or vector. In the classification setting, $y_i$ would contain the ground truth class label and could be represented either as an integer label or a 1-hot vector. In the regression setting, $y_i$ could more generally be an arbitrary real-valued vector. Since the following formulation could apply to either, we will refer to it as nearest neighbor prediction.
In typical machine learning algorithms, the model adjusts its weights (i.e., parameters) to minimize a loss function that measures the difference between the model’s predictions and the actual target values. However, there is no such loss function in nearest neighbors. Rather, the learning procedure is extremely simple as it only consists of storing the training examples in a suitable data structure as they become available. We will still refer to this data structure as the “weights” since it contains the learned knowledge extracted from the training set, which in this case is just the training examples themselves. For easier connection to the matrix-factorization-based models that follow in later sections, we will formulate nearest neighbor prediction as a special case of matrix factorization. For each training example $(x_i, y_i)$, we create a corresponding column vector in the weights matrix $W$ as the vertical concatenation of the target $y_i$ on top of input feature $x_i$. Specifically, let $y_i \in \mathbb{R}^C$ (i.e., $C$ class labels) and let $x_i \in \mathbb{R}^D$. We then construct $W$ as:

$$W = \begin{bmatrix} y_1 & y_2 & \cdots & y_N \\ x_1 & x_2 & \cdots & x_N \end{bmatrix} \qquad (2.3)$$

As a result, $W$ will be a $(C + D) \times N$ matrix containing all of the $N$ training examples as its columns. We will find it useful to split $W$ into an upper “prediction” sub-matrix $W_y$ consisting of the first $C$ rows (containing the targets), and a lower “recognition” sub-matrix $W_x$ consisting of the last $D$ rows (containing the input features):

$$W = \begin{bmatrix} W_y \\ W_x \end{bmatrix} \qquad (2.4)$$
The nearest neighbor algorithm also requires us to define a suitable distance or similarity metric. For this, we can specify a similarity function $s(\cdot, \cdot)$ which computes a similarity score between two supplied vectors based on their Euclidean distance or cosine similarity, for example.
Given a new input example $x$, we can use the model to perform inference and predict the output (either classification or regression) as follows. In the first step, we use $s$ to compute the similarity score of $x$ against each of the columns in $W_x$. The general form of the algorithm finds the k-nearest neighbors (k-NN) but we consider only the 1-nearest neighbor (1-NN) method here for simplicity. Let $m$ denote the index of the best-matching column in $W_x$ (which contains the training example $x_m$). We refer to this step as “recognition” since the input $x$ was recognized as its nearest neighbor $x_m$, which represents the model’s reconstruction of $x$. The second step then involves selecting the same column $m$ of $W_y$ so that the model will output $\tilde{y} = y_m$ as its prediction. We therefore refer to this as the “prediction” step.
We can express the inference solution to the 1-NN in the form of a matrix-vector product that will make the reconstructive-predictive aspect of the model explicit. We introduce a 1-hot vector $h$ having the same dimensionality as the column count in $W$, where only index $m$ is set to 1. This allows us to express the model’s output prediction in the form of the following matrix-vector product:

$$\tilde{y} = W_y h \qquad (2.5)$$

We can likewise express the model’s input prediction as:

$$\tilde{x} = W_x h \qquad (2.6)$$

Using Eq. 2.4, these can be combined so that both the input and output predictions are given by a single product:

$$\begin{bmatrix} \tilde{y} \\ \tilde{x} \end{bmatrix} = W h \qquad (2.7)$$
Note that given the (1-hot) inference solution in $h$, the product picks out a single column of the weights $W$ to use as the predictions for both the output and input. For the more general k-NN algorithm, we could compute $h$ by setting all elements to 0 except those at indices corresponding to the k-nearest neighbors, for which we would use the corresponding value of the similarity function, and then normalize these values to sum to 1 in $h$. Several interpretable aspects of the 1-NN prediction algorithm are worth mentioning. For example, it provides confidence estimation by making use of the similarity score. When the model makes a wrong prediction (or when an out-of-distribution example is supplied), we can inspect the reconstructed input since it shows what the input was recognized as. Continual learning and knowledge removal are easily supported as well, since the training examples comprise the columns of $W$ and can therefore easily be added or removed as necessary.
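As an illustration, the 1-NN predictor in this matrix-product form can be sketched as follows, assuming cosine similarity as the similarity function $s$ (names are illustrative):

```python
import numpy as np

def one_nn_predict(x, W_y, W_x):
    """1-NN prediction in the matrix-product form of Eqs. 2.5-2.7.
    W_x: (D, N) training inputs as columns; W_y: (C, N) targets as columns."""
    # Similarity function s: cosine similarity of x against each column of W_x.
    col_norms = np.linalg.norm(W_x, axis=0) + 1e-12
    sims = (W_x.T @ x) / (col_norms * (np.linalg.norm(x) + 1e-12))
    m = int(np.argmax(sims))          # index of the best-matching column
    h = np.zeros(W_x.shape[1])
    h[m] = 1.0                        # 1-hot inference solution
    y_pred = W_y @ h                  # output prediction (Eq. 2.5)
    x_recon = W_x @ h                 # input reconstruction (Eq. 2.6)
    return y_pred, x_recon, sims[m]   # similarity doubles as a confidence score
```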
A drawback of the 1-NN predictor is that the output is not a differentiable function of $W$ and the input $x$. This prevents us from connecting it to a loss function such as a classification loss in order to learn the weights using backpropagation, and also prevents its use as a building block for more complex architectures. The reliance on a manually specified similarity function can limit the prediction accuracy. In addition, inference can be expensive for large datasets since the number of columns in $W$ grows with the number of training examples. Despite these drawbacks, 1-NN and the more general k-NN can still sometimes produce a good balance of predictive performance and interpretability depending on the dataset. It is also worth mentioning that learnable versions of nearest neighbor prediction exist and are referred to as Learning Vector Quantization (LVQ) [11] [12]. However, we are not aware of a fully differentiable version that would enable its use as a building block for more expressive architectures.
2.3 Review of MLP-based Neural Networks
In contrast to nearest neighbor methods, which are non-parametric, most machine learning algorithms use a set of learnable parameters. In modern neural networks, also known as deep neural networks (DNNs), most of the learnable parameters tend to be contained in the MLP blocks, as discussed in Section 1. The versatility of DNNs lies in their ability to use backpropagation [19] [20] [21] for learning parameters that minimize a chosen loss function, ranging from classification to regression losses. This flexibility allows us to take a foundational component like the MLP and scale it to create architectures of varying complexity such as convolutional networks, deep residual networks, recurrent neural networks (RNNs), transformer architectures, etc. It is this ability to optimize the parameters so as to minimize a desired loss function that appears to contribute to DNNs’ superior predictive performance over other more interpretable methods.
The basic MLP corresponds to the sequential connection of two affine transformations (linear layers) with a non-linear activation function in between. The hidden layer activation vector $h$ is expressed as the following (differentiable) function of the parameters and input vector $x$:

$$h = g(W_1 x + b_1) \qquad (2.8)$$

where the weights matrix $W_1$ and bias vector $b_1$ are the parameters of the first linear layer and $g$ can be an arbitrary differentiable activation function. Common choices for $g$ include ReLU, GELU [22], and tanh, for example. The MLP’s predicted output, $\tilde{y}$, is then obtained by applying a second affine transformation to $h$:

$$\tilde{y} = W_2 h + b_2 \qquad (2.9)$$

where the weights matrix $W_2$ and bias vector $b_2$ are the parameters of the second linear layer. Note that $\tilde{y}$ is differentiable with respect to the parameters of both linear layers, the hidden activations, and the input. This allows it to be used as a building block to create expressive neural architectures. Also note that unlike a single linear layer or Single Layer Perceptron (SLP), an MLP with at least one hidden layer can approximate complex non-linear functions [5] [6]. This ability comes from the non-linear activation function used in the hidden layer. As discussed in Section 1, a drawback of MLP-based models is that they are considered black-box models lacking interpretability and have challenges dealing with continual learning and training on non-i.i.d. examples.
Although the above presentation of MLPs might make them seem completely unrelated to nearest neighbor methods, we can make a connection between them. We do this by showing that a particular choice of weights initialization and activation function makes the MLP equivalent to the k-NN predictor. This is the reason why we overloaded $h$: it refers to the MLP hidden activations in this section, while also referring to the 1-hot vector (or k-nonzero vector in the case of k-NN) in Eqs. 2.5, 2.6, 2.7. To see the connection, we first ignore the MLP bias terms. Instead of using the usual backpropagation procedure to learn the weights, we instead initialize the MLP weights $W_1$ and $W_2$ to the corresponding nearest neighbor weights $W_x^T$ and $W_y$ that were used in Eqs. 2.5, 2.6, 2.7, and also normalize the columns of $W_x$ to have unit L2 norm. Recall that $W_x$ contained all of the training input vectors as its columns while $W_y$ contained the corresponding target output vectors as its columns, for a total of $N$ columns in each weight matrix. For the activation function $g$, we use a k-max activation followed by softmax, which has the effect of only passing the k largest values and then normalizing them to have unit sum. For example, with $k = 1$, the output of $g$ will be a 1-hot vector similar to the $h$ vector that we used for nearest neighbors. The final step is that we need to ensure that the input $x$ to the MLP is normalized to have unit L2 norm. Note that with this setup, the matrix product in Eq. 2.8 computes the cosine similarity between each of the training inputs in $W_x$ and $x$, corresponding to this choice of similarity measure in nearest neighbors. The activation function then has non-zero outputs only for the indices corresponding to the k-nearest neighbors in $W_x$, where these values must sum to 1 in $h$. As a result, we see from Eq. 2.9 that the MLP then outputs a linear combination of the corresponding columns in $W_y$ (which contain the target training vectors $y_i$). Without using backpropagation, this “nearest-neighbor” MLP remains interpretable. However, if we attempt to train it, the interpretation of the weights as containing prototypes is potentially lost. For example, with the usual backpropagation training, there is no input reconstruction loss used by default, and so the input prediction (analogous to Eq. 2.6) may not necessarily result in a good reconstruction.
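A sketch of this construction follows, under the stated assumptions (bias terms ignored, $x$ and the columns of $W_x$ L2-normalized); the function name and the k-max implementation details are illustrative:

```python
import numpy as np

def knn_mlp_forward(x, W_x, W_y, k=1):
    """Forward pass of the 'nearest-neighbor MLP': weights initialized from
    the stored training examples, with a k-max + softmax activation.
    Assumes the columns of W_x and the input x are L2-normalized."""
    a = W_x.T @ x                     # Eq. 2.8 with W_1 = W_x^T (cosine similarities)
    z = np.full_like(a, -np.inf)
    top_k = np.argpartition(a, -k)[-k:]
    z[top_k] = a[top_k]               # k-max: keep only the k largest values
    h = np.exp(z - z[top_k].max())
    h /= h.sum()                      # softmax over the surviving entries
    return W_y @ h                    # Eq. 2.9 with W_2 = W_y (no bias)
```

With $k = 1$, the activation output $h$ is exactly the 1-hot vector of the 1-NN predictor.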
2.4 The Predictive Factorized Coupling (PFC) block
In this section we introduce the Predictive Factorized Coupling (PFC) block and discuss how it relates to the models in the previous background sections. Having covered the related models, we can now easily introduce the PFC block simply by re-interpreting the matrix product shown in Eq. 2.7 with the NMF modeling interpretation of Eq. 2.2 instead of k-NN, so that it represents the predicted input and output vectors after using an NMF algorithm to solve for $h$:

$$\begin{bmatrix} \tilde{y} \\ \tilde{x} \end{bmatrix} = W h \qquad (2.10)$$

Interpreted as NMF, this shows the prediction for a single column vector of the data matrix $V$ in the factorization $V \approx W H$, where $V$ contains the input vectors to be recognized as the columns of its lower sub-matrix $X$ and the corresponding target output vectors to be predicted as the columns of its upper sub-matrix $Y$:

$$V = \begin{bmatrix} Y \\ X \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_x \end{bmatrix} H = W H \qquad (2.11)$$
When used as a neural building block, we consider $X$ to contain the input vectors to the PFC block. The block then infers (solves for) $H$ and predicts $\tilde{Y}$ for its output. Similar to neural network training, the corresponding targets are not available during the prediction (i.e., inference) process. $V$ is therefore partially observed during inference, since only its sub-matrix $X$ is available to the block as input. $Y$ is then only used for the purpose of computing the prediction loss when it is available.

Recall that under the NMF modeling constraints, the three matrices of $V \approx W H$ are only required to be non-negative (the non-negativity constraints for $V$ and $W$ can also potentially be removed if we allow semi-NMF, but we will assume NMF here for the purpose of describing the model). We see that $h$ is now less constrained compared to the nearest neighbor model that required $h$ to be 1-hot for 1-NN or to contain only $k$ non-zero elements for the k-NN case. Since $W$ is learnable under NMF, we no longer need to use the same number of column vectors (which are often called basis vectors in the NMF literature) as training examples. Similar to Section 2.1, we let the hyperparameter $R$ specify the number of learnable basis vectors in $W$, and these can now be initialized to random non-negative values as an alternative to initializing with training examples. $R$ also specifies the dimension of the hidden vector $h$. By keeping $R$ internal to the block, we are free to later add or remove basis vectors from $W$ without changing the external interface of the block. With this interpretation, $H$ corresponds to the inferred hidden activations that represent an encoding of the input in terms of the (parts-based) basis vectors in $W$. From Eq. 2.10 we see that $W$ is also composed of two sub-matrices, $W_y$ and $W_x$, which represent the learned basis vectors as coupled “input-output” parts or prototypes. These could also be interpreted as learned key-value factors, where the input vector serves as the query, and the block is seen to perform a kind of factorized attention over parameters in predicting its output.
The factorization expression in Eq. 2.10 corresponding to the PFC block was first proposed in our earlier work [18] where we referred to it as a “coupling module” or “coupling factorization” and only considered its use in sequential data models. This contrasts with its more general usage as a building block for both sequential and non-sequential models in the current work. Refer to Section 6 for a more detailed discussion of related work.
2.4.1 Training and inference
We will show that the PFC block’s inference procedure is differentiable. This allows it to be used as a general neural building block in arbitrary computation graphs and trained with the usual backpropagation, similar to how existing MLP-based models are trained. From inspection of Eqs. 2.10 and 2.11, it may not initially be clear that the prediction process is differentiable. We see that the block’s predicted output, $\tilde{y}$, is expressed as:

$$\tilde{y} = W_y h \qquad (2.12)$$

and so is clearly differentiable with respect to $W_y$ and $h$, but what about with respect to $W_x$ and $x$? Eq. 2.10 is simply a declarative expression stating that the input $x$ is approximately a linear function of the inferred $h$:

$$x \approx W_x h \qquad (2.13)$$

We need to show that the corresponding reverse-direction imperative process of inferring $h$ from $x$ (and $H$ from $X$) is also differentiable. Letting $f$ represent this inference process, we need to show that $h = f(x, W_x)$ is differentiable with respect to $x$ and $W_x$. Recall from Section 2.1 that $h$ is computed by an iterative NMF algorithm consisting of a sequence of right-update steps until approximate convergence of $h$ to a fixed point. Let $g_i$ denote the function that computes a single right-update step (the subscript $i$ denotes iteration number here, not column index). If it takes $N$ iterations to converge, then $f$ corresponds to the $N$-fold composition of $g_i$ so that we have:

$$h = f(x, W_x) = (g_N \circ g_{N-1} \circ \cdots \circ g_1)(h_0) \qquad (2.14)$$
It then only remains to show that $g_i$ is differentiable. There are several options for $g_i$ but we will use the simple and well-known SGD update steps which we review in Appendix A. Eq. A.6 shows the right-update step for the more general case of (batched) matrix input $X$ rather than vector input $x$, and so we repeat it here for the vector case:

$$h_{i+1} = g_i(h_i) = \mathrm{relu}\!\left( h_i - \eta\, W_x^T (W_x h_i - x) \right) \qquad (2.15)$$

The inference learning rate $\eta$ controls the step size. With the choice of Eq. 2.15 as $g_i$, we see that it is indeed differentiable. In summary, the inference procedure given an input $x$ is as follows. We first apply the NMF right-update rule in Eq. 2.15 $N$ times (assuming convergence is reached by then) to infer $h$. We then apply the final linear prediction step in Eq. 2.12 to compute the predicted output. Note that the targets in Eq. 2.10 are masked while performing inference, since we are predicting them as $\tilde{y}$. For this reason, we will also refer to this extension of NMF as masked predictive NMF.
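The following is a minimal PyTorch sketch of this unrolled inference procedure, assuming a fixed, manually chosen step size and iteration count (the accelerated variant discussed below automates both). Since every step is an ordinary differentiable tensor operation, autograd can backpropagate through the loop:

```python
import torch

def pfc_forward(x, W_y, W_x, n_iters=20, lr=0.1):
    """Unrolled PFC-block inference: repeat the right-update of Eq. 2.15,
    then predict the output with Eq. 2.12.
    x: (D,) non-negative input; W_x: (D, R); W_y: (C, R)."""
    h = torch.zeros(W_x.shape[1], device=x.device)
    for _ in range(n_iters):
        grad_h = W_x.t() @ (W_x @ h - x)     # gradient of 0.5 * ||x - W_x h||^2
        h = torch.relu(h - lr * grad_h)      # projected (non-negative) step
    y_pred = W_y @ h                         # prediction step (Eq. 2.12)
    x_recon = W_x @ h                        # reconstruction (feedback / loss)
    return y_pred, x_recon, h
```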
Since NMF is used to infer $h$, we can interpret it as follows. If the input is well modeled as consisting of a mixture of parts, then the NMF solution can potentially discover those parts (assuming an appropriate learning algorithm for $W_x$). Using NMF terminology, the columns of $W_x$ can be interpreted as the learned parts or “basis vectors” so that the inferred $h$ then specifies an additive encoding of the input in terms of the basis vectors. Intuitively, we can interpret Eq. 2.13 as expressing that the input is approximately generated (reconstructed) as a linear function of the inferred $h$. That is, the reconstruction is given as $\tilde{x} = W_x h$. Thus, the process of running an optimization algorithm to solve for $h$ corresponds to the model trying to recognize its input in terms of the already-learned “parts”. When the recognition is successful, the model is able to find an encoding of the input in terms of these parts that results in a low reconstruction error $\lVert x - \tilde{x} \rVert$. When it is unsuccessful, the error will be large. So although we now need to do more computation compared to the MLP, we get a useful new property: the PFC block can now give us feedback through the reconstruction and its error to tell us how well it was able to recognize its input.
We note that it is possible to automate the selection of the learning rate and to accelerate the inference procedure so that fewer iterations are required by leveraging a modified version of the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) [23] that removes the shrinkage step and adapts it to NMF, as we describe in Appendix B. We used this method of unrolling in all of our experiments.
To learn the weights, we can use the PFC block in an arbitrary computation graph and train with the usual backpropagation algorithm. For example, we could compute a classification or regression loss between $\tilde{y}$ and the target $y$, perform backpropagation to compute the error gradients, and use an existing optimizer such as SGD, RMSprop [24], etc. to update the weights. When using the NMF modeling constraint, we also need to clip any negative values in the weights to zero after each optimizer update. We can also consider training on mini-batches instead of individual examples by replacing $x$ with a matrix $X$ containing a batch of examples.
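Putting the pieces together, a hypothetical training step for a single PFC block might look as follows; the dimensions, loss weighting, and hyperparameters loosely mirror the experiments in Section 5 but are illustrative, and `pfc_forward` refers to the inference sketch above:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 784 input features, R = 300 basis vectors, 10 outputs.
W_x = torch.nn.Parameter(0.1 * torch.rand(784, 300))
W_y = torch.nn.Parameter(0.1 * torch.rand(10, 300))
opt = torch.optim.RMSprop([W_x, W_y], lr=3e-4, weight_decay=1e-4)

def train_step(x, y_target):
    y_pred, x_recon, _ = pfc_forward(x, W_y, W_x)  # unrolled inference
    # Equal weighting of prediction and input reconstruction losses (both MSE).
    loss = F.mse_loss(y_pred, y_target) + F.mse_loss(x_recon, x)
    opt.zero_grad()
    loss.backward()
    opt.step()
    with torch.no_grad():
        W_x.clamp_(min=0.0)   # clip negative weights to zero (NMF constraint)
        W_y.clamp_(min=0.0)
    return loss.item()
```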
The general idea of unfolding an iterative and differentiable optimization algorithm, such as NMF in our case, into a computation graph and using backpropagation to learn its parameters is not a new idea. It is sometimes referred to as algorithm unrolling or unrolled neural networks in the literature [25]. For details on related work, refer to Section 6.
3 Factorized RNNs
In this section we develop a simple matrix-factorization-based replacement for the vanilla RNN [26]. Our “factorized” RNN is modeled as a single matrix factorization of the form $V \approx W H$ without any activation functions, and containing all of the model’s input activations, weights, hidden state activations, and output activations. This makes its declarative model one of the conceptually simplest RNNs that we are aware of. We will often be interested in the case where these matrices are constrained to be non-negative to improve interpretability, so that the model will then more specifically correspond to an instance of non-negative matrix factorization (NMF). Since the model is NMF, it retains all of the desirable properties of NMF while also more directly supporting the modeling of sequential data. The models we develop in this section are based on the similar factorized recurrent approach used in Section 3 of [18], although here we use a different recurrent architecture and propose modified and more effective backpropagation-based learning methods using the algorithm unrolling method from Section 2.4.
3.1 Review of the vanilla RNN
To motivate the idea, we first need to review the vanilla RNN. In the following, we initially assume that the weights (the $W$ matrices) have already been learned, so that we only need to consider the inference or forward pass through the network to compute its output predictions from its inputs. Given an input sequence of length $T$ containing the feature vectors $x_1, \ldots, x_T$, the RNN maps them in order, one at a time, into a corresponding output sequence of vectors $y_1, \ldots, y_T$. The subscript $t$ denoting the position in the sequence is often called the “time slice”, even when there is no notion of time involved. The reason they must be mapped one at a time is because the RNN maintains an internal hidden state vector $h_t$ which is consumed as an additional (hidden) input during each time slice, modified, and produced as a (hidden) output for use in the next time slice. So, in time slice $t$, the RNN will consume input $x_t$ and previous hidden state input $h_{t-1}$. It then produces a new hidden state $h_t$ and output $y_t$. The vanilla RNN does this in two stages. In the first stage, we update the hidden state:

$$h_t = f(W_h h_{t-1} + W_x x_t + b_h) \qquad (3.1)$$

where $f$ denotes an arbitrary activation function and $b_h$ is the bias vector. In the second stage, we compute the output from the updated hidden state:

$$y_t = W_y h_t + b_y \qquad (3.2)$$
Note that Eq. 3.1 can be rewritten as a single linear layer followed by a nonlinear activation function:

$$h_t = f(W z_t + b_h) \qquad (3.3)$$

where we let $W$ refer to the combined weights:

$$W = \begin{bmatrix} W_h & W_x \end{bmatrix} \qquad (3.4)$$

and let $z_t$ refer to the combined inputs:

$$z_t = \begin{bmatrix} h_{t-1} \\ x_t \end{bmatrix} \qquad (3.5)$$

Combining this with the output computation of Eq. 3.2 gives the full per-slice mapping:

$$y_t = W_y f(W z_t + b_h) + b_y \qquad (3.6)$$

This shows that each time slice of the RNN can be interpreted as an MLP that takes input $z_t$ and produces outputs $h_t$ and $y_t$. From Eq. 3.5, we see that the previous hidden state $h_{t-1}$ appears together with $x_t$ in the input $z_t$, and the updated hidden state $h_t$ corresponds to the hidden layer of the MLP after the activation, as shown in Eq. 3.3.
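In code, the per-slice computation of Eqs. 3.1 and 3.2 is a straightforward loop (a review sketch; names are illustrative):

```python
import torch

def vanilla_rnn_forward(X, W_h, W_x, b_h, W_y, b_y, f=torch.tanh):
    """Review sketch of Eqs. 3.1-3.2. X: (Dx, T) input sequence.
    Each time slice is one MLP-like step: a combined linear layer plus
    activation, followed by an output linear layer."""
    R, T = W_h.shape[0], X.shape[1]
    h = torch.zeros(R)
    outputs = []
    for t in range(T):
        h = f(W_h @ h + W_x @ X[:, t] + b_h)   # hidden state update (Eq. 3.1)
        outputs.append(W_y @ h + b_y)          # output from hidden state (Eq. 3.2)
    return torch.stack(outputs, dim=1)         # (Dy, T)
```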
3.2 The factorized RNN
Section 3.1 showed that each time slice of the vanilla RNN can be interpreted as an MLP. With this interpretation, it is now interesting to consider the model that results from replacing each of these MLP blocks with a corresponding PFC block. This corresponds to keeping the output linear layer in Eq. 3.2 (although with the bias term removed). We replace the input linear layer and activation function from Eq. 3.3 with the following vector factorization for the $t$’th time slice:

$$z_t = \begin{bmatrix} h_{t-1} \\ x_t \end{bmatrix} \approx \begin{bmatrix} W_h \\ W_x \end{bmatrix} h_t \qquad (3.7)$$

With this change, we have now reversed the direction of the linear mapping compared to the MLP so that the input $z_t$ is approximately a linear function of the hidden state $h_t$ (contrast this to the MLP in Eq. 3.3 where the pre-activation is a linear function of $z_t$). Our factorized representation is also simplified compared to Eq. 3.3 since we have removed the need for an activation function and bias term. The tradeoff is that now when we are given an input $z_t$, we will require an iterative NMF update algorithm to solve for $h_t$, which could be more computationally costly compared to the MLP.
With the computed $h_t$, we then compute the output as in the vanilla RNN using Eq. 3.2. Applying Eq. 3.2 and Eq. 3.7 to all time slices $t = 1, \ldots, T$, we finally arrive at the factorized vanilla RNN expressed as a single matrix factorization of the form $V \approx W H$:

$$\begin{bmatrix} y_1 & y_2 & \cdots & y_T \\ z_1 & z_2 & \cdots & z_T \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} \begin{bmatrix} h_1 & h_2 & \cdots & h_T \end{bmatrix} \qquad (3.8)$$

Using Eq. 3.5, we can also expand each $z_t$ in the left matrix in terms of $x_t$ and $h_{t-1}$. This brings us to the key result which lets us express the factorized RNN with all of the model inputs, outputs, hidden states, and weights together in a single matrix factorization:

$$\begin{bmatrix} y_1 & y_2 & \cdots & y_T \\ h_0 & h_1 & \cdots & h_{T-1} \\ x_1 & x_2 & \cdots & x_T \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} \begin{bmatrix} h_1 & h_2 & \cdots & h_T \end{bmatrix} \qquad (3.9)$$
Note that for a single time slice, our model corresponds to the following vector factorization:

$$\begin{bmatrix} y_t \\ h_{t-1} \\ x_t \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} h_t \qquad (3.10)$$

If we use $W$ to denote the three stacked weights sub-matrices and $v_t$ to denote the left-hand vector, the notation simplifies even further to the following:

$$v_t \approx W h_t \qquad (3.11)$$
Now contrast the simplicity of the factorized RNN for a single time slice in Eq. 3.11 with the corresponding vanilla RNN expression for a single time slice in Eq. 3.6. Note that they differ in that Eq. 3.11 is a declarative representation while Eq. 3.6 is imperative. That is, for the factorized RNN we will still need to find suitable algorithms to actually solve the factorization, whereas the vanilla RNN expression explicitly tells us the steps needed to produce the output.
We can further simplify the notation by replacing each of the sequences in Eq. 3.9 with their respective sub-matrices. Letting $Y = [y_1 \cdots y_T]$, $H_{\mathrm{prev}} = [h_0 \cdots h_{T-1}]$, $X = [x_1 \cdots x_T]$, and $H = [h_1 \cdots h_T]$ results in the following:

$$V = \begin{bmatrix} Y \\ H_{\mathrm{prev}} \\ X \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} H \qquad (3.12)$$

Regarding the hidden states, we see that the initial “previous” state $h_0$ only appears in the left matrix while the final state $h_T$ only appears in the right matrix. The other hidden states are duplicated since $h_t$ for $t = 1, \ldots, T-1$ appear in both $H_{\mathrm{prev}}$ and $H$. Also note that $W_h$ must be an $R \times R$ square sub-matrix of $W$ since if the $h_t$ are $R$-dimensional then each of $W_y$, $W_h$, and $W_x$ must also have $R$ columns.
3.3 Training with alternating NMF update rules
A simple method of training the factorized RNN consists of performing alternating NMF updates to $W$ and $H$, while also copying the inferred states from $H$ to the corresponding duplicated positions in the data matrix after each inference update step to satisfy the consistency constraints. This is similar to the approach that we used for the sequential models in [18]. The details of this are as follows. We use the simple SGD-based algorithm which is reviewed in Appendix A for our experiments, but a variety of algorithms can potentially be used. Starting from the factorization in Eq. 3.9, we begin by initializing the weights to small random values and initializing the hidden states to either zeros or small random values. This corresponds to setting sub-matrices $H_{\mathrm{prev}}$ and $H$ to zeros or small random values in Eq. 3.12. Let $X$ represent the training inputs and $Y$ represent the corresponding target output values that we want to predict. Since we want to predict the outputs using only the provided inputs, we first need to infer the hidden states. For that we use the sub-part of Eq. 3.9 corresponding to Eq. 3.7:

$$\begin{bmatrix} H_{\mathrm{prev}} \\ X \end{bmatrix} \approx \begin{bmatrix} W_h \\ W_x \end{bmatrix} H \qquad (3.13)$$

The task is then to solve for the $h_t$, which we can do by alternating between matrix factorization updates of the right matrix $H$ followed by enforcing the constraint that the duplicated $h_t$ must have equal values. We can do this by simply copying the $h_t$ from $H$ (the right factor matrix) to $H_{\mathrm{prev}}$ (the left data matrix) after each NMF update to $H$. Once the updates converge, we can then use the top-most sub-factorization of Eq. 3.9 to compute the predicted outputs as a linear function of the $h_t$:

$$\tilde{Y} = W_y H \qquad (3.14)$$

Now that the inference part of the algorithm has completed, we perform the learning updates. We replace the predicted outputs $\tilde{Y}$ with the target outputs $Y$ above and perform an NMF update on $W_y$. Similarly, we perform an NMF update on $W_h$ and $W_x$ in Eq. 3.13. We then repeat the above procedure until convergence.
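The following NumPy sketch illustrates one possible implementation of this “full batch” alternating procedure; the inner/outer iteration counts and step sizes are illustrative assumptions rather than the exact schedule used in our experiments:

```python
import numpy as np

def train_factorized_rnn_nmf(X, Y, R, n_outer=100, n_inner=10, lr=1e-2, seed=0):
    """Sketch of the alternating-update training above (no backpropagation).
    X: (Dx, T) inputs; Y: (Dy, T) targets; R: hidden dimension."""
    rng = np.random.default_rng(seed)
    Dx, T = X.shape
    W_y = rng.uniform(0.0, 0.1, (Y.shape[0], R))
    W_h = rng.uniform(0.0, 0.1, (R, R))
    W_x = rng.uniform(0.0, 0.1, (Dx, R))
    H = np.zeros((R, T))          # h_1 ... h_T
    H_prev = np.zeros((R, T))     # h_0 ... h_{T-1} (duplicated states)
    for _ in range(n_outer):
        # Inference: right-updates on H against Eq. 3.13, copying the
        # duplicated states forward after each update.
        for _ in range(n_inner):
            W = np.vstack([W_h, W_x])
            V = np.vstack([H_prev, X])
            H = np.maximum(0.0, H - lr * W.T @ (W @ H - V))
            H_prev[:, 1:] = H[:, :-1]     # enforce consistency constraints
        # Learning: left-updates on each weights sub-matrix, using the
        # targets Y in place of the predicted outputs.
        W_y = np.maximum(0.0, W_y - lr * (W_y @ H - Y) @ H.T)
        W_h = np.maximum(0.0, W_h - lr * (W_h @ H - H_prev) @ H.T)
        W_x = np.maximum(0.0, W_x - lr * (W_x @ H - X) @ H.T)
    return W_y, W_h, W_x, H
```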
Since the hidden state updates propagate one slice forward for each NMF update, it will take the same number of updates as the sequence length for information from the first slice to potentially reach the last slice. As a result, if the sequence is long, the initial NMF updates late in the sequence could be considered wasted computation since they will be operating on hidden states that only contain information from nearby slices. Whether or not this actually becomes an issue in practice would seem to depend on how far information can actually propagate in an RNN, which is outside the scope of this paper.
As an alternative to performing the “full batch” updates as above, we can consider performing the inference “one time slice at a time”. Specifically, we start at the first slice and wait for the inference procedure to converge on that slice before continuing with the next. Since the (output) inferred hidden state $h_t$ has now converged, we then copy it into the duplicated (input) location in the next ($t+1$) slice of the data matrix. We can now increment the current slice to $t+1$ and continue in the same way, so that the input states in the current slice can always be considered to have already converged. This is similar to how inference is carried out in the vanilla RNN, since both RNNs share the same temporal dependency ordering between the hidden states. Although the slice-at-a-time inference option seems less able to take advantage of parallel hardware, it also seems potentially more efficient when the sequence length is extremely long since we perform the minimum number of iterations needed for convergence on each state before moving on to the next. Both options seem interesting but a detailed empirical comparison of their relative efficiency is outside the scope of this paper. Perhaps a batch-wise version could also be interesting as a topic for future research. We only consider the latter slice-at-a-time option in the experiments in Section 5.8.
3.4 Training by unrolling NMF inference and backpropagation
The training algorithm in Section 3.3 used standard NMF update rules to learn the model weights. These updates optimize the local reconstruction loss of the data matrix. Depending on the task, we empirically observed that such algorithms can sometimes be sufficient, and we demonstrate examples of this in the experiments in Sections 5.8 and 5.9. However, we generally found this method to fail to perform well on more complex tasks. For this reason we use the algorithm unrolling approach as discussed in Section 2.4 for all other experiments.
It is straightforward to apply unrolling to the factorized RNN since the inference algorithm remains unchanged: we evaluate the computation graph and backpropagate through it. Since we replaced each MLP of the vanilla RNN with a corresponding PFC block, the computation graph dependencies result in the inference progressing one time slice at a time, similar to the vanilla RNN. However, note that in performing the inference in a given time slice, the PFC block’s unrolled NMF update steps form another RNN (internal to each PFC block) with length corresponding to the number of unrolled iterations. We therefore have a computation graph corresponding to an RNN within an RNN.
We compute the loss using the MSE between the targets and the predicted values. Since the factorized RNN has three predicted sub-matrices in $\tilde{V}$ (Eq. 3.12), this means we will have three loss terms. Once the inference procedure converges so that the inferred hidden states are available in $H$, we can predict all three sub-matrices as follows:

$$\tilde{V} = \begin{bmatrix} \tilde{Y} \\ \tilde{H}_{\mathrm{prev}} \\ \tilde{X} \end{bmatrix} = \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} H \qquad (3.15)$$

The total loss is the sum of the three loss terms, each with an arbitrary non-negative scale factor to adjust its relative strength:

$$\mathcal{L} = \lambda_y \lVert Y - \tilde{Y} \rVert^2 + \lambda_h \lVert H_{\mathrm{prev}} - \tilde{H}_{\mathrm{prev}} \rVert^2 + \lambda_x \lVert X - \tilde{X} \rVert^2 \qquad (3.16)$$
We then backpropagate through the loss to compute the error gradients and use an existing SGD-based optimizer such as RMSprop to update the weights. Since the inference procedure operated one time slice at a time, the gradients then flow backward through the entire sequence, effectively making this an instance of backpropagation through time (BPTT).
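A PyTorch sketch of this unrolled forward pass is shown below; the step size and iteration count are illustrative, and detaching the hidden state between slices (noted in the docstring) is one way to realize the disable-BPTT ablation:

```python
import torch

def factorized_rnn_forward(X, W_y, W_h, W_x, n_iters=20, lr=0.1):
    """Sketch of the factorized RNN forward pass with unrolled NMF inference
    in each time slice (the 'RNN within an RNN'). Backpropagating through
    this graph yields BPTT; setting h_prev = h.detach() instead would
    disable BPTT as in the ablations. X: (Dx, T)."""
    R, T = W_h.shape[0], X.shape[1]
    W = torch.cat([W_h, W_x], dim=0)            # stacked [W_h; W_x] (Eq. 3.7)
    h_prev = torch.zeros(R)
    Hs, Ys = [], []
    for t in range(T):
        v = torch.cat([h_prev, X[:, t]])        # observed slice [h_{t-1}; x_t]
        h = torch.zeros(R)
        for _ in range(n_iters):                # unrolled NMF right-updates
            h = torch.relu(h - lr * W.t() @ (W @ h - v))
        Ys.append(W_y @ h)                      # output prediction for slice t
        Hs.append(h)
        h_prev = h                              # carries gradients across slices
    return torch.stack(Ys, dim=1), torch.stack(Hs, dim=1)
```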
4 An optimizer for continual learning and non-i.i.d. training
To motivate our approach, recall that k-NN simply stores the training examples as-is and then performs classification by finding the nearest training examples to the current input. Such an instance-based model can easily support continual learning since we simply append the new examples to the model weights as they become available. Knowledge removal after training (unlearning) is also easily supported by simply removing any desired subset of “bad” training examples from the model weights. In contrast, the knowledge representation in the PFC-based model is more distributed since the SGD optimizer update from any particular training example or batch could potentially result in modifications to any of the weights. Additionally, each optimizer update only slightly modifies the weight values, so that many updates are needed before the weights can be considered fully learned. With this understanding, we can modify the optimizer update so that only a narrow unmasked “learnable window” of basis vectors in each weights matrix in the model is able to be modified. We can allow the window to slowly sweep through (e.g., from left to right), advancing slightly with each optimizer update, so that each basis vector ideally remains inside the window long enough to be effectively learned, but not so long as to be overwritten if or when the distribution of training examples changes. If we keep track of the mapping from training batch index to window position during training, then we will be able to identify the (small) subset of weights that can potentially contain the corresponding learned knowledge. For example, this would be the case if the ordering of training examples within each epoch does not change, such as when using a fixed random shuffling. If the training distribution changes (e.g., in continual learning and/or non-i.i.d. training), the learned basis vectors will be protected from being overwritten because they are only learnable for a limited number of optimizer updates, over which we assume the training distribution to be relatively unchanging.
We now introduce a new “sliding learnable window” (SLW) optimizer that can be used with PFC-based models to support improved continual learning, training with non-i.i.d. examples, and knowledge removal after training. Specifically, suppose each weights matrix $W$ has $R$ columns (basis vectors) in total. The learnable window will have a width of $P$ basis vectors, where $P$ is a tunable hyperparameter and $P \le R$. We denote the current position of the window by the index $j$ of its left-most column in $W$. When training starts, we initialize the learnable window to consist only of the leftmost $P$ basis vectors of each weight matrix by setting $j = 0$. As training progresses, we then increment $j$ by some small fractional amount, which is specified by a sweep-speed hyperparameter. This results in the learnable window slowly sweeping to the right within its weight matrix. We can see that this will have the effect of leaving all basis vectors to the left of the window frozen at their current values. Only the basis vectors at column indices $[j, j + P)$ receive optimizer updates from the current training batch. Likewise, all basis vectors to the right of the window (i.e., with column index $i$ such that $i \ge j + P$) consist of unused weights which have not yet made their way into the learnable region. If the sweep speed is chosen too small, we will end up with many unused basis vectors at the end of training. However, if it is chosen too large then we will run out of weights storage before training can complete, unless we dynamically allocate additional weights storage and concatenate it to the right of $W$ to ensure a continuous supply of unused weights. If $P$ is chosen too small, any given basis vector might not remain under the learnable window long enough to be fully learned. However, if it is chosen too large then the basis columns could remain learnable too long, so that they would then be vulnerable to being overwritten as the training distribution shifts. In the case of non-i.i.d. training or when unlearning capability is required (and assuming the ordering of examples does not change from epoch to epoch), resetting $j = 0$ at the beginning of each epoch will ensure that the optimizer update for any given training batch is always mapped to the same location in $W$.
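One possible realization of the SLW optimizer is as a gradient mask applied between the backward pass and the optimizer step; the class and parameter names below are illustrative:

```python
import torch

class SlidingLearnableWindow:
    """Sketch of the SLW idea: zero the gradient of every basis vector
    (column) outside a window that advances by a small fractional amount
    per optimizer update."""
    def __init__(self, total_columns, window_width, sweep_speed):
        self.total = total_columns   # R: total basis vectors in the weights
        self.width = window_width    # P: learnable window width
        self.speed = sweep_speed     # fractional advance per optimizer update
        self.position = 0.0          # left edge j of the window

    def mask_gradients(self, weight):
        """weight: (rows, R) matrix whose columns are basis vectors."""
        j = int(self.position)
        mask = torch.zeros(self.total)
        mask[j:min(j + self.width, self.total)] = 1.0
        weight.grad *= mask          # broadcasts over rows; freezes columns outside [j, j+P)
        self.position += self.speed  # slide the window to the right

    def reset(self):
        self.position = 0.0          # call at epoch start for non-i.i.d./unlearning use
```

In use, `mask_gradients` would be called on each weights matrix after `loss.backward()` and before the optimizer step.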
5 Experiments
In this section we present experimental results demonstrating the use of the PFC block as a replacement for the MLP block. Section 5.1 first establishes baseline accuracy results of an MLP-based classifier. In Section 5.2 we then replace the MLP with a PFC block and evaluate its accuracy on the same datasets. We also conduct ablation experiments to compare the effect of different modeling constraints such as NMF vs semi-NMF, as well as the effect of either including or disabling the input reconstruction loss term. We compare these MLP and PFC-based models on a continual learning task in Section 5.3 and on a non-i.i.d. training task in Section 5.4. We demonstrate how PFC-based models can support knowledge removal after training in Section 5.5. We show another example of interpretability by visualizing in-domain vs out-of-domain inputs in Section 5.6.
A usable MLP replacement must also be capable of supporting architectures with multiple blocks. The remaining experiments consider two simple multi-block architectures. In Section 5.7 we show results for a 2-block (fully-connected) residual network using PFC blocks instead of MLP blocks and demonstrate that it produces competitive accuracy with corresponding MLP-based models. Replacing the MLP blocks of the vanilla RNN with PFC blocks results in a factorized RNN as introduced in Section 3, and we present results for these sequential architectures in Sections 5.8, 5.9, 5.10, and 5.11. We also perform the following ablations: we compare either using the default backpropagation (i.e., unrolled NMF inference algorithm) training or disabling it and using NMF-based weight update rules similar to [18]. When using backpropagation-based training, we also evaluate the effect of disabling BPTT.
We conduct the image classification experiments using fully-connected models instead of architectures such as convolutional networks that arguably have a more suitable inductive prior. Our accuracy results will therefore be significantly below state of the art on the datasets used. Since the purpose of these experiments is to evaluate the suitability of the PFC block as a more interpretable replacement to the MLP block, we are therefore concerned with comparing the relative accuracy of these two blocks rather than attempting to achieve state of the art results. For the same reason, we compare the relative accuracy of the vanilla RNN against the corresponding factorized RNN rather than using more sophisticated RNN or transformer models on the sequence modeling tasks.
All experiments were carried out using a single RTX 4090 GPU. Due to the limited compute, we did little hyperparameter tuning and worked with small datasets. Consequently, it seems possible that further improvements to accuracy and/or efficiency could be obtained with additional hyperparameter tuning, and it remains unknown how well these results will scale to larger datasets. We did not perform ablations on the number of iterations required for reliable convergence. It is potentially possible to improve efficiency by dynamically unrolling the inference algorithm only for the number of iterations needed for reasonable convergence based on the current inputs. However, we do not attempt this in these experiments and leave it as future research. Results for the MLP-based models were generally run 3 times, but due to the limited compute, the PFC-based model results are shown for a single run and not averaged unless otherwise mentioned.
We trained the models using early stopping with an 85%-15% train-validation split. For simplicity and convenience, we use the MSE loss everywhere, for both the classification/regression loss and the input reconstruction loss (applicable to PFC-based models only). We used equal weighting between the reconstruction and prediction loss terms for simplicity. We use the RMSprop optimizer [24] since it is a simple optimizer that we found to perform well with minimal hyperparameter tuning. For the PFC-based models, parameters are initialized to non-negative uniform random values by default, corresponding to the NMF modeling constraint. In some experiments we allow negative parameters, corresponding to the semi-NMF modeling constraint. When negative parameters are allowed, they are initialized to uniform random values symmetric about zero. The inferred values for the $H$ factor matrices are always constrained to be non-negative. We initialize the $H$ values to zeros, but note that it is also an option to initialize them to small random values. Unless otherwise mentioned, we set the learning rate to 3e-4 for the PFC-based models and 1e-4 for the MLP models. The weight decay was set to 1e-4 unless otherwise mentioned. For the MLP experiments, weights were initialized using the default PyTorch Linear layer initializer and negative parameters were always allowed, since attempting to use non-negative weights with MLPs resulted in optimization difficulties. The GELU activation [22] was used as the MLP hidden layer activation function.
5.1 MLP baseline for image classification
We first evaluate a simple 1-hidden-layer MLP-based classifier as a baseline model on the MNIST [27], Fashion MNIST [28], and CIFAR10 [29] datasets. For each dataset, we train models for hidden dimension sizes of 300, 2000, and 5000. The input feature size is equal to the number of image pixels when the image is flattened into a vector. This is 784 for MNIST and Fashion MNIST, since they contain 28x28 grayscale images, and 3072 for CIFAR10, since 32x32 color images with 3 channels are used. The output layer dimension is 10 since all three datasets have 10 class labels. Table 1 shows the accuracy results.
Dataset | Hidden Dimension | Test Accuracy |
---|---|---|
MNIST | 300 | 98.01% |
MNIST | 2000 | 98.26% |
MNIST | 5000 | 98.32% |
Fashion MNIST | 300 | 88.07% |
Fashion MNIST | 2000 | 88.72% |
Fashion MNIST | 5000 | 88.85% |
CIFAR10 | 300 | 51.59%
CIFAR10 | 2000 | 52.78%
CIFAR10 | 5000 | 53.17%
5.2 PFC network for image classification
We train a 1-block PFC-based network on the same datasets and with the same parameter sizes as the MLP from Section 5.1. Recall that the PFC basis vector count corresponds to the MLP hidden dimension. We also train with and without enforcing non-negative parameters (i.e., NMF vs semi-NMF) and also evaluate the effect of including vs disabling the input reconstruction loss term.
We train with early stopping after 20 epochs with no validation loss improvement. Table 2 shows the accuracy results, averaged over 3 training runs. We see that using semi-NMF generally leads to slightly better accuracy. The impact of input reconstruction loss on accuracy is less clear and seems to vary depending on the specific dataset and parameter configuration. Since both the NMF constraint and enabling reconstruction loss can potentially lead to better interpretability, we will enable the reconstruction loss in all remaining experiments. We will also use the NMF parameter constraint in the remaining experiments unless otherwise mentioned.
Comparing the PFC results in Table 2 with the MLP results in Table 1 shows the PFC-based models to perform competitively. We see that the PFC-based model performs significantly better compared to the MLP on CIFAR10, slightly better on Fashion MNIST, and similarly on MNIST, although the MLP does perform slightly better on MNIST for the case of 300-dimensional hidden dimension when the NMF constraint is used on the PFC-based network.
Dataset | Basis Vector Count | Parameter Constraints | Input Reconstruction Loss | Test Accuracy |
---|---|---|---|---|
MNIST | 300 | NMF | Yes | 96.83% |
MNIST | 300 | NMF | No | 97.54% |
MNIST | 300 | Semi-NMF | Yes | 97.51% |
MNIST | 300 | Semi-NMF | No | 98.68% |
MNIST | 2000 | NMF | Yes | 97.78% |
MNIST | 2000 | NMF | No | 98.14% |
MNIST | 2000 | Semi-NMF | Yes | 98.47% |
MNIST | 2000 | Semi-NMF | No | 98.84% |
MNIST | 5000 | NMF | Yes | 98.08% |
MNIST | 5000 | NMF | No | 98.27% |
MNIST | 5000 | Semi-NMF | Yes | 98.65% |
MNIST | 5000 | Semi-NMF | No | 98.88% |
Fashion MNIST | 300 | NMF | Yes | 88.62% |
Fashion MNIST | 300 | NMF | No | 88.24% |
Fashion MNIST | 300 | Semi-NMF | Yes | 88.60% |
Fashion MNIST | 300 | Semi-NMF | No | 89.77% |
Fashion MNIST | 2000 | NMF | Yes | 90.21% |
Fashion MNIST | 2000 | NMF | No | 90.02% |
Fashion MNIST | 2000 | Semi-NMF | Yes | 90.58% |
Fashion MNIST | 2000 | Semi-NMF | No | 90.56% |
Fashion MNIST | 5000 | NMF | Yes | 90.67% |
Fashion MNIST | 5000 | NMF | No | 90.64% |
Fashion MNIST | 5000 | Semi-NMF | Yes | 90.82% |
Fashion MNIST | 5000 | Semi-NMF | No | 90.78% |
CIFAR10 | 300 | NMF | Yes | 53.25% |
CIFAR10 | 300 | NMF | No | 54.42% |
CIFAR10 | 300 | Semi-NMF | Yes | 51.54% |
CIFAR10 | 300 | Semi-NMF | No | 54.32% |
CIFAR10 | 2000 | NMF | Yes | 58.69% |
CIFAR10 | 2000 | NMF | No | 58.30% |
CIFAR10 | 2000 | Semi-NMF | Yes | 57.46% |
CIFAR10 | 2000 | Semi-NMF | No | 55.78% |
CIFAR10 | 5000 | NMF | Yes | 60.12% |
CIFAR10 | 5000 | NMF | No | 59.68% |
CIFAR10 | 5000 | Semi-NMF | Yes | 59.69% |
CIFAR10 | 5000 | Semi-NMF | No | 57.90% |
5.3 Continual learning on the Split MNIST task
In this experiment we evaluate and compare the performance of the PFC and MLP-based models on the Split MNIST task [30] under the Class-IL scenario as described in [13]. In Split MNIST, the MNIST dataset is split into 5 task partitions, so that each task only contains two digits: Split 0 contains digits 0 and 1, Split 1 contains digits 2 and 3, and so on up to Split 4. The Class-IL scenario is the most difficult of the three considered in [13], as it requires the model to solve all tasks seen so far as well as to infer the task ID. Since the model is not told which of the 5 tasks it needs to solve, it only receives the image pixels as input and needs to predict the correct digit label. The model therefore has 28x28 = 784 inputs corresponding to the image pixels flattened into a vector, and 10 outputs corresponding to the digit labels. We train the model on one split at a time, ordered by the split number so that Split 0 is the first task. Within each task the examples are presented i.i.d. to the model. The validation loss is computed over all tasks seen so far, and early stopping is used to end the current task and move on to the next once the best validation loss is achieved.
5.3.1 Continual learning performance of baseline MLP-based model
In [13], the authors found that regularization-based continual learning methods such as EWC [31] fail completely on the Class-IL scenario and that memory replay-based methods were needed in order to achieve acceptable accuracy. Specifically, on the Split MNIST task under Class-IL, they found that both a baseline MLP model and regularization-based methods such as EWC resulted in accuracies in the 19-20% range (Table 4 of [13]).
We also evaluate an MLP as a baseline model on this task. We experimented with a range of hidden dimension sizes from 300 to 2000, but only report results for a hidden dimension of 1357 since it corresponds to the same parameter size as the PFC-based model used in Section 5.3.2. We use the RMSprop optimizer with weight decay set to 1e-4. When we tried a learning rate of 1e-4, which worked well for the MLP in other experiments, it resulted in 19-20% accuracy on the test set, roughly matching the results reported in [13]. However, when we tuned the learning rate to maximize validation set accuracy, we were surprised to find much better performance when the learning rate was reduced to 2e-6, which resulted in a test set accuracy of 40.11%. We therefore suspect that the difference in accuracies between our baseline MLP and that reported in [13] could be due to hyperparameter tuning. Regardless, even 40.11% is a poor result considering that the upper bound on accuracy when the examples of all tasks are combined and presented i.i.d. to the model is 97.94% [13]. In the following sections, we will attempt to improve on this 40.11% baseline accuracy without resorting to replay-based methods.
5.3.2 Continual learning performance of baseline PFC-based model
We evaluate a one-block PFC-based model using non-negative weights and 1357 basis vectors, resulting in approximately the same parameter sizes as the baseline MLP. We continue to use the RMSprop optimizer with weight decay of 1e-4 and a learning rate of 1e-5 found through hyperparameter search on the validation set. This model achieved 67.07% accuracy on the test set, which is significantly higher than the baseline MLP but still far below the upper bound accuracy.
Since this model uses non-negative weights and can potentially learn parts-based representations, it is interesting to visualize the learned weights after training on each of the five tasks. Perhaps we might then better understand what causes the model to gradually forget earlier learned tasks. Figure 1 shows the first 100 weight basis vectors, reshaped into images, after learning to classify the two MNIST digits in each of the 5 consecutive split tasks. Comparing these images, we immediately notice a problem: some of the “digit” images learned in earlier tasks gradually fade away as the model learns newer tasks. For example, after completing the first split task, the model can classify the digits 0 and 1, and so we see that the weights in Figure 1(a) look like zeros, ones, or noise. The model learns to classify digits 2 and 3 during the next split task, and so these new digits appear in the weights after this task has completed, as expected. However, notice that some of the original 0 and 1 patterns have also started to fade or degrade slightly. By the time the model has completed training on the final split task (i.e., classifying digits 8 and 9), many of the original 0 and 1 patterns are quite degraded, although a few of them still seem relatively unaffected. We can also see that digits 8 and 9 are difficult to find after learning the final split task in Figure 1(e). This seems to be due to the model training for only a single epoch on the final task, which minimized the overall validation loss (now computed over all splits). It was apparently a better accuracy tradeoff to have relatively poor performance in classifying digits 8 and 9, rather than learning to classify them well and significantly reducing the performance (through forgetting) on all previous tasks. In the next section, we introduce a simple method to prevent the earlier learned weights from degrading as new tasks are learned.
[Figure 1: The first 100 weight basis vectors, reshaped into images, after training on each of the 5 Split MNIST tasks; sub-figures (a) through (e) correspond to Splits 0 through 4.]
5.3.3 Continual learning performance of PFC-based model with a learnable sliding window optimizer
We have implemented the sliding learnable window idea from Section 4 in a customized RMSprop optimizer which we will refer to as RMSpropSLW. This optimizer takes two additional hyperparameters compared to the standard one: the learnable width of the window, and the slide speed, which specifies the (fractional) number of columns that the window advances to the right on each optimizer update.
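The following is a minimal sketch of how such an optimizer could mask updates to a sliding window of weight columns. The class name matches the text, but the internals (a single shared window position, gradient masking before a standard RMSprop step) are our simplified assumptions, not the exact implementation.

```python
import torch

class RMSpropSLW(torch.optim.RMSprop):
    # Standard RMSprop, except gradients are zeroed outside a window of
    # `width` columns that slides right by `speed` columns per update.
    # Assumes each parameter is a weight matrix whose columns are basis vectors.
    def __init__(self, params, width=15, speed=0.25, **kwargs):
        super().__init__(params, **kwargs)
        self.width, self.speed, self.pos = width, speed, 0.0

    def reset_window(self):
        # Used in Section 5.4: move the window back to the left-most
        # position at the start of each epoch.
        self.pos = 0.0

    def step(self, closure=None):
        left = int(self.pos)
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or p.grad.dim() < 2:
                    continue
                mask = torch.zeros_like(p.grad)
                mask[:, left:left + self.width] = 1.0
                p.grad.mul_(mask)   # columns outside the window stay frozen
        self.pos += self.speed
        return super().step(closure)
```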
We then repeat the Split MNIST experiment from Section 5.3.2, replacing the RMSprop optimizer with RMSpropSLW. We set the initial (unused) number of basis vectors to 2000 in each of the two weight matrices and used a slide speed of 0.25 and a learnable width of 15. We adjusted the slide speed until training was able to complete without using all 2000 basis vectors; these values resulted in the right-most column of the learnable window arriving at the 1357'th basis column at the end of training the last task, leaving 643 columns unused at their initialized random values. Note that since early stopping was used to decide when to switch to the next task, the number of in-use basis vectors at the end of training will vary slightly from run to run. This configuration resulted in an accuracy of 93.73% on the test set, demonstrating the effectiveness of the approach compared to the standard RMSprop optimizer. Figure 2 shows the weights (reshaped as images) immediately to the left of the learnable window as each of the split tasks is completed. Since these weights are to the left of the sliding learnable window, they remain frozen during the remainder of training. Thus, weights learned during the first split task (i.e., classifying digits 0 and 1) as shown in 2(a) were protected from being overwritten during later tasks. Notice that the image features visible after training each split only contain the two digits learned during that split. For example, 2(c) only contains image features for the digits 4 and 5, since the training examples for this split only contain these two digits. Table 3 summarizes the results on this task for the various approaches discussed. We see that the PFC-based model performs significantly better than the MLP even when both use the same RMSprop optimizer. Switching to the RMSpropSLW optimizer further increases the accuracy of the PFC-based model, much closer to the upper bound accuracy. Note that neither the RMSpropSLW optimizer nor the model is given any information about which task is active, as required by the Class-IL scenario. Also note that the MLP-based model cannot use the RMSpropSLW optimizer since the modeling assumptions appear to be incompatible. For reference, replay-based approaches were reported to achieve between 90.79% and 91.79%, and replay + exemplars 94.57% accuracy, in Table 4 of [13]. This shows that our non-replay-based PFC + RMSpropSLW approach is competitive with even replay-based approaches.
[Figure 2: Weights (reshaped as images) immediately to the left of the learnable window after completing each split task; sub-figures (a) through (e) correspond to Splits 0 through 4.]
Approach | Optimizer | Test Accuracy |
---|---|---|
MLP offline i.i.d. training ([13]) (upper bound) | RMSprop | 97.94% |
MLP baseline (ours) | RMSprop | 40.11% |
MLP baseline ([13]) | ADAM | 19.90% |
EWC ([13]) | ADAM | 20.01% |
PFC | RMSprop | 67.07% |
PFC | RMSpropSLW | 93.73% |
5.4 Non-i.i.d. training: MNIST classification with label-ordered examples
Neural networks are most effectively trained when the stream of training examples is presented i.i.d. and the class labels are balanced. Training difficulties begin to arise when distribution shifts are introduced. In this experiment, we train MLP and PFC-based models on the MNIST classification task. However, rather than the usual i.i.d. training in which the examples are presented in shuffled order, we instead present the examples in a fixed label-sorted order. Specifically, we sort the training examples in ascending order of class label so that all of the digit 0 examples are presented first, followed by all of the digit 1 examples, and so on, until finally presenting all of the digit 9 examples. We also train the networks with the usual i.i.d. (shuffled) ordering for comparison to provide an upper bound on the achievable accuracy.
With the PFC-based model, we also have the option of using the sliding learnable window optimizer from Section 4, which is implemented in RMSpropSLW. Since the examples are presented in the same order each epoch, we simply reset the learnable window to the starting position (i.e., the left-most column of each weight matrix) at the beginning of each epoch. This can potentially improve performance when training on non-i.i.d. data since optimizer updates for a particular example (or its batch) are constrained to a small learnable window of the weights. As a result, only nearby training batches (in terms of the number of training iterations between them) are capable of overwriting previous learned knowledge. More widely spaced batches cannot interfere with each other.
5.4.1 Training details
The batch size was set to 50 for all models. The MLP network has a single hidden layer with a size of 2600. We use the RMSprop optimizer. For the i.i.d. training case, the learning rate is 1e-4 and the weight decay is 1e-4. For the label-ordered case, we found a lower learning rate of 5e-7 and no weight decay to give the best validation performance.
For the PFC-based network, we trained networks under the following 4 combinations: using either i.i.d. or label-sorted examples, and using either RMSprop or RMSpropSLW optimizers. When using RMSpropSLW, we set the maximum weight basis vector count to 3000, of which 2600 were actually used in training. We used a learning rate of 5e-6 and no weight decay. When using RMSprop, we used 2600 basis vectors. This resulted in the PFC-based networks having the same parameter count as the MLP network. The weights were constrained to be non-negative. We used a learning rate of 2e-3 and weight decay of 1e-4.
5.4.2 Results
The results are summarized in Table 4. The MLP and PFC-based models perform similarly when trained on i.i.d. examples using the RMSprop optimizer. When the PFC-based model is instead trained with the RMSpropSLW optimizer, the accuracy is reduced slightly. This is not surprising, since the narrow learnable window restricts optimizer updates to only a small subset of the weights at a time. We see that the MLP suffers worse degradation under label-sorted training, only reaching 84.92%, compared to the PFC-based model's 88.19% when both use RMSprop. However, the PFC-based model's accuracy remains essentially unchanged at 96.06% when using the RMSpropSLW optimizer. Note that when using RMSpropSLW, the PFC-based model has similar accuracy under both the i.i.d. and label-sorted cases. In summary, the PFC-based model was more robust to non-i.i.d. examples than the MLP when both used the same optimizer, and when using the sliding window RMSpropSLW (which is only compatible with the PFC-based model), this robustness was further increased.
Model | Training Method | Optimizer | Test Accuracy |
---|---|---|---|
MLP | i.i.d. training (upper bound) | RMSprop | 97.91% |
PFC | i.i.d. training (upper bound) | RMSprop | 97.88% |
PFC | i.i.d. training (upper bound) | RMSpropSLW | 96.02% |
MLP | label-ordered | RMSprop | 84.92% |
PFC | label-ordered | RMSprop | 88.19% |
PFC | label-ordered | RMSpropSLW | 96.06% |
5.5 Unlearning: removing knowledge from a trained model
It is sometimes desirable to remove learned knowledge from a model, for example if it is discovered that certain training batches contained errors, or if certain knowledge needs to be removed for legal reasons. A large model can take a long time to train, so the ability to quickly remove specific knowledge could be preferable to retraining from scratch.
In this experiment, we train a network on the MNIST classification task in which some of the training examples have incorrect class labels. We show that this reduces the classification performance, as is expected. Provided that the bad examples/batches can be identified after training has completed, we show that it is possible to perform unlearning by removing the subset of model weights that were influenced by these bad examples during training, restoring much of the lost classification accuracy.
We use a model with 1 PFC block and 3000 available basis columns for each of the two weight matrices. For simplicity, we constrain the corrupted training examples so that they appear in a contiguous range of batches. We use the same fixed shuffling of examples each epoch so that the i'th batch of each epoch always contains the same examples. There are then a total of 1020 batches per epoch (with 50 examples per batch). Batches 500 through 800 are corrupted by changing the class label to a different (and incorrect) class, accomplished by incrementing the label modulo the number of classes, as sketched below.
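For concreteness, the label corruption can be expressed as the following small helper (the function name is illustrative):

```python
# Corrupt the labels of batches 500 through 800 by shifting each label by
# one, modulo the number of classes (10 for MNIST).
def maybe_corrupt_label(label: int, batch_index: int, num_classes: int = 10) -> int:
    if 500 <= batch_index <= 800:
        return (label + 1) % num_classes
    return label
```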
We use the same sliding-window optimizer (RMSpropSLW) from the continual learning experiments. We reset the sliding window to the left-most position at the beginning of each epoch and use a deterministic dataset loader so that the i'th batch always contains the same examples in each epoch. This causes the learning updates corresponding to the i'th batch to be stored within a knowable range of columns in the weight matrices, corresponding to the position of the sliding window while that batch was active.
After the model is trained, 2600 of the 3000 basis columns are in use, as Figure 3(a) shows. The right-most 400 columns remain at their randomly initialized values, but this is not a problem and does not affect the results since these columns are not activated during the NMF inference process. Since the training data included a significant fraction of corrupted examples, the classification accuracy is a somewhat low 83.81% on the MNIST test set.
Next, we perform the unlearning operation, which attempts to remove the knowledge obtained from the corrupted training examples from the network. This works as follows. First, recall that training batches 500 through 800 were identified as containing the corrupted labels. Since our SLW optimizer constrains the optimizer update for each batch to a small learnable window of the weights, we need to find the union of all positions of the learnable window during the learning updates for these batches. We find that batch index 500 corresponds to the left-most column of the learnable window being at index 1250, and batch index 800 corresponds to the right-most column of the learnable window being at index 2050. That is, during this range of (corrupted example) batch updates, the sliding learnable window covered columns 1250 through 2050 of the weight matrices, so any knowledge learned from these batches must be in this subset of the weights. It is then straightforward to remove the knowledge, such as by deleting these weights, setting them to zero, or reinitializing them to random values. For this experiment, we set these weights to zero, as shown in Figure 3(b) and sketched in the snippet below. With the corrupted knowledge removed, we then re-evaluate the network on the test set and see that the accuracy has improved to 92.77%. For reference, if we train the model from scratch excluding the corrupted examples, we get 95.35% on the test set, which sets an upper bound on the unlearning accuracy. These results are summarized in Table 5. This shows that our method is effective in restoring accuracy without retraining, provided that we can identify which training batches were bad. Note that if any training example is identified as bad, then the entire batch and corresponding region of the weights must be thrown out. The method is therefore most effective when the examples to remove correspond to a contiguous sequence of batches in the training data.
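A minimal sketch of the removal step, assuming the column range identified above:

```python
import torch

# Zero out the columns of a weight matrix that the sliding window covered
# while the corrupted batches (500-800) were active.
def unlearn_columns(weight: torch.Tensor, left: int = 1250, right: int = 2050) -> None:
    with torch.no_grad():
        weight[:, left:right + 1] = 0.0
```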
[Figure 3: The weight matrix columns (a) after training, with 2600 of the 3000 basis columns in use, and (b) after the unlearning step, with columns 1250 through 2050 set to zero.]
Model Weights at Evaluation | Test Accuracy |
---|---|
Including only good data (upper bound) | 95.35% |
Before unlearning | 83.81% |
After unlearning | 92.77% |
5.6 Visualizing out-of-domain interpretability
In this section, we train a network containing 1 PFC block on an image classification task and then evaluate it on both in-domain and out-of-domain (OOD) inputs. Since the network produces reconstructed input features during the recognition process, it is interesting to visualize and compare these reconstructions on in-domain vs OOD inputs. We might expect the learned weights to correspond to parts of the in-domain images, so that reconstruction quality should be better on in-domain inputs. When the network is given an OOD image, it is forced to attempt to reconstruct it using only the “in-domain” parts, potentially reducing the reconstruction quality compared to in-domain inputs.
We now empirically investigate these effects using MNIST as the in-domain images and Fashion MNIST as the OOD images. We train a small network on the MNIST classification task using 100 weight templates to enable easy visualization of all weights in a single figure. We use the combined input reconstruction and classification loss as usual, except that here we use an adjustable trade-off between the two in the form of a hyperparameter $\lambda$:
$$\mathcal{L} = \lambda \, \mathcal{L}_{\text{pred}} + (1 - \lambda) \, \mathcal{L}_{\text{recon}} \qquad (5.1)$$
Figure 4 shows the weights, reshaped into images, for models trained using different values of $\lambda$. As $\lambda$ is increased, the strength of the classification loss increases and that of the reconstruction loss decreases. Sub-figure 4(a) thus corresponds to a loss that emphasizes classification accuracy, since only a small reconstruction loss is used; this is reflected in the good classification accuracy. We see that some of the weights resemble MNIST digits, but the images appear somewhat “noisy”. The middle sub-figure 4(b) shows the weights using the default blend of classification and reconstruction loss used in most of the other experiments. Here the weights resemble MNIST digits or parts thereof. Finally, sub-figure 4(c) shows the effect of a strong reconstruction loss: the weights now appear as more localized image parts and the classification accuracy is significantly lower.
[Figure 4: Learned weights reshaped into images for models trained with different values of λ: (a) weak reconstruction loss, (b) the default blend, (c) strong reconstruction loss.]
We now compare the input reconstructions produced by the model for in-distribution vs OOD examples. For the following, we keep $\lambda$ set to the default value of 0.5. For the in-distribution visualization, we train on MNIST and also evaluate on a batch of MNIST test images, as shown in Figure 5. Sub-figure 5(a) shows a batch of input MNIST test images and sub-figure 5(b) shows the corresponding input reconstructions generated by the model. We see that many of the reconstructed images resemble the corresponding inputs, although some are less recognizable.
[Figure 5: (a) A batch of MNIST test images; (b) the corresponding input reconstructions generated by the model.]
We now move on to the visualization of the model’s reconstructed images when given OOD inputs. For this we will use a batch of images from the Fashion MNIST test set, which contains grayscale images of fashion products of the same size as MNIST images. Recall that the model is constrained to reconstruct an input image by solving for the best additive combination of its weights (i.e., using the 100 images in 4(b)). However, since the available weights correspond to MNIST images and/or their parts, we might expect the reconstructed Fashion MNIST inputs to have less resemblance to the actual images. Indeed, this is what we observe in Figure 6, which shows the OOD inputs and their reconstructions.
[Figure 6: OOD Fashion MNIST inputs and their reconstructions using the MNIST-trained weights.]
5.7 Residual PFC network for image classification
In this experiment we use more than one PFC block to build a more complex architecture. The main purpose is to verify that, in the absence of optimization difficulties, the network produces similar accuracy to the 1-block network of Section 5.2. We construct a simple fully-connected residual network containing 2 PFC blocks in which the first block uses a skip connection. That is, letting $x$ denote the input to the first PFC block and $y_1$ its output, the input to the second PFC block is given by $\mathrm{relu}(x + y_1)$. The relu is optional when using the semi-NMF assumption but is needed when using the NMF constraint to prevent the input from becoming negative. The second PFC block then outputs the final prediction.
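Reusing the hypothetical `pfc_forward` sketch from Section 5.2, the forward pass could look roughly as follows (weight shapes are assumed such that $y_1$ matches the dimensionality of $x$):

```python
import torch

def residual_pfc_forward(x, W_x1, W_y1, W_x2, W_y2):
    # Block 1: W_y1 must map back to the input dimension for the skip sum.
    y1, x_recon1 = pfc_forward(x, W_x1, W_y1)
    h = torch.relu(x + y1)   # skip connection; relu keeps the input non-negative
    # Block 2 produces the final class prediction.
    y2, h_recon2 = pfc_forward(h, W_x2, W_y2)
    return y2
```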
We train on the same datasets as in Section 5.2 using the same basis vector sizes (equivalent to MLP hidden dimension). Since there are two blocks, the total number of parameters is doubled from the 1-block model. Here we only train using the NMF modeling constraint (non-negative parameters) and we use both prediction and input reconstruction MSE loss terms for each PFC block. Table 6 shows the accuracy results on the test set. We see that the accuracy results appear in a similar range compared to those of the 1-block model in Table 2.
Dataset | Basis Vector Count | Test Accuracy |
---|---|---|
MNIST | 300 | 97.81% |
MNIST | 2000 | 98.30% |
MNIST | 5000 | 98.09% |
Fashion MNIST | 300 | 88.92% |
Fashion MNIST | 2000 | 89.89% |
Fashion MNIST | 5000 | 90.08% |
CIFAR10 | 300 | 54.00% |
CIFAR10 | 2000 | 56.26% |
CIFAR10 | 5000 | 58.15% |
5.8 Memorizing a deterministic sequence with a factorized RNN
We developed the factorized RNN in Section 3.2 and modeled it as the matrix factorization of the form shown in Eq. 3.9 which we repeat here:
$$\begin{bmatrix} y_1 & y_2 & \cdots & y_T \\ h_0 & h_1 & \cdots & h_{T-1} \\ x_1 & x_2 & \cdots & x_T \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} \begin{bmatrix} h_1 & h_2 & \cdots & h_T \end{bmatrix} \qquad (5.2)$$
We also provided a more compact notation in which the sequences are replaced by their respective sub-matrices in Eq. 3.12, which we repeat here:
$$\begin{bmatrix} Y \\ H_{\text{prev}} \\ X \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} H \qquad (5.3)$$
We then provided two basic training procedures. Here we consider the first method, which uses matrix factorization update rules for both inference and learning as described in Section 3.3. It is initially unclear whether such simple update rules can even learn a task spanning several time slices, since there is no backward flow of error-gradient-like information through the sequence. It therefore seems appropriate to start with a relatively simple temporal learning task.
5.8.1 Training on a repeating sequence
For this initial task, we present a fixed-length pattern that is repeated over and over in the training data and require the network to predict the next item in the sequence. Since we use a deterministic repeating pattern, the network only needs to identify and memorize this underlying pattern in order to predict with perfect accuracy. We also chose this task because the underlying generative model corresponds to a simple finite state machine (FSM) containing the same number of (deterministic) transitions as there are time slices in the repeating pattern. We therefore know the minimum number of model parameters needed in principle to solve it, and it is straightforward to imagine interpretable solutions to the corresponding RNN factorization ourselves. This task therefore also serves as a simple interpretability test, since we can train the model, visualize the learned weights, and see whether they align with the interpretable solutions that we know should be possible.
Specifically, we use the following fixed repeating pattern consisting of 25 4-dimensional 1-hot vectors for easy presentation and visualization, shown here represented as integer-valued tokens for easier readability:
[Equation 5.4: the fixed 25-token repeating pattern, shown as integer tokens.]
We then repeat this pattern 8 times to create the full training sequence shown in Figure 7, which has a length of 25 x 8 = 200.
[Figure 7: The full training sequence, consisting of 8 repetitions of the 25-token pattern.]
We want the model to memorize the repeating pattern, and so it needs to predict the next vector in the sequence. For this we can use an autoregressive model so that $y_t = x_{t+1}$ in Eq. 3.12. With this, Eq. 3.9 looks like the following:
$$\begin{bmatrix} x_2 & x_3 & \cdots & x_{T+1} \\ h_0 & h_1 & \cdots & h_{T-1} \\ x_1 & x_2 & \cdots & x_T \end{bmatrix} \approx \begin{bmatrix} W_y \\ W_h \\ W_x \end{bmatrix} \begin{bmatrix} h_1 & h_2 & \cdots & h_T \end{bmatrix} \qquad (5.5)$$
We use NMF for the inference and learning updates. Specifically, we use SGD updates with negative value clipping and scaling of the rows and columns of the factor matrices to keep them from exploding, as described in Section A. For the results in this section, we do not use any weight decay or sparsity regularizers, so that only the non-negativity constraint is used. We initialize the weights to non-negative uniform random values and initialize the hidden states ($H$ and $H_{\text{prev}}$) to zero. Recall from Section 3.2 that the hyperparameter $R$ specifies both the dimensionality of the hidden state vectors and the number of basis column vectors in $W$, so that $W_h$ is an $R$ x $R$ sub-matrix of $W$. Since the repeating sub-sequence has length 25, $R = 25$ seems to be the minimum value that could have any chance of learning the state transition model. We then run the training procedure using the alternating NMF inference and learning update rules described in Section 3.3. We ran several training runs for values of $R$ ranging between 25 and 500.
5.8.2 Interpretation of the learned weights
We observed some interesting and surprising results when training with different values of $R$. The first is that training seems to become faster and more reliable as $R$ is increased. The model was consistently able to learn an exact or nearly exact factorization (training MSE below 1e-7 or so) for $R$ around 100 or larger. However, training gradually became less reliable as $R$ was decreased toward the lower limit of 25, often requiring multiple training attempts to successfully learn the factorization. Figure 8 shows one such unsuccessful training run at this lower limit, where the training MSE only converged to 0.112. Still, training was sometimes successful even at the limit value $R = 25$. Since we only observed unreliable training with small values of $R$, we did not attempt any hyperparameter optimization to fix it.
[Figure 8: Training loss for an unsuccessful training run, where the training MSE only converged to 0.112.]
Perhaps our most surprising observation was that the learned models tended to be highly sparse and interpretable even though we did not use any sparsity regularization. Regardless of the value of $R$, a successful training run resulted in the model discovering that only 25 basis vectors were actually needed in $W$, with the other columns tending toward 0. Additionally, the learned columns of the state transition sub-matrix $W_h$ appeared close to 1-hot vectors, as Figure 9 shows. Such a learned representation lets us understand the underlying state transition model that generated the repeating sub-sequence by quick visual inspection of the model weights. Recall that for any particular activated basis column of $W$, the corresponding column of the bottom sub-matrix $W_x$ additively reconstructs the input $x_t$. Similarly, the same column of $W_h$ additively reconstructs the previous state $h_{t-1}$, and the same column of $W_y$ additively reconstructs $y_t$, which is the prediction for the next input $x_{t+1}$. As a concrete example, consider the right-most basis vector in Figure 9(a). Since its $W_x$ part has a 1 in the second row and its $W_y$ part has a 1 in the third row, this column explains a 1-hot input vector having a 1 in its second dimension and predicts a transition to a 1-hot vector having a 1 in its third dimension at the next time step. In the integer token representation, this corresponds to the transition $1 \rightarrow 2$ (with 0-indexed token values). The same column of $W_h$ additively reconstructs the previous hidden state (as a 1-hot vector with a 1 in its first dimension). When the right-most basis vector (i.e., column index = 24) of $W$ is activated by a column of $H$ such as $h_t$, it causes the inferred next state to have a 1 in its last dimension. This inferred $h_t$ then becomes the previous-state input in the next time slice, and we see that the third basis column of $W$ from the left has a 1 in the final dimension of its $W_h$ part, which would cause this basis column to be activated in the next time slice $t+1$. Continuing in this way, we can easily read out the underlying transition model by inspection of $W$.

Note that the ordering of the basis vectors in $W$ is significant, because activating column $j$ of $W$ results in a corresponding positive activated value in dimension $j$ of the inferred next state vector. With this understanding, it makes sense that the learned $W$ would tend to be sparse even without any explicit sparsity regularization. Each activated basis column of $W$ becomes a positive entry in the corresponding dimension of the previous-state input in the next time slice (recalling that the inferred states in $H$ are copied into the next time slice of $H_{\text{prev}}$), which in turn needs to be explained as an additive combination of the basis vectors in $W$. That is, any activated columns of $W$ translate to corresponding non-zero (positive-valued) rows in the next time slice's previous-state vector. If $W_h$ contains a non-zero column with multiple positive values in different rows and this column is activated, it implies that multiple columns of $W$ must have been activated in the previous time slice. Consider also the case where there are duplicated columns in $W$. The NMF inference algorithm might then choose to activate both of them with some positive strength, again resulting in a non-1-hot $h_t$ that would in turn need to be explained in the next time slice. Intuitively, it then makes sense that the model would tend toward the sparsest possible learned representation.
[Figure 9: The learned weights after a successful training run; sub-figure (a) shows the basis columns of $W$, whose $W_y$, $W_h$, and $W_x$ parts appear close to 1-hot vectors.]
5.8.3 Evaluation and visualization
With the model weights trained, we can now verify that it has successfully memorized the sequence. We provide a short “seed” sequence that contains part of the repeating sub-sequence and then ask the model to generate a continuation of it. For this evaluation, we use one of the successfully trained models from the previous section. Figure 10 shows the seed sequence, for which we use the first 15 time slices. We will generate 50 additional time slices after the seed sequence. We then initialize the $X$ and $H_{\text{prev}}$ sub-matrices of the factorization so that they only contain the seed as the initial part at the left, as shown in Figure 11.
[Figure 10: The 15-time-slice seed sequence.]
[Figure 11: The $X$ and $H_{\text{prev}}$ sub-matrices initialized to contain only the seed at the left.]
We then run the inference procedure starting from the first time slice, since the hidden states need to be inferred for all time slices. The generation procedure works as follows. For each time slice $t$, we iterate the NMF updates until the current state vector $h_t$ converges. We copy the inferred $h_t$ into the next time slice of the $H_{\text{prev}}$ sub-matrix. If $t$ is less than the seed length, we leave $x_t$ as is (since it is part of the seed); otherwise, we also update the current prediction $y_t$ and copy it into the $x_{t+1}$ position in the next time slice of $X$, which serves to propagate the predicted sequence vectors forward. Note that we do not sample from the predicted $y_t$ and instead simply copy the predicted vector directly into the following time slice. We then increment $t$ and continue in this way until reaching the end of the generated sequence length; a sketch of this loop appears below. Figure 12 shows the resulting generated sequence including the seed. From this we see that the model can successfully generate the memorized sequence, with a small amount of noise, seen as slight yellow or red in the “hot” colormap that we used. Figure 13 shows all three matrices of the factorization in Eq. 3.12 after generating the sequence from the seed.
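The following is a rough sketch of this seeded generation loop, using the same simplified projected-gradient right-updates as the earlier `pfc_forward` sketch; the per-slice inference form, step size, and iteration count are assumptions for illustration.

```python
import torch

def generate(seed, W_h, W_x, W_y, total_len, n_iters=100, step=0.05):
    # seed: list of 1-hot input vectors; W_h: (R, R); W_x: (in_dim, R);
    # W_y: (in_dim, R) for the autoregressive case y_t = x_{t+1}.
    R = W_h.shape[1]
    xs = list(seed)
    h_prev = torch.zeros(R)
    for t in range(total_len):
        if t >= len(xs):
            xs.append(y_pred)          # past the seed: feed the prediction forward
        x_t = xs[t]
        h = torch.zeros(R)
        for _ in range(n_iters):       # NMF right-updates to infer h_t
            e_h = h_prev - W_h @ h     # previous-state reconstruction error
            e_x = x_t - W_x @ h        # input reconstruction error
            h = torch.relu(h + step * (W_h.t() @ e_h + W_x.t() @ e_x))
        y_pred = W_y @ h               # prediction for x_{t+1} (no sampling)
        h_prev = h                     # copied into the next time slice
    return xs
```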
[Figure 12: The generated sequence, including the seed.]
[Figure 13: The three matrices of the factorization in Eq. 3.12 after generating the sequence from the seed.]
5.9 Copy task
We now test the factorized RNN and compare it to the vanilla RNN on the more difficult Copy Task [32], using the setup from [33]. This involves generating a random sequence of tokens which are supplied to the network one per time slice. After this we supply a padding token for some fixed number of time slices. We then supply a “remember” token for a number of time slices equal to the length of the input sequence supplied earlier, during which the network must attempt to output the remembered tokens of the input sequence. This task thus tests the network's ability to recall information presented earlier, and it gets more difficult as the task spans longer time intervals. We use the same parameters as [33], in which the input sequence has length 10 and there are 10 distinct token values, which we supply to the network as 1-hot vectors. Combined with the pad and remember tokens, this leads to 12 distinct token values in total, so that the input 1-hot vectors are 12-dimensional. We can then still adjust the difficulty by controlling the padding length.
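A sketch of how one such example could be constructed (the assignment of token ids 10 and 11 to the pad and remember tokens is our assumption):

```python
import torch

# Build one Copy Task example: 10 random tokens, pad_len padding tokens,
# then 10 "remember" tokens during which the inputs must be recalled.
def make_copy_example(pad_len: int, seq_len: int = 10, vocab: int = 10):
    PAD, REMEMBER = vocab, vocab + 1
    tokens = torch.randint(0, vocab, (seq_len,))
    inputs = torch.cat([tokens,
                        torch.full((pad_len,), PAD),
                        torch.full((seq_len,), REMEMBER)])
    onehots = torch.nn.functional.one_hot(inputs, vocab + 2).float()
    return onehots, tokens   # targets apply only to the final seq_len slices
```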
5.9.1 Training details for the factorized RNN
We used the same model and training hyperparameters as in Section 5.8; the two setups only differ in how the loss is applied. In the model of Section 5.8, the RNN outputs a prediction at every time slice. In the copy task, however, we only care about the predicted tokens during the time slices when the “remember” token is being supplied, and so we only compute the prediction loss on these time slices here. Note that since we are using NMF learning (i.e., the NMF update rule) to learn the weights, this corresponds to having implicit MSE reconstruction loss terms on all of the hidden states and inputs as well. We used the same hidden state dimension for the factorized RNN as for the vanilla RNN baseline below, and 100 NMF iterations per time slice for inference.
With the factorized RNN and non-negative parameters (i.e., NMF), most training runs resulted in perfect validation accuracy at shorter padding lengths. At a padding length of 10, multiple training runs were needed to reach perfect accuracy, and at longer padding lengths, accuracy was no better than chance level. When we tried allowing negative parameters, the models failed to do better than chance accuracy. These results seem interesting in that the network is able to successfully learn the task with up to 10 padding tokens, even though there is no BPTT-like backward flow of error gradients.
5.9.2 Training details for the vanilla RNN
For comparison, we also train a vanilla RNN on the same task. For this we use the network described in Section 3.1, except that we used LayerNorm on the hidden states as suggested for RNNs in [7], and the GELU activation. We used the same hidden state dimension as the factorized network. Negative-valued parameters were allowed, as in all other experiments with MLP-based models. We observed that a weight decay of 1e-4 was needed for the best results, although we did not do much hyperparameter tuning.
With the vanilla RNN, we were also able to get perfect accuracy when BPTT was used, up to some maximum padding length. However, when we disabled BPTT, we were not able to get perfect accuracy for any padding length. We were not able to do much hyperparameter tuning, though, and so it remains possible that performance could improve with better tuning.
5.10 Sequential MNIST classification
The Sequential MNIST classification task is a common task used for evaluating the performance of RNNs. It adapts the MNIST dataset [27], which consists of 28x28 pixel grayscale images of handwritten digits (0 to 9), for a sequential data processing context. In the standard MNIST task, the entire image is presented to the model at once, as in Section 5.2. However, in the Sequential MNIST task, the image is presented as a sequence of pixels, typically in a row or column-wise manner. We evaluate the column-wise version in this experiment. The image is unrolled column by column, resulting in a 28-time-slice sequence, where each slice is a 28-dimensional vector representing one column of pixels. Each time slice of the RNN outputs a 10-dimensional class prediction vector, although we only compute the loss on the final time slice of the sequence, ignoring the model predictions from earlier time slices.
5.10.1 Training details
For the factorized RNN, we used 10 unrolled inference iterations since it allowed us to run more experiments. We noticed only slightly reduced accuracy in one experiment compared to using 50-100 iterations and so we made the trade-off based on our limited computational resources. We used a learning rate of 5e-5 and weight decay of 1e-5 for all of the training runs and for both the factorized and vanilla RNNs. Similar to the other experiments, we used the input reconstruction loss on the factorized model. We trained under both the NMF and semi-NMF parameter constraints. We trained both the factorized and vanilla RNNs with and without BPTT, and for hidden state dimensions of 512 and 2048.
5.10.2 Results
Table 7 shows the accuracy results on this task. With BPTT disabled, the factorized RNN has significantly higher accuracy than the vanilla RNN (96.23% vs 83.33% at hidden dimension 512, and 97.16% vs 94.06% at hidden dimension 2048), although this required the semi-NMF parameter constraint; using the NMF constraint without BPTT produced significantly worse accuracy. We were somewhat surprised to see both the factorized and vanilla RNNs performing so well without BPTT. Both models performed better with BPTT enabled, but only slightly, and the factorized RNN slightly outperformed the vanilla RNN. Interestingly, when using BPTT, the factorized RNN performed similarly under both the NMF and semi-NMF constraints. Notice also that when BPTT is disabled, both the factorized and vanilla RNNs suffer a more severe accuracy degradation in going from 2048 to 512 hidden dimensions compared to when BPTT is enabled; it seems that disabling BPTT could be making the models less parameter efficient.
Model | BPTT | Hidden State Dimension | Parameter Constraints | Test Accuracy |
---|---|---|---|---|
Factorized RNN | No | 512 | NMF | 77.43% |
Factorized RNN | No | 512 | semi-NMF | 96.23% |
Factorized RNN | No | 2048 | NMF | 75.81% |
Factorized RNN | No | 2048 | semi-NMF | 97.16% |
Vanilla RNN | No | 512 | N/A | 83.33% |
Vanilla RNN | No | 2048 | N/A | 94.06% |
Factorized RNN | Yes | 512 | NMF | 98.56% |
Factorized RNN | Yes | 512 | semi-NMF | 98.16% |
Factorized RNN | Yes | 2048 | NMF | 98.66% |
Factorized RNN | Yes | 2048 | semi-NMF | 99.00% |
Vanilla RNN | Yes | 512 | N/A | 97.94% |
Vanilla RNN | Yes | 2048 | N/A | 98.66% |
5.11 Audio source separation on MUSDB18
Audio tends to be well modeled as a mixture of source components. Much of the music we listen to is explicitly mixed together from isolated recordings (often called stems). In the source separation problem, we are interested in separating multiple sources that have been mixed together. For example, suppose we have a recording by a small band in which the vocals, guitar, and drums have all been captured together in a single audio channel. A source separation system would then be tasked with taking this recording as input and producing an output containing only a single desired source such as the vocals.
A useful inductive bias in this case would be to have the model satisfy the linear superposition property: $f(a\,x_1 + b\,x_2) = a\,f(x_1) + b\,f(x_2)$. Any scaled mixture of the two sources should produce the corresponding scaled output. This is a desirable property since it implies that if the model is capable of recognizing the sources in isolation, it automatically generalizes to handling the mixture case. Such a property could potentially allow the model to generalize better from limited training data.
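As a toy illustration, any (non-negative) linear map trivially satisfies this property, whereas a network with non-linear activations generally does not; the matrix below is an arbitrary stand-in, not a trained separator:

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.random((4, 8))            # arbitrary non-negative linear map
f = lambda x: M @ x
x1, x2 = rng.random(8), rng.random(8)
a, b = 0.5, 2.0
# Superposition: the response to a scaled mixture equals the scaled
# mixture of the individual responses.
assert np.allclose(f(a * x1 + b * x2), a * f(x1) + b * f(x2))
```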
NMF can potentially satisfy this property, and there are several existing works in which it has been applied to related audio problems [34]. Since our factorized RNN is essentially a sequential extension of matrix factorization, it seems interesting to apply it to the source separation task and compare its generalization performance against a standard vanilla RNN (which does not have the superposition property due to its use of non-linear activation functions). We train both factorized and vanilla RNNs on the source separation task using the MUSDB18 dataset. MUSDB18 contains 150 music tracks of various genres corresponding to approximately 10 hours of music, split into a training dataset containing 100 tracks and a test dataset containing 50 tracks. It includes the isolated tracks (stems) for drums, bass, other, and vocals. This makes it straightforward to use for training and evaluating source separation models.
5.11.1 Training details
For training we mix together the isolated stems to create the inputs and then supply just one of the stems as the target. In this experiment we chose to use vocals as the target. The audio is aligned during validation and testing but not aligned during training to provide more variation in the examples. During training only, each source is scaled by random per-example values between 0.5 and 2.0.
We used the same MSE output prediction loss (i.e., regression loss) for both models to ensure that the resulting MSE test loss would be comparable between the factorized and vanilla RNN. That is, the factorized RNN only used the MSE output prediction loss term, without the input reconstruction loss term, in this experiment. For the audio features, we use the following hyperparameters: the sample rate was 44100 Hz, and the short-time Fourier transform (STFT) used a 2048 window size and 1024 hop size. We limited the audio feature vectors (time slices of the STFT) to contain only the lowest 350 frequency bins. Negative parameter values were allowed in both models. BPTT was used in both models. The hidden state dimension was 1024 in both models. The batch size was 100. The learning rate was 5e-4 for the factorized RNN and 1e-4 for the vanilla RNN. Weight decay was 5e-5 for both models.
5.11.2 Results
The vanilla RNN had an MSE test loss of 3.657e-3, and the factorized RNN a test loss of 2.497e-3, averaged over 3 runs. Both models seem to be underfitting, since the training loss converges to a similar range as the validation loss. Although the factorized RNN produced a better (lower) test loss than the vanilla RNN, we should note that both models perform significantly below the state of the art on this dataset due to the use of the simple RNN architecture.
6 Related work
6.1 Positive Factor Networks
Perhaps the single most related existing work is our previous work on Positive Factor Networks [18], which similarly proposed NMF-based building blocks that can be composed to build expressive and interpretable architectures. Although the factorized equivalent of the vanilla RNN that we introduce in this paper was not considered, various other factorized-RNN-like architectures (referred to as dynamic positive factor networks) were considered, including some more sophisticated architectures such as a sequential data model for target tracking and a model employing dilated deconvolutional layers in a Wavenet-like [35] architecture. A block with the same factorized model as our PFC block appears in Eq. 3.1 of [18], where it was referred to as a “coupling module” or “coupling factorization”. Similar iterative NMF algorithms were used to perform inference in these models. However, a key distinction between that work and our present work is that it did not use backpropagation for learning the parameters, instead relying on NMF left-update rules. We also experimented with NMF left-update rules for learning the parameters as an ablation in Sections 5.8, 5.9, and 5.10, but found them to underperform backpropagation on the more challenging learning tasks. As a result of not using backpropagation for parameter learning, the models presented in [18] failed to produce results competitive with other approaches on supervised learning tasks. The PFC block and the corresponding neural networks constructed from it can therefore be considered as unrolled positive factor networks, or positive factor networks employing backpropagation-based training.
6.2 Non-negative matrix factorization
NMF was originally proposed by [14] as positive matrix factorization and later popularized by [15] [17]. NMF is related to other methods such as sparse coding [36] that use a similar dictionary learning model but with differences in the modeling constraints, regularizers, and/or algorithms used. NMF is often observed to provide parts-based decompositions of data, which can be useful when interpretable decompositions are desired. It also potentially satisfies the (non-negative) linear superposition property and as a result has been applied to audio processing tasks such as source separation and music transcription in which the audio features in the input data matrix are assumed to be well modeled as an additive combination of audio sources (i.e., individual instruments and/or notes) [34].
Our PFC block is related in that its declarative model corresponds to a masked predictive NMF in which the data matrix is partitioned along the row dimension into “input” and “output” vectors of the block, with the right factor matrix $H$ corresponding to inferred hidden activations. The output vectors are masked during recognition (during the inference of $H$) and the resulting inferred $H$ is then used to predict the outputs. This allows our block to retain the additive superpositional NMF model while also supporting the construction of more expressive differentiable architectures such as factorized RNNs. The resulting neural networks (or positive factor networks) can then be interpreted as an extension of NMF that increases its modeling expressiveness and suitability for supervised learning tasks.
6.3 Unrolled neural networks
The key enabler of the increased predictive performance of our present work compared to the Positive Factor Networks of [18] is the modification of the learning algorithm to use backpropagation instead of relying on NMF left-update steps. As discussed in Section 2.4, backpropagation-based training is enabled by unrolling the iterative NMF inference steps into an RNN-like structure in the computation graph, making the PFC block differentiable and therefore compatible with backpropagation training.
This process of unrolling an iterative optimization algorithm into a neural network and training with backpropagation is referred to as algorithm unrolling or unrolled neural networks in the literature [25]. An existing example of unrolled NMF appears in [37]. The idea that improved results could be obtained on supervised tasks by adding an additional task-specific loss function to an iterative and differentiable optimization algorithm and using it to optimize the parameters was proposed in [38] and [39], with earlier related ideas being proposed in [40] and [41]. More recent works have explored the use of unrolled convolutional sparse coding for improved robustness to corrupt and/or adversarial input images [42], [43]. Of these, [44], [43] also use FISTA to accelerate the convergence of the unrolled inference. As far as we are aware, ours is the first work to explore the use of a modular unrolled block (PFC block) supporting the construction of arbitrary neural architectures, unrolled factorized RNNs, and the first to explore the interpretability advantages of unrolled NMF-based networks for continual learning.
6.4 Nearest neighbor classification and regression
As discussed in Section 2.2, our PFC block shares some similarities with classification and regression based on the k-nearest neighbors (k-NN) algorithm [10]. We show that the PFC block can be derived by first formulating k-NN prediction as a matrix-vector product, and then generalizing and interpreting it as a matrix factorization.
6.5 Learning Vector Quantization (LVQ)
Learning Vector Quantization (LVQ) [11] [12] is a prototype-based method that extends the k-nearest neighbors algorithm by introducing the concept of learnable prototypes. An advantage of this approach over matrix-factorization-based methods is faster recognition, since an iterative algorithm is not used. However, we are not aware of a fully differentiable version of LVQ that could match MLP predictive performance and serve as a building block for more complex architectures such as RNNs.
6.6 Future research
This work is preliminary and we leave several unanswered questions and possibilities for future research directions. We list some of them here:
- As discussed in [45], methods such as ours that perform recognition by iterating to a fixed point can potentially make use of adaptive computation. This could be used to improve recognition efficiency, as well as to provide an additional form of confidence estimation, since “difficult” inputs could require more iterations to converge.
- When unrolling the NMF inference updates, we observed that memory usage can potentially be reduced by computing the initial several iterations “without gradients” and only unrolling the last several iterations “with gradients”. In some cases this resulted in improved efficiency, but in others it resulted in reduced accuracy, and so a better understanding is still needed.
- Although we found FISTA to be one effective method for accelerating the NMF inference step of the PFC block, it could be interesting to also consider other approaches to further accelerate the inference.
- Our sliding learnable window optimizer introduced in Section 4 enabled improved continual learning in PFC-based models. However, more sophisticated approaches based on basis-vector-specific learning rates could potentially lead to improved continual learning performance and/or parameter efficiency. For example, it could be interesting to consider methods that identify the important or in-use basis vectors and reduce their learning rates accordingly, so as to make them more resistant to being overwritten or modified later in training.
- We did not consider any form of sparsity regularization (e.g., an L1 penalty) in our experiments, leaving regularization methods other than the non-negativity constraint and weight decay unexplored in the current work.
- As Table 7 shows, disabling BPTT in both factorized and vanilla RNNs often resulted in only a small drop in accuracy, although more parameters were needed to recover accuracy. It is also unclear why the semi-NMF parameter constraint resulted in better accuracy compared to NMF when BPTT was disabled. It could be interesting to conduct a more detailed investigation as future research.
- The declarative model of the PFC block in Eq. 2.10 is symmetric with respect to the inputs and outputs. This potentially allows us to reverse the prediction direction of the block, swapping the roles of inputs and outputs so that the block computes the “inputs” given the “outputs”. In addition to reversible blocks/layers, it could be interesting to explore the case where only some subset of the inputs and outputs are observed and the block then jointly predicts all unobserved elements.
- As discussed in Section 2.4, the PFC block can be interpreted as performing factorized attention over parameters. It could be interesting to also consider its use as an attention block that performs factorized attention over key and value activations (output from other upstream blocks) instead of over parameters, making it more like a factorized version of the attention block used in the transformer. Specifically, in Eq. 2.10 the input weights sub-matrix would be replaced by a keys matrix and the output weights sub-matrix by a corresponding values matrix, which would then represent output activations from other layers/blocks instead of learnable parameters.
- Note that the PFC block has the inductive bias of NMF, which is different from that of the MLP. In our experiments, we observed it to perform similarly to, and sometimes better than, the MLP. However, it is possible that the NMF inductive bias could make PFC-based models particularly well or poorly suited to a given dataset, depending on its modeling assumptions. Also note that relaxing the non-negativity constraint to semi-NMF could support datasets containing negative values, at the expense of possibly reduced interpretability. We leave such an exploration as future research.
7 Conclusion
In this paper, we presented the Predictive Factorized Coupling (PFC) block, a neural network building block that combines the interpretability of non-negative matrix factorization (NMF) with the predictive performance of multi-layer perceptron (MLP) networks. We demonstrated the versatility of the PFC block by using it to build various architectures, including a single-block network, a fully-connected residual network containing two PFC blocks, and a factorized RNN.
Our experiments showed that the PFC block achieves competitive accuracy with MLPs on small datasets while providing better interpretability. We also demonstrated the benefits of the PFC block in continual learning, training on non-i.i.d. data, and knowledge removal after training. Additionally, we showed that the factorized RNN outperforms vanilla RNNs in certain tasks while providing improved interpretability.
While the PFC block has limitations, such as slower training and inference and increased memory consumption during training, it offers a promising direction for developing more interpretable neural networks without sacrificing predictive performance. Future work includes evaluating the PFC block on larger datasets and exploring its suitability for use in more complex multi-block architectures.
Appendix A Alternating SGD updates for NMF
One of the simplest methods that can be used to solve for the factors $W$ and $H$ in Equation (2.1) is gradient descent (GD) or stochastic gradient descent (SGD). Whether GD or SGD applies depends on whether we are operating on the full training data or on batches, and so we will simply refer to both cases as SGD in the following. In SGD, we first choose a suitable loss function to quantify the approximation error and then compute its gradients with respect to each factor matrix. We alternately apply small additive update steps in the opposite direction of the gradient to each factor matrix, while clipping any negative values to zero, until convergence. A commonly used choice of loss is the squared Euclidean error, and so we use it here. The squared error loss between the data matrix $V$ and the approximation $W H$ is given by:
$$L = \frac{1}{2} \sum_{i,j} \left( v_{ij} - (W H)_{ij} \right)^2 \qquad (A.1)$$
where $e_{ij} = v_{ij} - (W H)_{ij}$ is the approximation error of the $(i,j)$'th element of $V$. Let $E$ denote the approximation error matrix with elements $e_{ij}$. Since $L$ is differentiable with respect to $W$ and $H$, the loss gradients are:
$$\frac{\partial L}{\partial W} = -E H^\top \qquad (A.2)$$
$$\frac{\partial L}{\partial H} = -W^\top E \qquad (A.3)$$
The resulting SGD updates are then given by:
$$W \leftarrow \operatorname{relu}\!\left( W - \eta_W \frac{\partial L}{\partial W} \right) \qquad (A.4)$$
$$H \leftarrow \operatorname{relu}\!\left( H - \eta_H \frac{\partial L}{\partial H} \right) \qquad (A.5)$$
where $\eta_W$ and $\eta_H$ are learning rate hyperparameters set to small non-negative values. The $\operatorname{relu}$ function is used to prevent negative values in the updated matrices. Substituting the gradients into the above updates gives:
$$W \leftarrow \operatorname{relu}\!\left( W + \eta_W \, E H^\top \right) \qquad (A.6)$$
$$H \leftarrow \operatorname{relu}\!\left( H + \eta_H \, W^\top E \right) \qquad (A.7)$$
We then initialize $W$ and $H$ with non-negative noise and iteratively perform the updates until convergence.
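The following NumPy sketch transcribes these alternating updates on a small random example; the learning rates, matrix sizes, and iteration count are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.random((50, 40))                       # non-negative data matrix
W = rng.random((50, 10))
H = rng.random((10, 40))
lr_w = lr_h = 1e-3

for _ in range(5000):
    E = V - W @ H                              # error matrix (elements e_ij)
    W = np.maximum(W + lr_w * E @ H.T, 0.0)    # Eq. (A.6)
    E = V - W @ H                              # recompute error after updating W
    H = np.maximum(H + lr_h * W.T @ E, 0.0)    # Eq. (A.7)

print(np.mean((V - W @ H) ** 2))               # final approximation MSE
```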
A.1 Normalization and preventing numerical issues
We observed that numerical issues can sometimes be reduced by additionally replacing any zero-valued elements in the factor matrices with some small minimum allowable non-negative value immediately after the relu in Eqs. A.6 and A.7.
We introduce the following two normalization methods, which we have observed to often perform well in our experiments, depending on whether we are performing alternating updates to jointly learn both $W$ and $H$, or performing unrolled inference to infer only $H$ followed by backpropagation-based learning of $W$.
A.1.1 Normalization for the joint factorization case
Note that we can scale $W$ by an arbitrary positive value if we also scale $H$ by its inverse so that their product is unchanged. Depending on the update algorithm used, one factor matrix might be scaled slightly larger on each update while the other is scaled slightly smaller, so that one of $W$ and $H$ tends toward infinity while the other tends toward zero. To avoid this problem, one of the factor matrices (typically $W$) is usually normalized to have, e.g., unit column norms after each update [46]. However, this prevents the basis vectors from becoming arbitrarily small, potentially resulting in less sparse and/or less interpretable solutions.
We have empirically observed that optimization can be easier when the ranges of values in $W$ and $H$ are similar. We use NumPy/PyTorch slicing notation in the following, where “:” denotes selecting all indices along the dimension in which it appears. Intuitively, if a particular column $W[:, j]$ does not contribute significantly to the approximation of $X$, then we would like its values, as well as the corresponding activations in row $H[j, :]$, to be small, and vice versa. We use the following normalization algorithm to achieve this.
For each column $W[:, j]$, we compute its maximum value along with the maximum value of the corresponding row $H[j, :]$, and take the mean of these two values. We then scale the column and row so that their updated maximum values equal this mean. We have found this simple algorithm to work well in our experiments: after training, it tends to result in $W$ and $H$ having nearly identical maximum values, and in sparser learned factorizations and/or improved approximation error compared to the other normalization methods that we tried. Care is needed if mini-batch training is used, though, as the above maximum values are intended to be computed over the full $W$ and $H$ matrices.
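A minimal NumPy sketch of this rebalancing follows, under our own naming; the `eps` guard against all-zero columns/rows is our addition:

```python
import numpy as np

def rebalance_max_values(W, H, eps=1e-12):
    """For each basis column W[:, j] and activation row H[j, :], rescale
    both so that their new maxima equal the mean of their current maxima
    (Section A.1.1). Note that this leaves the product W @ H exactly
    unchanged only when the two maxima already agree; the subsequent
    alternating updates absorb the small perturbation."""
    W, H = W.copy(), H.copy()
    for j in range(W.shape[1]):
        w_max = W[:, j].max()
        h_max = H[j, :].max()
        target = 0.5 * (w_max + h_max)   # mean of the two maxima
        W[:, j] *= target / max(w_max, eps)
        H[j, :] *= target / max(h_max, eps)
    return W, H
```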
A.1.2 Normalization for unrolled inference with backpropagation
For the unrolled algorithm in which backpropagation is used together with another optimizer (e.g., RMSprop) to update $W$, we have found the following method to perform well. The basic idea is that, for each column of the data matrix $X$, we limit the maximum value of the corresponding inferred column of $H$ so that it cannot exceed the maximum value of that column of $X$. This normalization step involves computing the column-wise maximum values of $X$ at the start of inference. After each unrolled NMF right-update to $H$, we apply a column scaling step so that the column maxima of $H$ do not exceed those of $X$ (if a column's maximum is already smaller, no scaling is performed). This is intended to prevent the inferred values in $H$ from exploding during the unrolled inference. Recall that backpropagation and another optimizer are used to update $W$; we observed that it was not necessary to explicitly apply any normalization to $W$, since we did not encounter exploding values there. Although weight decay was used in the experiments, it was not needed to prevent numerical issues and was used only for its observed slightly beneficial effect on predictive performance in some cases.
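A minimal NumPy sketch of this column-capping step; the function name and `eps` guard are our assumptions:

```python
import numpy as np

def cap_h_columns(H, x_col_max, eps=1e-12):
    """Rescale any column of H whose maximum exceeds the maximum of the
    corresponding column of X (Section A.1.2); columns already below the
    cap are left unchanged. x_col_max = X.max(axis=0) is computed once at
    the start of inference."""
    h_col_max = H.max(axis=0)
    scale = np.where(h_col_max > x_col_max,
                     x_col_max / np.maximum(h_col_max, eps),
                     1.0)
    return H * scale   # per-column scale broadcasts over the rows of H
```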
A.2 Automatic learning rate selection
We empirically observe that setting the SGD learning rates $\eta_H$ and $\eta_W$ in Eqs. (A.6) and (A.7) automatically, using the same method as in FISTA, often works well. Specifically, $\eta_H$ is set to $1/L$, where $L$ is the largest eigenvalue of $W^\top W$. Likewise, $\eta_W$ is set to $1/L'$, where $L'$ is the largest eigenvalue of $H H^\top$. In our experiments, we used the power method to compute the approximate largest eigenvalues.
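A minimal NumPy sketch of the power-method estimate for $L$; the iteration count is our illustrative choice:

```python
import numpy as np

def largest_eigenvalue_wtw(W, n_iters=30, seed=0):
    """Approximate the largest eigenvalue of W.T @ W by power iteration,
    giving the automatic learning rate eta_H = 1 / L of Section A.2."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(W.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iters):
        v = W.T @ (W @ v)          # one power-iteration step on W.T @ W
        v /= np.linalg.norm(v)
    return v @ (W.T @ (W @ v))     # Rayleigh quotient for unit-norm v
```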
Appendix B Implementation details for FISTA-accelerated NMF inference in the PFC block
We adapt the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) to accelerate the NMF inference procedure in the PFC block. FISTA is an optimization method that combines gradient-based steps with an acceleration technique [23]. Originally designed for sparse recovery, it can be adapted to various applications, including matrix factorization. We use rectification (relu) as the proximal operator instead of the usual shrinkage-thresholding operator. The resulting algorithm interleaves the SGD NMF right-update steps with the FISTA steps that compute the momentum term. Both NMF and semi-NMF constraints are supported.
Given an input matrix $X$, we aim to infer the hidden matrix $H$ in the factorization $X \approx W H$. We assume that the weights $W$ remain fixed during the inference procedure. In the standard SGD-based NMF procedure reviewed in Appendix A, inference amounts to the repeated application of the update step in Eq. (A.6), here followed by the normalization step and the FISTA momentum update steps. The procedure is as follows:
1. Initialize $H_0$ (e.g., with non-negative noise), set $Y_1 = H_0$ and $t_1 = 1$.
2. For each iteration $k = 1, 2, \ldots$, update by:
   (a) Updating $H_k = \operatorname{relu}\!\left( Y_k + \tfrac{1}{L}\, W^\top (X - W Y_k) \right)$.
   (b) Applying the normalization scaling to $H_k$.
   (c) Updating $t_{k+1} = \tfrac{1}{2}\left( 1 + \sqrt{1 + 4 t_k^2} \right)$.
   (d) Updating $Y_{k+1} = H_k + \tfrac{t_k - 1}{t_{k+1}} \left( H_k - H_{k-1} \right)$.
3. Repeat the process until convergence is achieved.
We estimate the Lipschitz constant $L$ using the power method. The normalization scaling step is described in Section A.1.2. This step may not always be needed, but we leave it enabled in all experiments to prevent numerical issues in $H$.
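Putting the steps together, the following minimal NumPy sketch performs the accelerated inference, reusing the `largest_eigenvalue_wtw` and `cap_h_columns` helpers sketched in Appendix A; zero initialization of $H$ and the fixed iteration count are our assumptions:

```python
import numpy as np

def fista_infer_h(X, W, n_iters=50):
    """Infer H in X ~= W @ H with W held fixed, using rectified FISTA
    gradient steps with the Section A.1.2 column scaling."""
    eta = 1.0 / largest_eigenvalue_wtw(W)    # step size 1 / L
    x_col_max = X.max(axis=0)                # column caps, computed once
    H_prev = np.zeros((W.shape[1], X.shape[1]))
    Y, t = H_prev.copy(), 1.0
    for _ in range(n_iters):
        # Step (a): rectified gradient step on the momentum iterate Y.
        H = np.maximum(Y + eta * (W.T @ (X - W @ Y)), 0.0)
        # Step (b): normalization scaling of Section A.1.2.
        H = cap_h_columns(H, x_col_max)
        # Steps (c)-(d): FISTA momentum updates.
        t_next = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
        Y = H + ((t - 1.0) / t_next) * (H - H_prev)
        H_prev, t = H, t_next
    return H_prev
```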
References
- [1] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin, “Attention is all you need,” 2017. Version 3.
- [2] F. Rosenblatt, “The perceptron: a probabilistic model for information storage and organization in the brain.,” Psychological review 65(6), p. 386, 1958.
- [3] A. G. Ivakhnenko, V. G. Lapa, et al., “Cybernetic predicting devices,” 1965.
- [4] S. Amari, “A theory of adaptive pattern classifiers,” IEEE Transactions on Electronic Computers (3), pp. 299–307, 1967.
- [5] G. Cybenko, “Approximation by superpositions of a sigmoidal function,” Mathematics of Control, Signals, and Systems 2(4), pp. 303–314, 1989.
- [6] K. Hornik, M. Stinchcombe, and H. White, “Multilayer feedforward networks are universal approximators,” Neural Networks 2(5), pp. 359–366, 1989.
- [7] J. L. Ba, J. R. Kiros, and G. E. Hinton, “Layer normalization,” 2016.
- [8] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov, “Dropout: a simple way to prevent neural networks from overfitting,” The journal of machine learning research 15(1), pp. 1929–1958, 2014.
- [9] M. McCloskey and N. J. Cohen, “Catastrophic interference in connectionist networks: The sequential learning problem,” in Psychology of learning and motivation, 24, pp. 109–165, Elsevier, 1989.
- [10] T. M. Cover and P. E. Hart, “Nearest neighbor pattern classification,” IEEE Trans. Inf. Theory 13, pp. 21–27, 1967.
- [11] T. Kohonen, “The ‘neural’ phonetic typewriter,” IEEE Computer 21(3), pp. 11–22, 1988.
- [12] T. Kohonen, “Improved versions of learning vector quantization,” in 1990 IJCNN International Joint Conference on Neural Networks, pp. 545–550, IEEE, 1990.
- [13] G. M. van de Ven and A. S. Tolias, “Three scenarios for continual learning,” 2019.
- [14] P. Paatero and U. Tapper, “Positive matrix factorization: A non-negative factor model with optimal utilization of error estimates of data values,” Environmetrics 5(2), pp. 111–126, 1994.
- [15] D. Lee and H. Seung, “Learning the parts of objects by non-negative matrix factorization,” Nature 401, pp. 788–791, 1999.
- [16] C. H. Ding, T. Li, and M. I. Jordan, “Convex and semi-nonnegative matrix factorizations,” IEEE transactions on pattern analysis and machine intelligence 32(1), pp. 45–55, 2008.
- [17] D. Lee and H. S. Seung, “Algorithms for non-negative matrix factorization,” Advances in neural information processing systems 13, 2000.
- [18] B. K. Vogel, “Positive factor networks: A graphical framework for modeling non-negative sequential data,” arXiv preprint arXiv:0807.4198, 2008.
- [19] D. E. Rumelhart, G. E. Hinton, and R. J. Williams, “Learning representations by back-propagating errors,” Nature 323(6088), pp. 533–536, 1986.
- [20] Y. LeCun, Learning processes in an asymmetric threshold network. PhD thesis, University of Paris 6, 1985.
- [21] P. J. Werbos, Beyond regression: New tools for prediction and analysis in the behavioral sciences. PhD thesis, Harvard University, 1974.
- [22] D. Hendrycks and K. Gimpel, “Gaussian error linear units (GELUs),” arXiv preprint arXiv:1606.08415, 2016.
- [23] A. Beck and M. Teboulle, “A fast iterative shrinkage-thresholding algorithm for linear inverse problems,” SIAM journal on imaging sciences 2(1), pp. 183–202, 2009.
- [24] G. Hinton, “Lecture 6a overview of mini-batch gradient descent.” https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf, 2018.
- [25] V. Monga, Y. Li, and Y. C. Eldar, “Algorithm unrolling: Interpretable, efficient deep learning for signal and image processing,” IEEE Signal Processing Magazine 38(2), pp. 18–44, 2021.
- [26] J. L. Elman, “Finding structure in time,” Cognitive science 14(2), pp. 179–211, 1990.
- [27] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner, “Gradient-based learning applied to document recognition,” Proceedings of the IEEE 86(11), pp. 2278–2324, 1998.
- [28] H. Xiao, K. Rasul, and R. Vollgraf, “Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms,” 2017.
- [29] A. Krizhevsky, V. Nair, and G. Hinton, “Learning multiple layers of features from tiny images,” tech. rep., University of Toronto, 2009.
- [30] F. Zenke, B. Poole, and S. Ganguli, “Continual learning through synaptic intelligence,” 2017.
- [31] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, et al., “Overcoming catastrophic forgetting in neural networks,” Proceedings of the national academy of sciences 114(13), pp. 3521–3526, 2017.
- [32] S. Hochreiter and J. Schmidhuber, “Long short-term memory,” Neural computation 9(8), pp. 1735–1780, 1997.
- [33] M. Arjovsky, A. Shah, and Y. Bengio, “Unitary evolution recurrent neural networks,” 2016.
- [34] G. Grindlay and D. P. W. Ellis, “Multi-voice polyphonic music transcription using eigeninstruments,” in IEEE Workshop on Applications of Signal Processing to Audio and Acoustics, pp. 53–56, 2009.
- [35] A. van den Oord, S. Dieleman, H. Zen, K. Simonyan, O. Vinyals, A. Graves, N. Kalchbrenner, A. Senior, and K. Kavukcuoglu, “Wavenet: A generative model for raw audio,” arXiv preprint arXiv:1609.03499, 2016.
- [36] B. A. Olshausen and D. J. Field, “Sparse coding with an overcomplete basis set: A strategy employed by v1?,” Vision research 37(23), pp. 3311–3325, 1997.
- [37] R. Nasser, Y. C. Eldar, and R. Sharan, “Deep unfolding for non-negative matrix factorization with application to mutational signature analysis,” 2021.
- [38] J. Mairal, F. Bach, and J. Ponce, “Task-driven dictionary learning,” IEEE transactions on pattern analysis and machine intelligence 34(4), pp. 791–804, 2011.
- [39] J. T. Rolfe and Y. LeCun, “Discriminative recurrent sparse auto-encoders,” arXiv preprint arXiv:1301.3775, 2013.
- [40] Y. Bengio and F. Gingras, “Recurrent neural networks for missing or asynchronous data,” Advances in neural information processing systems 8, 1995.
- [41] H. S. Seung, “Learning continuous attractors in recurrent networks,” Advances in neural information processing systems 10, 1997.
- [42] J. Sulam, R. Muthukumar, and R. Arora, “Adversarial robustness of supervised sparse coding,” Advances in neural information processing systems 33, pp. 2110–2121, 2020.
- [43] M. Li, P. Zhai, S. Tong, X. Gao, S.-L. Huang, Z. Zhu, C. You, Y. Ma, et al., “Revisiting sparse convolutional model for visual recognition,” Advances in Neural Information Processing Systems 35, pp. 10492–10504, 2022.
- [44] X. Sun, N. M. Nasrabadi, and T. D. Tran, “Supervised deep sparse coding networks,” in 2018 25th IEEE International Conference on Image Processing (ICIP), pp. 346–350, IEEE, 2018.
- [45] T. Achler, “A flexible online classifier using supervised generative reconstruction during recognition,” 2012.
- [46] W. Liu, N. Zheng, and Q. You, “Nonnegative matrix factorization and its applications in pattern recognition,” Chinese Science Bulletin 51, pp. 7–18, 2006.