
"""
U-Net model.
"""

# This file is part of Echofilter.
#
# Copyright (C) 2020-2022  Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import functools

import torch
from torch import nn

from . import modules


class Down(nn.Module):
    """
    Downscaling layer, downsampling by a factor of two in one or more dimensions.
    """

    def __init__(self, mode="max", compress_dims=True):
        super(Down, self).__init__()
        compress_dims = modules.utils._pair(compress_dims)
        kernel_sizes = modules.utils._pair(
            2 if compress_dim else 1 for compress_dim in compress_dims
        )
        if mode == "max":
            self.pool = nn.MaxPool2d(kernel_sizes)
        elif mode == "avg":
            self.pool = nn.AvgPool2d(kernel_sizes)
        else:
            raise ValueError("Unsupported pooling method: {}".format(mode))

    def forward(self, x):
        return self.pool(x)
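
# A minimal usage sketch of Down, assuming the usual NCHW input layout: with
# compress_dims=(True, False), only the first spatial dimension is pooled and
# halved, while the second is left untouched.
#
#     >>> pool = Down(mode="max", compress_dims=(True, False))
#     >>> pool(torch.randn(1, 3, 10, 8)).shape
#     torch.Size([1, 3, 5, 8])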


class Up(nn.Module):
    """
    Upscaling layer, upsampling by a factor of two in one or more dimensions.
    """

    def __init__(self, in_channels=None, up_dims=True, mode="bilinear"):
        super(Up, self).__init__()
        up_dims = modules.utils._pair(up_dims)
        kernel_sizes = modules.utils._pair(2 if up_dim else 1 for up_dim in up_dims)
        # If conv mode, use a transposed convolution to increase the size.
        # Otherwise, use one of the nn.Upsample modes:
        # {"nearest", "linear", "bilinear", "bicubic"}
        if "conv" in mode:
            if in_channels is None:
                raise ValueError(
                    "Number of channels must be provided if upscaling with"
                    " transposed convolution."
                )
            self.up = nn.ConvTranspose2d(
                in_channels,
                in_channels,
                kernel_size=kernel_sizes,
                stride=kernel_sizes,
            )
        else:
            self.up = nn.Upsample(
                scale_factor=kernel_sizes, mode=mode, align_corners=True
            )

    def forward(self, x):
        return self.up(x)
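
# A minimal sketch of both Up modes, assuming an NCHW input. The default
# "bilinear" mode doubles the compressed dimensions with nn.Upsample, while
# "conv" uses a learnable nn.ConvTranspose2d and therefore needs in_channels.
#
#     >>> up = Up(up_dims=(True, False), mode="bilinear")
#     >>> up(torch.randn(1, 3, 5, 8)).shape
#     torch.Size([1, 3, 10, 8])
#     >>> up = Up(in_channels=3, up_dims=(True, False), mode="conv")
#     >>> up(torch.randn(1, 3, 5, 8)).shape
#     torch.Size([1, 3, 10, 8])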


class UNetBlock(nn.Module):
    """
    Create a (cascading set of) UNet block(s).

    Each block performs the steps:
        - Store input to be used in skip connection
        - Down step
        - Horizontal block
        - <Recursion>
        - Up step
        - Concatenate with skip connection
        - Horizontal block

    Where <Recursion> is a call generating a child UNetBlock instance.

    Parameters
    ----------
    in_channels : int
        Number of input channels to this block.
    horizontal_block_factory : callable
        A :class:`torch.nn.Module` constructor or function which returns a
        block of layers. The resulting module must accept ``in_channels``
        and ``out_channels`` as its first two arguments.
    n_block : int, optional
        The number of nested UNetBlocks to use. Default is ``1`` (no
        nesting).
    block_expansion_factor : int or float, optional
        Expansion factor for the number of channels between nested
        UNetBlocks. Default is ``2``.
    expand_only_on_down : bool, optional
        Whether to expand the number of channels only when one of the
        spatial dimensions is compressed. Default is ``False``.
    blocks_per_downsample : int or sequence, optional
        How many blocks to include between each downsample operation. This
        can be a tuple of values for each spatial dimension, or an int
        which uses the same value for each spatial dimension. Default is
        ``1``.
    blocks_before_first_downsample : int or sequence, optional
        How many blocks to include before the first spatial downsampling
        occurs. Default is ``0``.
    always_include_skip_connection : bool, optional
        If ``True``, a skip connection is included even if no dimensions
        were downsampled in this block. Default is ``True``.
    deepest_inner : {callable, "horizontal_block", "identity", None}, optional
        A layer which should be applied at the deepest part of the network,
        before the first upsampling step. The parameter should either be a
        pre-instantiated layer, or the string ``"horizontal_block"``, to
        indicate an additional block as generated by the
        ``horizontal_block_factory``. If it is the string ``"identity"``
        (default) or ``None``, no additional layer is included at the
        deepest point before upsampling begins.
    downsampling_modes : {"max", "avg", "stride"} or sequence, optional
        The downsampling mode to use. If this is a string, the same
        downsampling mode is used for every downsampling step. If it is a
        sequence, it should contain a string for each downsampling step.
        If the input sequence is too short, the final value will be used
        for all remaining downsampling steps. Default is ``"max"``.
    upsampling_modes : str or sequence, optional
        The upsampling mode to use. If this is a string, it must be
        ``"conv"``, or something supported by :class:`torch.nn.Upsample`;
        the same upsampling mode is used for every upsampling step. If it
        is a sequence, it should contain a string for each upsampling step.
        If the input sequence is too short, the final value will be used
        for all remaining upsampling steps. Default is ``"bilinear"``.
    _i_block : int, optional
        The current block number. Used internally to track recursion.
        Default is ``0``.
    _i_down : int, optional
        Used internally to track downsampling depth. Default is ``0``.

    Notes
    -----
    This class is defined recursively, and will instantiate itself as its
    own child until the number of blocks has been satisfied.
""" def __init__( self, in_channels, horizontal_block_factory, n_block=1, block_expansion_factor=2, expand_only_on_down=False, blocks_per_downsample=1, blocks_before_first_downsample=0, always_include_skip_connection=True, deepest_inner="identity", downsampling_modes="max", upsampling_modes="bilinear", _i_block=0, _i_down=0, ): super(UNetBlock, self).__init__() # Ensure these variables are a tuple of length two blocks_per_downsample = modules.utils._pair(blocks_per_downsample) blocks_before_first_downsample = modules.utils._pair( blocks_before_first_downsample ) # Check which downsampling and upsampling mode we are using for this # layer (may be the same for every layer) if isinstance(downsampling_modes, str): downsampling_mode = downsampling_modes elif _i_down >= len(downsampling_modes): downsampling_mode = downsampling_modes[-1] else: downsampling_mode = downsampling_modes[_i_down] if isinstance(upsampling_modes, str): upsampling_mode = upsampling_modes elif _i_down >= len(upsampling_modes): upsampling_mode = upsampling_modes[-1] else: upsampling_mode = upsampling_modes[_i_down] # Check which dimensions need to be compressed with this block compress_dims = tuple( _i_block >= i0 and (_i_block - i0) % k == 0 for i0, k in zip(blocks_before_first_downsample, blocks_per_downsample) ) compress_any_dims = any(compress_dims) # Determine whether we are increasing the number of channels, and # if so what to if expand_only_on_down and not compress_any_dims: out_channels = in_channels else: out_channels = int(max(1, round(in_channels * block_expansion_factor))) # Downsamling step. If the mode is "stride", this is incorporated # into the horizontal block with a strided convolution. If there # is no need to downsample, it will be the Identity function. stride = 1 if not compress_any_dims: self.down = nn.Identity() elif downsampling_mode == "stride": self.down = nn.Identity() stride = tuple(2 if compress_dim else 1 for compress_dim in compress_dims) else: self.down = Down(mode=downsampling_mode, compress_dims=compress_dims) # First horizontal block. It might begin with a downsampling stride, # and might increase the number of channels. self.horizontal_block_a = horizontal_block_factory( in_channels, out_channels, stride=stride, ) # In the sequence, the inner step comes next. But we will define it # once we have finished defining everything else in this UNet block. # Upsampling step. Does the inverse of the Down step, using some # method. if not compress_any_dims: self.up = nn.Identity() else: self.up = Up( in_channels=out_channels, up_dims=compress_dims, mode=upsampling_mode ) if compress_any_dims or always_include_skip_connection: # Concatenation step self.concatenate = modules.FlexibleConcat2d() b_in_channels = in_channels + out_channels else: # No concatenation step self.concatenate = None b_in_channels = out_channels # Second horizontal block. Takes both the skip connection and the # upsampled data as its input. self.horizontal_block_b = horizontal_block_factory(b_in_channels, in_channels) if _i_block + 1 < n_block: # Recurse deeper! Call this class again, but with the # block counter increased. 
            self.nested = UNetBlock(
                out_channels,
                horizontal_block_factory,
                n_block=n_block,
                block_expansion_factor=block_expansion_factor,
                expand_only_on_down=expand_only_on_down,
                blocks_per_downsample=blocks_per_downsample,
                blocks_before_first_downsample=blocks_before_first_downsample,
                always_include_skip_connection=always_include_skip_connection,
                deepest_inner=deepest_inner,
                downsampling_modes=downsampling_modes,
                upsampling_modes=upsampling_modes,
                _i_block=_i_block + 1,
                _i_down=_i_down + compress_any_dims,
            )
        elif callable(deepest_inner):
            self.nested = deepest_inner
        elif deepest_inner is None or deepest_inner.lower() == "identity":
            # End recursion by doing nothing for the inner loop.
            self.nested = nn.Identity()
        elif deepest_inner == "horizontal_block":
            # End recursion by doing an extra regular block.
            self.nested = horizontal_block_factory(out_channels, out_channels)
        else:
            raise ValueError(
                "Unsupported deepest_inner value: {}".format(deepest_inner)
            )

    def forward(self, input):
        x = self.down(input)
        x = self.horizontal_block_a(x)
        x = self.nested(x)
        x = self.up(x)
        if self.concatenate is not None:
            x = self.concatenate(x, input)
        x = self.horizontal_block_b(x)
        return x
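
# A minimal sketch of UNetBlock in isolation. The ``conv_factory`` below is a
# hypothetical stand-in for the MBConv factory that UNet supplies: any factory
# works, provided its module maps ``in_channels`` to ``out_channels`` and
# accepts a ``stride`` keyword. With ``n_block=2`` the outer block nests one
# child, and the output shape matches the input shape.
#
#     >>> def conv_factory(in_channels, out_channels, stride=1):
#     ...     return nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
#     >>> block = UNetBlock(8, conv_factory, n_block=2)
#     >>> block(torch.randn(1, 8, 16, 16)).shape
#     torch.Size([1, 8, 16, 16])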


class UNet(nn.Module):
    """
    UNet model.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    initial_channels : int, optional
        Number of latent channels to output from the initial convolution
        facing the input layer. Default is ``32``.
    bottleneck_channels : int, optional
        Number of channels to output from the first block, before the first
        unet downsampling step can occur. Default is the same as
        ``initial_channels``.
    n_block : int, optional
        Number of blocks, both up and down. Default is ``4``.
    unet_expansion_factor : int or float, optional
        Channel expansion factor between unet blocks. Default is ``2``.
    expand_only_on_down : bool, optional
        Whether to only apply ``unet_expansion_factor`` on unet blocks
        which actually contain a down/up sampling component, and not on
        vanilla blocks. Default is ``False``.
    blocks_per_downsample : int or sequence, optional
        Block interval between downsampling steps in the unet. If this is
        a sequence, it corresponds to the number of blocks for each spatial
        dimension. Default is ``1``.
    blocks_before_first_downsample : int, optional
        Number of blocks to use before and after the main unet structure.
        Must be at least ``1``. Default is ``1``.
    always_include_skip_connection : bool, optional
        If ``True``, a skip connection is included between all blocks
        equally far from the start and end of the UNet. If ``False``, skip
        connections are only used between downsampling and upsampling
        operations. Default is ``True``.
    deepest_inner : {callable, "horizontal_block", "identity", None}, optional
        A layer which should be applied at the deepest part of the network,
        before the first upsampling step. The parameter should either be a
        pre-instantiated layer, or the string ``"horizontal_block"``, to
        indicate an additional block as generated by the
        ``horizontal_block_factory``. If it is the string ``"identity"``
        (default) or ``None``, no additional layer is included at the
        deepest point before upsampling begins.
    intrablock_expansion : int or float, optional
        Channel expansion factor within inverse residual block. Default is
        ``6``.
    se_reduction : int or float, optional
        Channel reduction factor within squeeze and excite block. Default
        is ``4``.
    downsampling_modes : {"max", "avg", "stride"} or sequence, optional
        The downsampling mode to use. If this is a string, the same
        downsampling mode is used for every downsampling step. If it is a
        sequence, it should contain a string for each downsampling step.
        If the input sequence is too short, the final value will be used
        for all remaining downsampling steps. Default is ``"max"``.
    upsampling_modes : str or sequence, optional
        The upsampling mode to use. If this is a string, it must be
        ``"conv"``, or something supported by :class:`torch.nn.Upsample`;
        the same upsampling mode is used for every upsampling step. If it
        is a sequence, it should contain a string for each upsampling step.
        If the input sequence is too short, the final value will be used
        for all remaining upsampling steps. Default is ``"bilinear"``.
    depthwise_separable_conv : bool, optional
        Whether to use depthwise separable convolutions in the MBConv
        block. Otherwise, the depthwise and pointwise convolutions are
        fused together into a regular convolution. Default is ``True``.
    residual : bool, optional
        Whether to use a residual architecture for the MBConv blocks.
        Default is ``True``.
    actfn : str, optional
        Name of the activation function to use. Default is
        ``"InplaceReLU"``.
    kernel_size : int, optional
        Size of convolution kernel to use.
        Default is ``5``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        initial_channels=32,
        bottleneck_channels=None,
        n_block=4,
        unet_expansion_factor=2,
        expand_only_on_down=False,
        blocks_per_downsample=1,
        blocks_before_first_downsample=1,
        always_include_skip_connection=True,
        deepest_inner="identity",
        intrablock_expansion=6,
        se_reduction=4,
        downsampling_modes="max",
        upsampling_modes="bilinear",
        depthwise_separable_conv=True,
        residual=True,
        actfn="InplaceReLU",
        kernel_size=5,
    ):
        super(UNet, self).__init__()

        if bottleneck_channels is None:
            bottleneck_channels = initial_channels

        blocks_before_first_downsample = modules.utils._pair(
            blocks_before_first_downsample
        )
        if any(b < 1 for b in blocks_before_first_downsample):
            raise ValueError(
                "An initial block is hard coded. Number of blocks before first"
                " downsample must be at least 1."
            )

        self.in_channels = in_channels
        self.out_channels = out_channels

        actfn_factory = modules.activations.str2actfnfactory(actfn)

        horizontal_block_factory = functools.partial(
            modules.MBConv,
            expansion=intrablock_expansion,
            se_reduction=se_reduction,
            fused=not depthwise_separable_conv,
            residual=residual,
            actfn=actfn_factory,
            kernel_size=kernel_size,
        )

        self.initial_conv = nn.Sequential(
            modules.Conv2dSame(in_channels, initial_channels, kernel_size=kernel_size),
            nn.BatchNorm2d(initial_channels),
            actfn_factory(),
        )

        self.first_block = horizontal_block_factory(
            initial_channels,
            bottleneck_channels,
            expansion=1,
        )

        self.main_blocks = UNetBlock(
            bottleneck_channels,
            horizontal_block_factory,
            n_block=n_block,
            block_expansion_factor=unet_expansion_factor,
            expand_only_on_down=expand_only_on_down,
            blocks_per_downsample=blocks_per_downsample,
            blocks_before_first_downsample=tuple(
                b - 1 for b in blocks_before_first_downsample
            ),
            always_include_skip_connection=always_include_skip_connection,
            deepest_inner=deepest_inner,
            downsampling_modes=downsampling_modes,
            upsampling_modes=upsampling_modes,
        )

        self.final_block = horizontal_block_factory(bottleneck_channels, out_channels)

    def forward(self, x):
        x = self.initial_conv(x)
        x = self.first_block(x)
        x = self.main_blocks(x)
        x = self.final_block(x)
        return x
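
# A minimal sketch of the full model, assuming spatial sizes divisible by
# ``2 ** n_block`` so every downsampled size can be restored exactly on the
# way back up. The channel counts here are illustrative only; the exact
# behaviour inside each block depends on modules.MBConv.
#
#     >>> net = UNet(1, 2, initial_channels=8, n_block=2)
#     >>> net(torch.randn(1, 1, 64, 32)).shape
#     torch.Size([1, 2, 64, 32])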