
args = parser.parse_args()

-# if dtype is not specified, use the dtype of the original checkpoint(recommended)
-if args.dtype == "fp16":
-    dtype = torch.float16
-elif args.dtype == "bf16":
-    dtype = torch.bfloat16
-elif args.dtype == "fp32":
-    dtype = torch.float32
-else:
-    dtype = None
-

def load_original_checkpoint(ckpt_path):
    original_state_dict = safetensors.torch.load_file(ckpt_path)
@@ -245,9 +235,6 @@ def convert_sd3_transformer_checkpoint_to_diffusers(
        original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
    )

-    if len(original_state_dict) > 0:
-        raise ValueError(f"{len(original_state_dict)} keys are not converted: {original_state_dict.keys()}")
-
    return converted_state_dict


@@ -260,43 +247,71 @@ def is_vae_in_checkpoint(original_state_dict):
def get_add_attn2_layers(state_dict):
    add_attn2_layers = []
    for key in state_dict.keys():
-        if "attn2.to_q.weight" in key:
+        if "attn2." in key:
            # Extract the layer number from the key
            layer_num = int(key.split(".")[1])
            add_attn2_layers.append(layer_num)
    return tuple(sorted(set(add_attn2_layers)))


+def get_pos_embed_max_size(state_dict):
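+    # the original pos_embed covers a square grid of patches, so its side length is sqrt(num_patches)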
+    num_patches = state_dict["pos_embed"].shape[1]
+    pos_embed_max_size = int(num_patches**0.5)
+    return pos_embed_max_size
+
+
+def get_caption_projection_dim(state_dict):
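+    # context_embedder is a Linear projection, so the first dim of its weight is the output (caption projection) size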
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
def main(args):
    original_ckpt = load_original_checkpoint(args.checkpoint_path)
    original_dtype = next(iter(original_ckpt.values())).dtype
-    if dtype is None:
+
+    # Initialize dtype with a default value
+    dtype = None
+
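+    # if dtype is not specified, use the dtype of the original checkpoint (recommended)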
+    if args.dtype is None:
        dtype = original_dtype
-    elif dtype != original_dtype:
+    elif args.dtype == "fp16":
+        dtype = torch.float16
+    elif args.dtype == "bf16":
+        dtype = torch.bfloat16
+    elif args.dtype == "fp32":
+        dtype = torch.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+    if dtype != original_dtype:
        print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}")

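+    # infer the number of transformer blocks from the largest joint_blocks index in the checkpoint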
    num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+
+    caption_projection_dim = get_caption_projection_dim(original_ckpt)
+
    # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
    add_attn2_layers = get_add_attn2_layers(original_ckpt)
+
    # sd3.5 uses qk norm ("rms_norm")
    has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
-    # sd3.5 use pox_embed_max_size=384 and sd3.0 use 192
-    pos_embed_max_size = 384 if has_qk_norm else 192
+
+    # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
+    pos_embed_max_size = get_pos_embed_max_size(original_ckpt)

    converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
        original_ckpt, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm
    )

    with CTX():
        transformer = SD3Transformer2DModel(
-            sample_size=64,
+            sample_size=128,
            patch_size=2,
            in_channels=16,
            joint_attention_dim=4096,
            num_layers=num_layers,
            caption_projection_dim=caption_projection_dim,
-            num_attention_heads=24,
+            num_attention_heads=num_layers,
            pos_embed_max_size=pos_embed_max_size,
            qk_norm="rms_norm" if has_qk_norm else None,
            add_attn2_layers=add_attn2_layers,