Classification with rejection

reject is a Python package for classification with rejection. When the uncertainty of a prediction is too high, the model abstains from predicting and the observation is passed on to a human expert, who makes the final decision. This is useful in applications where making an error is more costly than asking a human expert for help.

import reject
from reject.reject import ClassificationRejector
from reject.utils import generate_synthetic_output
print(reject.__version__)
0.3.2

Generate synthetic NN output

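In this example, we generate synthetic outputs of a neural network (NN) with multiple samples of the predictive distribution. The predicted probabilities have shape (n_observations, n_samples, n_classes) and the true labels have shape (n_observations,). The data generation function uses 10 output classes. Note that, as the printed shapes show, the returned arrays contain 2000 observations for NUM_OBSERVATIONS = 1000.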

NUM_SAMPLES = 10
NUM_OBSERVATIONS = 1000

y_pred_all, y_true_all = generate_synthetic_output(NUM_SAMPLES, NUM_OBSERVATIONS)
print(y_pred_all.shape, y_true_all.shape)
(2000, 10, 10) (2000,)
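The internals of generate_synthetic_output are not described here. For readers who want arrays with the same layout without the package, the following is a minimal NumPy sketch. The function make_synthetic_output, the +2.0 logit boost towards the true class and the unit noise scale are all hypothetical choices, not the package's method. Unlike the package function, it returns exactly n_observations rows.

```python
import numpy as np


def make_synthetic_output(n_samples, n_observations, n_classes=10, seed=0):
    """Hypothetical generator: softmax of noisy logits, NOT the package's function."""
    rng = np.random.default_rng(seed)
    y_true = rng.integers(0, n_classes, size=n_observations)

    # One base logit vector per observation, nudged towards the true class
    base = rng.normal(size=(n_observations, 1, n_classes))
    base[np.arange(n_observations), 0, y_true] += 2.0

    # Each predictive sample (e.g. one MC Dropout pass) adds its own noise
    noise = rng.normal(size=(n_observations, n_samples, n_classes))
    logits = base + noise

    # Softmax over the class axis (max subtracted for numerical stability)
    logits -= logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs, y_true


y_pred, y_true = make_synthetic_output(n_samples=10, n_observations=1000)
print(y_pred.shape, y_true.shape)  # (1000, 10, 10) (1000,)
```

Broadcasting the per-observation base logits against per-sample noise gives samples that scatter around a shared prediction, which is the kind of spread MC Dropout or Deep Ensembles produce.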

ClassificationRejector

Object creation

The ClassificationRejector class rejects predictions from a classification model. It is initialized with the true labels and the predicted probabilities. The predicted probabilities must have shape (n_observations, n_classes), or (n_observations, n_samples, n_classes) for models with multiple samples of the predictive distribution, such as MC Dropout or Deep Ensembles. The true labels must have shape (n_observations,).

rej = ClassificationRejector(y_true_all, y_pred_all)
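The two accepted input shapes can be unified by treating a single forward pass as one sample. The sketch below does this and derives predicted labels from the probabilities averaged over samples. Both helpers (to_3d, predicted_labels) are hypothetical, and averaging over samples is a common convention assumed here, not something the package documents.

```python
import numpy as np


def to_3d(y_pred):
    """Promote (n_observations, n_classes) to (n_observations, 1, n_classes)."""
    y_pred = np.asarray(y_pred, dtype=float)
    if y_pred.ndim == 2:
        return y_pred[:, np.newaxis, :]     # a single "sample" per observation
    if y_pred.ndim == 3:
        return y_pred
    raise ValueError(f"expected a 2-D or 3-D array, got shape {y_pred.shape}")


def predicted_labels(y_pred):
    """Predicted class = argmax of the probabilities averaged over the samples."""
    return to_3d(y_pred).mean(axis=1).argmax(axis=-1)


single = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6]])            # (2, 3): one forward pass
second = np.array([[0.1, 0.8, 0.1],
                   [0.2, 0.2, 0.6]])            # another ensemble member
ensemble = np.stack([single, second], axis=1)   # (2, 2, 3)

print(predicted_labels(single))     # [0 2]
print(predicted_labels(ensemble))   # [1 2]
```

Adding the second member flips the first observation from class 0 to class 1, because its mean probabilities become [0.4, 0.5, 0.1].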

uncertainty method

The uncertainty/confidence in a prediction can be quantified in 2 ways:

  • Total uncertainty, which is entropy-based.
    => can be decomposed into aleatoric (i.e. data) and epistemic (i.e. model) uncertainty

  • Confidence, which directly uses the predicted probabilities.
    => cannot be decomposed

Passing “TU” (total uncertainty), “AU” (aleatoric uncertainty), “EU” (epistemic uncertainty), or “confidence” as unc_type returns the corresponding NumPy array.

rej.uncertainty(unc_type="TU")
array([2.36096002, 2.77688085, 2.81727076, ..., 1.82345458, 2.0367632 ,
       1.89708975])

Passing None as unc_type returns a dict with the three entropy-based uncertainties (TU, AU and EU).

rej.uncertainty(unc_type=None)
{'TU': array([2.36096002, 2.77688085, 2.81727076, ..., 1.82345458, 2.0367632 ,
        1.89708975]),
 'AU': array([0.8389478 , 1.137137  , 1.2077203 , ..., 1.21363037, 1.53001765,
        1.30948878]),
 'EU': array([1.52201222, 1.63974385, 1.60955046, ..., 0.60982421, 0.50674554,
        0.58760097])}

plot_uncertainty method

The uncertainty can be visualized with the plot_uncertainty method, which takes unc_type as an argument. If unc_type is None, it plots all 3 uncertainties.

fig = rej.plot_uncertainty(unc_type="TU")
[Figure: output of plot_uncertainty for TU]
fig = rej.plot_uncertainty(unc_type=None)
[Figure: output of plot_uncertainty for TU, AU and EU]

reject method

reject is the core method of the package. It takes as arguments a rejection threshold, the uncertainty type to use, and whether the threshold is relative or absolute. A relative threshold is the fraction of observations to reject, starting from the most uncertain ones (threshold=0.5 rejects 50% of the observations); an absolute threshold rejects every observation whose uncertainty is >= threshold. It returns the 3 evaluation metrics and a boolean array of the same length as the input data, where True means the observation is rejected. The 3 evaluation metrics are Non-rejected accuracy (NRA), Classification quality (CQ) and Rejection quality (RQ), following the work of Condessa et al. (2017).

Again, unc_type is one of {‘TU’, ‘AU’, ‘EU’, ‘confidence’}
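The difference between relative and absolute thresholds can be sketched as follows. rejection_mask is a hypothetical helper; how the package rounds the number of rejected observations or breaks ties is not documented here, so the rounding and the stable sort are assumptions. It expects "higher means less certain", so a confidence score would be passed as 1 - confidence.

```python
import numpy as np


def rejection_mask(unc, threshold, relative=True):
    """Hypothetical helper; True = rejected. Higher `unc` must mean less certain."""
    unc = np.asarray(unc, dtype=float)
    if not relative:
        # Absolute threshold: reject whenever the uncertainty reaches it
        return unc >= threshold
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("a relative threshold must lie in [0, 1]")
    # Relative threshold: reject this fraction of the most uncertain observations
    n_reject = int(round(threshold * unc.size))
    mask = np.zeros(unc.size, dtype=bool)
    if n_reject > 0:                               # order[-0:] would select everything
        order = np.argsort(unc, kind="stable")     # ascending uncertainty
        mask[order[-n_reject:]] = True
    return mask


unc = np.array([0.2, 1.5, 0.9, 2.4])
print(rejection_mask(unc, 0.5).tolist())                  # [False, True, False, True]
print(rejection_mask(unc, 0.9, relative=False).tolist())  # [False, True, True, True]
```

A relative threshold fixes how many observations go to the human expert, whatever their uncertainty values; an absolute threshold fixes the uncertainty level instead, so the number of rejections depends on the data.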

# reject the 50% most uncertain observations according to total uncertainty (TU)
rej.reject(threshold=0.5, unc_type="TU", relative=True, show=True)
             Non-rejected    Rejected
---------  --------------  ----------
Correct               875          24
Incorrect             125         976

  Non-rejected accuracy    Classification quality    Rejection quality
-----------------------  ------------------------  -------------------
                 0.8750                    0.9255              33.2056
((0.875, 0.9255, 33.205570693309106),
 array([ True,  True,  True, ..., False, False, False]))
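The three metrics can be recomputed from the 2×2 table above. The sketch below uses the definitions we infer from Condessa et al. (2017): NRA is the accuracy on non-rejected observations, CQ is the share of observations that are either kept and correct or rejected and incorrect, and RQ compares the incorrect/correct ratio among rejected observations with the same ratio over the whole set. rejection_metrics is a hypothetical helper, not the package's API; rebuilding the table from its counts reproduces the printed 0.8750, 0.9255 and 33.2056.

```python
import numpy as np


def rejection_metrics(correct, rejected):
    """NRA, CQ and RQ from boolean arrays (True = correct / True = rejected)."""
    correct = np.asarray(correct, dtype=bool)
    rejected = np.asarray(rejected, dtype=bool)
    n_cor_kept = np.sum(correct & ~rejected)
    n_inc_kept = np.sum(~correct & ~rejected)
    n_cor_rej = np.sum(correct & rejected)
    n_inc_rej = np.sum(~correct & rejected)

    with np.errstate(divide="ignore", invalid="ignore"):   # edge cases give inf/nan
        # NRA: accuracy on the observations that were not rejected
        nra = n_cor_kept / (n_cor_kept + n_inc_kept)
        # CQ: kept-and-correct plus rejected-and-incorrect, over all observations
        cq = (n_cor_kept + n_inc_rej) / correct.size
        # RQ: incorrect/correct ratio among rejected, relative to the overall ratio
        rq = (n_inc_rej / n_cor_rej) / ((~correct).sum() / correct.sum())
    return nra, cq, rq


# Rebuild the confusion table printed above: (correct?, rejected?) x count
correct = np.repeat([True, True, False, False], [875, 24, 125, 976])
rejected = np.repeat([False, True, False, True], [875, 24, 125, 976])
nra, cq, rq = rejection_metrics(correct, rejected)
print(nra, cq, round(rq, 4))   # 0.875 0.9255 33.2056
```

An RQ above 1 means the rejector removes incorrect predictions at a higher rate than their share in the data; if no correct prediction is rejected, RQ is infinite.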

plot_reject method

It is often useful to plot the metrics over a range of thresholds. plot_reject allows for visualization with different uncertainty types and metrics. At least one of unc_type and metric must be set (i.e. not None): if unc_type is None, it plots all 3 uncertainties, and if metric is None, it plots all 3 metrics. The threshold can be relative or absolute.

Once an appropriate threshold has been selected, the reject method can be used to get the boolean array of rejected observations.
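The grid of thresholds that plot_reject evaluates is not documented here, so the following is only a sketch of the underlying computation for NRA with relative thresholds, plus one possible way to choose a threshold before calling reject: take the smallest rejection fraction that reaches a target NRA. Both nra_curve and smallest_fraction_for are hypothetical helpers, and the toy data is made up for illustration.

```python
import numpy as np


def nra_curve(unc, correct, fractions):
    """NRA after rejecting each fraction of the most uncertain observations."""
    unc = np.asarray(unc, dtype=float)
    correct = np.asarray(correct, dtype=bool)
    correct_sorted = correct[np.argsort(unc, kind="stable")]   # most certain first
    curve = []
    for f in fractions:
        n_keep = unc.size - int(round(f * unc.size))
        # NRA is undefined when every observation is rejected
        curve.append(correct_sorted[:n_keep].mean() if n_keep > 0 else np.nan)
    return np.array(curve)


def smallest_fraction_for(target_nra, unc, correct, fractions):
    """Hypothetical helper: first (smallest) fraction whose NRA reaches the target."""
    curve = nra_curve(unc, correct, fractions)
    hits = np.flatnonzero(curve >= target_nra)   # nan compares as False
    return None if hits.size == 0 else float(fractions[hits[0]])


unc = np.array([0.1, 0.4, 0.5, 0.9, 1.2])
correct = np.array([True, True, False, True, False])
fractions = np.array([0.0, 0.2, 0.4, 0.6])    # must be sorted ascending

print(nra_curve(unc, correct, fractions).round(3).tolist())   # [0.6, 0.75, 0.667, 1.0]
print(smallest_fraction_for(0.7, unc, correct, fractions))    # 0.2
print(smallest_fraction_for(0.95, unc, correct, fractions))   # 0.6
```

The toy curve is not monotonic (0.75 drops to 0.667 when a correct but uncertain observation is rejected), which is why plotting the whole curve is more informative than checking a single threshold.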

Specific unc_type and metric, relative rejection

fig = rej.plot_reject(unc_type="TU", metric="NRA")
print(fig)
Figure(450x300)
[Figure: NRA as a function of the relative threshold, TU]

Specific unc_type and metric, absolute rejection

fig = rej.plot_reject(unc_type="TU", metric="NRA", relative=False)
print(fig)
Figure(450x300)
[Figure: NRA as a function of the absolute threshold, TU]

Specific unc_type for all three metrics, absolute rejection

fig = rej.plot_reject(unc_type="TU", metric=None, relative=False)
print(fig)
Figure(1600x300)
[Figure: NRA, CQ and RQ as functions of the absolute threshold, TU]

Specific metric on all three unc_types, absolute rejection

fig = rej.plot_reject(unc_type=None, metric="NRA", relative=False)
print(fig)
Figure(1600x300)
[Figure: NRA as a function of the absolute threshold for TU, AU and EU]