
Commit 3ab805a

update
1 parent 8e659b9 commit 3ab805a

File tree

1 file changed: +36 -21 lines changed

scripts/convert_sd3_to_diffusers.py

Lines changed: 36 additions & 21 deletions
@@ -20,16 +20,6 @@
 
 args = parser.parse_args()
 
-# if dtype is not specified, use the dtype of the original checkpoint(recommended)
-if args.dtype == "fp16":
-    dtype = torch.float16
-elif args.dtype == "bf16":
-    dtype = torch.bfloat16
-elif args.dtype == "fp32":
-    dtype = torch.float32
-else:
-    dtype = None
-
 
 def load_original_checkpoint(ckpt_path):
     original_state_dict = safetensors.torch.load_file(ckpt_path)
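The dtype selection removed here is reinstated inside main() later in this diff. As a point of reference only, a minimal standalone sketch of that resolution logic (the helper name resolve_dtype is hypothetical and not part of the script) could look like:

import torch

_DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}


def resolve_dtype(name, original_dtype):
    # No --dtype given: fall back to the dtype of the original checkpoint.
    if name is None:
        return original_dtype
    if name not in _DTYPE_MAP:
        raise ValueError(f"Unsupported dtype: {name}")
    return _DTYPE_MAP[name]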
@@ -245,9 +235,6 @@ def convert_sd3_transformer_checkpoint_to_diffusers(
         original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
     )
 
-    if len(original_state_dict) > 0:
-        raise ValueError(f"{len(original_state_dict)} keys are not converted: {original_state_dict.keys()}")
-
     return converted_state_dict
 
 
@@ -260,43 +247,71 @@ def is_vae_in_checkpoint(original_state_dict):
 def get_add_attn2_layers(state_dict):
     add_attn2_layers = []
     for key in state_dict.keys():
-        if "attn2.to_q.weight" in key:
+        if "attn2." in key:
             # Extract the layer number from the key
             layer_num = int(key.split(".")[1])
             add_attn2_layers.append(layer_num)
     return tuple(sorted(add_attn2_layers))
 
 
+def get_pos_embed_max_size(state_dict):
+    num_patches = state_dict["pos_embed"].shape[1]
+    pos_embed_max_size = int(num_patches**0.5)
+    return pos_embed_max_size
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def main(args):
     original_ckpt = load_original_checkpoint(args.checkpoint_path)
     original_dtype = next(iter(original_ckpt.values())).dtype
-    if dtype is None:
+
+    # Initialize dtype with a default value
+    dtype = None
+
+    if args.dtype is None:
         dtype = original_dtype
-    elif dtype != original_dtype:
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    if dtype != original_dtype:
         print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}")
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+
+    caption_projection_dim = get_caption_projection_dim(original_ckpt)
+
     # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
     add_attn2_layers = get_add_attn2_layers(original_ckpt)
+
     # sd3.5 use qk norm("rms_norm")
     has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
-    # sd3.5 use pox_embed_max_size=384 and sd3.0 use 192
-    pos_embed_max_size = 384 if has_qk_norm else 192
+
+    # sd3.5 2b use pox_embed_max_size=384 and sd3.0 and sd3.5 8b use 192
+    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)
 
     converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
         original_ckpt, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm
     )
 
     with CTX():
         transformer = SD3Transformer2DModel(
-            sample_size=64,
+            sample_size=128,
             patch_size=2,
             in_channels=16,
             joint_attention_dim=4096,
             num_layers=num_layers,
             caption_projection_dim=caption_projection_dim,
-            num_attention_heads=24,
+            num_attention_heads=num_layers,
             pos_embed_max_size=pos_embed_max_size,
             qk_norm="rms_norm" if has_qk_norm else None,
             add_attn2_layers=add_attn2_layers,
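The helpers added in this hunk infer the model configuration from tensor shapes in the checkpoint rather than hard-coding it. Below is a minimal, self-contained sketch of that inference on a dummy state dict; the tensor shapes are illustrative assumptions loosely modeled on an SD3-medium-sized checkpoint, and only the dimensions the helpers actually read are given realistic values.

import torch


def get_pos_embed_max_size(state_dict):
    # pos_embed has shape [1, num_patches, hidden_dim]; the positional grid is
    # square, so its side length is the square root of num_patches.
    num_patches = state_dict["pos_embed"].shape[1]
    return int(num_patches**0.5)


def get_caption_projection_dim(state_dict):
    # context_embedder projects the text encoder output to the transformer
    # width; its weight has shape [hidden_dim, joint_attention_dim].
    return state_dict["context_embedder.weight"].shape[0]


dummy_ckpt = {
    "pos_embed": torch.zeros(1, 192 * 192, 8),        # [1, num_patches, hidden]
    "context_embedder.weight": torch.zeros(1536, 8),  # [hidden, joint_attention_dim]
}

print(get_pos_embed_max_size(dummy_ckpt))      # 192
print(get_caption_projection_dim(dummy_ckpt))  # 1536

With the real checkpoint loaded by load_original_checkpoint, the same calls replace the previously hard-coded caption_projection_dim = 1536 and the has_qk_norm-based pos_embed_max_size guess.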
