Debugging using Orthogonal Gradient Descent
Abstract
In this report we consider the following problem: given a trained model that is partially faulty, can we correct its behaviour without retraining the model from scratch? In other words, can we “debug” neural networks in the same way we address bugs in our mathematical models and in standard computer code? We base our approach on the hypothesis that debugging can be treated as a two-task continual learning problem. In particular, we employ a modified version of a continual learning algorithm called Orthogonal Gradient Descent (OGD) and demonstrate, via two simple experiments on the MNIST dataset, that we can in fact unlearn the undesirable behaviour while retaining the general performance of the model, and that we can additionally relearn the appropriate behaviour, in both cases without training the model from scratch.
1 Introduction
While the field of machine learning has seen incredible progress in the past decade, in our view an important problem has not received enough traction: the problem of “debugging” machine learning models. When a trained model is discovered to be partially faulty, that is, it responds incorrectly to a portion of its inputs, perhaps due to dataset bias or insufficient training data, a very common remedy is simply to train the model again from scratch. While this naive approach works well, retraining can be prohibitively expensive: many state-of-the-art models require several weeks and multiple GPUs to train. Cost and time aside, starting from scratch is unsatisfactory as a matter of principle. We rarely address bugs in our mathematical models or standard computer code by throwing everything away and starting with a clean slate, and neither do we abandon a child every time they misbehave or say unpleasant things!
With these concerns in mind, we explore the problem of debugging machine learning models without training them from scratch. Our approach rests on the claim that debugging can be treated as a two-task continual learning problem. More specifically, we employ a modified version of a continual learning method known as Orthogonal Gradient Descent [4]. Using a simple convolutional neural network trained on a (purposefully) corrupted version of the MNIST dataset, we show that (1) it is possible to unlearn the way the model responds to the misclassified data points without affecting its general performance; and (2) beyond unlearning, it is also possible to relearn the correct responses to these inputs and attain the maximum test accuracy achievable with this simple network, again without training from scratch.
2 Background
In this section we describe Orthogonal Gradient Descent (OGD), the method that underlies our approach. While continual learning in general is concerned with learning a sequence of many tasks $T_1, T_2, T_3, \dots$, here we discuss the method with just two tasks in mind, $T_A$ and $T_B$. We use $f(x; w)$ to denote the neural network, where $x$ is the input to the network and $w \in \mathbb{R}^p$ are the weights that parametrize the model. The network takes in an input $x$ and outputs its prediction $f(x; w) \in \mathbb{R}^c$. For classification problems, $f_j(x; w)$ denotes the $j$-th logit, associated with the $j$-th class. The loss function (for a single task $T$) is denoted as

$$L(w) = \sum_{(x, y) \in T} \ell\big(y, f(x; w)\big),$$

where the sum is over the training data of the task and $y$ denotes the ground truth.
When discussing OGD, it is crucial to distinguish between two types of gradients: the gradient of the neural network's output, $\nabla_w f_j(x; w)$, and the gradient of the loss function, $\nabla_w L(w)$. The latter is the one encountered more frequently in machine learning applications, since it is what we follow to minimize the loss.
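To make the distinction concrete, the following is a minimal sketch under assumptions not made in the text: a linear classifier $f(x; W) = Wx$ (rather than the CNN used in our experiments) and a cross-entropy loss $\ell$. The model gradient $\nabla_W f_j(x; W)$ only involves the $j$-th row of $W$ and does not depend on the label, whereas the loss gradient depends on the label $y$ through the softmax probabilities. The helper names are hypothetical.

```python
import math

def logits(W, x):
    """f(x; W) = W x for a linear classifier; W is a list of c rows, each of length d."""
    return [sum(wk * xk for wk, xk in zip(row, x)) for row in W]

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def model_gradient(W, x, j):
    """Gradient of the j-th logit f_j(x; W) w.r.t. W, flattened row-major.
    Only row j of W affects f_j, and d f_j / d W[j][k] = x[k]."""
    c, d = len(W), len(x)
    grad = [0.0] * (c * d)
    grad[j * d:(j + 1) * d] = list(x)
    return grad

def loss_gradient(W, x, y):
    """Gradient of the cross-entropy loss -log softmax(W x)[y] w.r.t. W, flattened row-major.
    d loss / d W[j][k] = (p_j - [j == y]) * x[k]."""
    p = softmax(logits(W, x))
    return [(p[j] - (1.0 if j == y else 0.0)) * xk
            for j in range(len(W)) for xk in x]

W = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]  # 3 classes, 2 input features
x = [1.0, 2.0]
print(model_gradient(W, x, 1))  # [0.0, 0.0, 1.0, 2.0, 0.0, 0.0]
print(loss_gradient(W, x, 1))   # label-dependent: row 1 gets (p_1 - 1) * x
```

For this linear model the model gradient depends only on the shape of $W$, not on its values; for a CNN both gradients would be obtained by automatic differentiation.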
With the preliminaries established, we can now describe the OGD algorithm in some detail. Please refer to the Appendix (1) or the original paper for more information [4].
1. Computing Model Gradients. After training the model on the first task $T_A$ in the usual manner, we compute the model gradients for the class that corresponds to the ground-truth label; that is, if $x$ belongs to the $j$-th class, we compute and store only $\nabla_w f_j(x; w)$. On the MNIST dataset, the authors of [4] recommend computing these gradients at around 300 data points randomly sampled from the training set, so the memory requirements are reasonable (the full training set has 60,000 images).
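2. Orthogonal Basis. Once we have the set of model gradients $G = \{\nabla_w f_y(x; w)\}$ over the sampled points $(x, y)$, we compute an orthonormal basis for $\operatorname{span}(G)$ by applying the Gram-Schmidt method to all of the gradients. From here on, we use $S$ to denote this orthonormal set.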
3. Computing Loss Gradients. Next, we compute the gradient of the loss function on the data points that belong to the second task $T_B$:

$$g = \nabla_w L_B(w).$$
4. Projection. We then project the loss gradient from the previous step onto the subspace orthogonal to the span of the model gradients:

$$\tilde{g} = g - \sum_{v \in S} \langle g, v \rangle \, v.$$
5. Parameter Update. Finally, we update the parameters of the model using the projected vector:

$$w \leftarrow w - \eta \, \tilde{g},$$

where $\eta$ is the step size, which we keep fixed throughout our experiments.
We repeat steps 3-5 until the model converges. In this report we make slight modifications to the original approach; the modified versions are described in Section 4.
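Steps 2, 4 and 5 can be sketched in plain Python as follows. This is an illustrative sketch, not the implementation of [4]: vectors are lists of floats, the helper names are hypothetical, and we use the modified Gram-Schmidt variant, which is algebraically equivalent to the classical one but numerically more stable. Real model gradients have as many entries as the network has parameters (around 20k for our CNN), so an array library would be used in practice.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def gram_schmidt(vectors, tol=1e-10):
    """Step 2: orthonormal basis S for span(vectors), using modified Gram-Schmidt
    (each projection is taken against the running residual). Vectors that are
    numerically linear combinations of earlier ones are dropped."""
    basis = []
    for g in vectors:
        r = list(g)
        for v in basis:
            c = dot(r, v)
            r = [ri - c * vi for ri, vi in zip(r, v)]
        norm = dot(r, r) ** 0.5
        if norm > tol:
            basis.append([ri / norm for ri in r])
    return basis

def project_orthogonal(g, basis):
    """Step 4: g_tilde = g - sum_{v in S} <g, v> v, the component of g orthogonal to span(S)."""
    out = list(g)
    for v in basis:
        c = dot(g, v)
        out = [oi - c * vi for oi, vi in zip(out, v)]
    return out

def ogd_step(w, loss_grad, basis, eta):
    """Step 5: w <- w - eta * g_tilde."""
    g_tilde = project_orthogonal(loss_grad, basis)
    return [wi - eta * gi for wi, gi in zip(w, g_tilde)]

# Toy usage in R^3: two stored model gradients span the first two coordinates.
S = gram_schmidt([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
g = [1.0, 2.0, 3.0]                      # stand-in for a task-B loss gradient
print(project_orthogonal(g, S))          # [0.0, 0.0, 3.0]
print(ogd_step([0.0, 0.0, 0.0], g, S, eta=0.1))
```

In the toy example the stored gradients span the first two coordinates, so only the third coordinate of the loss gradient survives the projection and only the third weight moves.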
3 Related Work
The problems discussed in this report are related to those explored in the field of machine unlearning [2]. This nascent field has been motivated in part by recently introduced privacy regulations such as the General Data Protection Regulation in the European Union [11], the California Consumer Privacy Act in the United States, and the Consumer Privacy Protection Act in Canada. While there are many popular approaches to data deletion [1, 6, 9, 5, 12, 3], our approach is, at a high level, closely related to the one presented in [7]. Their method is based on a continual learning algorithm called Elastic Weight Consolidation [10], which uses the Fisher Information Matrix [8] to assign importance to the learned weights.
4 Experiments
In this section we use the MNIST dataset to perform two simple experiments that demonstrate unlearning and relearning. Unless otherwise stated, we use the standard train/test split of 60k/10k images. Throughout, we use a simple convolutional neural network (CNN) with around 20k parameters, which attains an accuracy of 98% on the test set when trained on the 60k training images. We do not tune hyperparameters.
4.1 Unlearning
Setup
We first look at the question of unlearning. To do so, we take the MNIST dataset and purposefully mislabel two classes in the training set to create $\tilde{D}_{\text{train}}$, i.e., the original training set with two classes interchanged. Note that we leave the test set $D_{\text{test}}$ unchanged. For this experiment, we label the images containing the digit ‘2’ as ‘3’ and the images containing the digit ‘3’ as ‘2’. We then train the CNN on the mislabelled training set; it achieves an accuracy of 78% on the untouched test set $D_{\text{test}}$, which is expected given that the model has learned to systematically swap the labels of these two classes.
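The label corruption, and the test accuracy one should expect from it, can be sketched as follows (function names are hypothetical). If the model learns the swap perfectly and classifies every other digit correctly, every test image of a ‘2’ or ‘3’ is misclassified, which caps the test accuracy at the fraction of test images belonging to the other eight classes.

```python
from collections import Counter

def swap_labels(labels, a=2, b=3):
    """Copy of `labels` with classes a and b interchanged; all other labels unchanged."""
    swap = {a: b, b: a}
    return [swap.get(y, y) for y in labels]

def accuracy_ceiling(test_labels, a=2, b=3):
    """Highest test accuracy reachable by a model that has perfectly learned the swap,
    i.e. it predicts every test image of class a as b and vice versa."""
    counts = Counter(test_labels)
    return 1.0 - (counts[a] + counts[b]) / len(test_labels)

print(swap_labels([0, 2, 3, 7, 2]))  # [0, 3, 2, 7, 3]

# Class counts of the standard 10k MNIST test set, digits 0-9.
mnist_test_counts = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]
test_labels = [d for d, n in enumerate(mnist_test_counts) for _ in range(n)]
print(accuracy_ceiling(test_labels))  # approximately 0.7958
```

On the standard MNIST test set, where digits ‘2’ and ‘3’ account for 1,032 and 1,010 of the 10,000 images, this ceiling is about 79.6%, consistent with the 78% we observe.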
Method
At the highest level, we approach this debugging problem by exploiting catastrophic forgetting in neural networks. That is, if we can preserve the right kind of knowledge that the network has acquired, then forgetting (unlearning) the faulty behaviour comes naturally through catastrophic forgetting. We do so with the help of a modified OGD, as follows:
1. After training the model on the mislabelled training set, we compute the gradients of the model (not the loss) at a few hundred points randomly sampled from the training set with the mislabelled points excluded. In other words, we sample from $\tilde{D}_{\text{train}}$ with the images of digits ‘2’ and ‘3’ removed. We then orthonormalize this set of gradient vectors; the resulting set is denoted by $S$.
2. Then, we randomly sample a vector $r$ with as many elements as there are parameters in the model, so that $r$ has the same dimension as the gradient of the model with respect to its parameters. We sample each element uniformly from the interval $[0, 1]$: $r_i \sim \mathcal{U}(0, 1)$.
3. Third, we project the randomly sampled vector onto the subspace orthogonal to the model gradients:

$$\tilde{r} = r - \sum_{v \in S} \langle r, v \rangle \, v.$$
4. Finally, we update the parameters using the projected vector, $w \leftarrow w - \eta \, \tilde{r}$, where $\eta$ is a fixed step size.
5. We repeat steps 2-4 until the accuracy of the model on the (purposefully) mislabelled portion of the dataset is no better than chance. We present the algorithm in the Appendix (2).
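A minimal sketch of steps 2-5, assuming the orthonormal set $S$ from step 1 is already available. The accuracy callback, the chance level of 0.1 (ten classes) and the step cap are assumptions for illustration, and all names are hypothetical; the update uses the same sign convention as the OGD update $w \leftarrow w - \eta \tilde{g}$.

```python
import random

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def project_orthogonal(r, basis):
    """r_tilde = r - sum_{v in S} <r, v> v (basis assumed orthonormal)."""
    out = list(r)
    for v in basis:
        c = dot(r, v)
        out = [oi - c * vi for oi, vi in zip(out, v)]
    return out

def unlearn(w, basis, eta, mislabelled_accuracy, chance=0.1, max_steps=10_000, seed=0):
    """Modified OGD for unlearning (sketch of steps 2-5).

    basis: orthonormal model gradients S from the retained (non-mislabelled) data.
    mislabelled_accuracy: hypothetical callback, w -> accuracy on the mislabelled subset.
    chance: assumed chance level (1/10 for ten MNIST classes).
    """
    rng = random.Random(seed)
    for _ in range(max_steps):
        if mislabelled_accuracy(w) <= chance:       # step 5: stop at chance level
            break
        r = [rng.random() for _ in w]               # step 2: r_i uniform on [0, 1)
        r_tilde = project_orthogonal(r, basis)      # step 3: remove components in span(S)
        w = [wi - eta * ri for wi, ri in zip(w, r_tilde)]  # step 4: w <- w - eta * r_tilde
    return w

# Toy usage: S protects the first coordinate; the stand-in "mislabelled accuracy"
# drops to 0 once the second weight falls to -1 or below.
def toy_accuracy(w):
    return 1.0 if w[1] > -1.0 else 0.0

S = [[1.0, 0.0, 0.0]]
w_final = unlearn([1.0, 0.0, 0.0], S, eta=0.5, mislabelled_accuracy=toy_accuracy)
print(w_final[0])  # 1.0: the protected coordinate is untouched
```

Because $\tilde{r}$ has no component along any vector in $S$, the toy run leaves the protected first coordinate exactly unchanged while the remaining weights drift until the stopping criterion is met.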
Results
On this simple example, we found that modified OGD can unlearn the model's response to the mislabelled data points while largely preserving its general performance on everything else. Figure 1 (left) shows that the model's accuracy on the purposefully mislabelled portion of the dataset is close to perfect right after training, but as we change the parameters of the model as detailed in steps 2-4, the model unlearns this behaviour and instead responds to input images of ‘2’ and ‘3’ in a random manner. We note that this unlearning generalizes to the test set as well.


4.2 Relearning
Setup
The setup is similar to the unlearning case. We use the same CNN and train it on the mislabelled training set $\tilde{D}_{\text{train}}$, achieving around 78% accuracy on the (untouched) test set.
Method
Here, to make the model respond appropriately to all inputs, including the initially mislabelled ones, with minimal retraining, we do the following:
1. As before, after training the model on the mislabelled training set, we compute the gradients of the model (not the loss) at a few hundred points randomly sampled from the training set with the mislabelled points excluded, i.e., from $\tilde{D}_{\text{train}}$ with the images of digits ‘2’ and ‘3’ removed. We then orthonormalize this set of gradient vectors; the resulting set is denoted by $S$.
2. Then, we compute the gradient of the loss function, $g = \nabla_w L_B(w)$, using mini-batches drawn from a small portion of the training set, namely all the images of digits ‘2’ and ‘3’ with their correct labels; $L_B$ denotes the loss on this corrected subset.
3. Third, we project the gradient of the loss function computed in step 2 onto the subspace orthogonal to the model gradients:

$$\tilde{g} = g - \sum_{v \in S} \langle g, v \rangle \, v.$$
4. Finally, we update the parameters using the projected vector:

$$w \leftarrow w - \eta \, \tilde{g},$$

where $\eta$ is a fixed step size.
5. We repeat steps 2-4 until the accuracy of the model on the test set $D_{\text{test}}$ reaches 98%. Please see the Appendix (3) for the complete algorithm.
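The relearning procedure can be illustrated end to end on a toy problem. This sketch makes several assumptions of its own: a linear classifier with cross-entropy loss instead of our CNN, full-batch gradients over the corrected points instead of mini-batches, and a fixed number of steps instead of the 98% test-accuracy criterion; all names are hypothetical. For the linear model the model gradients do not depend on $w$, so computing them once is exact here; for a CNN they are computed once at the weights obtained after training on the mislabelled set and then held fixed.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def logits(w, x, c):
    """Hypothetical linear model f(x; w) = W x, with the c x d matrix W flattened row-major into w."""
    d = len(x)
    return [dot(w[j * d:(j + 1) * d], x) for j in range(c)]

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def model_gradient(x, j, c):
    """d f_j / d w for the linear model: x placed in block j, zeros elsewhere."""
    d = len(x)
    g = [0.0] * (c * d)
    g[j * d:(j + 1) * d] = list(x)
    return g

def loss_gradient(w, x, y, c):
    """Gradient of the cross-entropy loss -log softmax(f(x; w))[y] w.r.t. w."""
    p = softmax(logits(w, x, c))
    return [(p[j] - (1.0 if j == y else 0.0)) * xk for j in range(c) for xk in x]

def gram_schmidt(vectors, tol=1e-10):
    basis = []
    for g in vectors:
        r = list(g)
        for v in basis:
            coef = dot(r, v)
            r = [ri - coef * vi for ri, vi in zip(r, v)]
        n = math.sqrt(dot(r, r))
        if n > tol:
            basis.append([ri / n for ri in r])
    return basis

def project_orthogonal(g, basis):
    out = list(g)
    for v in basis:
        coef = dot(g, v)
        out = [oi - coef * vi for oi, vi in zip(out, v)]
    return out

def relearn(w, retained, corrected, c, eta=0.5, steps=100):
    """Modified OGD for relearning (sketch, fixed number of steps).
    retained: (x, y) pairs whose ground-truth logits should be preserved (-> S).
    corrected: (x, y) pairs with the desired labels (the task-B data)."""
    S = gram_schmidt([model_gradient(x, y, c) for x, y in retained])   # step 1
    for _ in range(steps):
        g = [0.0] * len(w)
        for x, y in corrected:                                          # step 2 (full batch)
            g = [gi + hi for gi, hi in zip(g, loss_gradient(w, x, y, c))]
        g_tilde = project_orthogonal(g, S)                              # step 3
        w = [wi - eta * gi for wi, gi in zip(w, g_tilde)]               # step 4
    return w

# Toy usage: 3 classes, 2 features. The initial weights assign the corrected point
# x_cor to class 0; its desired label is 1. The retained point x_ret has label 0.
w0 = [2.0, 2.0, 0.0, 0.0, 0.0, 0.0]
x_ret, x_cor = [1.0, 1.0], [0.0, 1.0]
w1 = relearn(w0, retained=[(x_ret, 0)], corrected=[(x_cor, 1)], c=3)
print(logits(w0, x_ret, 3)[0], logits(w1, x_ret, 3)[0])  # ground-truth logit of x_ret, before and after
z = logits(w1, x_cor, 3)
print(z.index(max(z)))  # 1: x_cor is now assigned its corrected class
```

Since only the ground-truth logit of each retained point contributes a direction to $S$ (as in step 1 of Section 2), the other logits of retained points are not protected and can still change.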
Results
The results are presented in Figure 1 (right). Using the model gradients and the (corrected) mislabelled subset, we can train the model to an accuracy of 98%, which is what the model attains when trained from scratch on the whole clean training set. We also note that a simpler approach, adding an L2 term such as $\lambda \lVert w - w_A^* \rVert_2^2$ (where $w_A^*$ denotes the weights of the model after training on task 1) to the loss function while training on the second task, does not work.
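One way to see how the L2 baseline differs from modified OGD is to compare their update directions. Writing $g = \nabla_w L_B(w)$ and letting $S$ be the orthonormalized model gradients stored at $w_A^*$:

```latex
% Gradient step on the L2-regularized objective L_B(w) + \lambda \lVert w - w_A^* \rVert_2^2:
w \leftarrow w - \eta \Big( g + 2\lambda \, (w - w_A^*) \Big)

% Modified OGD step: g is projected onto the orthogonal complement of span(S):
w \leftarrow w - \eta \, \tilde{g}, \qquad
\tilde{g} = g - \sum_{v \in S} \langle g, v \rangle \, v, \qquad
\langle \tilde{g}, v \rangle = 0 \;\; \text{for all } v \in S

% First-order effect on a retained point (x, y) whose gradient \nabla_w f_y(x; w_A^*) lies in span(S):
f_y(x; w_A^* - \eta \tilde{g}) \approx f_y(x; w_A^*) - \eta \, \langle \nabla_w f_y(x; w_A^*), \tilde{g} \rangle = f_y(x; w_A^*)
```

The L2 penalty pulls every weight back toward $w_A^*$ with the same strength, whether or not moving in that direction affects the retained predictions; the projection instead removes only the components of $g$ along the stored model-gradient directions and leaves every other direction unconstrained.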
5 Conclusion
We present two experiments supporting the hypothesis that debugging can be treated as a continual learning problem. First, we use the phenomenon of catastrophic forgetting to our advantage and show that we can in fact perform targeted unlearning: the model stops (mis)classifying images of the digit ‘2’ as ‘3’ and vice versa. However, when we compare the average (highest) confidence with which the model classifies these images before and after unlearning, we do not observe a drastic difference: 0.98 vs 0.78, whereas with randomly initialized weights the confidence is around 0.12. Ideally, the model would not only unlearn, but the confidence of its predictions would also resemble that of a randomly initialized model. Second, we present an experiment on relearning and demonstrate that the model can successfully relearn the correct behaviour, reaching the maximum accuracy achievable on the dataset with this model architecture.
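Assuming that “confidence” here means the largest softmax probability, averaged over the images (the measure is not defined formally above), it can be computed as in the sketch below; the function name is hypothetical. For ten classes, logits that express no preference between classes give $1/10 = 0.1$, close to the 0.12 we report for randomly initialized weights.

```python
import math

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def average_max_confidence(batch_logits):
    """Mean over inputs of the largest softmax probability, max_j softmax(f(x; w))_j."""
    return sum(max(softmax(z)) for z in batch_logits) / len(batch_logits)

uniform = [[0.0] * 10, [1.5] * 10]   # logits with no preference between the 10 classes
peaked = [[8.0] + [0.0] * 9]         # one dominant class
print(average_max_confidence(uniform))  # 0.1, up to floating-point rounding
print(average_max_confidence(peaked))   # close to 1
```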
We believe it would be interesting to apply these methods to problems where the model has many more parameters and/or the dataset is much larger. We have also considered using modified OGD (as in the unlearning case) to increase the robustness of models against targeted adversarial attacks.
References
- [1] L. Bourtoule, V. Chandrasekaran, C. A. Choquette-Choo, H. Jia, A. Travers, B. Zhang, D. Lie, and N. Papernot. Machine unlearning. In 2021 IEEE Symposium on Security and Privacy (SP), pages 141–159. IEEE, 2021.
- [2] Y. Cao and J. Yang. Towards making systems forget with machine unlearning. In 2015 IEEE Symposium on Security and Privacy, pages 463–480. IEEE, 2015.
- [3] M. Du, Z. Chen, C. Liu, R. Oak, and D. Song. Lifelong anomaly detection through unlearning. In Proceedings of the 2019 ACM SIGSAC Conference on Computer and Communications Security, pages 1283–1297, 2019.
- [4] M. Farajtabar, N. Azizan, A. Mott, and A. Li. Orthogonal gradient descent for continual learning. In International Conference on Artificial Intelligence and Statistics, pages 3762–3773. PMLR, 2020.
- [5] S. Garg, S. Goldwasser, and P. N. Vasudevan. Formalizing data deletion in the context of the right to be forgotten. In Annual International Conference on the Theory and Applications of Cryptographic Techniques, pages 373–402. Springer, 2020.
- [6] A. Ginart, M. Guan, G. Valiant, and J. Y. Zou. Making AI forget you: Data deletion in machine learning. Advances in Neural Information Processing Systems, 32, 2019.
- [7] A. Golatkar, A. Achille, and S. Soatto. Eternal sunshine of the spotless net: Selective forgetting in deep networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 9304–9312, 2020.
- [8] R. Grosse and J. Martens. A Kronecker-factored approximate Fisher matrix for convolution layers. In International Conference on Machine Learning, pages 573–582. PMLR, 2016.
- [9] C. Guo, T. Goldstein, A. Hannun, and L. Van Der Maaten. Certified data removal from machine learning models. arXiv preprint arXiv:1911.03030, 2019.
- [10] J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences, 114(13):3521–3526, 2017.
- [11] A. Mantelero. The EU proposal for a General Data Protection Regulation and the roots of the ‘right to be forgotten’. Computer Law & Security Review, 29(3):229–235, 2013.
- [12] A. Sekhari, J. Acharya, G. Kamath, and A. T. Suresh. Remember what you want to forget: Algorithms for machine unlearning. Advances in Neural Information Processing Systems, 34, 2021.
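Appendix

Algorithm 1: Orthogonal Gradient Descent [4]
Input: task sequence $T_A, T_B$; step size $\eta$
Train the model on $T_A$ in the usual manner to obtain $w$
Sample examples $(x, y)$ from $T_A$: $G = \{\nabla_w f_y(x; w)\}$
$S \leftarrow$ Gram-Schmidt$(G)$
while stopping criterion not met do
  $g \leftarrow \nabla_w L_B(w)$
  $\tilde{g} \leftarrow g - \sum_{v \in S} \langle g, v \rangle \, v$
  $w \leftarrow w - \eta \, \tilde{g}$
end while

Algorithm 2: Unlearning with modified OGD
Input: weights $w$ trained on $\tilde{D}_{\text{train}}$; step size $\eta$
$G \leftarrow \{\nabla_w f_y(x; w) : (x, y)$ sampled from $\tilde{D}_{\text{train}} \setminus \{$mislabelled examples$\}\}$
$S \leftarrow$ Gram-Schmidt$(G)$
while stopping criterion not met do
  sample $r$ with $r_i \sim \mathcal{U}(0, 1)$
  $\tilde{r} \leftarrow r - \sum_{v \in S} \langle r, v \rangle \, v$
  $w \leftarrow w - \eta \, \tilde{r}$
end while

Algorithm 3: Relearning with modified OGD
Input: weights $w$ trained on $\tilde{D}_{\text{train}}$; $T_B$: data with desired labels; step size $\eta$
$G \leftarrow \{\nabla_w f_y(x; w) : (x, y)$ sampled from $\tilde{D}_{\text{train}} \setminus \{$mislabelled examples$\}\}$
$S \leftarrow$ Gram-Schmidt$(G)$
while stopping criterion not met do
  $g \leftarrow \nabla_w L_B(w)$ on a mini-batch from $T_B$
  $\tilde{g} \leftarrow g - \sum_{v \in S} \langle g, v \rangle \, v$
  $w \leftarrow w - \eta \, \tilde{g}$
end while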