|
8 | 8 | import argparse |
9 | 9 | import contextlib |
10 | 10 | import json |
| 11 | +import safetensors.torch |
11 | 12 | import os |
12 | 13 | import re |
13 | 14 | import sys |
@@ -177,6 +178,31 @@ def write(self) -> None: |
177 | 178 |         self.gguf_writer.write_tensors_to_file(progress=True)
178 | 179 |         self.gguf_writer.close()
179 | 180 |
|
| 181 | +# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 |
| 182 | +def _merge_sharded_checkpoints(folder: Path):
| 183 | +    with open(folder / "diffusion_pytorch_model.safetensors.index.json", "r") as f:
| 184 | +        ckpt_metadata = json.load(f)
| 185 | +    weight_map = ckpt_metadata.get("weight_map", None)
| 186 | +    if weight_map is None:
| 187 | +        raise KeyError("'weight_map' key not found in the shard index file.")
| 188 | +
| 189 | +    # Collect all unique safetensors files from weight_map
| 190 | +    files_to_load = set(weight_map.values())
| 191 | +    merged_state_dict = {}
| 192 | +
| 193 | +    # Load tensors from each unique file
| 194 | +    for file_name in files_to_load:
| 195 | +        part_file_path = folder / file_name
| 196 | +        if not os.path.exists(part_file_path):
| 197 | +            raise FileNotFoundError(f"Part file {file_name} not found.")
| 198 | +
| 199 | +        with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
| 200 | +            for tensor_key in f.keys():
| 201 | +                if tensor_key in weight_map:
| 202 | +                    merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
| 203 | +
| 204 | +    return merged_state_dict
| 205 | + |
180 | 206 |
|
181 | 207 | def parse_args() -> argparse.Namespace: |
182 | 208 |     parser = argparse.ArgumentParser(
@@ -216,10 +242,29 @@ def main() -> None: |
216 | 242 |     else:
217 | 243 |         logging.basicConfig(level=logging.INFO)
218 | 244 |
|
219 | | -    if not args.model.is_file():
| 245 | +    if not args.model.is_dir() and not args.model.is_file():
220 | 246 |         logging.error(f"Model path {args.model} does not exist.")
221 | 247 |         sys.exit(1)
222 | 248 |
|
| 249 | +    if args.model.is_dir():
| 250 | +        logging.info(f"Model path {args.model} is a directory; searching for safetensors files.")
| 251 | +        merged_state_dict = None
| 252 | +        files = list(args.model.glob("*.safetensors"))
| 253 | +        n = len(files)
| 254 | +        if n == 0:
| 255 | +            logging.error(f"No safetensors files found in {args.model}.")
| 256 | +            sys.exit(1)
| 257 | +        if n == 1:
| 258 | +            logging.info(f"Assigning {files[0]} to `args.model`")
| 259 | +            args.model = files[0]
| 260 | +        if n > 1:
| 261 | +            assert (args.model / "diffusion_pytorch_model.safetensors.index.json").is_file(), "sharded checkpoint is missing its index file"
| 262 | +            merged_state_dict = _merge_sharded_checkpoints(args.model)
| 263 | +            filepath = "merged_state_dict.safetensors"
| 264 | +            safetensors.torch.save_file(merged_state_dict, filepath)
| 265 | +            logging.info(f"Serialized merged state dict to {filepath}")
| 266 | +            args.model = Path(filepath)
| 267 | +
223 | 268 |     if args.model.suffix != ".safetensors":
224 | 269 |         logging.error(f"Model path {args.model} is not a safetensors file.")
225 | 270 |         sys.exit(1)
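For context, _merge_sharded_checkpoints expects the standard Hugging Face sharded-safetensors layout: a diffusion_pytorch_model.safetensors.index.json whose "weight_map" entry maps every tensor name to the shard file that stores it. A minimal inspection sketch, assuming that layout (the folder path and tensor names below are illustrative, not taken from this change):

import json
from pathlib import Path

folder = Path("path/to/sharded_checkpoint")  # hypothetical directory containing the shards
with open(folder / "diffusion_pytorch_model.safetensors.index.json") as f:
    index = json.load(f)

# weight_map maps tensor names to shard files, e.g.
# "transformer_blocks.0.attn.to_q.weight" -> "diffusion_pytorch_model-00001-of-00002.safetensors"
for shard in sorted(set(index["weight_map"].values())):
    count = sum(1 for v in index["weight_map"].values() if v == shard)
    print(f"{shard}: {count} tensors")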
|