"""
Interacting with the list of available checkpoints.
"""
# This file is part of Echofilter.
#
# Copyright (C) 2020-2022 Scott C. Lowe and Offshore Energy Research Association (OERA)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import argparse
import os
import pickle
from collections import OrderedDict
import appdirs
import yaml
from . import style
PACKAGE_DIR = os.path.dirname(os.path.dirname(__file__))
REPO_DIR = os.path.dirname(PACKAGE_DIR)
CHECKPOINT_FILE = os.path.join(PACKAGE_DIR, "checkpoints.yaml")
CHECKPOINT_FILE_ALT = os.path.join(REPO_DIR, "checkpoints.yaml")
CHECKPOINT_EXT = ".pt"
[docs]def get_checkpoint_list():
"""
List the currently available checkpoints, as stored in a local file.
Returns
-------
checkpoints : OrderedDict
Dictionary with a key for each checkpoint. Each key maps to a dictionary
whose elements describe the checkpoint.
"""
checkpoint_file_use = None
if os.path.isfile(CHECKPOINT_FILE):
checkpoint_file_use = CHECKPOINT_FILE
elif os.path.isfile(CHECKPOINT_FILE_ALT):
checkpoint_file_use = CHECKPOINT_FILE_ALT
else:
raise EnvironmentError(f"No such file: '{CHECKPOINT_FILE}'")
with open(checkpoint_file_use, "r") as hf:
checkpoints = OrderedDict(yaml.safe_load(hf)["checkpoints"])
return checkpoints
[docs]def get_default_checkpoint():
"""
Get the name of the current default checkpoint.
Returns
-------
checkpoint_name : str
Name of current checkpoint.
"""
return next(iter(get_checkpoint_list()))
[docs]def cannonise_checkpoint_name(name):
"""
Cannonises checkpoint name by removing extension.
Parameters
----------
name : str
Name of checkpoint, possibly including extension.
Returns
-------
name : str
Name of checkpoint, with extension removed it matches a possible
checkpoint file extension.
"""
for possible_ext in [
".ckpt.pth.tar",
".checkpoint.pth.tar",
".pth.tar",
".ckpt.tar",
".checkpoint.tar",
CHECKPOINT_EXT,
".pt",
".pth",
".ckpt",
".tar",
]:
if name.lower().endswith(possible_ext):
name = name[: -len(possible_ext)]
return name
return name
[docs]class ListCheckpoints(argparse.Action):
def __call__(self, parser, namespace, values, option_string):
print("Currently available model checkpoints:")
default_checkpoint = get_default_checkpoint()
checkpoints = get_checkpoint_list()
for checkpoint in checkpoints:
if checkpoint == default_checkpoint:
print(" * " + style.progress_fmt(checkpoint))
else:
print(" " + checkpoint)
parser.exit() # exits the program with no more arg parsing and checking
[docs]def get_default_cache_dir():
"""Determine the default cache directory."""
return appdirs.user_cache_dir("echofilter", "DeepSense")
[docs]class ShowCacheDir(argparse.Action):
def __call__(self, parser, namespace, values, option_string):
print("Downloaded model checkpoints are cached in:")
print(get_default_cache_dir())
parser.exit() # exits the program with no more arg parsing and checking
[docs]def download_checkpoint(checkpoint_name, cache_dir=None, verbose=1):
"""
Download a checkpoint if it isn't already cached.
Parameters
----------
checkpoint_name : str
Name of checkpoint to download.
cache_dir : str or None, optional
Path to local cache directory. If ``None`` (default), an OS-appropriate
application-specific default cache directory is used.
verbose : int, optional
Verbosity level. Default is ``1``. Set to ``0`` to disable print
statements.
Returns
-------
str
Path to downloaded checkpoint file.
"""
if cache_dir is None:
cache_dir = get_default_cache_dir()
checkpoint_name = cannonise_checkpoint_name(checkpoint_name)
destination = os.path.join(cache_dir, checkpoint_name + CHECKPOINT_EXT)
if os.path.exists(destination):
return destination
checkpoint_resources = get_checkpoint_list()
if checkpoint_name in checkpoint_resources:
sources = checkpoint_resources[checkpoint_name]
else:
for key, sources in checkpoint_resources.items():
if checkpoint_name in sources.get("aliases", []):
checkpoint_name = key
break
else:
msg = style.error_fmt(
"The checkpoint parameter should either be a path to a file or one of:"
)
msg += "\n ".join([""] + list(checkpoint_resources.keys()))
msg += style.error_fmt("\nbut '{}' was provided.".format(checkpoint_name))
with style.error_message():
raise ValueError(msg)
destination = os.path.join(cache_dir, checkpoint_name + CHECKPOINT_EXT)
if os.path.exists(destination):
return destination
# Import packages needed for downloading files
import urllib
import requests
from torchvision.datasets.utils import download_file_from_google_drive, download_url
os.makedirs(cache_dir, exist_ok=True)
if "aliases" in sources:
sources.pop("aliases")
success = False
for key, url_or_id in sources.items():
if key == "gdrive":
if verbose >= 1:
print(
"Downloading checkpoint {} from GDrive...".format(checkpoint_name)
)
try:
download_file_from_google_drive(
url_or_id,
os.path.dirname(destination),
filename=os.path.basename(destination),
)
success = True
continue
except pickle.UnpicklingError:
if verbose >= 1:
print(
style.error_fmt(
"\nCould not download checkpoint {} from GDrive!".format(
checkpoint_name
)
)
)
except (requests.exceptions.ConnectionError, urllib.error.URLError):
msg = "Could not connect to Google Drive. Please check your Internet connection."
with style.error_message(msg) as msg:
raise EnvironmentError(msg)
else:
if verbose >= 1:
print(
"Downloading checkpoint {} from {}...".format(
checkpoint_name, url_or_id
)
)
try:
download_url(url_or_id, cache_dir, filename=checkpoint_name)
success = True
continue
except pickle.UnpicklingError:
if verbose >= 1:
print(
style.error_fmt(
"\nCould not download checkpoint {} from {}".format(
checkpoint_name, url_or_id
)
)
)
except (requests.exceptions.ConnectionError, urllib.error.URLError):
msg = "Could not connect to file server to download {}. Please check your Internet connection.".format(
url_or_id
)
with style.error_message(msg) as msg:
raise EnvironmentError(msg)
if not success:
msg = "Unable to download {} from {}".format(checkpoint_name, sources)
with style.error_message(msg) as msg:
raise OSError(msg)
if verbose >= 1:
print("Downloaded checkpoint to {}".format(destination))
return destination
[docs]def load_checkpoint(
ckpt_name=None, cache_dir=None, device="cpu", return_name=False, verbose=1
):
"""
Load a checkpoint, either from absolute path or the cache.
Parameters
----------
checkpoint_name : str or None, optional
Path to checkpoint file, or name of checkpoint to download.
Default is ``None``.
cache_dir : str or None, optional
Path to local cache directory. If ``None`` (default), an OS-appropriate
application-specific default cache directory is used.
device : str or torch.device or None, optional
Device onto which weight tensors will be mapped. If ``None``, no mapping
is performed and tensors will be loaded onto the same device as they
were on when saved (which will result in an error if the device is not
present). Default is ``"cpu"``.
return_name : bool, optional
If ``True``, a tuple is returned indicting the name of the checkpoint
which was loaded. This is useful if the default checkpoint was loaded.
Default is ``False``.
verbose : int, optional
Verbosity level. Default is ``1``. Set to ``0`` to disable print
statements.
Returns
-------
checkpoint : dict
Loaded checkpoint.
checkpoint_name : str, optional
If ``return_name`` is ``True``, the name of the checkpoint is also
returned.
"""
import torch
if ckpt_name is None:
ckpt_name = get_default_checkpoint()
if cache_dir is None:
cache_dir = get_default_cache_dir()
ckpt_name_cannon = cannonise_checkpoint_name(ckpt_name)
builtin_ckpt_path_a = os.path.join(
PACKAGE_DIR,
"checkpoints",
os.path.split(ckpt_name)[1],
)
builtin_ckpt_path_b = os.path.join(
PACKAGE_DIR,
"checkpoints",
ckpt_name_cannon + CHECKPOINT_EXT,
)
using_cache = False
if os.path.isfile(ckpt_name):
ckpt_path = ckpt_name
ckpt_dscr = "local"
elif os.path.isfile(ckpt_name + CHECKPOINT_EXT):
ckpt_path = ckpt_name + CHECKPOINT_EXT
ckpt_dscr = "local"
elif os.path.isfile(builtin_ckpt_path_a):
ckpt_path = builtin_ckpt_path_a
ckpt_dscr = "builtin"
elif os.path.isfile(builtin_ckpt_path_b):
ckpt_path = builtin_ckpt_path_b
ckpt_dscr = "builtin"
else:
using_cache = True
ckpt_path = download_checkpoint(
ckpt_name_cannon, cache_dir=cache_dir, verbose=verbose
)
ckpt_dscr = "cached"
if not os.path.isfile(ckpt_path):
msg = "No checkpoint found at '{}'".format(ckpt_path)
with style.error_message(msg) as msg:
raise EnvironmentError(msg)
if verbose >= 1:
print("Loading model from {} checkpoint:\n '{}'".format(ckpt_dscr, ckpt_path))
load_args = {}
if device is not None:
# Map model to be loaded to specified single gpu.
load_args = dict(map_location=device)
try:
checkpoint = torch.load(ckpt_path, **load_args)
except pickle.UnpicklingError:
if not using_cache:
# Direct path to checkpoint was given, so we shouldn't delete
# the user's file
msg = "Error: Unable to load checkpoint {}".format(
os.path.abspath(ckpt_path)
)
with style.error_message(msg) as msg:
print(msg)
raise
else:
# Delete the checkpoint and try again, in case it is just a
# malformed download (interrupted download, etc)
os.remove(ckpt_path)
ckpt_path = download_checkpoint(
ckpt_name, cache_dir=cache_dir, verbose=verbose
)
try:
checkpoint = torch.load(ckpt_path, **load_args)
except pickle.UnpicklingError:
msg = "Error: Unable to load checkpoint {}.".format(
os.path.abspath(ckpt_path)
)
with style.error_message(msg) as msg:
print(msg)
# Check if there was an error because the file was missing
# and we downloaded a 404 error page instead.
with open(ckpt_path) as myfile:
contents = myfile.read()
if "Not Found" in contents or "Error 404" in contents:
msg = (
"The file you are trying to download is not available"
" at this web address."
" The download produced a 404 Error (File Not Found)."
"\nThe original source file may have been moved or deleted"
" by its host."
)
with style.error_message(msg) as msg:
print(msg + "\n")
raise EnvironmentError(msg) from None
# Check if the user ran out of storage space part-way through
# the download.
import shutil
_, _, free_B = shutil.disk_usage("/")
free_MiB = free_B // 10**6
if free_MiB < 64:
msg = (
"You only have {}MB of free space on your hard disk."
" Please free up 100MB of space on your hard disk and"
" then try again.\n"
).format(free_MiB)
with style.error_message(msg) as msg:
print(msg + "\n")
raise EnvironmentError(msg) from None
msg = "There was an unknown issue opening the downloaded file."
with style.error_message(msg) as msg:
print(msg)
raise
if return_name:
return checkpoint, ckpt_name
return checkpoint