Skip to content

Commit ed5e1dd

Browse files
ydcjeffvfdev-5
andauthored
feat(hash_checkpoint): create the directory if not exist (#2273)
* feat(hash_checkpoint): create the directory if not exist * Apply suggestions from code review Co-authored-by: vfdev <vfdev.5@gmail.com> * fix: check for checkpoint file existence * chore: format Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 1c25adf commit ed5e1dd

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

ignite/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,24 @@ def hash_checkpoint(checkpoint_path: Union[str, Path], output_dir: Union[str, Pa
292292
293293
Args:
294294
checkpoint_path: Path to the checkpoint file.
295-
output_dir: Output directory to store the hashed checkpoint file.
295+
output_dir: Output directory to store the hashed checkpoint file
296+
(will be created if not exist).
296297
297298
Returns:
298-
Path to the hashed checkpoint file, The 8 digits of SHA256 hash.
299+
Path to the hashed checkpoint file, the first 8 digits of SHA256 hash.
299300
300301
.. versionadded:: 0.5.0
301302
"""
302303

303304
if isinstance(checkpoint_path, str):
304305
checkpoint_path = Path(checkpoint_path)
305306

307+
if not checkpoint_path.exists():
308+
raise FileNotFoundError(f"{checkpoint_path.name} does not exist in {checkpoint_path.parent}.")
309+
306310
if isinstance(output_dir, str):
307311
output_dir = Path(output_dir)
312+
output_dir.mkdir(parents=True, exist_ok=True)
308313

309314
sha_hash = hashlib.sha256(checkpoint_path.read_bytes()).hexdigest()
310315
old_filename = checkpoint_path.stem

tests/ignite/test_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,7 @@ def test_hash_checkpoint(tmp_path):
256256
model.load_state_dict(torch.load(hash_checkpoint_path), True)
257257
assert sha_hash[:8] == "b66bff10"
258258
assert hash_checkpoint_path.name == f"squeezenet1_0-{sha_hash[:8]}.pt"
259+
260+
# test non-existent checkpoint_path
261+
with pytest.raises(FileNotFoundError, match=r"not_found.pt does not exist in *"):
262+
hash_checkpoint(f"{tmp_path}/not_found.pt", tmp_path)

0 commit comments

Comments
 (0)