#!/usr/bin/env python3
# =============================================================================
# Created By : Arthur Thuy
# Created Date: Wed February 28 2024
# =============================================================================
"""Module for utils."""
# =============================================================================
from typing import Union
import numpy as np
from numpy.typing import NDArray
from scipy.special import softmax
[docs]
def compute_correct(y_true: NDArray, y_pred: NDArray) -> NDArray:
    """Compute correct predictions.

    Parameters
    ----------
    y_true : NDArray
        Array of true labels. Shape (n_observations,).
    y_pred : NDArray
        Array of predictions. Shape (n_observations,) when it already holds
        predicted labels, (n_observations, n_classes),
        or (n_observations, n_samples, n_classes).

    Returns
    -------
    NDArray
        Boolean array flagging correct predictions. Shape (n_observations,).

    Raises
    ------
    ValueError
        If shape of `y_pred` or `y_true` is invalid.
    """
    # Validate ranks before touching `shape[0]`: indexing the shape of a
    # rank-0 array raises IndexError rather than the documented ValueError.
    if y_true.ndim != 1:
        raise ValueError(f"`y_true` should have rank 1, has rank {y_true.ndim}")
    if y_pred.ndim not in (1, 2, 3):
        raise ValueError(
            f"`y_pred` should have rank 1, 2, or 3, has rank {y_pred.ndim}"
        )
    if y_true.shape[0] != y_pred.shape[0]:
        # Implicit concatenation keeps the message on one clean line; a
        # backslash continuation inside the literal would leak source
        # whitespace (or drop the space) into the message.
        raise ValueError(
            "Number of observations in `y_true` and `y_pred` should match, "
            f"got {y_true.shape[0]} and {y_pred.shape[0]}"
        )

    if y_pred.ndim == 1:
        # rank-1 predictions are taken to be labels already
        y_label = y_pred
    else:
        # rank 2 or 3: reduce class scores to a label per observation
        _, _, y_label = aggregate_preds(y_pred)
    return np.equal(y_true, y_label)
[docs]
def aggregate_preds(y_pred: NDArray) -> tuple[NDArray, NDArray, NDArray]:
    """Reduce raw predictions to their stacked, averaged, and label forms.

    Parameters
    ----------
    y_pred : NDArray
        Array of predictions. Shape (n_observations, n_classes)
        or (n_observations, n_samples, n_classes).

    Returns
    -------
    tuple[NDArray, NDArray, NDArray]
        The rank-3 stack of predictions (a rank-2 input gains a singleton
        sample axis), the rank-2 mean over the sample axis, and the rank-1
        argmax label per observation.

    Raises
    ------
    ValueError
        If `y_pred` is not of rank 2 or 3.
    """
    rank = y_pred.ndim
    if rank != 2 and rank != 3:
        raise ValueError(f"`y_pred` should have rank 2 or 3, has rank {y_pred.ndim}")

    multiple_samples = rank == 3
    # with a single prediction per observation there is nothing to average;
    # only a sample axis of length one is inserted for the stack
    stacked = y_pred if multiple_samples else np.expand_dims(y_pred, axis=-2)
    averaged = np.mean(y_pred, axis=-2) if multiple_samples else y_pred
    labels = np.argmax(averaged, axis=-1)
    return stacked, averaged, labels
[docs]
def generate_synthetic_output(
    num_samples: int, num_observations: int, concat: bool = True
) -> Union[
    tuple[NDArray, NDArray], tuple[tuple[NDArray, NDArray], tuple[NDArray, NDArray]]
]:
    """Build synthetic NN predictions and labels for demonstration purposes.

    Two groups of `num_observations` observations are produced, each with
    `num_samples` predictions over 10 classes. Every prediction is the
    softmax of a 10-trial multinomial count vector. In-distribution (ID)
    counts come from one fixed probability vector; out-of-distribution (OOD)
    counts come from a probability vector that is rolled by a randomly drawn
    offset for every single prediction. ID observations are labelled 8 and
    OOD observations 999.

    Parameters
    ----------
    num_samples : int
        Number of samples to draw per observation.
    num_observations : int
        Number of observations in each of the ID and OOD groups.
    concat : bool, optional
        Whether to concatenate ID and OOD samples, by default True.

    Returns
    -------
    Union[tuple[NDArray, NDArray], tuple[tuple[NDArray, NDArray],
    tuple[NDArray, NDArray]]]
        If `concat` is True, a `(predictions, labels)` pair with the OOD
        observations first and the ID observations after them. Otherwise
        `((id_predictions, id_labels), (ood_predictions, ood_labels))`.
    """
    NUM_CLASSES = 10
    group_shape = (num_observations, num_samples, NUM_CLASSES)

    # --- OOD: base probabilities, rolled by a random offset per draw ---
    base_probs = [0.01, 0.01, 0.01, 0.4, 0.01, 0.01, 0.03, 0.01, 0.40, 0.11]
    assert np.isclose(np.sum(base_probs), 1.0)
    offset_probs = [0.11, 0.01, 0.01, 0.27, 0.01, 0.11, 0.02, 0.01, 0.20, 0.25]
    ood_preds = np.empty(group_shape)
    # row-major traversal keeps the random draws in observation-then-sample
    # order
    for obs, draw in np.ndindex(num_observations, num_samples):
        offset = np.random.choice(10, 1, p=offset_probs)
        ood_preds[obs, draw] = np.random.multinomial(
            10, np.roll(base_probs, offset), size=1
        )
    ood_preds = softmax(ood_preds, axis=-1)
    assert ood_preds.shape == group_shape

    # --- ID: a single fixed probability vector for all draws ---
    id_probs = [0.01, 0.01, 0.01, 0.27, 0.01, 0.01, 0.02, 0.01, 0.40, 0.25]
    assert np.isclose(np.sum(id_probs), 1.0)
    id_counts = np.random.multinomial(
        10, id_probs, size=(num_observations, num_samples)
    )
    id_preds = softmax(id_counts, axis=-1)
    assert id_preds.shape == group_shape

    # --- stacked predictions: OOD block first, then ID block ---
    all_preds = np.concatenate([ood_preds, id_preds], axis=0)
    assert all_preds.shape == (2 * num_observations, num_samples, NUM_CLASSES)

    # --- labels: 8 for ID, 999 as the OOD marker ---
    id_labels = np.full(num_observations, 8)
    ood_labels = np.full(num_observations, 999)
    all_labels = np.concatenate([ood_labels, id_labels], axis=0)
    assert all_labels.shape == (2 * num_observations,)

    if not concat:
        return (id_preds, id_labels), (ood_preds, ood_labels)
    return all_preds, all_labels