"""
Model wrapper.
"""
# This file is part of Echofilter.
#
# Copyright (C) 2020-2022 Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from .utils import TensorDict, logavgexp
class Echofilter(nn.Module):
    """
    Echofilter logit mapping wrapper.

    Parameters
    ----------
    model : :class:`torch.nn.Module`
        The model backbone, which converts inputs to logits.
    top : str, optional
        Type of output for top line and surface line. If ``"mask"``, the top
        output corresponds to logits, which are converted into probabilities
        with sigmoid. If ``"boundary"`` (default), the output corresponds to
        logits for the location of the line, which is converted into a
        probability mask using softmax and cumsum.
    bottom : str, optional
        As for ``top``, but for the bottom line. Default is ``"boundary"``.
    mapping : dict or None, optional
        Mapping from logit names to output channels provided by ``model``.
        If ``None``, a default mapping is used. The mapping is stored as
        ``self.mapping``.
    reduction_ispassive : str, default="logavgexp"
        Method used to reduce the depths dimension for the
        ``"logit_is_passive"`` output.
    reduction_isremoved : str, default="logavgexp"
        Method used to reduce the depths dimension for the
        ``"logit_is_removed"`` output.
    conditional : bool, optional
        Whether to build a conditional model as well as an unconditional
        model. If ``True``, there are additional logits in the call output
        named ``"x|downfacing"`` and ``"x|upfacing"``, in addition to
        ``"x"``. For instance, ``"p_is_above_turbulence|downfacing"``.
        Default is ``False``.
    """

    # Pairs of interchangeable name fragments. For every mapping key
    # containing one member of a pair, an equivalent key using the other
    # member is added as an alias pointing at the same channel.
    aliases = [("top", "turbulence")]

    def __init__(
        self,
        model,
        top="boundary",
        bottom="boundary",
        mapping=None,
        reduction_ispassive="logavgexp",
        reduction_isremoved="logavgexp",
        conditional=False,
    ):
        super(Echofilter, self).__init__()
        self.model = model
        self.params = {
            "top": top,
            "bottom": bottom,
            "reduction_ispassive": reduction_ispassive,
            "reduction_isremoved": reduction_isremoved,
            "conditional": conditional,
        }
        if mapping is None:
            mapping = {
                "logit_is_above_turbulence": 0,
                "logit_is_boundary_turbulence": 0,
                "logit_is_below_bottom": 1,
                "logit_is_boundary_bottom": 1,
                "logit_is_removed": 2,
                "logit_is_passive": 3,
                "logit_is_patch": 4,
                "logit_is_above_surface": 5,
                "logit_is_boundary_surface": 5,
                "logit_is_above_turbulence-original": 6,
                "logit_is_boundary_turbulence-original": 6,
                "logit_is_below_bottom-original": 7,
                "logit_is_boundary_bottom-original": 7,
                "logit_is_patch-original": 8,
                "logit_is_patch-ntob": 9,
            }
            # Each line is realised either as a "mask" output or as a
            # "boundary" output; drop the representation not in use.
            if top == "boundary":
                mapping.pop("logit_is_above_turbulence")
                mapping.pop("logit_is_above_turbulence-original")
                mapping.pop("logit_is_above_surface")
            else:
                mapping.pop("logit_is_boundary_turbulence")
                mapping.pop("logit_is_boundary_turbulence-original")
                mapping.pop("logit_is_boundary_surface")
            if bottom == "boundary":
                mapping.pop("logit_is_below_bottom")
                mapping.pop("logit_is_below_bottom-original")
            else:
                mapping.pop("logit_is_boundary_bottom")
                mapping.pop("logit_is_boundary_bottom-original")
        self.mapping = mapping
        # Ensure all references for aliases are set
        mapping_extra = {}
        for key in mapping:
            for alias_map in self.aliases:
                for alias_a, alias_b in [alias_map, alias_map[::-1]]:
                    if "_" + alias_a not in key:
                        continue
                    alt_key = key.replace("_" + alias_a, "_" + alias_b)
                    if alt_key not in mapping:
                        mapping_extra[alt_key] = mapping[key]
        mapping.update(mapping_extra)
        self.conditions = [""]
        if conditional:
            self.conditions += ["downfacing", "upfacing"]
        # BUGFIX: channel indices are 0-based, so each condition spans
        # max(index) + 1 channels. Without the "+ 1", the per-condition
        # channel blocks used in forward() would overlap by one channel
        # for conditional models.
        self.n_outputs_per_condition = max(self.mapping.values()) + 1

    def forward(self, x, output_device=None):
        """
        Run the backbone and map its output channels to named outputs.

        Parameters
        ----------
        x : torch.Tensor
            Input batch, passed through to ``self.model``.
        output_device : torch.device or None, optional
            Device onto which output tensors are cast. If ``None`` (default),
            outputs stay on the device produced by the model.

        Returns
        -------
        outputs : TensorDict
            Logits, probabilities, and masks keyed by name. For conditional
            models, condition-specific outputs carry a ``"|downfacing"`` or
            ``"|upfacing"`` suffix.
        """
        logits = self.model(x)
        outputs = TensorDict()
        for i_condition, condition in enumerate(self.conditions):
            # Define the condition string
            cs = condition
            if cs != "":
                cs = "|" + cs
            # Define the logit index offset
            ofs = i_condition * self.n_outputs_per_condition
            # Include raw logits in output
            for key, index in self.mapping.items():
                outputs[key + cs] = logits[:, index + ofs]
            # Flatten some outputs which are vectors not arrays
            if self.params["reduction_isremoved"] == "mean":
                outputs["logit_is_removed" + cs] = torch.mean(
                    outputs["logit_is_removed" + cs], dim=-1
                ).to(device=output_device)
            elif self.params["reduction_isremoved"] in {"logavgexp", "lae"}:
                outputs["logit_is_removed" + cs] = logavgexp(
                    outputs["logit_is_removed" + cs], dim=-1
                ).to(device=output_device)
            else:
                raise ValueError(
                    "Unsupported reduction_isremoved value: {}".format(
                        self.params["reduction_isremoved"]
                    )
                )
            if self.params["reduction_ispassive"] == "mean":
                outputs["logit_is_passive" + cs] = torch.mean(
                    outputs["logit_is_passive" + cs], dim=-1
                ).to(device=output_device)
            elif self.params["reduction_ispassive"] in {"logavgexp", "lae"}:
                outputs["logit_is_passive" + cs] = logavgexp(
                    outputs["logit_is_passive" + cs], dim=-1
                ).to(device=output_device)
            else:
                raise ValueError(
                    "Unsupported reduction_ispassive value: {}".format(
                        self.params["reduction_ispassive"]
                    )
                )
            # Convert logits to probabilities
            outputs["p_is_removed" + cs] = torch.sigmoid(
                outputs["logit_is_removed" + cs]
            ).to(device=output_device)
            outputs["p_is_passive" + cs] = torch.sigmoid(
                outputs["logit_is_passive" + cs]
            ).to(device=output_device)
            for sfx in ("turbulence", "turbulence-original", "surface"):
                if self.params["top"] == "mask":
                    outputs["p_is_above_" + sfx + cs] = torch.sigmoid(
                        outputs["logit_is_above_" + sfx + cs]
                    ).to(device=output_device)
                    outputs["p_is_below_" + sfx + cs] = (
                        1 - outputs["p_is_above_" + sfx + cs]
                    ).to(device=output_device)
                elif self.params["top"] == "boundary":
                    # Softmax over depths gives a distribution over boundary
                    # locations; the reversed cumsum converts it into the
                    # "above" mask, and the forward cumsum into "below".
                    outputs["p_is_boundary_" + sfx + cs] = F.softmax(
                        outputs["logit_is_boundary_" + sfx + cs], dim=-1
                    ).to(device=output_device)
                    outputs["p_is_above_" + sfx + cs] = torch.flip(
                        torch.cumsum(
                            torch.flip(
                                outputs["p_is_boundary_" + sfx + cs], dims=(-1,)
                            ),
                            dim=-1,
                        ),
                        dims=(-1,),
                    ).to(device=output_device)
                    outputs["p_is_below_" + sfx + cs] = torch.cumsum(
                        outputs["p_is_boundary_" + sfx + cs], dim=-1
                    ).to(device=output_device)
                    # Due to floating point precision, max value can exceed 1.
                    # Fix this by clipping the values to the appropriate range.
                    outputs["p_is_above_" + sfx + cs].clamp_(0, 1)
                    outputs["p_is_below_" + sfx + cs].clamp_(0, 1)
                else:
                    raise ValueError(
                        'Unsupported "top" parameter: {}'.format(self.params["top"])
                    )
            for sfx in ("bottom", "bottom-original"):
                if self.params["bottom"] == "mask":
                    outputs["p_is_below_" + sfx + cs] = torch.sigmoid(
                        outputs["logit_is_below_" + sfx + cs]
                    ).to(device=output_device)
                    outputs["p_is_above_" + sfx + cs] = (
                        1 - outputs["p_is_below_" + sfx + cs]
                    ).to(device=output_device)
                elif self.params["bottom"] == "boundary":
                    outputs["p_is_boundary_" + sfx + cs] = F.softmax(
                        outputs["logit_is_boundary_" + sfx + cs], dim=-1
                    ).to(device=output_device)
                    outputs["p_is_below_" + sfx + cs] = torch.cumsum(
                        outputs["p_is_boundary_" + sfx + cs], dim=-1
                    ).to(device=output_device)
                    outputs["p_is_above_" + sfx + cs] = torch.flip(
                        torch.cumsum(
                            torch.flip(
                                outputs["p_is_boundary_" + sfx + cs], dims=(-1,)
                            ),
                            dim=-1,
                        ),
                        dims=(-1,),
                    ).to(device=output_device)
                    # Due to floating point precision, max value can exceed 1.
                    # Fix this by clipping the values to the appropriate range.
                    outputs["p_is_below_" + sfx + cs].clamp_(0, 1)
                    outputs["p_is_above_" + sfx + cs].clamp_(0, 1)
                else:
                    raise ValueError(
                        'Unsupported "bottom" parameter: {}'.format(
                            self.params["bottom"]
                        )
                    )
            for sfx in ("", "-original", "-ntob"):
                outputs["p_is_patch" + sfx + cs] = torch.sigmoid(
                    outputs["logit_is_patch" + sfx + cs]
                ).to(device=output_device)
            # Soft keep probability: averages the two complementary line
            # probabilities for each of turbulence and bottom, then excludes
            # removed, passive, and patch regions.
            outputs["p_keep_pixel" + cs] = (
                1.0
                * 0.5
                * (
                    (1 - outputs["p_is_above_turbulence" + cs])
                    + outputs["p_is_below_turbulence" + cs]
                )
                * 0.5
                * (
                    (1 - outputs["p_is_below_bottom" + cs])
                    + outputs["p_is_above_bottom" + cs]
                )
                * (1 - outputs["p_is_removed" + cs].unsqueeze(-1))
                * (1 - outputs["p_is_passive" + cs].unsqueeze(-1))
                * (1 - outputs["p_is_patch" + cs])
            ).clamp_(0, 1)
            # Hard keep mask: a pixel survives only if every component
            # individually votes to keep it (probability below 0.5).
            outputs["mask_keep_pixel" + cs] = (
                1.0
                * (outputs["p_is_above_turbulence" + cs] < 0.5)
                * (outputs["p_is_below_bottom" + cs] < 0.5)
                * (outputs["p_is_removed" + cs].unsqueeze(-1) < 0.5)
                * (outputs["p_is_passive" + cs].unsqueeze(-1) < 0.5)
                * (outputs["p_is_patch" + cs] < 0.5)
            )
        return outputs
