reject.utils
Utility functions for the reject package.
Module Contents
Functions
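compute_correct(y_true, y_pred) | Compute correct predictions.
aggregate_preds(y_pred) | Aggregate predictions to get stack, mean, and label.
generate_synthetic_output(num_samples, num_observations, concat=True) | Generate synthetic NN output for showcasing functions.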
- reject.utils.compute_correct(y_true: numpy.typing.NDArray, y_pred: numpy.typing.NDArray) → numpy.typing.NDArray
Compute correct predictions.
- Parameters:
y_true (NDArray) – Array of true labels. Shape (n_observations,).
y_pred (NDArray) – Array of predictions. Shape (n_observations, n_classes) or (n_observations, n_samples, n_classes).
- Returns:
Array of correct predictions. Shape (n_observations,).
- Return type:
NDArray
- Raises:
ValueError – If shape of y_pred or y_true is invalid.
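The behavior described above can be illustrated with a minimal NumPy sketch. This is not the library's implementation: the name `compute_correct_sketch`, the boolean return dtype, and the choice to average over the sample axis before taking the arg-max are all assumptions made for illustration.

```python
import numpy as np


def compute_correct_sketch(y_true, y_pred):
    """Hypothetical stand-in: flag observations whose predicted label matches y_true."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.ndim != 1:
        raise ValueError("y_true must have shape (n_observations,)")
    if y_pred.ndim == 3:
        # Assumption: collapse the sample axis by averaging class scores.
        y_pred = y_pred.mean(axis=1)
    elif y_pred.ndim != 2:
        raise ValueError("y_pred must be rank 2 or rank 3")
    if y_pred.shape[0] != y_true.shape[0]:
        raise ValueError("y_true and y_pred disagree on n_observations")
    # Predicted label is the class with the highest (mean) score.
    return y_pred.argmax(axis=1) == y_true


y_true = np.array([0, 1, 1])
y_pred = np.array([[0.7, 0.3], [0.4, 0.6], [0.8, 0.2]])
correct = compute_correct_sketch(y_true, y_pred)  # shape (3,)
```

In this example the third observation is predicted as class 0 while its true label is 1, so only the first two entries of `correct` are true.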
- reject.utils.aggregate_preds(y_pred: numpy.typing.NDArray) → tuple[numpy.typing.NDArray, numpy.typing.NDArray, numpy.typing.NDArray]
Aggregate predictions to get stack, mean, and label.
- Parameters:
y_pred (NDArray) – Array of predictions. Shape (n_observations, n_classes) or (n_observations, n_samples, n_classes).
- Returns:
Stack (rank 2 or 3), mean (rank 2), and label (rank 1) of predictions.
- Return type:
tuple[NDArray, NDArray, NDArray]
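One plausible reading of "stack, mean, and label" is sketched below with NumPy. The function name `aggregate_preds_sketch` is hypothetical, and the specifics are assumptions rather than the library's confirmed behavior: the stack is taken to be the input itself, the mean is taken over the sample axis, and the label is the arg-max of the mean.

```python
import numpy as np


def aggregate_preds_sketch(y_pred):
    """Hypothetical stand-in returning (stack, mean, label) for rank-2 or rank-3 input."""
    y_stack = np.asarray(y_pred)
    if y_stack.ndim == 3:
        # (n_observations, n_samples, n_classes) -> (n_observations, n_classes)
        y_mean = y_stack.mean(axis=1)
    elif y_stack.ndim == 2:
        # A single prediction per observation is already its own mean.
        y_mean = y_stack
    else:
        raise ValueError("y_pred must be rank 2 or rank 3")
    y_label = y_mean.argmax(axis=1)  # shape (n_observations,)
    return y_stack, y_mean, y_label


# Two observations, three samples each, two classes.
y_pred = np.array([
    [[0.9, 0.1], [0.6, 0.4], [0.3, 0.7]],
    [[0.2, 0.8], [0.4, 0.6], [0.3, 0.7]],
])
y_stack, y_mean, y_label = aggregate_preds_sketch(y_pred)
```

Here `y_stack` keeps the rank-3 shape `(2, 3, 2)`, `y_mean` has shape `(2, 2)`, and `y_label` has shape `(2,)`.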
- reject.utils.generate_synthetic_output(num_samples: int, num_observations: int, concat: bool = True) → tuple[numpy.typing.NDArray, numpy.typing.NDArray] | tuple[tuple[numpy.typing.NDArray, numpy.typing.NDArray], tuple[numpy.typing.NDArray, numpy.typing.NDArray]]
Generate synthetic NN output for showcasing functions.
- Parameters:
num_samples (int) – Number of samples to draw per observation.
num_observations (int) – Number of observations.
concat (bool, optional) – Whether to concatenate in-distribution (ID) and out-of-distribution (OOD) samples. Defaults to True.
- Returns:
Tuple of synthetic predictions and true labels.
- Return type:
Union[tuple[NDArray, NDArray], tuple[tuple[NDArray, NDArray], tuple[NDArray, NDArray]]]
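To make the expected shapes concrete, the sketch below builds a synthetic prediction array and label vector with plain NumPy. It is a hypothetical stand-in, not the library's generator: the name `synthetic_output_sketch`, the `n_classes` and `seed` parameters, and the use of softmax over Gaussian logits are all assumptions, and it does not model the ID/OOD split or the `concat=False` return structure.

```python
import numpy as np


def synthetic_output_sketch(num_samples, num_observations, n_classes=3, seed=0):
    """Hypothetical stand-in: random softmax outputs plus random integer labels."""
    rng = np.random.default_rng(seed)
    # Logits laid out as (n_observations, n_samples, n_classes).
    logits = rng.normal(size=(num_observations, num_samples, n_classes))
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    y_pred = exp / exp.sum(axis=-1, keepdims=True)
    y_true = rng.integers(0, n_classes, size=num_observations)
    return y_pred, y_true


y_pred, y_true = synthetic_output_sketch(num_samples=10, num_observations=5)
```

The resulting `y_pred` has shape `(5, 10, 3)` with each class vector summing to one, and `y_true` has shape `(5,)`, matching the input shapes documented for `compute_correct` and `aggregate_preds`.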