Source code for echofilter.nn.modules.utils

"""
nn.modules utility functions.
"""

import collections
from itertools import repeat

import torch
from torch import nn

__all__ = ["same_to_padding", "init_cnn"]


def _ntuple(n):
    # Build a parser which broadcasts an int (or a single-element iterable)
    # into an n-tuple, and passes longer iterables through as tuples.
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            x = tuple(x)
            if len(x) == 0:
                raise ValueError("Input {} is an empty iterable".format(x))
            if len(x) > 1 or n == 1:
                return x
            x = x[0]
        return tuple(repeat(x, n))

    return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
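
# A minimal usage sketch of the tuple helpers above (values illustrative,
# not from the original source): ints and single-element iterables are
# broadcast to fill the tuple; longer iterables pass through unchanged.
#
#   >>> _pair(3)
#   (3, 3)
#   >>> _pair((3, 5))
#   (3, 5)
#   >>> _pair((3,))
#   (3, 3)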


def same_to_padding(kernel_size, stride=1, dilation=1, ndim=None):
    """
    Determine the amount of padding to use for a convolutional layer.

    Parameters
    ----------
    kernel_size : int or sequence
        Size of kernel for each dimension.
    stride : int or sequence, optional
        Amount of stride to apply in each dimension of the kernel. If
        ``stride`` is an int, the same value is applied for each dimension.
        Default is ``1``. Note that ``stride`` is only used to infer the
        number of dimensions; it does not affect the padding returned.
    dilation : int or sequence, optional
        Amount of dilation to apply in each dimension of the kernel. If
        ``dilation`` is an int, the same value is applied for each dimension.
        Default is ``1``.
    ndim : int or None, optional
        Number of dimensions of kernel to pad. If ``None`` (default), the
        number of dimensions is inferred from the length of ``kernel_size``
        (or of ``stride`` or ``dilation``, whichever is a sequence).

    Returns
    -------
    padding : int or tuple
        Amount of padding to apply to each dimension before convolving with
        the kernel in order to preserve the size of the input. An int is
        returned if all of the inputs are ints; otherwise a tuple with one
        entry per dimension is returned.
    """
    if isinstance(kernel_size, int):
        if kernel_size % 2 == 0:
            raise ValueError("Same padding is not implemented for even kernel sizes.")
    elif any(k % 2 == 0 for k in kernel_size):
        raise ValueError("Same padding is not implemented for even kernel sizes.")
    args = [kernel_size, stride, dilation]
    if all(isinstance(arg, int) for arg in args):
        return (kernel_size - 1) * dilation // 2
    if ndim is not None:
        pass
    elif isinstance(kernel_size, collections.abc.Sequence):
        ndim = len(kernel_size)
    elif isinstance(stride, collections.abc.Sequence):
        ndim = len(stride)
    elif isinstance(dilation, collections.abc.Sequence):
        ndim = len(dilation)
    else:
        raise ValueError("Wrong argument type. Must be int or iterable.")
    ntuple = _ntuple(ndim)
    kernel_size = ntuple(kernel_size)
    dilation = ntuple(dilation)
    padding = tuple((k - 1) * d // 2 for k, d in zip(kernel_size, dilation))
    return padding
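
# A hedged usage sketch for ``same_to_padding`` (values illustrative, not
# from the original source): for stride-1 convolutions, this padding keeps
# the output the same size as the input.
#
#   >>> same_to_padding(5)              # all ints -> int
#   2
#   >>> same_to_padding((3, 5))         # sequence -> per-dimension tuple
#   (1, 2)
#   >>> same_to_padding(3, dilation=2)  # dilation widens the kernel reach
#   2
#   >>> conv = nn.Conv2d(1, 8, kernel_size=(3, 5), padding=same_to_padding((3, 5)))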


def rint(x, minval=None):
    """
    Round and cast as an int, optionally with a floor value.

    Parameters
    ----------
    x : float
        Number to round.
    minval : int or None, optional
        A floor value for the output. If ``None``, no floor is applied.
        Default is ``None``.

    Returns
    -------
    int
        The number rounded to the nearest int, and cast as an int. If
        ``minval`` is set, the maximum of the result and ``minval`` is
        returned.
    """
    x = int(round(x))
    if minval is not None:
        x = max(minval, x)
    return x
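
# A minimal usage sketch for ``rint`` (values illustrative). Note that
# Python 3's ``round`` uses banker's rounding, so ties go to the even int.
#
#   >>> rint(3.6)
#   4
#   >>> rint(2.5)            # banker's rounding: tie goes to the even int
#   2
#   >>> rint(-1.2, minval=0)
#   0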


def init_cnn(m):
    """
    Initialize biases and weights for a CNN layer.

    Uses a Kaiming normal distribution for the weight and 0 for biases.

    Function is applied recursively within the module.

    Parameters
    ----------
    m : torch.nn.Module
        Module to initialize.
    """
    if getattr(m, "bias", None) is not None:
        nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.modules.conv._ConvNd, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    for layer in m.children():
        init_cnn(layer)
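
# A hedged usage sketch for ``init_cnn`` (the model below is hypothetical,
# not from the original source). The function recurses into child modules
# itself, so it is called once on the top-level module.
#
#   >>> model = nn.Sequential(
#   ...     nn.Conv2d(1, 8, kernel_size=3, padding=1),
#   ...     nn.ReLU(),
#   ...     nn.Linear(8, 2),
#   ... )
#   >>> init_cnn(model)  # Kaiming-normal weights, zero biases throughout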