Source code for echofilter.optim.criterions

"""
Evaluation criteria.
"""

# This file is part of Echofilter.
#
# Copyright (C) 2020-2022  Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import torch


def _binarise_and_reshape(arg, threshold=0.5, ndim=None):
    """
    Binarize and partially flatten a tensor.

    Parameters
    ----------
    arg : array_like
        Input tensor or array.
    threshold : float, optional
        Threshold which entries in ``arg`` must exceed. Default is ``0.5``.
    ndim : int or None
        Number of trailing dimensions to flatten together into the sample
        dimension. If ``None``, only the first (batch) dimension is kept and
        all remaining dimensions are flattened. Default is ``None``.

    Returns
    -------
    array_like
        A :class:`numpy.ndarray` or :class:`torch.Tensor` (corresponding to
        the type of ``arg``), but partially flattened and binarised.
    """
    # Binarise mask
    arg = arg > threshold
    # Reshape so pixels are vectorised by batch
    if ndim is None:
        shape = [arg.shape[0], -1]
    else:
        shape = list(arg.shape)
        shape = shape[:-ndim] + [-1]
    arg = arg.reshape(shape)
    return arg
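
# A minimal usage sketch for the helper above; the tensor shape here is an
# arbitrary illustration of how ``ndim`` controls the partial flattening.
_example = torch.rand(2, 3, 4, 5)
# Default: keep only the batch dimension and flatten the rest -> shape (2, 60)
assert _binarise_and_reshape(_example).shape == (2, 60)
# ndim=2: flatten only the last two dimensions together -> shape (2, 3, 20)
assert _binarise_and_reshape(_example, ndim=2).shape == (2, 3, 20)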


def mask_active_fraction(input, threshold=0.5, ndim=None, reduction="mean"):
    """
    Measure the fraction of input which exceeds a threshold.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    threshold : float, optional
        Threshold which entries in ``input`` must exceed. Default is ``0.5``.
    ndim : int or None
        Number of trailing dimensions to flatten together into the sample
        dimension. If ``None``, only the first (batch) dimension is kept and
        all remaining dimensions are flattened. Default is ``None``.
    reduction : ``"none"`` or ``"mean"`` or ``"sum"``, optional
        Specifies the reduction to apply to the output:
        ``"none"`` | ``"mean"`` | ``"sum"``.
        ``"none"``: no reduction will be applied,
        ``"mean"``: the sum of the output will be divided by the number of
        elements in the output,
        ``"sum"``: the output will be summed.
        Default: ``"mean"``.

    Returns
    -------
    torch.Tensor
        The fraction of ``input`` which exceeds ``threshold``, with shape
        corresponding to ``reduction``.
    """
    # Binarise and reshape mask
    input = _binarise_and_reshape(input, threshold=threshold, ndim=ndim)
    # Measure the fraction of active entries per sample
    output = input.sum(-1).float() / input.size(-1)
    # Apply reduction
    if reduction == "none":
        return output
    elif reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    else:
        raise ValueError("Unsupported reduction value: {}".format(reduction))

def mask_active_fraction_with_logits(input, *args, **kwargs):
    """
    Convert logits to probabilities, and measure what fraction exceeds the
    threshold.

    See Also
    --------
    mask_active_fraction
    """
    return mask_active_fraction(torch.sigmoid(input), *args, **kwargs)

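# A minimal usage sketch; the probability values below are arbitrary
# illustrations, and half of them exceed the default 0.5 threshold.
_example_probs = torch.tensor([[0.9, 0.1, 0.8, 0.2]])
_example_frac = mask_active_fraction(_example_probs)  # tensor(0.5000)
# Feeding the corresponding logits to the _with_logits variant gives the same
# result, since the sigmoid recovers the original probabilities.
_example_logits = torch.log(_example_probs / (1 - _example_probs))
_example_frac_logits = mask_active_fraction_with_logits(_example_logits)  # tensor(0.5000)
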
def mask_accuracy(input, target, threshold=0.5, ndim=None, reduction="mean"):
    """
    Measure accuracy of input compared to binary targets.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    target : torch.Tensor
        Target tensor, the same shape as ``input``.
    threshold : float, optional
        Threshold which entries in ``input`` and ``target`` must exceed to be
        binarised as the positive class. Default is ``0.5``.
    ndim : int or None
        Number of trailing dimensions to flatten together into the sample
        dimension. If ``None``, only the first (batch) dimension is kept and
        all remaining dimensions are flattened. Default is ``None``.
    reduction : ``"none"`` or ``"mean"`` or ``"sum"``, optional
        Specifies the reduction to apply to the output:
        ``"none"`` | ``"mean"`` | ``"sum"``.
        ``"none"``: no reduction will be applied,
        ``"mean"``: the sum of the output will be divided by the number of
        elements in the output,
        ``"sum"``: the output will be summed.
        Default: ``"mean"``.

    Returns
    -------
    torch.Tensor
        The fraction of ``input`` which has the same class as ``target``
        after thresholding.
    """
    # Binarise and reshape masks
    input = _binarise_and_reshape(input, threshold=threshold, ndim=ndim)
    target = _binarise_and_reshape(target, threshold=threshold, ndim=ndim)
    # Measure hit rate and number of samples
    hits = (input == target).sum(-1)
    count = input.size(-1)
    output = hits.float() / count
    # Apply reduction
    if reduction == "none":
        return output
    elif reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    else:
        raise ValueError("Unsupported reduction value: {}".format(reduction))

def mask_accuracy_with_logits(input, *args, **kwargs):
    """
    Measure accuracy with logit inputs.

    Pass through a sigmoid, binarize, then measure accuracy of predictions
    compared to ground truth target.

    See Also
    --------
    mask_accuracy
    """
    return mask_accuracy(torch.sigmoid(input), *args, **kwargs)

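# A minimal usage sketch with made-up predictions and targets: accuracy is the
# fraction of entries assigned the correct class after thresholding.
_example_pred = torch.tensor([[0.9, 0.2, 0.7, 0.1]])
_example_target = torch.tensor([[1.0, 0.0, 0.0, 1.0]])
# The binarised prediction [1, 0, 1, 0] matches the target at 2 of 4 positions.
_example_acc = mask_accuracy(_example_pred, _example_target)  # tensor(0.5000)
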
def mask_precision(input, target, threshold=0.5, ndim=None, reduction="mean"):
    """
    Measure precision of probability input.

    Binarize with a threshold, then measure precision compared to a ground
    truth target.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    target : torch.Tensor
        Target tensor, the same shape as ``input``.
    threshold : float, optional
        Threshold which entries in ``input`` and ``target`` must exceed to be
        binarised as the positive class. Default is ``0.5``.
    ndim : int or None
        Number of trailing dimensions to flatten together into the sample
        dimension. If ``None``, only the first (batch) dimension is kept and
        all remaining dimensions are flattened. Default is ``None``.
    reduction : ``"none"`` or ``"mean"`` or ``"sum"``, optional
        Specifies the reduction to apply to the output:
        ``"none"`` | ``"mean"`` | ``"sum"``.
        ``"none"``: no reduction will be applied,
        ``"mean"``: the sum of the output will be divided by the number of
        elements in the output,
        ``"sum"``: the output will be summed.
        Default: ``"mean"``.

    Returns
    -------
    torch.Tensor
        The precision of ``input`` as compared to ``target`` after
        thresholding. This is the fraction of predicted positive cases,
        ``input > threshold``, which are true positive cases
        (``input > threshold`` and ``target > threshold``).
        If there are no predicted positives, the output is ``0`` if there are
        any positives to predict, and ``1`` if there are none.
    """
    # Binarise and reshape masks
    input = _binarise_and_reshape(input, threshold=threshold, ndim=ndim)
    target = _binarise_and_reshape(target, threshold=threshold, ndim=ndim)
    # Measure true positives, predicted positives, and actual positives
    true_p = (input & target).sum(-1)
    predicted_p = input.sum(-1)
    actual_p = target.sum(-1)
    output = true_p.float() / predicted_p.float()
    # Handle division by 0: If there were no positives predicted, the
    # precision is 1 when there were also no positives to find, and 0
    # otherwise.
    output[predicted_p == 0] = (actual_p[predicted_p == 0] == 0).float()
    # Apply reduction
    if reduction == "none":
        return output
    elif reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    else:
        raise ValueError("Unsupported reduction value: {}".format(reduction))

def mask_precision_with_logits(input, *args, **kwargs):
    """
    Measure precision of logit input.

    Pass through sigmoid, threshold, then measure precision.

    See Also
    --------
    mask_precision
    """
    return mask_precision(torch.sigmoid(input), *args, **kwargs)

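# A minimal usage sketch with made-up masks: precision is the fraction of
# predicted positives which are actually positive in the target.
_example_pred = torch.tensor([[0.9, 0.8, 0.7, 0.1]])
_example_target = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
# Three entries are predicted positive, two of which are correct -> 2/3
_example_precision = mask_precision(_example_pred, _example_target)
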
def mask_recall(input, target, threshold=0.5, ndim=None, reduction="mean"):
    """
    Measure recall of probability input.

    Binarize with a threshold, then measure the recall compared to a ground
    truth target.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    target : torch.Tensor
        Target tensor, the same shape as ``input``.
    threshold : float, optional
        Threshold which entries in ``input`` and ``target`` must exceed to be
        binarised as the positive class. Default is ``0.5``.
    ndim : int or None
        Number of trailing dimensions to flatten together into the sample
        dimension. If ``None``, only the first (batch) dimension is kept and
        all remaining dimensions are flattened. Default is ``None``.
    reduction : ``"none"`` or ``"mean"`` or ``"sum"``, optional
        Specifies the reduction to apply to the output:
        ``"none"`` | ``"mean"`` | ``"sum"``.
        ``"none"``: no reduction will be applied,
        ``"mean"``: the sum of the output will be divided by the number of
        elements in the output,
        ``"sum"``: the output will be summed.
        Default: ``"mean"``.

    Returns
    -------
    torch.Tensor
        The recall of ``input`` as compared to ``target`` after thresholding.
        This is the fraction of ground truth positive cases,
        ``target > threshold``, which are true positive cases
        (``input > threshold`` and ``target > threshold``).
        If there are no ground truth positives, the output is ``1``.
    """
    # Binarise and reshape masks
    input = _binarise_and_reshape(input, threshold=threshold, ndim=ndim)
    target = _binarise_and_reshape(target, threshold=threshold, ndim=ndim)
    # Measure true positives and actual positives
    true_p = (input & target).sum(-1)
    ground_truth_p = target.sum(-1)
    output = true_p.float() / ground_truth_p.float()
    # Handle division by 0: If there were no positives to find, all were found.
    output[ground_truth_p == 0] = 1
    # Apply reduction
    if reduction == "none":
        return output
    elif reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    else:
        raise ValueError("Unsupported reduction value: {}".format(reduction))

def mask_recall_with_logits(input, *args, **kwargs):
    """
    Measure recall of logit input.

    Pass through sigmoid, binarize, then measure recall against the ground
    truth target.

    See Also
    --------
    mask_recall
    """
    return mask_recall(torch.sigmoid(input), *args, **kwargs)

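# A minimal usage sketch with made-up masks: recall is the fraction of ground
# truth positives which were recovered by the prediction.
_example_pred = torch.tensor([[0.9, 0.2, 0.7, 0.1]])
_example_target = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
# The target has two positives, one of which is predicted -> 1/2
_example_recall = mask_recall(_example_pred, _example_target)  # tensor(0.5000)
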
def mask_f1_score(input, target, reduction="mean", **kwargs):
    """
    Measure F1-score of probability input.

    Binarize with a threshold, then measure the F1-score of the input
    compared to a ground truth target.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    target : torch.Tensor
        Target tensor, the same shape as ``input``.
    threshold : float, optional
        Threshold which entries in ``input`` and ``target`` must exceed to be
        binarised as the positive class. Default is ``0.5``.
    ndim : int or None
        Number of trailing dimensions to flatten together into the sample
        dimension. If ``None``, only the first (batch) dimension is kept and
        all remaining dimensions are flattened. Default is ``None``.
    reduction : ``"none"`` or ``"mean"`` or ``"sum"``, optional
        Specifies the reduction to apply to the output:
        ``"none"`` | ``"mean"`` | ``"sum"``.
        ``"none"``: no reduction will be applied,
        ``"mean"``: the sum of the output will be divided by the number of
        elements in the output,
        ``"sum"``: the output will be summed.
        Default: ``"mean"``.

    Returns
    -------
    torch.Tensor
        The F1-score of ``input`` as compared to ``target`` after
        thresholding. The F1-score is the harmonic mean of precision and
        recall.

    See Also
    --------
    mask_precision
    mask_recall
    """
    precision = mask_precision(input, target, reduction="none", **kwargs)
    recall = mask_recall(input, target, reduction="none", **kwargs)
    sum_pr = precision + recall
    output = 2 * precision * recall / sum_pr
    # Handle division by 0: If both precision and recall are 0, f1 is 0 too.
    output[sum_pr == 0] = 0
    # Apply reduction
    if reduction == "none":
        return output
    elif reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    else:
        raise ValueError("Unsupported reduction value: {}".format(reduction))

def mask_f1_score_with_logits(input, *args, **kwargs):
    """
    Measure F1-score of logit input.

    Convert logits to probabilities with sigmoid, apply a threshold, then
    measure the F1-score of the tensor as compared to ground truth.

    See Also
    --------
    mask_f1_score
    """
    return mask_f1_score(torch.sigmoid(input), *args, **kwargs)

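# A minimal sketch checking that the F1-score is the harmonic mean of the
# precision and recall; the masks are arbitrary illustrations.
_example_pred = torch.tensor([[0.9, 0.2, 0.7, 0.1]])
_example_target = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
_example_p = mask_precision(_example_pred, _example_target)  # 1/2
_example_r = mask_recall(_example_pred, _example_target)  # 1/2
_example_f1 = mask_f1_score(_example_pred, _example_target)
# 2 * p * r / (p + r) = 2 * 0.25 / 1.0 = 0.5
assert torch.isclose(_example_f1, 2 * _example_p * _example_r / (_example_p + _example_r))
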
def mask_jaccard_index(input, target, threshold=0.5, ndim=None, reduction="mean"):
    """
    Measure Jaccard Index from probabilities.

    Measure the Jaccard Index (intersection over union) of the input as
    compared to a ground truth target, after binarising with a threshold.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    target : torch.Tensor
        Target tensor, the same shape as ``input``.
    threshold : float, optional
        Threshold which entries in ``input`` and ``target`` must exceed to be
        binarised as the positive class. Default is ``0.5``.
    ndim : int or None
        Number of trailing dimensions to flatten together into the sample
        dimension. If ``None``, only the first (batch) dimension is kept and
        all remaining dimensions are flattened. Default is ``None``.
    reduction : ``"none"`` or ``"mean"`` or ``"sum"``, optional
        Specifies the reduction to apply to the output:
        ``"none"`` | ``"mean"`` | ``"sum"``.
        ``"none"``: no reduction will be applied,
        ``"mean"``: the sum of the output will be divided by the number of
        elements in the output,
        ``"sum"``: the output will be summed.
        Default: ``"mean"``.

    Returns
    -------
    torch.Tensor
        The Jaccard Index of ``input`` as compared to ``target``. The Jaccard
        Index is the number of elements where both ``input`` and ``target``
        exceed ``threshold``, divided by the number of elements where at least
        one of ``input`` and ``target`` exceeds ``threshold``.
    """
    # Binarise and reshape masks
    input = _binarise_and_reshape(input, threshold=threshold, ndim=ndim)
    target = _binarise_and_reshape(target, threshold=threshold, ndim=ndim)
    # Use bitwise operators to determine intersection and union of masks
    intersect = input & target
    union = input | target
    # Count the number of pixels at which the intersect and union are active
    intersect = intersect.sum(-1)
    union = union.sum(-1)
    output = intersect.float() / union.float()
    # Handle division by 0: If there is no union, the two masks match completely.
    output[union == 0] = 1
    # Apply reduction
    if reduction == "none":
        return output
    elif reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    else:
        raise ValueError("Unsupported reduction value: {}".format(reduction))

def mask_jaccard_index_with_logits(input, *args, **kwargs):
    """
    Measure Jaccard Index from logits.

    Convert logits to probabilities with sigmoid, apply a threshold, then
    measure the Jaccard Index (intersection over union) of the tensor as
    compared to ground truth.

    See Also
    --------
    mask_jaccard_index
    """
    return mask_jaccard_index(torch.sigmoid(input), *args, **kwargs)

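# A minimal usage sketch with made-up masks: the Jaccard index is the size of
# the intersection divided by the size of the union of the binarised masks.
_example_pred = torch.tensor([[0.9, 0.2, 0.7, 0.1]])
_example_target = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
# The intersection covers 1 entry and the union covers 3 entries -> 1/3
_example_iou = mask_jaccard_index(_example_pred, _example_target)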