# Source code for echofilter.data.utils

"""
Utility functions for dataset.
"""

# This file is part of Echofilter.
#
# Copyright (C) 2020-2022  Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import random

import numpy as np
import torch

from ..utils import first_nonzero, last_nonzero


def worker_seed_fn(worker_id):
    """
    Seed builtin :mod:`random` and :mod:`numpy` with :func:`torch.randint`.

    A worker initialization function for
    :class:`torch.utils.data.DataLoader` objects which seeds builtin
    :mod:`random` and :mod:`numpy` with :func:`torch.randint` (which is
    stable if torch is manually seeded in the main program).

    Two independent values are drawn from torch's global generator: the
    first (offset by ``worker_id``) seeds :mod:`numpy`, the second (also
    offset by ``worker_id``) seeds :mod:`random`. Torch's own generator is
    advanced by these two draws but is not re-seeded.

    Parameters
    ----------
    worker_id : int
        The ID of the worker.
    """
    # Seeds are drawn from [0, 2**32). Adding worker_id can push the value
    # past that range, so it is wrapped back before being handed to NumPy,
    # whose legacy seeding only accepts values in [0, 2**32 - 1].
    numpy_seed = torch.randint(0, 4294967296, (1,)).item() + worker_id
    np.random.seed(numpy_seed % 4294967296)
    # The builtin generator accepts arbitrarily large integers, so the
    # second draw is used without wrapping.
    builtin_seed = torch.randint(0, 4294967296, (1,)).item() + worker_id
    random.seed(builtin_seed)


def worker_staticseed_fn(worker_id):
    """
    Seed builtin :mod:`random`, :mod:`numpy`, and :mod:`torch` with ``worker_id``.

    A worker initialization function for
    :class:`torch.utils.data.DataLoader` objects which produces the same
    seed for builtin :mod:`random`, :mod:`numpy`, and :mod:`torch` every
    time, so it is the same for every epoch.

    Parameters
    ----------
    worker_id : int
        The ID of the worker. Used directly as the seed for all three
        random number generators.
    """
    # The seed depends only on worker_id, so a given worker replays the
    # same random stream each time this initializer is run.
    random.seed(worker_id)
    np.random.seed(worker_id)
    torch.manual_seed(worker_id)