diff --git a/pynitrokey/cli/nk3/__init__.py b/pynitrokey/cli/nk3/__init__.py index 816cf29d..15da42e6 100644 --- a/pynitrokey/cli/nk3/__init__.py +++ b/pynitrokey/cli/nk3/__init__.py @@ -30,6 +30,7 @@ from pynitrokey.nk3.exceptions import TimeoutException from pynitrokey.nk3.updates import get_repo from pynitrokey.nk3.utils import Version +from pynitrokey.updates import OverwriteError T = TypeVar("T", bound=Nitrokey3Base) @@ -250,15 +251,18 @@ def fetch_update(path: str, force: bool, version: Optional[str]) -> None: path = update.download_to_dir(path, overwrite=force, callback=bar.update) else: if not force and os.path.exists(path): - raise CliException( - f"{path} already exists. Use --force to overwrite the file." - ) + raise OverwriteError(path) with open(path, "wb") as f: update.download(f, callback=bar.update) bar.close() local_print(f"Successfully downloaded firmware release {update.tag} to {path}") + except OverwriteError as e: + raise CliException( + f"{e.path} already exists. Use --force to overwrite the file.", + support_hint=False, + ) except Exception as e: raise CliException(f"Failed to download firmware update {update.tag}", e) diff --git a/pynitrokey/updates.py b/pynitrokey/updates.py index bc176237..b38ba812 100644 --- a/pynitrokey/updates.py +++ b/pynitrokey/updates.py @@ -19,6 +19,17 @@ ProgressCallback = Callable[[int, int], None] +class DownloadError(Exception): + def __init__(self, msg: str) -> None: + super().__init__("Cannot download firmware: " + msg) + + +class OverwriteError(Exception): + def __init__(self, path: str) -> None: + super().__init__(f"File {path} already exists and may not be overwritten") + self.path = path + + class FirmwareUpdate: def __init__(self, tag: str, url: str) -> None: self.tag = tag @@ -37,14 +48,14 @@ def download_to_dir( callback: Optional[ProgressCallback] = None, ) -> str: if not os.path.exists(d): - raise Exception(f"Cannot download firmware: {d} does not exist") + raise DownloadError(f"Directory {d} does not exist") if not os.path.isdir(d): - raise Exception(f"Cannot download firmware: {d} is not a directory") + raise DownloadError(f"{d} is not a directory") url = urllib.parse.urlparse(self.url) filename = os.path.basename(url.path) path = os.path.join(d, filename) if os.path.exists(path) and not overwrite: - raise Exception(f"File {path} already exists and may not be overwritten") + raise OverwriteError(path) with open(path, "wb") as f: self.download(f, callback=callback) return path