"""
PyTorch classes backported from later versions.
This module contains classes copied from versions of PyTorch newer than v1.2.0,
which is the latest version currently available from IBM compiled for ppc64
architectures.
From PyTorch:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
From Caffe2:
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""
import math
import warnings
import weakref
from collections import Counter
from functools import wraps
from torch._six import inf
from torch.optim.optimizer import Optimizer
class _LRScheduler(object):
    r"""
    Backported from pytorch 1.4.0.

    Base class for learning-rate schedulers. Subclasses implement
    :meth:`get_lr`, which returns one learning rate per parameter group of the
    wrapped optimizer; :meth:`step` advances ``last_epoch`` and writes those
    values into ``optimizer.param_groups``.

    Arguments
    ---------
    optimizer : :class:`torch.optim.optimizer.Optimizer`
        Wrapped optimizer.
    last_epoch : int
        Index of the last epoch. With ``-1`` each param group's current
        ``lr`` is recorded as its ``initial_lr`` (unless already present);
        any other value means a run is being resumed, and every param group
        must already contain ``initial_lr`` (otherwise ``KeyError``).
        Default: -1
    """

    def __init__(self, optimizer, last_epoch=-1):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault("initial_lr", group["lr"])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if "initial_lr" not in group:
                    raise KeyError(
                        "param 'initial_lr' is not specified "
                        "in param_groups[{}] when resuming an optimizer".format(i)
                    )
        self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`, so optimizer.step is wrapped with a call counter.
        def with_counter(method):
            if getattr(method, "_with_counter", False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        # Initial step: increments last_epoch by one (-1 -> 0 for a fresh
        # schedule) and applies the first learning rates to the optimizer.
        self.step()

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        """Load the schedulers state.

        Arguments
        ---------
        state_dict : dict
            Scheduler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """Return last computed learning rate by current scheduler."""
        return self._last_lr

    def get_lr(self):
        """Return one learning rate per param group; must be overridden."""
        # Compute learning rate using chainable form of the scheduler
        raise NotImplementedError

    def step(self, epoch=None):
        """Advance the schedule and write the new learning rates.

        Arguments
        ---------
        epoch : int, optional
            Deprecated. When given, ``last_epoch`` is set to it directly and
            ``_get_closed_form_lr`` is used if the subclass defines it;
            otherwise ``last_epoch`` is incremented by one.

        Any exception raised while computing the learning rates propagates
        to the caller, and no learning rate is changed in that case.
        """
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn(
                    "Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                    "initialization. Please, make sure to call `optimizer.step()` before "
                    "`lr_scheduler.step()`. See more details at "
                    "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                    UserWarning,
                )

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn(
                    "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                    "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                    "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                    "will result in PyTorch skipping the first value of the learning rate schedule. "
                    "See more details at "
                    "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                    UserWarning,
                )
        self._step_count += 1

        # The flag lets get_lr() distinguish internal calls from direct user
        # calls. try/finally always resets it while letting exceptions from
        # get_lr() propagate (a context manager whose __exit__ returns a truthy
        # value would silently suppress them).
        self._get_lr_called_within_step = True
        try:
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(
                    "The epoch parameter in `scheduler.step()` was not necessary and is "
                    "being deprecated where possible. Please use `scheduler.step()` to "
                    "step the scheduler. During the deprecation, if epoch is different "
                    "from None, the closed form is used instead of the new chainable "
                    "form, where available.",
                    DeprecationWarning,
                )
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()
        finally:
            self._get_lr_called_within_step = False

        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group["lr"] = lr

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
class OneCycleLR(_LRScheduler):
    r"""
    Backported from pytorch 1.4.0.

    Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    ``step`` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    Parameters
    ----------
    optimizer : :class:`torch.optim.optimizer.Optimizer`
        Wrapped optimizer.
    max_lr : float or list
        Upper learning rate boundaries in the cycle
        for each parameter group.
    total_steps : int
        The total number of steps in the cycle. Note that
        if a value is not provided here, then it must be inferred by providing
        a value for epochs and steps_per_epoch.
        Default: None
    epochs : int
        The number of epochs to train for. This is used along
        with steps_per_epoch in order to infer the total number of steps in the cycle
        if a value for total_steps is not provided.
        Default: None
    steps_per_epoch : int
        The number of steps per epoch to train for. This is
        used along with epochs in order to infer the total number of steps in the
        cycle if a value for total_steps is not provided.
        Default: None
    pct_start : float
        The fraction of the cycle (in number of steps) spent
        increasing the learning rate. Must be a float in [0, 1].
        Default: 0.3
    anneal_strategy : {'cos', 'linear'}
        Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
        linear annealing.
        Default: 'cos'
    cycle_momentum : bool
        If ``True``, momentum is cycled inversely
        to learning rate between 'base_momentum' and 'max_momentum'.
        Default: True
    base_momentum : float or list
        Lower momentum boundaries in the cycle
        for each parameter group. Note that momentum is cycled inversely
        to learning rate; at the peak of a cycle, momentum is
        'base_momentum' and learning rate is 'max_lr'.
        Default: 0.85
    max_momentum : float or list
        Upper momentum boundaries in the cycle
        for each parameter group. Functionally,
        it defines the cycle amplitude (max_momentum - base_momentum).
        Note that momentum is cycled inversely
        to learning rate; at the start of a cycle, momentum is 'max_momentum'
        and learning rate is 'base_lr'
        Default: 0.95
    div_factor : float
        Determines the initial learning rate via
        initial_lr = max_lr/div_factor
        Default: 25
    final_div_factor : float
        Determines the minimum learning rate via
        min_lr = initial_lr/final_div_factor
        Default: 1e4
    last_epoch : int
        The index of the last batch. This parameter is used when
        resuming a training job. Since ``step()`` should be invoked after each
        batch instead of after each epoch, this number represents the total
        number of *batches* computed, not the total number of epochs computed.
        When last_epoch=-1, the schedule is started from the beginning.
        Default: -1

    Example
    -------
    >>> data_loader = torch.utils.data.DataLoader(...)
    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(
    ...     optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10
    ... )
    >>> for epoch in range(10):
    ...     for batch in data_loader:
    ...         train_batch(...)
    ...         scheduler.step()

    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """

    def __init__(
        self,
        optimizer,
        max_lr,
        total_steps=None,
        epochs=None,
        steps_per_epoch=None,
        pct_start=0.3,
        anneal_strategy="cos",
        cycle_momentum=True,
        base_momentum=0.85,
        max_momentum=0.95,
        div_factor=25.0,
        final_div_factor=1e4,
        last_epoch=-1,
    ):
        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer

        # Validate total_steps. An explicit total_steps takes precedence;
        # otherwise BOTH epochs and steps_per_epoch are required (checking
        # them first avoids comparing None with 0 below).
        # isinstance is checked before the comparison so non-numeric input
        # yields the intended ValueError rather than a TypeError.
        if total_steps is not None:
            if not isinstance(total_steps, int) or total_steps <= 0:
                raise ValueError(
                    "Expected positive integer total_steps, but got {}".format(
                        total_steps
                    )
                )
            self.total_steps = total_steps
        elif epochs is None or steps_per_epoch is None:
            raise ValueError(
                "You must define either total_steps OR (epochs AND steps_per_epoch)"
            )
        else:
            if not isinstance(epochs, int) or epochs <= 0:
                raise ValueError(
                    "Expected positive integer epochs, but got {}".format(epochs)
                )
            if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
                raise ValueError(
                    "Expected positive integer steps_per_epoch, but got {}".format(
                        steps_per_epoch
                    )
                )
            self.total_steps = epochs * steps_per_epoch

        # Validate pct_start before it is used. The chained comparison also
        # rejects NaN.
        if not isinstance(pct_start, float) or not 0 <= pct_start <= 1:
            raise ValueError(
                "Expected float between 0 and 1 pct_start, but got {}".format(pct_start)
            )
        # Last step index of the warm-up phase, and the length of the
        # annealing phase that follows it.
        self.step_size_up = float(pct_start * self.total_steps) - 1
        self.step_size_down = float(self.total_steps - self.step_size_up) - 1

        # Validate anneal_strategy
        if anneal_strategy == "cos":
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == "linear":
            self.anneal_func = self._annealing_linear
        else:
            raise ValueError(
                "anneal_strategy must be one of 'cos' or 'linear', instead got {}".format(
                    anneal_strategy
                )
            )

        # Initialize learning rate variables (only for a fresh schedule; a
        # resumed run is expected to carry these keys in its param groups)
        max_lrs = self._format_param("max_lr", self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group["initial_lr"] = max_lrs[idx] / div_factor
                group["max_lr"] = max_lrs[idx]
                group["min_lr"] = group["initial_lr"] / final_div_factor

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if (
                "momentum" not in self.optimizer.defaults
                and "betas" not in self.optimizer.defaults
            ):
                raise ValueError(
                    "optimizer must support momentum with `cycle_momentum` option enabled"
                )
            # Adam-style optimizers expose momentum as betas[0].
            self.use_beta1 = "betas" in self.optimizer.defaults
            max_momentums = self._format_param("max_momentum", optimizer, max_momentum)
            base_momentums = self._format_param(
                "base_momentum", optimizer, base_momentum
            )
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(
                    max_momentums, base_momentums, optimizer.param_groups
                ):
                    if self.use_beta1:
                        _, beta2 = group["betas"]
                        group["betas"] = (m_momentum, beta2)
                    else:
                        group["momentum"] = m_momentum
                    group["max_momentum"] = m_momentum
                    group["base_momentum"] = b_momentum

        super(OneCycleLR, self).__init__(optimizer, last_epoch)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group.

        A list/tuple must have one entry per param group (else ``ValueError``)
        and is returned as-is; a scalar is repeated once per param group.
        """
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(
                    "expected {} values for {}, got {}".format(
                        len(optimizer.param_groups), name, len(param)
                    )
                )
            return param
        return [param] * len(optimizer.param_groups)

    def _annealing_cos(self, start, end, pct):
        """Cosine anneal from ``start`` to ``end`` as pct goes from 0.0 to 1.0."""
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        """Linearly anneal from ``start`` to ``end`` as pct goes from 0.0 to 1.0."""
        return (end - start) * pct + start

    @staticmethod
    def _phase_progress(steps_into_phase, phase_length):
        """Return progress through a phase, or 1.0 for a zero-length phase."""
        # A zero-length phase occurs e.g. when pct_start * total_steps == 1
        # (warm-up) or pct_start == 1.0 (annealing); dividing would raise
        # ZeroDivisionError, so treat the phase as already complete.
        if phase_length == 0:
            return 1.0
        return steps_into_phase / phase_length

    def get_lr(self):
        """Compute the learning rate of every param group at ``self.last_epoch``.

        When ``cycle_momentum`` is enabled this also writes the matching
        momentum (or ``betas[0]``) into each param group as a side effect.

        Returns
        -------
        list
            One learning rate per param group.

        Raises
        ------
        ValueError
            If ``self.last_epoch`` is greater than ``self.total_steps``.
        """
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                DeprecationWarning,
            )

        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError(
                "Tried to step {} times. The specified number of total steps is {}".format(
                    step_num + 1, self.total_steps
                )
            )

        # The phase depends only on step_num, so resolve it once for all
        # groups: warm-up anneals initial_lr -> max_lr (momentum max -> base),
        # then annealing goes max_lr -> min_lr (momentum base -> max).
        if step_num <= self.step_size_up:
            pct = self._phase_progress(step_num, self.step_size_up)
            lr_start, lr_end = "initial_lr", "max_lr"
            mom_start, mom_end = "max_momentum", "base_momentum"
        else:
            pct = self._phase_progress(
                step_num - self.step_size_up, self.step_size_down
            )
            lr_start, lr_end = "max_lr", "min_lr"
            mom_start, mom_end = "base_momentum", "max_momentum"

        lrs = []
        for group in self.optimizer.param_groups:
            lrs.append(self.anneal_func(group[lr_start], group[lr_end], pct))
            if self.cycle_momentum:
                computed_momentum = self.anneal_func(
                    group[mom_start], group[mom_end], pct
                )
                if self.use_beta1:
                    _, beta2 = group["betas"]
                    group["betas"] = (computed_momentum, beta2)
                else:
                    group["momentum"] = computed_momentum

        return lrs