|
1 | 1 | import collections.abc as collections |
2 | 2 | import functools |
| 3 | +import hashlib |
3 | 4 | import logging |
4 | 5 | import random |
| 6 | +import shutil |
5 | 7 | import warnings |
| 8 | +from pathlib import Path |
6 | 9 | from typing import Any, Callable, Dict, Optional, TextIO, Tuple, Type, TypeVar, Union, cast |
7 | 10 |
|
8 | 11 | import torch |
9 | 12 |
|
10 | | -__all__ = ["convert_tensor", "apply_to_tensor", "apply_to_type", "to_onehot", "setup_logger", "manual_seed"] |
| 13 | +__all__ = [ |
| 14 | + "convert_tensor", |
| 15 | + "apply_to_tensor", |
| 16 | + "apply_to_type", |
| 17 | + "to_onehot", |
| 18 | + "setup_logger", |
| 19 | + "manual_seed", |
| 20 | + "hash_checkpoint", |
| 21 | +] |
11 | 22 |
|
12 | 23 |
|
13 | 24 | def convert_tensor( |
@@ -272,3 +283,34 @@ def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> Callable: |
272 | 283 | return cast(F, wrapper) |
273 | 284 |
|
274 | 285 | return decorator |
| 286 | + |
| 287 | + |
| 288 | +def hash_checkpoint(checkpoint_path: Union[str, Path], output_dir: Union[str, Path],) -> Tuple[Path, str]: |
| 289 | + """ |
| 290 | + Hash the checkpoint file in the format of ``<filename>-<hash>.<ext>`` |
| 291 | + to be used with ``check_hash`` of :func:`torch.hub.load_state_dict_from_url`. |
| 292 | +
|
| 293 | + Args: |
| 294 | + checkpoint_path: Path to the checkpoint file. |
| 295 | + output_dir: Output directory to store the hashed checkpoint file. |
| 296 | +
|
| 297 | + Returns: |
| 298 | + Path to the hashed checkpoint file, The 8 digits of SHA256 hash. |
| 299 | +
|
| 300 | + .. versionadded:: 0.5.0 |
| 301 | + """ |
| 302 | + |
| 303 | + if isinstance(checkpoint_path, str): |
| 304 | + checkpoint_path = Path(checkpoint_path) |
| 305 | + |
| 306 | + if isinstance(output_dir, str): |
| 307 | + output_dir = Path(output_dir) |
| 308 | + |
| 309 | + sha_hash = hashlib.sha256(checkpoint_path.read_bytes()).hexdigest() |
| 310 | + old_filename = checkpoint_path.stem |
| 311 | + new_filename = "-".join((old_filename, sha_hash[:8])) + ".pt" |
| 312 | + |
| 313 | + hash_checkpoint_path = output_dir / new_filename |
| 314 | + shutil.move(str(checkpoint_path), hash_checkpoint_path) |
| 315 | + |
| 316 | + return hash_checkpoint_path, sha_hash |
0 commit comments