Skip to content

Commit 3d24bb2

Browse files
authored
Merge pull request #1 from sayakpaul/support-diffusers-ckpt
feat: support diffusers ckpt.
2 parents 7c9dda5 + b16d946 commit 3d24bb2

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

convert_flux_to_gguf.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import argparse
99
import contextlib
1010
import json
11+
import safetensors.torch
1112
import os
1213
import re
1314
import sys
@@ -177,6 +178,31 @@ def write(self) -> None:
177178
self.gguf_writer.write_tensors_to_file(progress=True)
178179
self.gguf_writer.close()
179180

181+
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
182+
def _merge_sharded_checkpoints(folder: Path):
    """Merge a sharded safetensors checkpoint into a single state dict.

    Reads ``diffusion_pytorch_model.safetensors.index.json`` inside *folder*,
    opens every shard file named in its ``weight_map`` and collects each
    tensor whose key is listed in that map.

    Shards are visited in sorted file-name order so the result is
    reproducible: should a key be present in more than one shard, the copy
    that ends up in the result no longer depends on set iteration order.

    Args:
        folder: Directory holding the index file and the shard files.

    Returns:
        A dict mapping tensor name to the tensor, loaded on the CPU.

    Raises:
        KeyError: The index file has no ``weight_map`` entry.
        FileNotFoundError: A shard listed in ``weight_map`` does not exist
            (also raised by ``open`` when the index file itself is missing).
    """
    index_path = folder / "diffusion_pytorch_model.safetensors.index.json"
    # The index is JSON; read it as UTF-8 instead of the locale default.
    with open(index_path, "r", encoding="utf-8") as f:
        ckpt_metadata = json.load(f)
    weight_map = ckpt_metadata.get("weight_map")
    if weight_map is None:
        raise KeyError("'weight_map' key not found in the shard index file.")

    merged_state_dict = {}

    # Several tensors usually share one shard, so de-duplicate the file
    # names; sort them to get a stable, run-independent load order.
    for file_name in sorted(set(weight_map.values())):
        part_file_path = folder / file_name
        if not os.path.exists(part_file_path):
            raise FileNotFoundError(f"Part file {file_name} not found.")

        with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as shard:
            for tensor_key in shard.keys():
                # Ignore tensors the index does not know about.
                if tensor_key in weight_map:
                    merged_state_dict[tensor_key] = shard.get_tensor(tensor_key)

    return merged_state_dict
205+
180206

181207
def parse_args() -> argparse.Namespace:
182208
parser = argparse.ArgumentParser(
@@ -216,10 +242,29 @@ def main() -> None:
216242
else:
217243
logging.basicConfig(level=logging.INFO)
218244

219-
if not args.model.is_file():
245+
if not args.model.is_dir() and not args.model.is_file():
220246
logging.error(f"Model path {args.model} does not exist.")
221247
sys.exit(1)
222248

249+
if args.model.is_dir():
250+
logging.info("Supplied a directory.")
251+
merged_state_dict = None
252+
files = list(args.model.glob('*.safetensors'))
253+
n = len(files)
254+
if n == 0:
255+
logging.error("No safetensors files found.")
256+
sys.exit(1)
257+
if n == 1:
258+
logging.info(f"Assinging {files[0]} to `args.model`")
259+
args.model = files[0]
260+
if n > 1:
261+
assert args.model / "diffusion_pytorch_model.safetensors.index.json" in list(args.model.glob("*.*"))
262+
merged_state_dict = _merge_sharded_checkpoints(args.model)
263+
filepath = "merged_state_dict.safetensors"
264+
safetensors.torch.save_file(merged_state_dict, filepath)
265+
logging.info(f"Serialized merged state dict to {filepath}")
266+
args.model = Path(filepath)
267+
223268
if args.model.suffix != ".safetensors":
224269
logging.error(f"Model path {args.model} is not a safetensors file.")
225270
sys.exit(1)

0 commit comments

Comments
 (0)