#!/usr/bin/env python
"""
Model training routine.
"""
# This file is part of Echofilter.
#
# Copyright (C) 2020-2022 Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import copy
import datetime
import os
import pprint
import shutil
import sys
import time
import traceback
from collections import OrderedDict
try:
import apex
except ImportError:
apex = None
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn
import torch.optim
import torch.utils.data
import torchvision.transforms
import echofilter.data
import echofilter.optim.utils
import echofilter.raw.shardloader
from echofilter.nn.unet import UNet
from echofilter.nn.utils import count_parameters, seed_all
from echofilter.nn.wrapper import Echofilter, EchofilterLoss
from echofilter.optim import criterions, schedulers
from echofilter.optim.meters import AverageMeter, ProgressMeter
from echofilter.plotting import plot_transect_predictions
from echofilter.raw.loader import get_partition_list
from echofilter.raw.manipulate import load_decomposed_transect_mask
from echofilter.ui.train_cli import main
# --- For mobile dataset,
# DATA_CENTER = -81.5
# DATA_DEVIATION = 21.9
# CENTER_METHOD = "mean"
# DEVIATION_METHOD = "stdev"
# --- For stationary dataset,
# DATA_CENTER = -78.7
# DATA_DEVIATION = 19.2
# CENTER_METHOD = "mean"
# DEVIATION_METHOD = "stdev"
# --- For intermediate values between both datasets
# DATA_CENTER = -80.
# DATA_DEVIATION = 20.
# CENTER_METHOD = "mean"
# DEVIATION_METHOD = "stdev"
# NAN_VALUE = -3
# --- Overall values to use
# DATA_CENTER = -97.5
# DATA_DEVIATION = 16.5
# CENTER_METHOD = "pc10"
# DEVIATION_METHOD = "idr"
# NAN_VALUE = -1
# Normalise each sample independently, based on its own intensity distribution
CENTER_METHOD = "median"
DEVIATION_METHOD = "idr"
DATA_CENTER = CENTER_METHOD
DATA_DEVIATION = DEVIATION_METHOD
NAN_VALUE = -3
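# With these per-sample settings, each sample is (roughly speaking) rescaled
# as (Sv - median(Sv)) / idr(Sv) by the Normalize transform, and any
# remaining NaN values are then replaced by NAN_VALUE.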
# Transects to plot for debugging
PLOT_TRANSECTS = {
"mobile": [
"mobile/Survey07/Survey07_GR4_N5W_survey7",
"mobile/Survey14/Survey14_GR4_N0W_E",
"mobile/Survey16/Survey16_GR4_N5W_E",
"mobile/Survey17/Survey17_GR4_N5W_E",
],
"MinasPassage": [
"MinasPassage/december2017/december2017_D20180213-T115216_D20180213-T172216",
"MinasPassage/march2018/march2018_D20180513-T195216_D20180514-T012216",
"MinasPassage/september2018/september2018_D20181027-T202217_D20181028-T015217",
"MinasPassage/september2018/september2018_D20181107-T122220_D20181107-T175217",
],
"GrandPassage": [
"GrandPassage/phase2/GrandPassage_WBAT_2B_20200125_UTC160020_ebblow",
"GrandPassage/phase2/GrandPassage_WBAT_2B_20200202_UTC040019_floodhigh",
],
}
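# Depth at which to crop the sample transect plots when no crop_depth is given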
DEFAULT_CROP_DEPTH_PLOTS = 100
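# Maximum number of timestamps from a transect to send through the model
# when generating sample transect plots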
MAX_INPUT_LEN = 3500
def train(
data_dir="/data/dsforce/surveyExports",
dataset_name="mobile",
train_partition=None,
val_partition=None,
sample_shape=(128, 512),
crop_depth=None,
resume="",
restart="",
log_name=None,
log_name_append=None,
conditional=False,
n_block=6,
latent_channels=32,
expansion_factor=1,
expand_only_on_down=False,
blocks_per_downsample=(2, 1),
blocks_before_first_downsample=(2, 1),
always_include_skip_connection=True,
deepest_inner="horizontal_block",
intrablock_expansion=6,
se_reduction=4,
downsampling_modes="max",
upsampling_modes="bilinear",
depthwise_separable_conv=True,
residual=True,
actfn="InplaceReLU",
kernel_size=5,
use_mixed_precision=None,
amp_opt="O1",
device="cuda",
multigpu=False,
n_worker=8,
batch_size=16,
stratify=True,
n_epoch=20,
seed=None,
print_freq=50,
optimizer="adam",
schedule="constant",
lr=0.1,
momentum=0.9,
base_momentum=None,
weight_decay=1e-5,
warmup_pct=0.2,
warmdown_pct=0.7,
anneal_strategy="cos",
overall_loss_weight=0.0,
):
"""
Train a model.
"""
if restart and not resume:
raise ValueError(
"A checkpoint must be provided to restart from when doing a cold restart"
)
# Lazy import of tensorboard, so training-only requirements are not needed
# for automated documentation building.
from torch.utils.tensorboard import SummaryWriter
seed_all(seed)
# Can't get this to be deterministic anyway, so may as well keep the
# non-deterministic optimization enabled
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
# Input handling
schedule = schedule.lower()
if base_momentum is None:
base_momentum = momentum
if log_name is None or log_name == "":
log_name = datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
if log_name_append is None:
log_name_append = os.uname()[1]
if len(log_name_append) > 0:
log_name += "_" + log_name_append
print("Output will be written to {}/{}".format(dataset_name, log_name))
if use_mixed_precision is None:
use_mixed_precision = "cpu" not in device
if use_mixed_precision and apex is None:
print("NVIDIA apex must be installed to use mixed precision.")
use_mixed_precision = False
# Need to set the default device for apex.amp
if device is not None and device != "cpu":
torch.cuda.set_device(torch.device(device))
# Build dataset
dataset_train, dataset_val, dataset_augval = build_dataset(
dataset_name,
data_dir,
sample_shape,
train_partition=train_partition,
val_partition=val_partition,
)
print("Train dataset has {:4d} samples".format(len(dataset_train)))
print("Val dataset has {:4d} samples".format(len(dataset_val)))
stratify_allowed = stratify
if stratify_allowed:
stratify = isinstance(dataset_train, echofilter.data.dataset.ConcatDataset)
if stratify:
# Use custom stratified sampler to handle multiple datasets
sampler = echofilter.data.dataset.StratifiedRandomSampler(dataset_train)
else:
sampler = None
loader_train = torch.utils.data.DataLoader(
dataset_train,
batch_size=batch_size,
shuffle=not stratify,
sampler=sampler,
num_workers=n_worker,
pin_memory=True,
drop_last=True,
worker_init_fn=echofilter.data.utils.worker_seed_fn,
)
loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=batch_size,
shuffle=False,
num_workers=n_worker,
pin_memory=True,
drop_last=False,
worker_init_fn=echofilter.data.utils.worker_staticseed_fn,
)
loader_augval = torch.utils.data.DataLoader(
dataset_augval,
batch_size=batch_size,
shuffle=False,
num_workers=n_worker,
pin_memory=True,
drop_last=False,
worker_init_fn=echofilter.data.utils.worker_staticseed_fn,
)
print("Train loader has {:3d} batches".format(len(loader_train)))
print("Val loader has {:3d} batches".format(len(loader_val)))
print()
print(
"Constructing U-Net model with "
"{} blocks, "
"initial latent channels {}, "
"expansion_factor {}".format(n_block, latent_channels, expansion_factor)
)
model_parameters = dict(
in_channels=1,
out_channels=10,
initial_channels=latent_channels,
bottleneck_channels=latent_channels,
n_block=n_block,
unet_expansion_factor=expansion_factor,
expand_only_on_down=expand_only_on_down,
blocks_per_downsample=blocks_per_downsample,
blocks_before_first_downsample=blocks_before_first_downsample,
always_include_skip_connection=always_include_skip_connection,
deepest_inner=deepest_inner,
intrablock_expansion=intrablock_expansion,
se_reduction=se_reduction,
downsampling_modes=downsampling_modes,
upsampling_modes=upsampling_modes,
depthwise_separable_conv=depthwise_separable_conv,
residual=residual,
actfn=actfn,
kernel_size=kernel_size,
)
if conditional:
model_parameters["out_channels"] *= 3
print()
pprint.pprint(model_parameters)
print()
unet = UNet(**model_parameters)
unet.to(device)
print(
"Built UNet model with {} trainable parameters".format(
count_parameters(unet, only_trainable=True)
)
)
# define loss function (criterion) and optimizer
criterion = EchofilterLoss(conditional=conditional, overall=overall_loss_weight)
optimizer_name = optimizer.lower()
if optimizer_name == "adam":
optimizer_class = torch.optim.Adam
elif optimizer_name == "adamw":
optimizer_class = torch.optim.AdamW
elif optimizer_name == "ranger":
import ranger
optimizer_class = ranger.Ranger
elif optimizer_name == "rangerva":
import ranger
optimizer_class = ranger.RangerVA
elif optimizer_name == "rangerqh":
import ranger
optimizer_class = ranger.RangerQH
else:
# We don't support arbitrary optimizers from torch.optim because they
# need different configuration parameters to Adam.
raise ValueError("Unrecognised optimizer: {}".format(optimizer))
optimizer = optimizer_class(
unet.parameters(),
lr,
betas=(momentum, 0.999),
weight_decay=weight_decay,
)
schedule_data = {"name": schedule}
lr_initial_div_factor = 1e3
lr_final_div_factor = 1e5
if schedule == "lrfinder":
pass
elif schedule == "constant":
pass
elif schedule == "onecycle":
schedule_data["scheduler"] = schedulers.OneCycleLR(
optimizer,
max_lr=lr,
steps_per_epoch=len(loader_train),
epochs=n_epoch,
pct_start=warmup_pct,
anneal_strategy=anneal_strategy,
cycle_momentum=True,
base_momentum=base_momentum,
max_momentum=momentum,
div_factor=lr_initial_div_factor,
final_div_factor=lr_final_div_factor,
)
elif schedule == "mesaonecycle":
schedule_data["scheduler"] = schedulers.MesaOneCycleLR(
optimizer,
max_lr=lr,
steps_per_epoch=len(loader_train),
epochs=n_epoch,
pct_start=warmup_pct,
pct_end=warmdown_pct,
anneal_strategy=anneal_strategy,
cycle_momentum=True,
base_momentum=base_momentum,
max_momentum=momentum,
div_factor=lr_initial_div_factor,
final_div_factor=lr_final_div_factor,
)
else:
raise ValueError("Unsupported schedule: {}".format(schedule))
if use_mixed_precision:
print('Converting unet to mixed precision, opt="{}"'.format(amp_opt))
unet, optimizer = apex.amp.initialize(unet, optimizer, opt_level=amp_opt)
model_inner = unet
if multigpu and torch.cuda.device_count() > 1:
print("Using local parallelism on {} GPUs".format(torch.cuda.device_count()))
model_inner = torch.nn.DataParallel(model_inner)
# Add UI wrapper around model
model = Echofilter(
model_inner,
top="boundary",
bottom="boundary",
conditional=conditional,
)
if schedule == "lrfinder":
from torch_lr_finder import LRFinder
print("Running learning rate finder")
lr_finder = LRFinder(model, optimizer, criterion, device=device)
lr_finder.range_test(loader_train, end_lr=100, num_iter=100, diverge_th=3)
print("Plotting learning rate finder results")
hf = plt.figure(figsize=(15, 9))
ax = plt.axes()
lr_finder.plot(skip_start=0, skip_end=1, log_lr=True, ax=ax)
plt.tick_params(reset=True, color=(0.2, 0.2, 0.2))
plt.tick_params(labelsize=14)
ax.minorticks_on()
ax.tick_params(direction="out")
# Save figure
figpth = os.path.join("models", dataset_name, log_name, "lrfinder.png")
os.makedirs(os.path.dirname(figpth), exist_ok=True)
plt.savefig(figpth)
print("LR Finder results saved to {}".format(figpth))
return
# Initialise loop tracking
start_epoch = 1
best_loss_val = float("inf")
# optionally resume from a checkpoint
if resume:
if not os.path.isfile(resume):
raise EnvironmentError("No checkpoint found at '{}'".format(resume))
print("Loading checkpoint '{}'".format(resume))
if device is None:
checkpoint = torch.load(resume)
else:
# Map model to be loaded to specified single gpu.
checkpoint = torch.load(resume, map_location=device)
try:
unet.load_state_dict(checkpoint["state_dict"])
except RuntimeError:
print(
"Checkpoint doesn't seem to be for the UNet component. Trying"
" to load it as the whole model instead."
)
model.load_state_dict(checkpoint["state_dict"])
if restart == "cold":
print("Loaded checkpoint '{}' for cold restart".format(resume))
else:
optimizer.load_state_dict(checkpoint["optimizer"])
if use_mixed_precision and "amp" in checkpoint:
apex.amp.load_state_dict(checkpoint["amp"])
if restart == "warm":
# We loaded the optimizer state from the checkpoint to get the
# current buffer values (the weighted history of updates).
# However, this also changed our schedule settings to the ones
# used to train the model in the checkpoint, so we need to
# overwrite it with the schedule initialisation step now.
for group in optimizer.param_groups:
if "onecycle" in schedule:
group["lr"] = group["initial_lr"] = lr / lr_initial_div_factor
group["max_lr"] = lr
group["min_lr"] = group["initial_lr"] / lr_final_div_factor
group["base_momentum"] = base_momentum
group["max_momentum"] = momentum
else:
group["lr"] = lr
if "betas" in group:
_, beta2 = group["betas"]
group["betas"] = (momentum, beta2)
else:
group["momentum"] = momentum
print("Loaded checkpoint '{}' for warm restart".format(resume))
elif not restart:
start_epoch = checkpoint["epoch"] + 1
best_loss_val = checkpoint["best_loss"]
if "scheduler" in schedule_data:
step_num = checkpoint["epoch"] * len(loader_train)
schedule_data["scheduler"].last_epoch = step_num
print(
"Loaded checkpoint '{}' (epoch {})".format(resume, checkpoint["epoch"])
)
# Make a tensorboard writer
writer = SummaryWriter(log_dir=os.path.join("runs", dataset_name, log_name))
print("Starting training")
t_start = time.time()
for epoch in range(start_epoch, n_epoch + 1):
t_epoch_start = time.time()
# Set the seed state at the start of each epoch
# This ensures some level of consistency between models which continue
# running and those resumed from checkpoints (don't want to run the
# exact initial augmented samples from the first epoch again when
# resuming later in the run)
# Guard against seed being None (in which case, seed randomly each epoch)
seed_all(None if seed is None else 1000 + seed + epoch)
# Can't get this to be deterministic anyway, so may as well keep the
# non-deterministic optimization enabled
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
# Resample offsets for each window
loader_train.dataset.initialise_datapoints()
# train for one epoch
(
loss_tr,
meters_tr,
(ex_input_tr, ex_data_tr, ex_output_tr),
(batch_time, data_time),
) = train_epoch(
loader_train,
model,
criterion,
optimizer,
device,
epoch,
print_freq=print_freq,
schedule_data=schedule_data,
use_mixed_precision=use_mixed_precision,
)
t_val_start = time.time()
# evaluate on validation set
loss_val, meters_val, (ex_input_val, ex_data_val, ex_output_val) = validate(
loader_val,
model,
criterion,
device,
print_freq=print_freq,
prefix="Validation",
)
# evaluate on augmented validation set
(
loss_augval,
meters_augval,
(ex_input_augval, ex_data_augval, ex_output_augval),
) = validate(
loader_augval,
model,
criterion,
device,
print_freq=print_freq,
prefix="Aug-Val ",
)
t_val_end = time.time()
print(
"{}/{}\nCompleted {} of {} epochs in {}".format(
dataset_name,
log_name,
epoch,
n_epoch,
datetime.timedelta(seconds=time.time() - t_start),
)
)
# Print metrics to terminal
name_fmt = "{:.<28s}"
current_lr = echofilter.optim.utils.get_current_lr(optimizer)
current_mom = echofilter.optim.utils.get_current_momentum(optimizer)
print((name_fmt + " {:.4e}").format("Learning rate", current_lr))
print((name_fmt + " {:.4f}").format("Momentum", current_mom))
print(
(name_fmt + " Train: {:.4e} AugVal: {:.4e} Val: {:.4e}").format(
"Loss", loss_tr, loss_augval, loss_val
)
)
for chn in meters_tr:
if chn.lower() != "overall":
continue
# For each output plane
print(chn)
for cr in meters_tr[chn]:
# For each criterion
fmt_str = name_fmt
fmt_str += " Train: {" + meters_tr[chn][cr].fmt + "}"
fmt_str += " AugVal: {" + meters_augval[chn][cr].fmt + "}"
fmt_str += " Val: {" + meters_val[chn][cr].fmt + "}"
print(
fmt_str.format(
meters_tr[chn][cr].name,
meters_tr[chn][cr].avg,
meters_augval[chn][cr].avg,
meters_val[chn][cr].avg,
)
)
# Add hyper parameters to tensorboard
writer.add_scalar("learning_rate", current_lr, epoch)
writer.add_scalar("momentum", current_mom, epoch)
writer.add_scalar(
"parameter_count", count_parameters(model, only_trainable=True), epoch
)
# Add metrics to tensorboard
for loss_p, partition in (
(loss_tr, "Train"),
(loss_val, "Val"),
(loss_augval, "ValAug"),
):
writer.add_scalar("{}/{}".format("Loss", partition), loss_p, epoch)
for chn in meters_tr:
# For each output plane
for cr in meters_tr[chn]:
# For each criterion
writer.add_scalar(
"{}/{}/{}".format(cr, chn, "Train"), meters_tr[chn][cr].avg, epoch
)
writer.add_scalar(
"{}/{}/{}".format(cr, chn, "ValAug"),
meters_augval[chn][cr].avg,
epoch,
)
writer.add_scalar(
"{}/{}/{}".format(cr, chn, "Val"), meters_val[chn][cr].avg, epoch
)
# Determine whether to generate sample transect plots, or skip them
if n_epoch < 20:
# Every epoch
generate_sample_images = True
else:
# Every 20th of the way through training
generate_sample_images = (
int(20 * epoch / n_epoch) > 20 * (epoch - 1) / n_epoch
)
# But always generate samples for first two epochs and last epoch
if epoch <= 2 or epoch == n_epoch:
generate_sample_images = True
def ensure_clim_met(x, x0=0.0, x1=1.0):
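"""Pin two pixels of the first image to 0 and 1 so colour limits are fixed."""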
x = x.clone()
x[0, :, 0, 0] = 0
x[0, :, 0, 1] = 1
return x
def add_image_border(x):
"""
Add a green border around a tensor of images.
Parameters
----------
x : torch.Tensor
Tensor in NCWH or NCHW format.
Returns
-------
torch.Tensor
As ``x``, but padded with a green border.
"""
if x.shape[1] == 1:
x = torch.cat([x, x, x], dim=1)
if x.shape[1] != 3:
raise ValueError("RGB image needs three color channels")
shp = list(x.shape)
shp[-1] = 1
x = torch.cat(
[
torch.zeros(shp, dtype=x.dtype, device=x.device),
x,
torch.zeros(shp, dtype=x.dtype, device=x.device),
],
dim=-1,
)
shp = list(x.shape)
shp[-2] = 1
x = torch.cat(
[
torch.zeros(shp, dtype=x.dtype, device=x.device),
x,
torch.zeros(shp, dtype=x.dtype, device=x.device),
],
dim=-2,
)
x[:, 1, :, 0] = 1.0
x[:, 1, :, -1] = 1.0
x[:, 1, 0, :] = 1.0
x[:, 1, -1, :] = 1.0
return x
# Add example images to tensorboard
for (ex_input, ex_data, ex_output), partition in (
((ex_input_tr, ex_data_tr, ex_output_tr), "Train"),
((ex_input_val, ex_data_val, ex_output_val), "Val"),
((ex_input_augval, ex_data_augval, ex_output_augval), "ValAug"),
):
if not generate_sample_images:
continue
writer.add_images(
"Input/" + partition,
ex_input,
epoch,
dataformats="NCWH",
)
writer.add_images(
"Overall/" + partition + "/Target",
ensure_clim_met(add_image_border(ex_data["mask"].float().unsqueeze(1))),
epoch,
dataformats="NCWH",
)
writer.add_images(
"Overall/" + partition + "/Output/p",
ensure_clim_met(
add_image_border(ex_output["p_keep_pixel"].unsqueeze(1))
),
epoch,
dataformats="NCWH",
)
writer.add_images(
"Overall/" + partition + "/Overlap",
ensure_clim_met(
add_image_border(
torch.stack(
[
ex_output["mask_keep_pixel"].float(),
torch.zeros_like(
ex_output["mask_keep_pixel"], dtype=torch.float
),
ex_data["mask"].float(),
],
dim=1,
)
)
),
epoch,
dataformats="NCWH",
)
for k, plot_transects_k in PLOT_TRANSECTS.items():
if not generate_sample_images:
continue
if "stationary" in dataset_name and k != "mobile":
pass
elif k not in dataset_name:
continue
plot_crop_depth = crop_depth
if plot_crop_depth is None:
plot_crop_depth = DEFAULT_CROP_DEPTH_PLOTS
for transect_name in plot_transects_k:
transect, prediction = generate_from_shards(
os.path.join(data_dir + "_sharded", transect_name),
model,
sample_shape=sample_shape,
crop_depth=plot_crop_depth,
device=device,
dtype=torch.float,
)
hf = plt.figure(figsize=(15, 9))
plot_transect_predictions(
transect, prediction, cmap="viridis", linewidth=1
)
transect_name = transect_name.replace("/evExports", "")
if epoch == n_epoch:
# Only save png if this is the final epoch
figpth = os.path.join(
"models",
dataset_name,
log_name,
"samples",
transect_name + "_output.png",
)
os.makedirs(os.path.dirname(figpth), exist_ok=True)
plt.savefig(figpth)
writer.add_figure(transect_name, hf, epoch, close=True)
# remember best loss and save checkpoint
is_best = loss_val < best_loss_val
best_loss_val = min(loss_val, best_loss_val)
checkpoint = {
"model_parameters": model_parameters,
"sample_shape": sample_shape,
"epoch": epoch,
"state_dict": unet.state_dict(),
"best_loss": best_loss_val,
"optimizer": optimizer.state_dict(),
"meters": meters_val,
"data_center": DATA_CENTER,
"data_deviation": DATA_DEVIATION,
"center_method": CENTER_METHOD,
"deviation_method": DEVIATION_METHOD,
"nan_value": NAN_VALUE,
"wrapper_mapping": model.mapping,
"wrapper_params": model.params,
"training_routine": "echofilter-train {}".format(echofilter.__version__),
}
if use_mixed_precision:
checkpoint["amp"] = apex.amp.state_dict()
if (
schedule == "mesaonecycle"
and epoch / n_epoch <= warmdown_pct
and (epoch + 1) / n_epoch > warmdown_pct
):
dup = "ep{}".format(epoch)
else:
dup = None
save_checkpoint(
checkpoint,
is_best,
dirname=os.path.join("models", dataset_name, log_name),
dup=dup,
)
meters_to_csv(
meters_val, is_best, dirname=os.path.join("models", dataset_name, log_name)
)
# Note how long everything took
writer.add_scalar("time/batch", batch_time.avg, epoch)
writer.add_scalar("time/batch/data", data_time.avg, epoch)
writer.add_scalar("time/train", t_val_start - t_epoch_start, epoch)
writer.add_scalar("time/val", t_val_end - t_val_start, epoch)
writer.add_scalar("time/log", time.time() - t_val_end, epoch)
writer.add_scalar("time/epoch", time.time() - t_epoch_start, epoch)
# Ensure the tensorboard outputs for this epoch are flushed
writer.flush()
# Close tensorboard connection
writer.close()
def build_dataset(
dataset_name,
data_dir,
sample_shape,
train_partition=None,
val_partition=None,
crop_depth=None,
random_crop_args=None,
):
"""
Construct a pytorch Dataset.
Parameters
----------
dataset_name : str
Name of the dataset. This can optionally be a list of
multiple datasets joined with ``"+"``.
data_dir : str
Path to root data directory, containing the dataset.
sample_shape : iterable of length 2
The shape which will be used for training.
train_partition : str, optional
Name of the partition to use for training. Can optionally be a list of
multiple partitions joined with ``"+"``. Default is ``"train"``
(except for ``stationary2`` where it is mixed).
val_partition : str, optional
Name of the partition to use for validation. Can optionally be a list
of multiple partitions joined with ``"+"``. Default is ``"validate"``
(except for ``stationary2`` where it is mixed).
crop_depth : float or None, optional
Depth at which to crop samples. Default is ``None``.
random_crop_args : dict, optional
Arguments to control the random crop used during training. Default is
an empty dict, which uses the default arguments of
:class:`echofilter.data.transforms.RandomCropDepth`.
Returns
-------
dataset_train : echofilter.data.dataset.TransectDataset
Dataset of training samples.
dataset_val : echofilter.data.dataset.TransectDataset
Dataset of validation samples.
dataset_augval : echofilter.data.dataset.TransectDataset
Dataset of validation samples, applying the training augmentation
stack.
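Notes
-----
For example, passing ``dataset_name="mobile+MinasPassage"`` builds both
datasets and concatenates the corresponding partitions together into
:class:`echofilter.data.dataset.ConcatDataset` objects.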
"""
if random_crop_args is None:
random_crop_args = {}
if "+" in dataset_name:
# Join multiple datasets together
datasets = [
build_dataset(
subdataset_name,
data_dir=data_dir,
sample_shape=sample_shape,
train_partition=train_partition,
val_partition=val_partition,
crop_depth=crop_depth,
random_crop_args=random_crop_args,
)
for subdataset_name in dataset_name.split("+")
if len(subdataset_name) > 0
]
return tuple(
echofilter.data.dataset.ConcatDataset([d[i] for d in datasets])
for i in range(len(datasets[0]))
)
if dataset_name == "stationary2":
# The stationary2 dataset is MinasPassage and GrandPassage,
# plus a second duplicate copy of GrandPassage which uses only zoomed
# out depth crops.
random_crop_args2 = copy.deepcopy(random_crop_args)
random_crop_args2["p_crop_is_none"] = 0.2
random_crop_args2["p_crop_is_optimal"] = 0.0
random_crop_args2["p_crop_is_close"] = 0.0
# By default, we only evaluate on the validation partition of
# MinasPassage, and train on the train partition of both plus the
# validation partition of GrandPassage.
if train_partition is None and val_partition is None:
train_partition_main = "train"
val_partition_main = "validate"
train_partition_aux = "train+validate"
val_partition_aux = ""
else:
train_partition_main = train_partition_aux = train_partition
val_partition_main = val_partition_aux = val_partition
# Assemble the datasets
datasets = [
build_dataset(
"MinasPassage",
data_dir=data_dir,
sample_shape=sample_shape,
train_partition=train_partition_main,
val_partition=val_partition_main,
crop_depth=crop_depth,
random_crop_args=random_crop_args,
),
build_dataset(
"GrandPassage",
data_dir=data_dir,
sample_shape=sample_shape,
train_partition=train_partition_aux,
val_partition=val_partition_aux,
crop_depth=crop_depth,
random_crop_args=random_crop_args,
),
build_dataset(
"GrandPassage",
data_dir=data_dir,
sample_shape=sample_shape,
train_partition=train_partition_aux,
val_partition=val_partition_aux,
crop_depth=crop_depth,
random_crop_args=random_crop_args2,
),
]
return tuple(
echofilter.data.dataset.ConcatDataset([d[i] for d in datasets])
for i in range(len(datasets[0]))
)
if train_partition is None:
train_partition = "train"
if val_partition is None:
val_partition = "validate"
# Augmentations
train_transform = torchvision.transforms.Compose(
[
echofilter.data.transforms.RandomCropDepth(**random_crop_args),
echofilter.data.transforms.RandomReflection(),
echofilter.data.transforms.Normalize(DATA_CENTER, DATA_DEVIATION),
echofilter.data.transforms.ColorJitter(0.5, 0.3),
echofilter.data.transforms.ReplaceNan(NAN_VALUE),
echofilter.data.transforms.RandomElasticGrid(
sample_shape,
order=None,
p=0.5,
sigma=[8, 16],
alpha=0.1,
),
echofilter.data.transforms.Rescale(sample_shape, order=None),
]
)
val_transform = torchvision.transforms.Compose(
[
echofilter.data.transforms.OptimalCropDepth(),
echofilter.data.transforms.Normalize(DATA_CENTER, DATA_DEVIATION),
echofilter.data.transforms.ReplaceNan(NAN_VALUE),
echofilter.data.transforms.Rescale(sample_shape, order=1),
]
)
train_paths = []
for partition_name in train_partition.split("+"):
if len(partition_name) == 0:
continue
train_paths += get_partition_list(
partition_name,
dataset=dataset_name,
partitioning_version="firstpass",
root_data_dir=data_dir,
full_path=True,
sharded=True,
)
val_paths = []
for partition_name in val_partition.split("+"):
if len(partition_name) == 0:
continue
val_paths += get_partition_list(
partition_name,
dataset=dataset_name,
partitioning_version="firstpass",
root_data_dir=data_dir,
full_path=True,
sharded=True,
)
print(
"Found {:3d} train sample paths from partition {} for dataset {}".format(
len(train_paths), train_partition, dataset_name
)
)
print(
"Found {:3d} val sample paths from partition {} for dataset {}".format(
len(val_paths), val_partition, dataset_name
)
)
dataset_args = {}
if dataset_name == "mobile":
dataset_args["remove_nearfield"] = False
dataset_args["remove_offset_turbulence"] = 0
dataset_args["remove_offset_bottom"] = 1.0
elif dataset_name == "MinasPassage":
dataset_args["remove_nearfield"] = True
# NB: Nearfield distance is 1.7, except for the samples:
# march2018_D20180502-T045216_D20180502-T102215
# march2018_D20180502-T105216_D20180502-T162215
# march2018_D20180502-T165215_D20180502-T222216
# march2018_D20180502-T225214_D20180503-T042215
# For which it is 1.744
dataset_args["nearfield_distance"] = 1.745
dataset_args["remove_offset_turbulence"] = 0
dataset_args["remove_offset_bottom"] = 0
elif dataset_name == "GrandPassage":
dataset_args["remove_nearfield"] = True
dataset_args["nearfield_distance"] = 1.7
dataset_args["remove_offset_turbulence"] = 1.0
dataset_args["remove_offset_bottom"] = 0
dataset_train = echofilter.data.dataset.TransectDataset(
train_paths,
window_len=sample_shape[0],
p_scale_window=0.8,
num_windows_per_transect=None,
use_dynamic_offsets=True,
crop_depth=crop_depth,
transform=train_transform,
**dataset_args,
)
dataset_val = echofilter.data.dataset.TransectDataset(
val_paths,
window_len=sample_shape[0],
p_scale_window=0,
num_windows_per_transect=None,
use_dynamic_offsets=False,
crop_depth=crop_depth,
transform=val_transform,
**dataset_args,
)
dataset_augval = echofilter.data.dataset.TransectDataset(
val_paths,
window_len=sample_shape[0],
p_scale_window=0.8,
num_windows_per_transect=None,
use_dynamic_offsets=False,
crop_depth=crop_depth,
transform=train_transform,
**dataset_args,
)
return dataset_train, dataset_val, dataset_augval
def train_epoch(
loader,
model,
criterion,
optimizer,
device,
epoch,
dtype=torch.float,
print_freq=10,
schedule_data=None,
use_mixed_precision=False,
continue_through_error=True,
):
"""
Train a model through a single epoch of the dataset.
Parameters
----------
loader : iterable, torch.utils.data.DataLoader
Dataloader.
model : callable, echofilter.nn.wrapper.Echofilter
Model.
criterion : callable, torch.nn.modules.loss._Loss
Loss function.
device : str or torch.device
Which device the data should be loaded onto.
epoch : int
Which epoch is being performed.
dtype : str or torch.dtype
Datatype with which the data should be loaded.
print_freq : int, optional
Number of batches between reporting progress. Default is ``10``.
schedule_data : dict or None
If a learning rate schedule is being used, this may be passed as a
dictionary with the key ``"scheduler"`` mapping to the learning rate
schedule as a callable.
use_mixed_precision : bool
Whether to use :meth:`apex.amp.scale_loss` to automatically scale the
loss. Default is ``False``.
continue_through_error : bool
Whether to catch errors within an individual batch, ignore them and
continue running training on the rest of the batches. If more than
five errors are caught within a single epoch, training will halt
regardless of ``continue_through_error``. Default is ``True``.
Returns
-------
average_loss : float
Average loss as given by criterion (weighted equally for each sample
in ``loader``).
meters : dict of dict
Each key is a stratum of the model output, mapping to its own
dictionary of evaluation criteria: "Accuracy", "Precision", "Recall",
"F1 Score", "Jaccard".
examples : tuple of torch.Tensor
Tuple of `(example_input, example_data, example_output)`.
timing : tuple of floats
Tuple of `(batch_time, data_time)`.
"""
if schedule_data is None:
schedule_data = {"name": "constant"}
batch_time = AverageMeter("Time", ":6.3f")
data_time = AverageMeter("Data", ":6.3f")
losses = AverageMeter("Loss", ":6.3f")
meters = {}
for chn in [
"Overall",
"Turbulence",
"Bottom",
"RemovedSeg",
"Passive",
"Patch",
"Surface",
]:
for condition in model.conditions:
cs = condition
if condition != "":
cs = "|" + condition
cc = chn + cs
meters[cc] = {}
meters[cc]["Accuracy"] = AverageMeter("Accuracy (" + cc + ")", ":6.2f")
meters[cc]["Precision"] = AverageMeter("Precision (" + cc + ")", ":6.2f")
meters[cc]["Recall"] = AverageMeter("Recall (" + cc + ")", ":6.2f")
meters[cc]["F1 Score"] = AverageMeter("F1 Score (" + cc + ")", ":6.4f")
meters[cc]["Jaccard"] = AverageMeter("Jaccard (" + cc + ")", ":6.4f")
progress = ProgressMeter(
len(loader),
[
batch_time,
data_time,
losses,
meters["Overall"]["Accuracy"],
meters["Overall"]["Jaccard"],
],
prefix="Epoch: [{}]".format(epoch),
)
# switch to train mode
model.train()
example_input = example_data = example_output = None
n_backward_errors = 0
end = time.time()
for i, (input, metadata) in enumerate(loader):
# measure data loading time
data_time.update(time.time() - end)
input = input.to(device, dtype, non_blocking=True)
metadata = {
k: v.to(device, dtype, non_blocking=True) for k, v in metadata.items()
}
# Compute output
output = model(input)
loss = criterion(output, metadata)
# Record loss
ns = input.size(0)
losses.update(loss.item(), ns)
with torch.no_grad():
if i == max(0, len(loader) - 2):
example_input = input.detach()
example_data = {k: v.detach() for k, v in metadata.items()}
example_output = output.detach()
# Measure and record performance with various metrics
for chn_cond, meters_k in meters.items():
chnparts = chn_cond.split("|")
chn = chnparts[0].lower()
if len(chnparts) < 2:
cs = cond = ""
else:
cond = chnparts[1]
cs = "|" + cond
if chn.startswith("overall"):
output_k = output["mask_keep_pixel" + cs].float()
target_k = metadata["mask"]
elif chn.startswith("turbulence"):
output_k = output["p_is_below_turbulence" + cs]
target_k = 1 - metadata["mask_turbulence"]
elif chn.startswith("surf"):
output_k = output["p_is_below_surface" + cs]
target_k = 1 - metadata["mask_surface"]
elif chn.startswith("bottom"):
output_k = output["p_is_above_bottom" + cs]
target_k = 1 - metadata["mask_bottom"]
elif chn.startswith("removedseg"):
output_k = output["p_is_removed" + cs]
target_k = metadata["is_removed"]
elif chn.startswith("passive"):
output_k = output["p_is_passive" + cs]
target_k = metadata["is_passive"]
elif chn.startswith("patch"):
output_k = output["p_is_patch" + cs]
target_k = metadata["mask_patches"]
else:
raise ValueError("Unrecognised output channel: {}".format(chn))
if cond:
if cond.startswith("up"):
mask = metadata["is_upward_facing"] > 0.5
elif cond.startswith("down"):
mask = metadata["is_upward_facing"] < 0.5
else:
raise ValueError("Unsupported condition {}".format(cond))
if torch.sum(mask).item() == 0:
continue
output_k = output_k[mask]
target_k = target_k[mask]
for c, v in meters_k.items():
c = c.lower()
if c == "accuracy":
v.update(
100.0 * criterions.mask_accuracy(output_k, target_k).item(),
ns,
)
elif c == "precision":
v.update(
100.0
* criterions.mask_precision(output_k, target_k).item(),
ns,
)
elif c == "recall":
v.update(
100.0 * criterions.mask_recall(output_k, target_k).item(),
ns,
)
elif c == "f1 score" or c == "f1":
v.update(
criterions.mask_f1_score(output_k, target_k).item(), ns
)
elif c == "jaccard":
v.update(
criterions.mask_jaccard_index(output_k, target_k).item(), ns
)
elif c == "active output":
v.update(
100.0 * criterions.mask_active_fraction(output_k).item(), ns
)
elif c == "active target":
v.update(
100.0 * criterions.mask_active_fraction(target_k).item(), ns
)
else:
raise ValueError("Unrecognised criterion: {}".format(c))
# compute gradient and do optimizer update step
optimizer.zero_grad()
try:
if use_mixed_precision:
with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
except Exception as ex:
n_backward_errors += 1
if n_backward_errors > 5:
# If there have been more than 5 faulty batches in this epoch, we halt
# now since it seems like a systemic problem and not a one-off.
continue_through_error = False
if not continue_through_error:
raise ex
print("Error in backward step:")
print("".join(traceback.TracebackException.from_exception(ex).format()))
optimizer.step()
if "scheduler" in schedule_data:
schedule_data["scheduler"].step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or i + 1 == len(loader):
progress.display(i + 1)
return (
losses.avg,
meters,
(example_input, example_data, example_output),
(batch_time, data_time),
)
def validate(
loader,
model,
criterion,
device,
dtype=torch.float,
print_freq=10,
prefix="Test",
num_examples=32,
):
"""
Validate the model's performance on the validation partition.
Parameters
----------
loader : iterable, torch.utils.data.DataLoader
Dataloader.
model : callable, echofilter.nn.wrapper.Echofilter
Model.
criterion : callable, torch.nn.modules.loss._Loss
Loss function.
device : str or torch.device
Which device the data should be loaded onto.
dtype : str or torch.dtype
Datatype with which the data should be loaded.
print_freq : int, optional
Number of batches between reporting progress. Default is ``10``.
prefix : str, optional
Prefix string to prepend to progress meter names. Default is ``"Test"``.
num_examples : int, optional
Number of example inputs to return. Default is ``32``.
Returns
-------
average_loss : float
Average loss as given by criterion (weighted equally for each sample
in ``loader``).
meters : dict of dict
Each key is a stratum of the model output, mapping to its own
dictionary of evaluation criteria: "Accuracy", "Precision", "Recall",
"F1 Score", "Jaccard".
examples : tuple of torch.Tensor
Tuple of `(example_input, example_data, example_output)`.
"""
batch_time = AverageMeter("Time", ":6.3f")
data_time = AverageMeter("Data", ":6.3f")
losses = AverageMeter("Loss", ":6.3f")
meters = {}
for chn in [
"Overall",
"Turbulence",
"Bottom",
"RemovedSeg",
"Passive",
"Patch",
"Surface",
]:
for condition in model.conditions:
cs = condition
if condition != "":
cs = "|" + condition
cc = chn + cs
meters[cc] = {}
meters[cc]["Accuracy"] = AverageMeter("Accuracy (" + cc + ")", ":6.2f")
meters[cc]["Precision"] = AverageMeter("Precision (" + cc + ")", ":6.2f")
meters[cc]["Recall"] = AverageMeter("Recall (" + cc + ")", ":6.2f")
meters[cc]["F1 Score"] = AverageMeter("F1 Score (" + cc + ")", ":6.4f")
meters[cc]["Jaccard"] = AverageMeter("Jaccard (" + cc + ")", ":6.4f")
progress = ProgressMeter(
len(loader),
[
batch_time,
data_time,
losses,
meters["Overall"]["Accuracy"],
meters["Overall"]["Jaccard"],
],
prefix=prefix + ": ",
)
# switch to evaluate mode
model.eval()
example_input = []
example_data = []
example_output = []
example_interval = max(1, len(loader) // num_examples)
with torch.no_grad():
end = time.time()
for i, (input, metadata) in enumerate(loader):
# measure data loading time
data_time.update(time.time() - end)
input = input.to(device, dtype, non_blocking=True)
metadata = {
k: v.to(device, dtype, non_blocking=True) for k, v in metadata.items()
}
# Compute output
output = model(input)
loss = criterion(output, metadata)
# Record loss
ns = input.size(0)
losses.update(loss.item(), ns)
if i % example_interval == 0 and len(example_input) < num_examples:
example_input.append(input[0].detach())
example_data.append({k: v[0].detach() for k, v in metadata.items()})
example_output.append({k: v[0].detach() for k, v in output.items()})
# Measure and record performance with various metrics
for chn_cond, meters_k in meters.items():
chnparts = chn_cond.split("|")
chn = chnparts[0].lower()
if len(chnparts) < 2:
cs = cond = ""
else:
cond = chnparts[1]
cs = "|" + cond
if chn.startswith("overall"):
output_k = output["mask_keep_pixel" + cs].float()
target_k = metadata["mask"]
elif chn.startswith("turbulence"):
output_k = output["p_is_below_turbulence" + cs]
target_k = 1 - metadata["mask_turbulence"]
elif chn.startswith("surf"):
output_k = output["p_is_below_surface" + cs]
target_k = 1 - metadata["mask_surface"]
elif chn.startswith("bottom"):
output_k = output["p_is_above_bottom" + cs]
target_k = 1 - metadata["mask_bottom"]
elif chn.startswith("removedseg"):
output_k = output["p_is_removed" + cs]
target_k = metadata["is_removed"]
elif chn.startswith("passive"):
output_k = output["p_is_passive" + cs]
target_k = metadata["is_passive"]
elif chn.startswith("patch"):
output_k = output["p_is_patch" + cs]
target_k = metadata["mask_patches"]
else:
raise ValueError("Unrecognised output channel: {}".format(chn))
if cond:
if cond.startswith("up"):
mask = metadata["is_upward_facing"] > 0.5
elif cond.startswith("down"):
mask = metadata["is_upward_facing"] < 0.5
else:
raise ValueError("Unsupported condition {}".format(cond))
if torch.sum(mask).item() == 0:
continue
output_k = output_k[mask]
target_k = target_k[mask]
for c, v in meters_k.items():
c = c.lower()
if c == "accuracy":
v.update(
100.0 * criterions.mask_accuracy(output_k, target_k).item(),
ns,
)
elif c == "precision":
v.update(
100.0
* criterions.mask_precision(output_k, target_k).item(),
ns,
)
elif c == "recall":
v.update(
100.0 * criterions.mask_recall(output_k, target_k).item(),
ns,
)
elif c == "f1 score" or c == "f1":
v.update(
criterions.mask_f1_score(output_k, target_k).item(), ns
)
elif c == "jaccard":
v.update(
criterions.mask_jaccard_index(output_k, target_k).item(), ns
)
elif c == "active output":
v.update(
100.0 * criterions.mask_active_fraction(output_k).item(), ns
)
elif c == "active target":
v.update(
100.0 * criterions.mask_active_fraction(target_k).item(), ns
)
else:
raise ValueError("Unrecognised criterion: {}".format(c))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or i + 1 == len(loader):
progress.display(i + 1)
# Restack samples, converting list into higher-dim tensor
example_input = torch.stack(example_input, dim=0)
example_data = {
k: torch.stack([a[k] for a in example_data], dim=0) for k in example_data[0]
}
example_output = {
k: torch.stack([a[k] for a in example_output], dim=0) for k in example_output[0]
}
return losses.avg, meters, (example_input, example_data, example_output)
def generate_from_transect(model, transect, sample_shape, device, dtype=torch.float):
"""
Generate an output for a sample transect.
"""
# Put model in evaluation mode
model.eval()
# Make a copy of the transect which we will use to
data = copy.deepcopy(transect)
# Configure data to match what the model expects to see
# Ensure depth is always increasing (which corresponds to descending from
# the air down the water column)
if data["depths"][-1] < data["depths"][0]:
# Found some upward-facing data that still needs to be reflected
for k in ["depths", "signals", "mask"]:
data[k] = np.flip(data[k], -1).copy()
# Apply transforms
transform = torchvision.transforms.Compose(
[
echofilter.data.transforms.Normalize(DATA_CENTER, DATA_DEVIATION),
echofilter.data.transforms.ReplaceNan(NAN_VALUE),
echofilter.data.transforms.Rescale(
(data["signals"].shape[0], sample_shape[1])
),
]
)
data = transform(data)
input = torch.tensor(data["signals"]).unsqueeze(0).unsqueeze(0)
input = input.to(device, dtype).contiguous()
# Put data through model
with torch.no_grad():
output = model(input)
output = {k: v.squeeze(0).cpu().numpy() for k, v in output.items()}
output["depths"] = data["depths"]
output["timestamps"] = data["timestamps"]
return output
def _generate_from_loaded(transect, model, *args, crop_depth=None, **kwargs):
"""
Generate an output from a loaded transect.
"""
# Crop long input
for key in (
echofilter.data.transforms._fields_2d
+ echofilter.data.transforms._fields_1d_timelike
):
if key in transect:
transect[key] = transect[key][:MAX_INPUT_LEN]
# Apply depth crop
if crop_depth is not None:
if transect["is_upward_facing"]:
depth_crop_mask = (
transect["depths"] >= np.max(transect["depths"]) - crop_depth
)
else:
depth_crop_mask = transect["depths"] <= crop_depth
for key in echofilter.data.transforms._fields_2d:
if key in transect:
transect[key] = transect[key][:, depth_crop_mask]
for key in echofilter.data.transforms._fields_1d_depthlike:
if key in transect:
transect[key] = transect[key][depth_crop_mask]
# Convert lines to masks
ddepths = np.broadcast_to(transect["depths"], transect["Sv"].shape)
transect["mask_turbulence"] = np.single(
ddepths < np.expand_dims(transect["turbulence"], -1)
)
transect["mask_bottom"] = np.single(
ddepths > np.expand_dims(transect["bottom"], -1)
)
# Add mask_patches to the data, for plotting
transect["mask_patches"] = 1 - transect["mask"]
transect["mask_patches"][transect["is_passive"] > 0.5] = 0
transect["mask_patches"][transect["is_removed"] > 0.5] = 0
transect["mask_patches"][transect["mask_turbulence"] > 0.5] = 0
transect["mask_patches"][transect["mask_bottom"] > 0.5] = 0
# Generate predictions for the transect
transect["signals"] = transect.pop("Sv")
prediction = generate_from_transect(model, transect, *args, **kwargs)
transect["Sv"] = transect.pop("signals")
return transect, prediction
def generate_from_file(fname, *args, **kwargs):
"""
Generate an output for a sample transect, specified by its file path.
"""
# Load the data
transect = load_decomposed_transect_mask(fname)
# Process the transect
return _generate_from_loaded(transect, *args, **kwargs)
def generate_from_shards(fname, *args, **kwargs):
"""
Generate an output for a sample transect, specified by a path to sharded data.
"""
# Load the data
transect = echofilter.raw.shardloader.load_transect_segments_from_shards_abs(fname)
# Process the transect
return _generate_from_loaded(transect, *args, **kwargs)
def save_checkpoint(state, is_best, dirname=".", fname_fmt="checkpoint{}.pt", dup=None):
"""
Save a model checkpoint, using :meth:`torch.save`.
Parameters
----------
state : dict
Model checkpoint state to record.
is_best : bool
Whether this model state is the best so far. If ``True``, the best
checkpoint (by default named ``"checkpoint_best.pt"``) will be overwritten
with this ``state``.
dirname : str, optional
Path to directory in which the checkpoint will be saved.
Default is ``"."`` (current directory of the executed script).
fname_fmt : str, optional
Format for the file name(s) of the saved checkpoint(s). Must contain
one string format placeholder. Default is ``"checkpoint{}.pt"``.
dup : str or None
If this is not ``None``, a duplicate copy of the checkpoint is recorded
in accordance with ``fname_fmt``. By default the duplicate output file
name will be styled as ``"checkpoint_<dup>.pt"``.
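Notes
-----
As a minimal illustration (with a hypothetical output directory),
``save_checkpoint(state, is_best=True, dirname="models/run01", dup="ep10")``
writes ``models/run01/checkpoint.pt`` and copies it to both
``checkpoint_best.pt`` and ``checkpoint_ep10.pt`` in the same directory.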
"""
os.makedirs(dirname, exist_ok=True)
fname = os.path.join(dirname, fname_fmt.format(""))
torch.save(state, fname)
if is_best:
shutil.copyfile(fname, os.path.join(dirname, fname_fmt.format("_best")))
if dup:
shutil.copyfile(
fname, os.path.join(dirname, fname_fmt.format("_{}".format(dup)))
)
def meters_to_csv(meters, is_best, dirname=".", filename="meters.csv"):
"""
Export performance metrics to CSV format.
Parameters
----------
meters : dict of dict
Collection of output meters, as a nested dictionary.
is_best : bool
Whether this model state is the best so far. If ``True``, the CSV file
will be copied to ``"model_best.meters.csv"``.
dirname : str, optional
Path to directory in which the CSV file will be saved.
Default is ``"."`` (current directory of the executed script).
filename : str, optional
Name of the output file. Default is ``"meters.csv"``.
"""
os.makedirs(dirname, exist_ok=True)
df = pd.DataFrame()
for chn in meters:
if "|" in chn:
# Skip conditional model evaluations
continue
# For each output plane
for _, meter in meters[chn].items():
# For each criterion
df[meter.name] = meter.values
df.to_csv(os.path.join(dirname, filename), index=False)
if is_best:
shutil.copyfile(
os.path.join(dirname, filename),
os.path.join(dirname, "model_best.meters.csv"),
)
if __name__ == "__main__":
main()