#!/usr/bin/env python3
# =============================================================================
# Created By : Arthur Thuy
# Created Date: Wed February 28 2024
# =============================================================================
"""Module for rejection."""
# =============================================================================
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
from tabulate import tabulate
from reject.constant import (
ALL_UNC_LIST,
ENTROPY_UNC_LIST,
GENERAL_UNC_LIST,
METRICS_DICT,
UNCERTAINTIES_DICT,
AllUnc,
GeneralUnc,
Metric,
)
from reject.uncertainty import compute_confidence, compute_uncertainty
from reject.utils import compute_correct
def confusion_matrix(
    correct: NDArray,
    unc_ary: NDArray,
    threshold: float,
    relative: bool = True,
    show: bool = False,
) -> tuple[tuple[int, int, int, int], NDArray]:
    """Compute confusion matrix.

    Confusion matrix with 2 axes: (i) correct/incorrect, (ii) rejected/non-rejected.

    Parameters
    ----------
    correct : NDArray
        1D array of correct/incorrect indicators (1.0 = correct, 0.0 = incorrect).
        Entries with any other value are counted in neither row.
    unc_ary : NDArray
        1D array of uncertainty values, largest value rejected first. Must have
        the same shape as `correct`.
    threshold : float
        Rejection threshold. With relative rejection, the fraction of
        observations to reject (in [0, 1]). With absolute rejection, observations
        with ``unc_ary >= threshold`` are rejected.
    relative : bool, optional
        Use relative rejection, otherwise absolute rejection, by default True
    show : bool, optional
        Print confusion matrix to console, by default False

    Returns
    -------
    n_cor_rej : int
        Number of correct observations that are rejected.
    n_cor_nonrej : int
        Number of correct observations that are not rejected.
    n_incor_rej : int
        Number of incorrect observations that are rejected.
    n_incor_nonrej : int
        Number of incorrect observations that are not rejected.
    pred_reject : ndarray
        Boolean array of True/False indicators to reject predictions.

    Raises
    ------
    ValueError
        If `threshold` is negative, if `relative` and `threshold` > 1, or if
        `correct` and `unc_ary` have different shapes.
    """
    correct = np.asarray(correct)
    unc_ary = np.asarray(unc_ary)
    # input checks
    if threshold < 0:
        raise ValueError("Threshold must be non-negative.")
    if relative and threshold > 1:
        raise ValueError("Threshold must be less than or equal to 1.")
    if correct.shape != unc_ary.shape:
        raise ValueError(
            "`correct` and `unc_ary` must have the same shape, got"
            f" {correct.shape} and {unc_ary.shape}."
        )
    n_obs = correct.shape[0]
    # axis 1: rejected or non-rejected
    if relative:
        # relative rejection: reject a fixed fraction of the observations.
        # Round away float noise before flooring, e.g. 0.29 * 100 evaluates to
        # 28.999999999999996 and must reject 29 observations, not 28.
        n_preds_rej = int(np.floor(np.round(threshold * n_obs, 9)))
        # sort by unc_ary, then by random numbers random_draws
        # -> if values equal e.g. 1.0 -> rejected randomly
        random_draws = np.random.random(n_obs)
        # lexsort uses the last key as primary key; flip for descending order
        order = np.flip(np.lexsort((random_draws, unc_ary)), axis=0)
        pred_reject = np.zeros(n_obs, dtype=bool)
        pred_reject[order[:n_preds_rej]] = True
    else:
        # absolute rejection
        pred_reject = unc_ary >= threshold
    # axis 0: correct or incorrect
    is_correct = correct == 1.0
    is_incorrect = correct == 0.0
    # intersections of both axes (cast to Python ints for downstream arithmetic)
    n_cor_rej = int(np.count_nonzero(is_correct & pred_reject))
    n_cor_nonrej = int(np.count_nonzero(is_correct & ~pred_reject))
    n_incor_rej = int(np.count_nonzero(is_incorrect & pred_reject))
    n_incor_nonrej = int(np.count_nonzero(is_incorrect & ~pred_reject))
    if show:
        print(  # TODO: use logging?
            tabulate(
                [
                    ["", "Non-rejected", "Rejected"],
                    ["Correct", n_cor_nonrej, n_cor_rej],
                    ["Incorrect", n_incor_nonrej, n_incor_rej],
                ],
                headers="firstrow",
            )
        )
    return (n_cor_rej, n_cor_nonrej, n_incor_rej, n_incor_nonrej), pred_reject