class EchofilterLoss(_Loss):
    """
    Evaluate loss for an Echofilter model.

    Parameters
    ----------
    reduction : ``"mean"`` or ``"sum"``, optional
        The reduction method, which is used to collapse batch and timestamp
        dimensions. Default is ``"mean"``.
    conditional : bool, optional
        Whether the model is conditional. If ``True``, the loss is also
        evaluated against the ``"|downfacing"`` and ``"|upfacing"``
        conditional outputs. Default is ``False``.
    turbulence_mask : float, optional
        Weighting for turbulence line/mask loss term. Default is ``1.0``.
    bottom_mask : float, optional
        Weighting for bottom line/mask loss term. Default is ``1.0``.
    removed_segment : float, optional
        Weighting for ``is_removed`` loss term. Default is ``1.0``.
    passive : float, optional
        Weighting for ``is_passive`` loss term. Default is ``1.0``.
    patch : float, optional
        Weighting for ``mask_patch`` loss term. Default is ``1.0``.
    overall : float, optional
        Weighting for overall mask loss term. Default is ``0.0``.
    surface : float, optional
        Weighting for surface line/mask loss term. Default is ``1.0``.
    auxiliary : float, optional
        Weighting for auxiliary loss terms ``"turbulence-original"``,
        ``"bottom-original"``, ``"mask_patches-original"``, and
        ``"mask_patches-ntob"``. Default is ``1.0``.
    ignore_lines_during_passive : bool, optional
        Whether targets for turbulence and bottom lines should be excluded
        from the loss during passive data collection. Default is ``False``.
    ignore_lines_during_removed : bool, optional
        Whether targets for turbulence and bottom lines should be excluded
        from the loss during entirely removed sections. Default is ``True``.
    ignore_surface_during_passive : bool, optional
        Whether target for the surface line should be excluded from the loss
        during passive data collection. Default is ``False``.
    ignore_surface_during_removed : bool, optional
        Whether target for the surface line should be excluded from the loss
        during entirely removed sections. Default is ``True``.
    """

    __constants__ = ["reduction"]

    def __init__(
        self,
        reduction="mean",
        conditional=False,
        turbulence_mask=1.0,
        bottom_mask=1.0,
        removed_segment=1.0,
        passive=1.0,
        patch=1.0,
        overall=0.0,
        surface=1.0,
        auxiliary=1.0,
        ignore_lines_during_passive=False,
        ignore_lines_during_removed=True,
        ignore_surface_during_passive=False,
        ignore_surface_during_removed=True,
    ):
        super(EchofilterLoss, self).__init__(None, None, reduction)
        self.conditional = conditional
        self.turbulence_mask = turbulence_mask
        self.bottom_mask = bottom_mask
        self.removed_segment = removed_segment
        self.passive = passive
        self.patch = patch
        self.overall = overall
        self.surface = surface
        self.auxiliary = auxiliary
        self.ignore_lines_during_passive = ignore_lines_during_passive
        self.ignore_lines_during_removed = ignore_lines_during_removed
        self.ignore_surface_during_passive = ignore_surface_during_passive
        self.ignore_surface_during_removed = ignore_surface_during_removed
        self.conditions = [""]
        if conditional:
            self.conditions += ["downfacing", "upfacing"]

    def _reduce(self, loss_term):
        """
        Collapse a loss term according to ``self.reduction``.

        Returns the input unchanged when ``self.reduction == "none"``;
        raises :class:`ValueError` for any other unrecognized reduction.
        """
        if self.reduction == "mean":
            return torch.mean(loss_term)
        if self.reduction == "sum":
            return torch.sum(loss_term)
        if self.reduction != "none":
            raise ValueError("Unsupported reduction: {}".format(self.reduction))
        return loss_term

    def forward(self, input, target):
        """
        Construct loss term.

        Parameters
        ----------
        input : dict
            Output from :class:`echofilter.wrapper.Echofilter` layer.
        target : dict
            A transect, as provided by
            :class:`echofilter.data.dataset.TransectDataset`.

        Returns
        -------
        loss : torch.Tensor
            Weighted sum of the enabled loss terms, collapsed per
            ``self.reduction``.
        """
        loss = 0
        # Cast shared targets onto the output device/dtype once up front,
        # since several loss terms below reuse them.
        target["is_passive"] = target["is_passive"].to(
            input["logit_is_passive"].device,
            input["logit_is_passive"].dtype,
        )
        target["is_removed"] = target["is_removed"].to(
            input["logit_is_removed"].device,
            input["logit_is_removed"].dtype,
        )
        n_conditions_in_loss = 0
        for condition in self.conditions:
            closs = 0
            with torch.no_grad():
                if condition == "":
                    # Unconditional model: include every sample in the batch
                    cs = condition
                    cmask = torch.ones_like(target["is_upward_facing"])
                else:
                    cs = "|" + condition
                    if condition == "upfacing":
                        cmask = target["is_upward_facing"] > 0.5
                    elif condition == "downfacing":
                        cmask = target["is_upward_facing"] < 0.5
                    else:
                        raise ValueError(
                            "Unsupported condition: {}".format(condition)
                        )
                n_samples_in_condition = torch.sum(cmask)
            if n_samples_in_condition.cpu().item() == 0:
                # No samples in this batch match this condition
                continue
            cmask = cmask.to(torch.float32)
            n_conditions_in_loss += 1
            for sfx in ("turbulence", "turbulence-original", "surface"):
                with torch.no_grad():
                    loss_inclusion_mask = (target["is_bad_labels"] < 1e-7).to(
                        torch.float32
                    )
                    if sfx == "surface":
                        target_key = "mask_surface"
                        target_i_key = "index_surface"
                        weight = self.surface
                        # Don't include surrogate surface datapoints in loss
                        loss_inclusion_mask *= (
                            target["is_surrogate_surface"] < 1e-7
                        ).to(torch.float32)
                        # Check whether surface line is masked out
                        if self.ignore_surface_during_passive:
                            loss_inclusion_mask *= 1 - target["is_passive"]
                        if self.ignore_surface_during_removed:
                            loss_inclusion_mask *= 1 - target["is_removed"]
                    else:
                        target_key = "mask_" + sfx
                        target_i_key = "index_" + sfx
                        weight = self.turbulence_mask
                        if sfx != "turbulence":
                            weight *= self.auxiliary
                        # Check whether line is masked out
                        if self.ignore_lines_during_passive:
                            loss_inclusion_mask *= 1 - target["is_passive"]
                        if self.ignore_lines_during_removed:
                            loss_inclusion_mask *= 1 - target["is_removed"]
                if not weight:
                    continue
                elif "logit_is_boundary_" + sfx in input:
                    # Load cross-entropy class target
                    C = target[target_i_key].to(
                        device=input["logit_is_boundary_" + sfx + cs].device,
                        dtype=torch.long,
                    )
                    loss_term = F.cross_entropy(
                        input["logit_is_boundary_" + sfx + cs].transpose(-2, -1),
                        C,
                        reduction="none",
                    )
                    loss_term *= cmask.unsqueeze(-1)
                    loss_term *= loss_inclusion_mask
                    loss_term = self._reduce(loss_term)
                elif "logit_is_above_" + sfx in input:
                    warnings.warn(
                        'Using loss corresponding to "mask" logits.'
                        ' The "boundary" is recommended instead.'
                        " The loss component for this line will be"
                        " F.binary_cross_entropy_with_logits(input[{}], target[{}])"
                        "".format("logit_is_above_" + sfx + cs, target_key)
                    )
                    loss_term = F.binary_cross_entropy_with_logits(
                        input["logit_is_above_" + sfx + cs],
                        target[target_key].to(
                            input["logit_is_above_" + sfx + cs].device,
                            input["logit_is_above_" + sfx + cs].dtype,
                        ),
                        reduction="none",
                    )
                    loss_term *= cmask.unsqueeze(-1).unsqueeze(-1)
                    loss_term *= loss_inclusion_mask.unsqueeze(-1)
                    loss_term = self._reduce(loss_term)
                else:
                    raise ValueError(
                        "The input does not contain either {} or {} fields."
                        " At least one of these is required if the loss term weighting"
                        " is non-zero.".format(
                            "logit_is_boundary_" + sfx,
                            "logit_is_above_" + sfx,
                        )
                    )
                if torch.isnan(loss_term).any():
                    print("Loss term {} is NaN".format(target_key))
                else:
                    closs += weight * loss_term
            for sfx in ("", "-original"):
                weight = self.bottom_mask
                if sfx != "":
                    weight *= self.auxiliary
                with torch.no_grad():
                    loss_inclusion_mask = (target["is_bad_labels"] < 1e-7).to(
                        torch.float32
                    )
                    # Check whether line is masked out
                    if self.ignore_lines_during_passive:
                        loss_inclusion_mask *= 1 - target["is_passive"]
                    if self.ignore_lines_during_removed:
                        loss_inclusion_mask *= 1 - target["is_removed"]
                if not weight:
                    continue
                elif "logit_is_boundary_bottom" + sfx in input:
                    # Load cross-entropy class target
                    C = target["index_bottom" + sfx].to(
                        device=input["logit_is_boundary_bottom" + sfx + cs].device,
                        dtype=torch.long,
                    )
                    loss_term = F.cross_entropy(
                        input["logit_is_boundary_bottom" + sfx + cs].transpose(
                            -2, -1
                        ),
                        C,
                        reduction="none",
                    )
                    loss_term *= cmask.unsqueeze(-1)
                    loss_term *= loss_inclusion_mask
                    loss_term = self._reduce(loss_term)
                elif "logit_is_below_bottom" + sfx in input:
                    warnings.warn(
                        'Using loss corresponding to "mask" logits.'
                        ' The "boundary" is recommended instead.'
                        " The loss component for this line will be"
                        " F.binary_cross_entropy_with_logits(input[{}], target[{}])"
                        # BUGFIX: previously interpolated a stale `target_key`
                        # left over from the turbulence/surface loop (and
                        # potentially unbound if that loop was skipped).
                        "".format(
                            "logit_is_below_bottom" + sfx + cs,
                            "mask_bottom" + sfx,
                        )
                    )
                    loss_term = F.binary_cross_entropy_with_logits(
                        input["logit_is_below_bottom" + sfx + cs],
                        target["mask_bottom" + sfx].to(
                            input["logit_is_below_bottom" + sfx + cs].device,
                            input["logit_is_below_bottom" + sfx + cs].dtype,
                        ),
                        reduction="none",
                    )
                    loss_term *= cmask.unsqueeze(-1).unsqueeze(-1)
                    loss_term *= loss_inclusion_mask.unsqueeze(-1)
                    loss_term = self._reduce(loss_term)
                else:
                    raise ValueError(
                        "The input does not contain either {} or {} fields."
                        " At least one of these is required if the loss term weighting"
                        " is non-zero.".format(
                            "logit_is_boundary_bottom" + sfx,
                            "logit_is_below_bottom" + sfx,
                        )
                    )
                if torch.isnan(loss_term).any():
                    print("Loss term mask_bottom{} is NaN".format(sfx))
                else:
                    closs += weight * loss_term
            if self.removed_segment:
                loss_term = F.binary_cross_entropy_with_logits(
                    input["logit_is_removed" + cs],
                    target["is_removed"].to(
                        input["logit_is_removed" + cs].device,
                        input["logit_is_removed" + cs].dtype,
                    ),
                    reduction="none",
                )
                loss_term *= cmask.unsqueeze(-1)
                loss_term = self._reduce(loss_term)
                if torch.isnan(loss_term).any():
                    print("Loss term is_removed is NaN")
                else:
                    closs += self.removed_segment * loss_term
            if self.passive:
                # BUGFIX: apply the weighting exactly once (at accumulation,
                # matching the removed_segment term); previously self.passive
                # was applied twice, effectively squaring the weight.
                loss_term = F.binary_cross_entropy_with_logits(
                    input["logit_is_passive" + cs],
                    target["is_passive"].to(
                        input["logit_is_passive" + cs].device,
                        input["logit_is_passive" + cs].dtype,
                    ),
                    reduction="none",
                )
                loss_term *= cmask.unsqueeze(-1)
                loss_term = self._reduce(loss_term)
                if torch.isnan(loss_term).any():
                    print("Loss term is_passive is NaN")
                else:
                    closs += self.passive * loss_term
            for sfx in ("", "-original", "-ntob"):
                weight = self.patch
                if sfx != "":
                    weight *= self.auxiliary
                if not weight:
                    continue
                loss_term = F.binary_cross_entropy_with_logits(
                    input["logit_is_patch" + sfx + cs],
                    target["mask_patches" + sfx].to(
                        input["logit_is_patch" + sfx + cs].device,
                        input["logit_is_patch" + sfx + cs].dtype,
                    ),
                    reduction="none",
                )
                loss_term *= cmask.unsqueeze(-1).unsqueeze(-1)
                loss_term = self._reduce(loss_term)
                if torch.isnan(loss_term).any():
                    print("Loss term mask_patches{} is NaN".format(sfx))
                else:
                    closs += weight * loss_term
            if self.overall:
                # BUGFIX: apply self.overall once (at accumulation), not twice.
                # Note: p_keep_pixel is already a probability, hence plain BCE.
                loss_term = F.binary_cross_entropy(
                    input["p_keep_pixel" + cs],
                    target["mask"].to(
                        input["p_keep_pixel" + cs].device,
                        input["p_keep_pixel" + cs].dtype,
                    ),
                    reduction="none",
                )
                loss_term *= cmask.unsqueeze(-1).unsqueeze(-1)
                loss_term = self._reduce(loss_term)
                if torch.isnan(loss_term).any():
                    print("Loss term overall is NaN")
                else:
                    closs += self.overall * loss_term
            loss += closs
        # Avoid double-counting the loss: the unconditional term plus the
        # two condition-specific terms together cover each sample twice.
        if n_conditions_in_loss > 1:
            loss /= 2
        return loss