Skip to content

Commit fffd964

Browse files
authored
fix FLUX.2 context parallel (#12737)
1 parent 859b809 commit fffd964

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,8 +676,8 @@ class Flux2Transformer2DModel(
676676
"": {
677677
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
678678
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
679-
"img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
680-
"txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
679+
"img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
680+
"txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
681681
},
682682
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
683683
}

0 commit comments

Comments
 (0)