Skip to content

Commit 1c25adf

Browse files
authored
feat(utils): add hashing a checkpoint utility (#2272)
* feat(utils): add hashing utility for checkpoints * fix(type): convert Path to str * chore(docs): code quote * chore: remove CLI for hashing
1 parent c399ee9 commit 1c25adf

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

ignite/utils.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
import collections.abc as collections
22
import functools
3+
import hashlib
34
import logging
45
import random
6+
import shutil
57
import warnings
8+
from pathlib import Path
69
from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Type, TypeVar, Union, cast
710

811
import torch
912

10-
__all__ = ["convert_tensor", "apply_to_tensor", "apply_to_type", "to_onehot", "setup_logger", "manual_seed"]
13+
__all__ = [
14+
"convert_tensor",
15+
"apply_to_tensor",
16+
"apply_to_type",
17+
"to_onehot",
18+
"setup_logger",
19+
"manual_seed",
20+
"hash_checkpoint",
21+
]
1122

1223

1324
def convert_tensor(
@@ -272,3 +283,34 @@ def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> Callable:
272283
return cast(F, wrapper)
273284

274285
return decorator
286+
287+
288+
def hash_checkpoint(checkpoint_path: Union[str, Path], output_dir: Union[str, Path]) -> Tuple[Path, str]:
    """
    Hash the checkpoint file and move it to ``output_dir`` as ``<filename>-<hash>.pt``
    to be used with ``check_hash`` of :func:`torch.hub.load_state_dict_from_url`.

    ``<hash>`` is the first 8 characters of the SHA256 hex digest of the file content.
    The checkpoint is moved (not copied), and the new file name always ends with ``.pt``
    whatever the extension of ``checkpoint_path`` is.

    Args:
        checkpoint_path: Path to the checkpoint file.
        output_dir: Output directory to store the hashed checkpoint file
            (created together with its parents if it does not exist).

    Returns:
        Path to the hashed checkpoint file, the full SHA256 hex digest of the checkpoint.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.

    .. versionadded:: 0.5.0
    """

    checkpoint_path = Path(checkpoint_path)
    output_dir = Path(output_dir)

    # Fail early, before touching the filesystem, when there is nothing to hash.
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"{checkpoint_path.name} does not exist in {checkpoint_path.parent}.")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Checkpoints can be several GB: feed the hash in fixed-size blocks
    # instead of loading the whole file into memory at once.
    hash_obj = hashlib.sha256()
    with checkpoint_path.open("rb") as f:
        for byte_block in iter(lambda: f.read(1024 * 1024), b""):
            hash_obj.update(byte_block)
    sha_hash = hash_obj.hexdigest()

    new_filename = "-".join((checkpoint_path.stem, sha_hash[:8])) + ".pt"
    hash_checkpoint_path = output_dir / new_filename

    # Plain strings keep ``shutil.move`` working on Python versions
    # that do not fully support path-like arguments.
    shutil.move(str(checkpoint_path), str(hash_checkpoint_path))

    return hash_checkpoint_path, sha_hash

tests/ignite/test_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88

99
from ignite.engine import Engine, Events
10-
from ignite.utils import convert_tensor, deprecated, setup_logger, to_onehot
10+
from ignite.utils import convert_tensor, deprecated, hash_checkpoint, setup_logger, to_onehot
1111

1212

1313
def test_convert_tensor():
@@ -242,3 +242,17 @@ def func_with_everything():
242242

243243
def test_smoke__utils():
244244
from ignite._utils import apply_to_tensor, apply_to_type, convert_tensor, to_onehot # noqa: F401
245+
246+
247+
def test_hash_checkpoint(tmp_path):
    """Check that ``hash_checkpoint`` renames a downloaded checkpoint with its SHA256 prefix."""
    # a lightweight model keeps the download small
    from torchvision.models import squeezenet1_0

    downloaded = f"{tmp_path}/squeezenet1_0.pt"
    torch.hub.download_url_to_file(
        "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", downloaded,
    )

    hashed_path, digest = hash_checkpoint(downloaded, str(tmp_path))

    # the moved file must still be a loadable state dict for the model
    model = squeezenet1_0()
    state_dict = torch.load(hashed_path)
    model.load_state_dict(state_dict, True)

    short_digest = digest[:8]
    assert short_digest == "b66bff10"
    assert hashed_path.name == f"squeezenet1_0-{short_digest}.pt"

0 commit comments

Comments
 (0)