# Source code for echofilter.optim.meters

"""
Meters for tracking measurements during training.
"""

# This file is part of Echofilter.
#
# Copyright (C) 2020-2022  Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import numpy as np
import torch


class AverageMeter(object):
    """Compute and store the average and current value.

    Parameters
    ----------
    name : str
        Label shown when the meter is converted to a string.
    fmt : str, optional
        Format specification (including the leading colon) applied to the
        current and average values in ``str(meter)``. Default is ``":f"``.
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Discard all recorded values and zero the running statistics."""
        self.values = []
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=None):
        """Record a new measurement (or batch of measurements).

        Parameters
        ----------
        val : int, float, array_like, or torch.Tensor
            A single value or a collection of values. Tensors are detached
            from the autograd graph and moved to the CPU before use.
            Multi-dimensional inputs are flattened, so every element counts
            as one value.
        n : int, optional
            Weight given to the mean of ``val`` in the running average.
            Default is the number of values contained in ``val``.

        Notes
        -----
        An empty ``val`` is ignored, rather than adding NaN to the running
        sum.
        """
        if isinstance(val, torch.Tensor):
            # .numpy() refuses tensors that require grad, so detach first.
            val = val.detach().cpu().numpy()
        if np.ndim(val) == 0:
            # Any scalar: Python number, NumPy scalar (e.g. np.float32),
            # or 0-d array. These are not iterable, so wrap them directly.
            values = [np.asarray(val).item()]
        else:
            # Flatten so that n (the number of values) agrees with the
            # element-wise mean computed below.
            values = list(np.ravel(val))
        if not values:
            # np.mean of an empty input is NaN, which would corrupt self.sum.
            return
        val = np.mean(values)
        if n is None:
            n = len(values)
        self.values += values
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against division by zero when only zero-weight updates
        # have been recorded so far.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Print one-line progress reports built from a batch counter and meters.

    Parameters
    ----------
    num_batches : int
        Total number of batches; it also fixes the width of the counter.
    meters : iterable
        Objects whose ``str()`` is appended to every progress line.
    prefix : str, optional
        Text placed in front of the batch counter. Default is ``""``.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for ``batch`` to standard output."""
        counter = self.prefix + self.batch_fmtstr.format(batch)
        print(" ".join([counter, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``"[{:3d}/100]"`` for ``num_batches``.

        The current-batch field is padded to the width of the total, so
        successive progress lines stay aligned.
        """
        total = "{:d}".format(num_batches)
        width_spec = "{:%dd}" % len(total)
        return "[%s/%s]" % (width_spec, total)