Conversation

@sayakpaul sayakpaul commented Nov 28, 2025

What does this PR do?

  • Introduces a dedicated test suite for the Z-Image DiT.
  • Adds the `is_flaky` decorator to `test_inference()` in the Z-Image pipeline test suite.
  • Adds a `return_dict` argument to the `forward()` of the Z-Image DiT, following other models in the library.
    • As a consequence, the model now follows the standard return pattern: a `Transformer2DModelOutput` when `return_dict=True`, otherwise a tuple like `(out,)`.
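The return pattern described above can be sketched as follows. This is a minimal stand-in, not the actual Z-Image `forward()`; the real `Transformer2DModelOutput` lives in diffusers and carries a tensor, but it is mocked here with a plain dataclass so the sketch is self-contained:

```python
# Hedged sketch of the diffusers return_dict convention; the real
# output class is diffusers' Transformer2DModelOutput, mocked here.
from dataclasses import dataclass
from typing import Any


@dataclass
class Transformer2DModelOutput:
    sample: Any


def forward(x, return_dict: bool = True):
    out = x  # stand-in for the actual DiT computation
    if not return_dict:
        # Tuple output for callers that opted out of the dataclass.
        return (out,)
    return Transformer2DModelOutput(sample=out)
```

Callers that pass `return_dict=False` unpack a tuple; everyone else accesses `.sample`, matching the convention used across the library's models.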

Notes

  • The model accepts the hidden states as a `list[torch.Tensor]`, which differs from other models, and the output follows the same type. This is why I had to modify a couple of tests (where it was reasonably easy) to allow this. Tests where that was not straightforward were skipped (such as `test_training`, `test_ema_training`, etc.).
  • The repeated block in this model is `ZImageTransformerBlock`, which is used for `noise_refiner`, `context_refiner`, and `layers`. As a consequence, the inputs recorded for the block vary during compilation, so full compilation with `fullgraph=True` triggers recompilation at least three times.
  • Some of the group offloading tests were skipped because of state that interfered between the tests (as also noted here).
  • The `x_pad_token` and `cap_pad_token` params within the DiT are initialized with `torch.empty()`, possibly for memory efficiency, but they interfere with tests in very weird ways because `torch.empty()` returns uninitialized memory that can contain NaNs. To prevent this from creeping into the tests, I tried adding `is_flaky()` to some of the affected tests, but that didn't help (see this). @JerryWu-code, would it be safe to initialize `x_pad_token` and `cap_pad_token` deterministically, maybe with something like `torch.ones()`? Or do you think it would have memory implications?
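The `torch.empty()` issue in the last note can be illustrated with a toy module. This is a hypothetical illustration, not the actual Z-Image DiT; the parameter names mirror the ones discussed above, but the shapes are assumed:

```python
# Hypothetical sketch: torch.empty() returns uninitialized memory, which
# may contain NaN/inf, while torch.ones() is deterministic and finite.
import torch
import torch.nn as nn


class PadTokenDemo(nn.Module):
    def __init__(self, dim: int, deterministic: bool = True):
        super().__init__()
        init = torch.ones if deterministic else torch.empty
        # Names follow the params discussed in the PR; shapes are assumed.
        self.x_pad_token = nn.Parameter(init(1, dim))
        self.cap_pad_token = nn.Parameter(init(1, dim))


m = PadTokenDemo(dim=8, deterministic=True)
assert torch.isfinite(m.x_pad_token).all()  # guaranteed with torch.ones
```

With `deterministic=False` the same assertion can fail nondeterministically, which is exactly the kind of flakiness described above.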

Minor nits

  • We usually avoid raw `assert` statements inside the model implementations in favor of properly raising errors. Should we follow something similar here, too?
  • There is a `self.scheduler.sigma_min = 0.0` assignment inside the Z-Image pipeline. Maybe I am missing something, but mutating the scheduler from the pipeline seems like an antipattern to me.
  • The signature of `forward()` of the DiT uses shorthand variable names (`x`, `t`, `cap_feats`) instead of the usual `hidden_states`, `timestep`, and `encoder_hidden_states`.
  • Should `_cfg_normalization` and `_cfg_truncation` inside the pipeline be turned into properties like `guidance_scale`?

Maybe we could consider revisiting them (but not a priority perhaps).
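For the first nit, the usual convention can be sketched like this. `check_inputs` is a hypothetical helper, not code from this PR; the point is that a raised error survives `python -O` (which strips `assert`) and gives the user an actionable message:

```python
# Hedged sketch of raising an informative error instead of a bare
# assert. `check_inputs` is a hypothetical validation helper.
def check_inputs(hidden_states):
    # Z-Image's DiT expects a list of tensors, per the notes above.
    if not isinstance(hidden_states, list):
        raise ValueError(
            f"`hidden_states` must be a `list` of tensors, got {type(hidden_states)}."
        )
```

Compare with `assert isinstance(hidden_states, list)`, which is silently removed under `python -O` and raises a bare `AssertionError` otherwise.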

Cc: @JerryWu-code

@sayakpaul sayakpaul requested review from dg845 and yiyixuxu November 28, 2025 13:59
Comment on lines -661 to +636
return x, {}
if not return_dict:
return (x,)

return Transformer2DModelOutput(sample=x)
Should be a very safe change?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul commented Nov 28, 2025

The tests failing in "Fast tests for PRs / Fast PyTorch Models & Schedulers CPU tests (pull_request)" pass locally, even when run with `CUDA_VISIBLE_DEVICES="" pytest tests/models/transformers/test_models_transformer_z_image.py`.

Edit: it likely fails when `CUDA_VISIBLE_DEVICES="" pytest tests/models/` is run.
