Source code for echofilter.optim.meters
"""
Meters for tracking measurements during training.
"""
# This file is part of Echofilter.
#
# Copyright (C) 2020-2022 Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
import torch
class AverageMeter(object):
    """
    Compute and store the average and current value.

    Parameters
    ----------
    name : str
        Label shown when the meter is rendered with ``str()``.
    fmt : str, optional
        Format-spec suffix (including the leading colon) applied to both
        the most recent value and the running average. Default is ``":f"``.
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Discard all recorded measurements and zero the running statistics."""
        self.values = []
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=None):
        """
        Record a new measurement.

        Parameters
        ----------
        val : scalar, sequence, numpy.ndarray or torch.Tensor
            Either a single measurement or a collection of measurements.
            A collection is appended element-wise to ``self.values`` and its
            mean becomes the current value ``self.val``. An empty collection
            is ignored.
        n : int, optional
            Weight of this update in the running average. Defaults to the
            number of measurements contained in ``val``.
        """
        if isinstance(val, torch.Tensor):
            # detach() first: numpy() refuses tensors that still track grads.
            val = val.detach().cpu().numpy()
        if isinstance(val, np.ndarray) and val.ndim == 0:
            # A 0-d array cannot be iterated; unwrap it into a plain scalar.
            val = val.item()
        if isinstance(val, (int, float, np.generic)):
            values = [val]
        else:
            values = list(val)
            if not values:
                # The mean of nothing is NaN, which would poison the running
                # sum for every later update, so skip empty input entirely.
                return
            val = np.mean(values)
        if n is None:
            n = len(values)
        self.values += values
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            # Keep the previous average while the total weight is still zero.
            self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. "{name} {val:f} ({avg:f})" and fills it from the
        # instance attributes.
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """
    Print a one-line progress report built from a collection of meters.

    Parameters
    ----------
    num_batches : int
        Total number of batches; shown as the denominator of the counter.
    meters : iterable
        Objects whose ``str()`` representations are appended to each report.
    prefix : str, optional
        Text placed directly before the batch counter. Default is ``""``.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the counter for ``batch`` and every meter."""
        counter = self.batch_fmtstr.format(batch)
        pieces = [self.prefix + counter, *(str(m) for m in self.meters)]
        print(" ".join(pieces))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``"[{:3d}/100]"`` for the batch counter."""
        # Pad the running index to the width of the total so columns line up.
        width = len(str(num_batches // 1))
        slot = f"{{:{width}d}}"
        total = slot.format(num_batches)
        return f"[{slot}/{total}]"