def compute_metrics(
    threshold: float,
    correct: NDArray,
    unc_ary: NDArray,
    relative: bool = True,
    return_bool: bool = True,
    show: bool = True,
) -> Union[tuple[float, float, float], tuple[tuple[float, float, float], NDArray]]:
    """Compute 3 rejection metrics using relative or absolute threshold.

    | 3 rejection metrics:
    | - non-rejected accuracy (NRA)
    | - classification quality (CQ)
    | - rejection quality (RQ)

    Parameters
    ----------
    threshold : float
        Rejection threshold.
    correct : NDArray
        1D array of correct/incorrect indicator.
    unc_ary : ndarray
        1D array of uncertainty values, largest value rejected first.
    relative : bool, optional
        Use relative rejection, otherwise absolute rejection, by default True
    return_bool : bool, optional
        Return boolean array of rejected predictions, by default True
    show : bool, optional
        Print confusion matrix and metrics to console, by default True

    Returns
    -------
    nonrej_acc : float
        Non-rejected accuracy (NRA).
    class_quality : float
        Classification quality (CQ).
    rej_quality : float
        Rejection quality (RQ).
    pred_reject : ndarray
        Array of True/False indicators to reject predictions. Only returned when
        `return_bool` is True, as ``((nonrej_acc, class_quality, rej_quality),
        pred_reject)``; otherwise only the 3-tuple of metrics is returned.

    Notes
    -----
    - NRA is undefined when every observation is rejected; `np.inf` is returned.
    - CQ is undefined when there are no observations; `np.inf` is returned.
    - RQ = (n_incor_rej / n_cor_rej) / (n_incor / n_cor) is undefined when no
      correct observation is rejected (`n_cor_rej=0`) or when there are no
      incorrect observations at all. In that case:

      - if any observation is rejected: RQ = positive infinity
      - if no observation is rejected: RQ = 1
    - see: `Condessa et al. (2017) <https://doi.org/10.1016/j.patcog.2016.10.011>`_
    """
    (n_cor_rej, n_cor_nonrej, n_incor_rej, n_incor_nonrej), pred_reject = (
        confusion_matrix(
            correct=correct,
            unc_ary=unc_ary,
            threshold=threshold,
            show=show,
            relative=relative,
        )
    )
    n_cor = n_cor_rej + n_cor_nonrej
    n_incor = n_incor_rej + n_incor_nonrej
    n_nonrej = n_cor_nonrej + n_incor_nonrej
    n_total = n_cor + n_incor
    # Explicit zero guards instead of catching ZeroDivisionError: NumPy integer
    # counts would warn and yield inf/nan rather than raise.
    # non-rejected accuracy
    if n_nonrej > 0:
        nonrej_acc = n_cor_nonrej / n_nonrej
    else:
        nonrej_acc = np.inf  # invalid: everything rejected
    # classification quality
    if n_total > 0:
        class_quality = (n_cor_nonrej + n_incor_rej) / n_total
    else:
        class_quality = np.inf  # invalid: no observations
    # rejection quality
    if n_cor_rej > 0 and n_incor > 0:
        rej_quality = (n_incor_rej / n_cor_rej) / (n_incor / n_cor)
    elif (n_incor_rej + n_cor_rej) > 0:
        rej_quality = np.inf
    else:
        rej_quality = 1.0
    if show:
        data = [[nonrej_acc, class_quality, rej_quality]]
        print(  # TODO: use logging instead of print
            "\n"
            + tabulate(
                data,
                headers=[
                    "Non-rejected accuracy",
                    "Classification quality",
                    "Rejection quality",
                ],
                floatfmt=".4f",
            )
        )
    if return_bool:
        return (nonrej_acc, class_quality, rej_quality), pred_reject
    return (nonrej_acc, class_quality, rej_quality)
