"""
Utility functions for interacting with optimizers.
"""
# This file is part of Echofilter.
#
# Copyright (C) 2020-2022 Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


def get_current_lr(optimizer):
"""
Get the learning rate of an optimizer.
Parameters
----------
optimizer : torch.optim.Optimizer
An optimizer, with a learning rate common to all parameter groups.
Returns
-------
float
The learning rate of the first parameter group.
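
    Examples
    --------
    A minimal illustrative sketch, assuming PyTorch is available:

    >>> import torch
    >>> # Build a trivial parameter list purely for illustration.
    >>> params = [torch.nn.Parameter(torch.zeros(1))]
    >>> optimizer = torch.optim.SGD(params, lr=0.1)
    >>> get_current_lr(optimizer)
    0.1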
"""
return optimizer.param_groups[0]["lr"]


def get_current_momentum(optimizer):
"""
Get the momentum of an optimizer.
Parameters
----------
optimizer : torch.optim.Optimizer
An optimizer which implements momentum or betas (where momentum is the
first beta, c.f. :class:`torch.optim.Adam`) with a momentum common to all
parameter groups.
Returns
-------
float
The momentum of the first parameter group.
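
    Examples
    --------
    A minimal illustrative sketch, assuming PyTorch is available, covering
    both a momentum-based and a betas-based optimizer:

    >>> import torch
    >>> # Build a trivial parameter list purely for illustration.
    >>> params = [torch.nn.Parameter(torch.zeros(1))]
    >>> optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9)
    >>> get_current_momentum(optimizer)
    0.9
    >>> optimizer = torch.optim.Adam(params, lr=0.1, betas=(0.9, 0.999))
    >>> get_current_momentum(optimizer)
    0.9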
"""
if "momentum" not in optimizer.defaults and "betas" not in optimizer.defaults:
raise ValueError(
"optimizer {} does not support momentum".format(optimizer.__class__)
)
group = optimizer.param_groups[0]
if "momentum" in group:
return group["momentum"]
else:
return group["betas"][0]