#!/usr/bin/env python3
# =============================================================================
# Created By : Arthur Thuy
# Created Date: Wed February 28 2024
# =============================================================================
"""Module for uncertainty."""
# =============================================================================
from typing import Optional, Union
import numpy as np
from numpy.typing import NDArray
from scipy.stats import entropy
from reject.constant import ENTROPY_UNC_LIST, EntropyUnc
from reject.utils import aggregate_preds
[docs]
def compute_uncertainty(
    y_pred: NDArray, unc_type: Optional[EntropyUnc] = None
) -> Union[NDArray, dict[str, NDArray]]:
    """Calculate total (TU), aleatoric (AU) and epistemic (EU) uncertainty.

    All quantities are Shannon entropies in base 2 (bits), computed along the
    last (class) axis:

    - TU: entropy of the prediction averaged over the sample axis.
    - AU: per-sample entropy, averaged over the sample axis.
    - EU: TU - AU.

    Predictions are presumably probabilities; note that
    ``scipy.stats.entropy`` renormalizes its input along ``axis`` so that it
    sums to 1.

    Parameters
    ----------
    y_pred : NDArray
        Array of predictions. Shape (n_observations, n_classes)
        or (n_observations, n_samples, n_classes).
    unc_type : EntropyUnc, optional
        Type of uncertainty to return; must be a member of
        ``ENTROPY_UNC_LIST`` (expected: "TU", "AU" or "EU"). If None (default),
        all three are returned.

    Returns
    -------
    Union[NDArray, dict[str, NDArray]]
        If ``unc_type`` is given, the array for that uncertainty type, shape
        (n_observations,). Otherwise a dict mapping "TU", "AU" and "EU" to
        their respective arrays.

    Raises
    ------
    ValueError
        If ``unc_type`` is not None and not in ``ENTROPY_UNC_LIST``, or if
        ``y_pred`` does not have rank 2 or 3.
    """
    # checks
    if unc_type is not None and unc_type not in ENTROPY_UNC_LIST:
        raise ValueError(
            f"`unc_type` must be `None` or one of TU, AU, EU, got {unc_type!r}."
        )
    if y_pred.ndim not in (2, 3):
        raise ValueError(
            f"`y_pred` should have rank 2 or 3, has rank {y_pred.ndim}"
        )
    # get rank 3 stack: (observations, samples, classes); rank-2 input is
    # presumably expanded to a single sample by aggregate_preds — confirm
    y_stack, _, _ = aggregate_preds(y_pred)
    # total: entropy of the sample-averaged prediction
    # (observations, samples, classes) => (observations, classes)
    # => (observations,)
    unc_total = entropy(np.mean(y_stack, axis=-2), base=2, axis=-1)
    # aleatoric: sample-average of the per-sample entropy
    # (observations, samples, classes) => (observations, samples)
    # => (observations,)
    unc_aleatoric = np.mean(entropy(y_stack, base=2, axis=-1), axis=-1)
    # epistemic: (observations,). Non-negative in exact arithmetic (entropy
    # is concave, so TU >= AU), but float round-off can yield tiny negatives.
    unc_epistemic = np.subtract(unc_total, unc_aleatoric)
    unc_all = {"TU": unc_total, "AU": unc_aleatoric, "EU": unc_epistemic}
    if unc_type is None:
        return unc_all
    return unc_all[unc_type]
[docs]
def compute_confidence(y_pred: NDArray) -> NDArray:
    """Compute confidence as the maximum of the mean prediction over classes.

    Parameters
    ----------
    y_pred : NDArray
        Array of predictions. Shape (n_observations, n_classes)
        or (n_observations, n_samples, n_classes).

    Returns
    -------
    conf : NDArray
        Array of confidence values: the maximum along the last (class) axis of
        the mean predictions returned by ``aggregate_preds``. Presumably of
        shape (n_observations,) — confirm against ``aggregate_preds``.

    Raises
    ------
    ValueError
        If ``y_pred`` does not have rank 2 or 3.
    """
    # checks
    if y_pred.ndim not in (2, 3):
        raise ValueError(
            f"`y_pred` should have rank 2 or 3, has rank {y_pred.ndim}"
        )
    # the second output of aggregate_preds is used as the mean prediction
    # (presumably averaged over the sample axis — verify in reject.utils)
    _, y_mean, _ = aggregate_preds(y_pred)
    return np.max(y_mean, axis=-1)