class ClassificationRejector:
    """Classification with rejection.

    Stores true labels and predictions, precomputes uncertainty, confidence and
    correctness, and offers methods to reject predictions and plot the results.
    """

    def __init__(
        self,
        y_true: NDArray,
        y_pred: NDArray,
    ):
        """Classification with rejection.

        Parameters
        ----------
        y_true : NDArray
            Array of true labels. Shape (n_observations,).
        y_pred : NDArray
            Array of predictions. Shape (n_observations, n_classes)
            or (n_observations, n_samples, n_classes).
        """
        self.y_true = y_true
        self.y_pred = y_pred
        self.num_classes = y_pred.shape[-1]
        # log2 of the number of classes; used to scale the x-axis of the
        # entropy-based plots (assumes entropies are in base 2 — confirm
        # against `compute_uncertainty`)
        self.max_entropy = np.log2(self.num_classes)
        # calculate uncertainty and correctness
        self._uncertainty = compute_uncertainty(y_pred)
        self.confidence = compute_confidence(y_pred)
        self.correct = compute_correct(y_true, y_pred)

    def uncertainty(
        self, unc_type: Optional[AllUnc] = None
    ) -> Union[NDArray, dict[str, NDArray]]:
        """Get uncertainty or confidence values.

        Parameters
        ----------
        unc_type : Optional[AllUnc], optional
            Uncertainty type to return. If None, dict of TU, AU, and EU is returned.
            By default None

        Returns
        -------
        Union[NDArray, dict[str, NDArray]]
            Array of one uncertainty type (or the confidence values), or a dict
            with the arrays of all three uncertainty types.

        Raises
        ------
        ValueError
            If `unc_type` is invalid.
        """
        # checks
        if unc_type is not None and unc_type not in ALL_UNC_LIST:
            raise ValueError(
                "Invalid uncertainty type."
                " Expected one of: TU, AU, EU, confidence or None"
            )
        if unc_type == "confidence":
            return self.confidence
        elif unc_type is not None:
            return self._uncertainty[unc_type]
        else:
            return self._uncertainty

    def plot_uncertainty(
        self,
        unc_type: Optional[AllUnc] = None,
        bins: int = 15,
    ) -> plt.Figure:
        """Plot histograms of uncertainty values.

        Parameters
        ----------
        unc_type : Optional[AllUnc], optional
            Uncertainty type to plot. If None, 3 panels with TU, AU, and EU are
            plotted. By default None
        bins : int, optional
            Number of histogram bins, by default 15

        Returns
        -------
        plt.Figure
            Figure object.

        Raises
        ------
        ValueError
            If `unc_type` is invalid.
        """
        # checks
        if unc_type is not None and unc_type not in ALL_UNC_LIST:
            raise ValueError(
                "Invalid uncertainty type."
                " Expected one of: TU, AU, EU, confidence or None"
            )
        # confidence lives in [0, 1]; entropies in [0, max_entropy]; 5% margin
        if unc_type == "confidence":
            xlim = (0 - 0.05, 1 + 0.05)
        else:
            xlim = (
                0 - 0.05 * self.max_entropy,
                self.max_entropy + 0.05 * self.max_entropy,
            )
        # draw plot
        if unc_type is not None:
            unc_enumerate = [unc_type]
            fig, axes = plt.subplots(ncols=1, figsize=(4.5, 3))
            axes = [axes]
        else:
            unc_enumerate = ENTROPY_UNC_LIST  # type: ignore
            fig, axes = plt.subplots(ncols=3, figsize=(16, 3))
        for i, unc_name in enumerate(unc_enumerate):
            axes[i].hist(self.uncertainty(unc_name), bins=bins)
            axes[i].grid(linestyle="dashed")
            axes[i].set(xlabel=UNCERTAINTIES_DICT[unc_name], ylabel="Frequency")
            axes[i].set_xlim(xlim)
        return fig

    def reject(
        self,
        threshold: float,
        unc_type: AllUnc,
        relative: bool = True,
        show: bool = False,
    ) -> tuple[tuple[float, float, float], NDArray]:
        """Reject with a single threshold.

        Parameters
        ----------
        threshold : float
            Rejection threshold.
        unc_type : AllUnc
            Uncertainty type to use for rejection order. For "confidence", the
            values ``1 - confidence`` are used so that the least confident
            predictions are rejected first (as in the plotting methods); with
            `relative=False`, observations with ``1 - confidence >= threshold``
            are rejected.
        relative : bool, optional
            Reject relative to the amount of observations, otherwise compare to the
            uncertainty value. By default True
        show : bool, optional
            Print confusion matrix and metrics, by default False

        Returns
        -------
        tuple[float, float, float]
            Non-rejected accuracy, classification quality, and rejection quality.
        NDArray
            Array of True/False indicators to reject predictions.

        Raises
        ------
        ValueError
            If `unc_type` is invalid.
        """
        # checks
        if unc_type not in ALL_UNC_LIST:
            raise ValueError(
                "Invalid uncertainty type. Expected one of: TU, AU, EU, confidence"
            )
        # NOTE: mypy thinks this can return tuple[float, float, float]
        # but not true because `return_bool=True`
        return compute_metrics(  # type: ignore[return-value]
            threshold=threshold,
            correct=self.correct,
            unc_ary=self.__rejection_scores(unc_type),
            relative=relative,
            show=show,
            return_bool=True,
        )

    def plot_reject(
        self,
        unc_type: Optional[AllUnc] = None,
        metric: Optional[Metric] = None,
        relative: bool = True,
        space_start: float = 0.001,
        space_stop: float = 0.99,
        space_bins: int = 100,
        filename: Optional[str] = None,
        **save_args
    ) -> plt.Figure:
        """Plot one or multiple rejection metrics for a range of thresholds.

        Rejection can be based on one or more uncertainty types.
        There should be at least one of `unc_type` or `metric` specified.

        Parameters
        ----------
        unc_type : Optional[AllUnc], optional
            Uncertainty type to use for rejection order. If None, 3 panels with
            TU, AU, and EU are plotted. By default None
        metric : Optional[Metric], optional
            Rejection metrics to compute. If None, 3 panels with NRA, CQ, and RQ
            are plotted. By default None
        relative : bool, optional
            Reject relative to the amount of observations, otherwise compare to
            the uncertainty value. By default True
        space_start : float, optional
            Threshold value to start figure at, by default 0.001
        space_stop : float, optional
            Threshold value to stop figure at, by default 0.99
        space_bins : int, optional
            Number of evaluation points in the line plot, by default 100
        filename : Optional[str], optional
            Filename to save figure. If None, no figure saved. By default None
        **save_args
            Extra keyword arguments forwarded to `Figure.savefig`.

        Returns
        -------
        plt.Figure
            Figure object.

        Raises
        ------
        ValueError
            If `unc_type` and `metric` are both None.
        ValueError
            If `unc_type` is invalid.
        """
        # checks
        if unc_type is None and metric is None:
            raise ValueError(
                "`unc_type` and `metric` cannot be both None, at least one must"
                " be specified."
            )
        if unc_type is not None and unc_type not in ALL_UNC_LIST:
            raise ValueError(
                "Invalid uncertainty type. Expected one of: %s" % ALL_UNC_LIST
            )
        if metric is not None:
            return self.__plot_1_3_uncertainty_panels(
                metric=metric,
                unc_type=unc_type,
                relative=relative,
                space_start=space_start,
                space_stop=space_stop,
                space_bins=space_bins,
                filename=filename,
                **save_args
            )
        # here metric is None, so unc_type is guaranteed not None by the check above
        return self.__plot_3_metric_panels(
            unc_type=unc_type,
            relative=relative,
            space_start=space_start,
            space_stop=space_stop,
            space_bins=space_bins,
            filename=filename,
            **save_args
        )

    def __rejection_scores(self, unc_type: AllUnc) -> NDArray:
        """Get the values used for rejection, largest value most uncertain.

        Confidence is converted to ``1 - confidence`` so the least confident
        observations are rejected first; other uncertainty types are returned
        unchanged.
        """
        if unc_type == "confidence":
            return 1.0 - self.confidence
        return self._uncertainty[unc_type]

    def __plot_1_3_uncertainty_panels(
        self,
        metric: Metric,
        unc_type: Optional[AllUnc] = None,
        relative: bool = True,
        space_start: float = 0.001,
        space_stop: float = 0.99,
        space_bins: int = 100,
        filename: Optional[str] = None,
        **save_args
    ) -> plt.Figure:
        """Plot one or 3 panels with uncertainty types, for a specific metric.

        Parameters
        ----------
        metric : Metric
            Rejection metrics to compute.
        unc_type : Optional[AllUnc], optional
            Uncertainty type to use for rejection order. If None, 3 panels with
            TU, AU, and EU are plotted. By default None
        relative : bool, optional
            Reject relative to the amount of observations, otherwise compare to
            the uncertainty value. By default True
        space_start : float, optional
            Threshold value to start figure at, by default 0.001
        space_stop : float, optional
            Threshold value to stop figure at, by default 0.99
        space_bins : int, optional
            Number of evaluation points in the line plot, by default 100
        filename : Optional[str], optional
            Filename to save figure. If None, no figure saved. By default None
        **save_args
            Extra keyword arguments forwarded to `Figure.savefig`.

        Returns
        -------
        plt.Figure
            Figure object.
        """
        # draw plot
        if unc_type is not None:
            unc_enumerate = [unc_type]
            fig, axes = plt.subplots(ncols=1, figsize=(4.5, 3))
            axes = [axes]
        else:
            unc_enumerate = ENTROPY_UNC_LIST  # type: ignore
            fig, axes = plt.subplots(ncols=3, figsize=(16, 3))
        xlabel = "Relative threshold" if relative else "Absolute threshold"
        for i, unc_name in enumerate(unc_enumerate):
            self.__plot_base_panel(
                correct=self.correct,
                unc_ary=self.__rejection_scores(unc_name),
                metric=metric,
                unc_type=(
                    GeneralUnc.CONFIDENCE
                    if unc_name == "confidence"
                    else GeneralUnc.ENTROPY
                ),
                relative=relative,
                space_start=space_start,
                space_stop=space_stop,
                space_bins=space_bins,
                ax=axes[i],
            )
            axes[i].grid(linestyle="dashed")
            axes[i].set(xlabel=xlabel, ylabel=METRICS_DICT[metric])
        if filename is not None:
            fig.tight_layout()
            fig.savefig(filename, **save_args)
        return fig

    def __plot_3_metric_panels(
        self,
        unc_type: AllUnc,
        relative: bool = True,
        space_start: float = 0.001,
        space_stop: float = 0.99,
        space_bins: int = 100,
        filename: Optional[str] = None,
        **save_args
    ) -> plt.Figure:
        """Plot 3 panels with rejection metrics, for a specific uncertainty type.

        Parameters
        ----------
        unc_type : AllUnc
            Uncertainty type to use for rejection order.
        relative : bool, optional
            Reject relative to the amount of observations, otherwise compare to
            the uncertainty value. By default True
        space_start : float, optional
            Threshold value to start figure at, by default 0.001
        space_stop : float, optional
            Threshold value to stop figure at, by default 0.99
        space_bins : int, optional
            Number of evaluation points in the line plot, by default 100
        filename : Optional[str], optional
            Filename to save figure. If None, no figure saved. By default None
        **save_args
            Extra keyword arguments forwarded to `Figure.savefig`.

        Returns
        -------
        plt.Figure
            Figure object.
        """
        unc_ary = self.__rejection_scores(unc_type)
        general_unc = (
            GeneralUnc.CONFIDENCE if unc_type == "confidence" else GeneralUnc.ENTROPY
        )
        xlabel = "Relative threshold" if relative else "Absolute threshold"
        # draw plot
        fig, axes = plt.subplots(ncols=3, figsize=(16, 3))
        for i, label in enumerate(Metric):
            self.__plot_base_panel(
                correct=self.correct,
                unc_ary=unc_ary,
                metric=label,
                unc_type=general_unc,
                relative=relative,
                space_start=space_start,
                space_stop=space_stop,
                space_bins=space_bins,
                ax=axes[i],
            )
            axes[i].grid(linestyle="dashed")
            axes[i].set(xlabel=xlabel, ylabel=METRICS_DICT[label])
        if filename is not None:
            fig.tight_layout()
            fig.savefig(filename, **save_args)
        return fig

    def __plot_base_panel(
        self,
        correct: NDArray,
        unc_ary: NDArray,
        metric: Metric,
        unc_type: GeneralUnc,
        relative: bool = True,
        space_start: float = 0.001,
        space_stop: float = 0.99,
        space_bins: int = 100,
        ax: Optional[plt.Axes] = None,
    ) -> plt.Axes:
        """Plot single panel with some rejection metric and uncertainty type.

        Parameters
        ----------
        correct : NDArray
            Array of correct predictions. Shape (n_observations,).
        unc_ary : NDArray
            Array of uncertainty values, largest value rejected first. For
            confidence, the callers in this class pass ``1 - confidence``.
        metric : Metric
            Rejection metric to compute.
        unc_type : GeneralUnc
            Uncertainty type to use for rejection order.
        relative : bool, optional
            Reject relative to the amount of observations, otherwise compare to
            the uncertainty value. By default True
        space_start : float, optional
            Threshold value to start figure at, by default 0.001
        space_stop : float, optional
            Threshold value to stop figure at, by default 0.99
        space_bins : int, optional
            Number of evaluation points in the line plot, by default 100
        ax : Optional[plt.Axes], optional
            Ax to plot on. If None, new ax is created. By default None

        Returns
        -------
        plt.Axes
            Axes object.

        Raises
        ------
        ValueError
            If `unc_type` is invalid.
        """
        # checks
        if unc_type not in GENERAL_UNC_LIST:
            raise ValueError(
                "Invalid uncertainty type. Expected one of: %s" % GENERAL_UNC_LIST
            )
        if relative:
            reject_ary = np.linspace(
                start=space_start, stop=space_stop, num=space_bins
            )
            plot_ary = reject_ary
        elif unc_type == "confidence":
            # rejection acts on 1 - confidence, so a confidence threshold c
            # corresponds to the rejection threshold 1 - c; plot against c
            reject_ary = np.linspace(
                start=(1 - space_start),
                stop=(1 - space_stop),
                num=space_bins,
            )
            plot_ary = 1.0 - reject_ary
        elif unc_type == "entropy":
            reject_ary = np.linspace(
                start=(1 - space_start) * self.max_entropy,
                stop=(1 - space_stop) * self.max_entropy,
                num=space_bins,
            )
            plot_ary = reject_ary
        else:
            raise ValueError(
                "Invalid uncertainty type. Expected one of: %s" % GENERAL_UNC_LIST
            )
        # one (NRA, CQ, RQ) row per threshold
        metric_rows = [
            compute_metrics(
                threshold,
                correct=correct,
                unc_ary=unc_ary,
                relative=relative,
                return_bool=False,
                show=False,
            )
            for threshold in reject_ary
        ]
        nonrej_acc, class_quality, rej_quality = (
            np.asarray(metric_rows, dtype=float).reshape(-1, 3).T
        )
        # plot on existing axis or new axis
        if ax is None:
            ax = plt.gca()
        if metric == "NRA":
            ax.plot(plot_ary, nonrej_acc)
        elif metric == "CQ":
            ax.plot(plot_ary, class_quality)
        elif metric == "RQ":
            ax.plot(plot_ary, rej_quality)
        if not relative and unc_type == "entropy":
            # invert x-axis, largest uncertainty values on the left
            ax.invert_xaxis()
            ax.set_xlim(
                self.max_entropy + 0.05 * self.max_entropy, 0 - 0.05 * self.max_entropy
            )
        return ax