Source code for echofilter.nn.modules.activations

"""
PyTorch activation functions.

Swish and Mish implementations taken from https://github.com/fastai/fastai2
under the Apache License Version 2.0.
"""

import functools

import torch
import torch.nn.functional as F
from torch import nn

# Names exported by ``from <this module> import *``. The private JIT helper
# functions and autograd Function classes defined below are deliberately
# left out of this list.
__all__ = [
    "str2actfnfactory",
    "InplaceReLU",
    "swish",
    "Swish",
    "HardSwish",
    "mish",
    "Mish",
    "HardMish",
]


def str2actfnfactory(actfn_name):
    """
    Look up the factory for an activation function from its name.

    Parameters
    ----------
    actfn_name : str
        Name of the activation function. If :mod:`torch.nn` has an attribute
        with exactly this name, that attribute is returned as-is (no check
        is made that it is an activation). Otherwise the name is matched
        case-insensitively, ignoring hyphens and underscores, against the
        activations defined in this module.

    Returns
    -------
    callable
        The class (or partial application of a class) which constructs the
        activation module when called.

    Raises
    ------
    ValueError
        If the name matches neither :mod:`torch.nn` nor this module.
    """
    # Anything torch.nn exposes under this exact spelling takes priority.
    try:
        return getattr(nn, actfn_name)
    except AttributeError:
        pass

    # Normalise spelling so "hard-swish", "Hard_Swish" and "hardswish" agree.
    key = actfn_name.lower().replace("-", "").replace("_", "")

    if key in ("inplacerelu", "reluinplace"):
        return InplaceReLU
    if key == "swish":
        return Swish
    if key == "hardswish":
        return HardSwish
    if key == "mish":
        return Mish
    if key == "hardmish":
        return HardMish

    # Report the name exactly as the caller supplied it.
    raise ValueError("Unrecognised activation function: {}".format(actfn_name))
# Factory for a ReLU module which overwrites its input tensor.
InplaceReLU = functools.partial(nn.ReLU, inplace=True)


# Swish


@torch.jit.script
def _swish_jit_fwd(x):
    # swish(x) = x * sigmoid(x)
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def _swish_jit_bwd(x, grad_output):
    # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))), with s = sigmoid.
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class _SwishJitAutoFn(torch.autograd.Function):
    """
    Swish with a hand-written backward pass.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there rather than being stored by the forward pass.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        # ``saved_tensors`` is the supported accessor; the ``saved_variables``
        # alias it replaces is deprecated.
        (x,) = ctx.saved_tensors
        return _swish_jit_bwd(x, grad_output)
def swish(x, inplace=False):
    """
    Apply the swish function elementwise.

    swish(x) = x * sigmoid(x)

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    inplace : bool, optional
        Accepted but not used by this implementation. Default is ``False``.

    Returns
    -------
    torch.Tensor
        The activated tensor.
    """
    activated = _SwishJitAutoFn.apply(x)
    return activated
class Swish(nn.Module):
    """
    Apply the swish function elementwise.

    swish(x) = x * sigmoid(x)
    """

    def forward(self, x):
        """Return the swish activation of ``x``."""
        activated = _SwishJitAutoFn.apply(x)
        return activated
class HardSwish(nn.Module):
    """
    Piecewise approximation to the swish activation function.

    Computes ``x * relu6(x + 3) / 6``.

    See https://arxiv.org/abs/1905.02244

    Parameters
    ----------
    inplace : bool, optional
        Passed on to the internal :class:`torch.nn.ReLU6`. Default is
        ``True``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu6 = torch.nn.ReLU6(inplace=inplace)
        self.inplace = inplace

    def forward(self, x):
        """Return the hard-swish activation of ``x``."""
        # The ReLU6 is applied to the temporary ``x + 3``, so even when it
        # runs in-place the caller's tensor ``x`` is left untouched.
        gate = self.relu6(x + 3)
        return x * gate / 6

    def extra_repr(self):
        """Describe the ``inplace`` setting for the module's repr."""
        return "inplace={}".format(bool(self.inplace))
# Mish


@torch.jit.script
def _mish_jit_fwd(x):
    # mish(x) = x * tanh(softplus(x))
    return x.mul(torch.tanh(F.softplus(x)))


@torch.jit.script
def _mish_jit_bwd(x, grad_output):
    # d/dx [x * t(x)] = t(x) + x * sigmoid(x) * (1 - t(x)^2),
    # where t(x) = tanh(softplus(x)) and d/dx softplus(x) = sigmoid(x).
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    """
    Mish with a hand-written backward pass.

    Only the input tensor is saved for the backward pass; the softplus,
    tanh and sigmoid terms are recomputed there rather than being stored by
    the forward pass.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        # ``saved_tensors`` is the supported accessor; the ``saved_variables``
        # alias it replaces is deprecated.
        (x,) = ctx.saved_tensors
        return _mish_jit_bwd(x, grad_output)
def mish(x):
    """
    Evaluate the mish activation on every element of the input.

        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    See https://arxiv.org/abs/1908.08681

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        The activated tensor.
    """
    activated = MishJitAutoFn.apply(x)
    return activated
class Mish(nn.Module):
    """
    Module which evaluates the mish activation on every element.

        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    See https://arxiv.org/abs/1908.08681
    """

    def forward(self, x):
        """Return the mish activation of ``x``."""
        activated = MishJitAutoFn.apply(x)
        return activated
class HardMish(nn.Module):
    """
    A second-order approximation to the mish activation function.

    Computes ``x * clamp(x + 3, 0, 5) / 5``.

    Parameters
    ----------
    inplace : bool, optional
        Passed on to the internal :class:`torch.nn.Hardtanh`. Default is
        ``True``.

    Notes
    -----
    https://forums.fast.ai/t/hard-mish-activation-function/59238
    """

    def __init__(self, inplace=True):
        # Module.__init__ must run before any submodule is assigned,
        # otherwise the assignment of ``relu5`` below raises AttributeError.
        super().__init__()
        # Stored so that extra_repr can report the setting.
        self.inplace = inplace
        self.relu5 = nn.Hardtanh(0.0, 5.0, inplace=inplace)

    def forward(self, x):
        """Return the hard-mish activation of ``x``."""
        # The clamp is applied to the temporary ``x + 3``, so even when it
        # runs in-place the caller's tensor ``x`` is left untouched.
        return x * self.relu5(x + 3) / 5

    def extra_repr(self):
        """Describe the ``inplace`` setting for the module's repr."""
        inplace_str = "inplace=True" if self.inplace else "inplace=False"
        return inplace_str