A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks: An Explainer
Published:
The OOD baseline paper titled “A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks” is one of the most important research papers in the OOD detection literature. Though the paper is not mathematically or technically heavy, it has some concepts that might not be straightforward for beginners. Discussions and tutorials of the paper are lacking on the internet. Hence, to help beginners dive into the field of out-of-distribution detection and fully understand the concepts of the paper, I have written this explainer blog.
What is OOD and failure detection?
OOD detection and failure detection are separate concepts. The training data follows a distribution determined by its data generation process. The process that generates the test data might differ from the one that generated the training data, and this difference is likely to cause errors in prediction. Detecting whether a data point belongs to the training data distribution is called out-of-distribution (OOD) detection.
Failure detection, on the other hand, is detecting whether a model is likely to make a wrong prediction. Several things other than out-of-distribution inputs can cause failure: inherent noise in the data, points near the decision boundary, or simply the limitations of the machine learning model. Having the model flag some predictions as error-prone increases its usability.
ML model design and overconfidence
Machine learning-based classification models generally use the softmax function as the last layer to convert real-valued scores into normalized values. Because the normalized values fall between 0 and 1 and sum to 1, they can be interpreted as class probabilities. Due to the exponential nature of the softmax function and the lack of an “I don’t know” option in the architecture of our models, the normalized values are generally skewed toward one class. If we take the normalized values as the model’s class confidence, we see overconfidence in the predictions. The paper also shows this overconfidence by calculating the mean softmax probability of the erroneous and out-of-distribution examples.
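To make the skew concrete, here is a minimal sketch of softmax in NumPy. The logit values are made up for illustration and do not come from the paper. It shows that the outputs always sum to 1, and that scaling the logits sharpens the distribution without changing which class wins.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

# Hypothetical logits for a 3-class classifier (illustrative values only).
logits = np.array([4.0, 2.0, 1.0])
probs = softmax(logits)

# Doubling every logit keeps the ranking but pushes more mass to the top class.
probs_scaled = softmax(2.0 * logits)

print("softmax(logits):    ", probs.round(3), "sum =", probs.sum())
print("softmax(2 * logits):", probs_scaled.round(3))
```

Because the exponential amplifies differences between logits, even a modest gap between the top two scores can produce a maximum probability close to 1.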
Why is OOD and failure detection important?
If we are using a machine learning model to give a user product recommendations, as on Amazon, or friend recommendations, as on Facebook, occasional bad recommendations might not hurt. However, if we are using a machine learning system to detect a disease, the errors might be catastrophic. The cost of a person’s life cannot be quantified in any measurable terms. So, in critical applications such as medicine, law, and fully autonomous driving, the safety and accuracy of the systems are extremely important. In such cases, it is far more helpful for the machine learning model to say “I don’t know” than to give a wrong prediction. As machine learning becomes ubiquitous in many fields, the safety and accuracy of machine learning-based systems are of utmost importance, and the problem of OOD and failure detection has become increasingly important.
What is the paper trying to convey?
This is the first question we need to answer. The paper is trying to convey a few important things about predictions of a softmax function.
- Interpretation of prediction probability: The prediction probability that the softmax function assigns to each class cannot be interpreted literally as a probability. If a cat/dog classifier predicts 90% probability for cat, it does not mean that the ground truth would be cat 9 out of 10 times. Instead, it is a proxy for the confidence the model has in its prediction.
- Usefulness of softmax probability: However, the softmax probability can be useful for estimating two things:
  - How likely is the model’s prediction to be wrong?
  - How likely is the input data to be out-of-distribution?
The lower value of softmax probability signals that the model might not be doing well on the example. This might be because the example is quite difficult for the model or because the example is out-of-distribution.
- Statistics of softmax probability: The authors look at the statistics of the softmax probability of correct examples (the examples on which the model made right predictions) and wrong examples (the examples on which the model made wrong predictions). They find that the distribution of the softmax probability is different for correct and wrong examples. They even show the statistical significance of the difference using the Wilcoxon rank-sum test. This means we can pick a threshold that largely separates the correct examples from the wrong ones, and then use it to decline to make predictions on examples that are probably wrong.
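The thresholding idea can be sketched in a few lines. The confidence scores and correctness flags below are hypothetical values for illustration, not results from the paper, and `selective_accuracy` is a helper name introduced here. The sketch abstains whenever the maximum softmax probability falls below a threshold, then reports how many examples are still answered (coverage) and how accurate the model is on those.

```python
import numpy as np

def selective_accuracy(confidence, is_correct, threshold):
    """Abstain when confidence < threshold; return (coverage, accuracy on kept examples)."""
    keep = confidence >= threshold
    coverage = float(keep.mean())
    accuracy = float(is_correct[keep].mean()) if keep.any() else float("nan")
    return coverage, accuracy

# Hypothetical max-softmax scores and correctness flags for 8 test examples.
confidence = np.array([0.99, 0.97, 0.95, 0.90, 0.85, 0.62, 0.55, 0.51])
is_correct = np.array([True, True, True, True, False, True, False, False])

# Threshold 0.0 keeps everything; 0.88 is an arbitrary cut chosen for this toy data.
base_cov, base_acc = selective_accuracy(confidence, is_correct, threshold=0.0)
cov, acc = selective_accuracy(confidence, is_correct, threshold=0.88)

print(f"no abstention:   coverage={base_cov:.2f}, accuracy={base_acc:.3f}")
print(f"threshold 0.88:  coverage={cov:.2f}, accuracy={acc:.3f}")
```

On real data the two distributions overlap, so raising the threshold trades coverage for accuracy rather than separating correct and wrong examples perfectly.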
Weight initialization
The authors use the weight initialization of Hendrycks & Gimpel, 2016c, as it is suited to arbitrary nonlinearities. According to that paper, for the GELU (µ = 0, σ = 1) activation function, $E(f(z^{l-1})^2) = 0.425$, and for the identity activation function, $E(f(z^{l-1})^2) = 1$. So, the four weight matrices are divided by $\sqrt{1 + 0.425}$, $\sqrt{0.425 + 0.425}$, $\sqrt{0.425 + 0.425}$, and $\sqrt{0.425 + 1}$ respectively. The initialization paper will be explained in detail in another blog post.
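The two second-moment values can be checked numerically. The sketch below assumes the input $z$ follows a standard normal distribution and integrates $f(z)^2$ against its density on a fine grid. The grid range and resolution are arbitrary choices made here, not something taken from the paper.

```python
import math
import numpy as np

def gelu(z):
    """Exact GELU: z * Phi(z), where Phi is the standard normal CDF."""
    erf = np.vectorize(math.erf)
    return z * 0.5 * (1.0 + erf(z / math.sqrt(2.0)))

# Riemann-sum integration against the standard normal density.
z = np.linspace(-10.0, 10.0, 200_001)
dz = z[1] - z[0]
pdf = np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)

second_moment_gelu = float(np.sum(gelu(z) ** 2 * pdf) * dz)
second_moment_identity = float(np.sum(z**2 * pdf) * dz)

# Divisors for the four weight matrices, as listed in the text above.
divisors = [
    math.sqrt(1 + 0.425),
    math.sqrt(0.425 + 0.425),
    math.sqrt(0.425 + 0.425),
    math.sqrt(0.425 + 1),
]

print(f"E[gelu(z)^2] ~ {second_moment_gelu:.4f}")
print(f"E[z^2]       ~ {second_moment_identity:.4f}")
print("divisors:", [round(d, 4) for d in divisors])
```

The identity value of 1 applies where a layer touches the raw input or the output, which is why the first and last divisors differ from the two in the middle.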
The authors have conducted experiments for computer vision, natural language processing, and speech recognition tasks.
Computer vision experiments
For computer vision, the authors have taken three datasets: MNIST, CIFAR-10, and CIFAR-100. Let’s first replicate the results for success vs error prediction.
Success vs error prediction
The softmax probability gives us a quantitative measure of the confidence of a prediction. The authors show that higher softmax probabilities are associated with success and relatively lower ones with error.
Let’s take the three datasets: MNIST, CIFAR-10, and CIFAR-100. The authors also describe the architecture used for each.
MNIST: For MNIST, they use a fully connected neural network with three hidden layers (excluding the input and output layers) of 256 neurons each. They use the Adam optimizer and the GELU activation function. They use the Hendrycks & Gimpel, 2016c weight initialization, as it is suited to arbitrary nonlinearities. The biases are initialized to zero. They train the neural network for 30 epochs.
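As a rough sketch of that MNIST architecture, here is a forward pass in NumPy. Several details are assumptions made for illustration and are not stated in this post: the base weights are drawn as unit-norm Gaussian columns before dividing by the scaling factors, GELU uses the common tanh approximation, and the input batch is random noise standing in for flattened 28x28 images. Training with Adam for 30 epochs is omitted.

```python
import math
import numpy as np

rng = np.random.default_rng(0)

def gelu(x):
    # tanh approximation of GELU (assumption: chosen to avoid extra dependencies)
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# 784 inputs, three hidden layers of 256 units, 10 output classes.
sizes = [784, 256, 256, 256, 10]
divisors = [math.sqrt(1 + 0.425), math.sqrt(0.425 + 0.425),
            math.sqrt(0.425 + 0.425), math.sqrt(0.425 + 1)]

weights, biases = [], []
for n_in, n_out, d in zip(sizes[:-1], sizes[1:], divisors):
    w = rng.standard_normal((n_in, n_out))
    w /= np.linalg.norm(w, axis=0, keepdims=True)  # assumption: unit-norm columns
    weights.append(w / d)
    biases.append(np.zeros(n_out))                 # biases start at zero

def forward(x):
    """Return class probabilities for a batch of flattened images."""
    h = x
    for w, b in zip(weights[:-1], biases[:-1]):
        h = gelu(h @ w + b)
    return softmax(h @ weights[-1] + biases[-1])

batch = rng.standard_normal((5, 784))  # stand-in for real MNIST images
probs = forward(batch)
print(probs.shape, probs.max(axis=1).round(3))
```

The maximum of each row of `probs` is the softmax confidence score that the rest of the paper builds on.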
CIFAR: For CIFAR-10 and CIFAR-100, they train a 40-4 wide residual network for 50 epochs using SGD with restarts. They use the GELU nonlinearity and standard mirroring and cropping data augmentation.
In- vs Out-of-distribution detection
For in-distribution vs out-of-distribution experiments, the authors have paired datasets in this way:
- CIFAR-10 vs Scene UNderstanding dataset
