diff --git a/ignite/utils.py b/ignite/utils.py index fba790e9c39..442b986d5b2 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -292,10 +292,11 @@ def hash_checkpoint(checkpoint_path: Union[str, Path], output_dir: Union[str, Pa Args: checkpoint_path: Path to the checkpoint file. - output_dir: Output directory to store the hashed checkpoint file. + output_dir: Output directory to store the hashed checkpoint file + (will be created if not exist). Returns: - Path to the hashed checkpoint file, The 8 digits of SHA256 hash. + Path to the hashed checkpoint file, the first 8 digits of SHA256 hash. .. versionadded:: 0.5.0 """ @@ -303,8 +304,12 @@ def hash_checkpoint(checkpoint_path: Union[str, Path], output_dir: Union[str, Pa if isinstance(checkpoint_path, str): checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + raise FileNotFoundError(f"{checkpoint_path.name} does not exist in {checkpoint_path.parent}.") + if isinstance(output_dir, str): output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) sha_hash = hashlib.sha256(checkpoint_path.read_bytes()).hexdigest() old_filename = checkpoint_path.stem diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py index d5ed7bd128f..308b12d2502 100644 --- a/tests/ignite/test_utils.py +++ b/tests/ignite/test_utils.py @@ -256,3 +256,7 @@ def test_hash_checkpoint(tmp_path): model.load_state_dict(torch.load(hash_checkpoint_path), True) assert sha_hash[:8] == "b66bff10" assert hash_checkpoint_path.name == f"squeezenet1_0-{sha_hash[:8]}.pt" + + # test non-existent checkpoint_path + with pytest.raises(FileNotFoundError, match=r"not_found.pt does not exist in *"): + hash_checkpoint(f"{tmp_path}/not_found.pt", tmp_path)