
Commit e682af2

Qwen Image Edit Support (#12164)
* feat(qwen-image): add qwen-image-edit support
* fix(qwen image):
  - compatible with torch.compile in the new RoPE setting
  - fix init import
  - add prompt truncation in the img2img and inpaint pipelines
  - remove unused logic and comments
  - add copy statements
  - guard logic for the RoPE video shape tuple
* fix(qwen image):
  - make fix-copies
  - update docs
1 parent a58a4f6 commit e682af2

9 files changed: +949 −89 lines changed
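For context, a minimal usage sketch of the new pipeline (not part of this diff; the checkpoint id, image URL, and call arguments are assumptions modeled on the sibling QwenImage pipelines):

```python
import torch
from diffusers import QwenImageEditPipeline
from diffusers.utils import load_image

# Hypothetical checkpoint id and input URL -- substitute your own.
pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image("https://example.com/cat.png")
edited = pipe(image=image, prompt="make the cat wear a tiny wizard hat").images[0]
edited.save("edited.png")
```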

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -492,6 +492,7 @@
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
         "QwenImagePipeline",
+        "QwenImageEditPipeline",
         "ReduxImageEncoder",
         "SanaControlNetPipeline",
         "SanaPAGPipeline",
@@ -1123,6 +1124,7 @@
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        QwenImageEditPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
         QwenImagePipeline,

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 32 additions & 23 deletions

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -161,17 +160,17 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        pos_freqs = torch.cat(
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        neg_freqs = torch.cat(
+        self.neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -180,10 +179,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
-        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
-        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
-        # whether to use scale rope
+        # DO NOT USE register_buffer HERE: IT WILL CAUSE THE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
         self.scale_rope = scale_rope
 
     def rope_params(self, index, dim, theta=10000):
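The dropped register_buffer calls and the new warning comment deserve a note: buffers are cast by Module.to(dtype), and casting a complex tensor to a real dtype silently discards the imaginary part. A minimal standalone sketch of the failure mode (hypothetical module, not from this repo):

```python
import torch
import torch.nn as nn

class RopeTable(nn.Module):  # hypothetical minimal reproduction
    def __init__(self):
        super().__init__()
        freqs = torch.polar(torch.ones(4, 2), torch.randn(4, 2))  # complex64
        self.register_buffer("freqs", freqs, persistent=False)

m = RopeTable()
print(m.freqs.dtype)   # torch.complex64
m.to(torch.bfloat16)   # what pipe.to(torch.bfloat16) does to every submodule
print(m.freqs.dtype)   # torch.bfloat16 -- PyTorch warns and drops the imaginary part
```

Keeping pos_freqs/neg_freqs as plain attributes avoids that cast, at the cost of moving them to the right device by hand, which is exactly what the new guard at the top of forward (next hunk) does.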
@@ -201,35 +198,47 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args:
             video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
             txt_seq_lens: [bs] a list of integers representing the length of the text
         """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
-        frame, height, width = video_fhw
-        rope_key = f"{frame}_{height}_{width}"
-
-        if not torch.compiler.is_compiling():
-            if rope_key not in self.rope_cache:
-                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
-            vid_freqs = self.rope_cache[rope_key]
-        else:
-            vid_freqs = self._compute_video_freqs(frame, height, width)
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs = []
+        max_vid_index = 0
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self.rope_cache[rope_key]
+            else:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            vid_freqs.append(video_freq)
 
-        if self.scale_rope:
-            max_vid_index = max(height // 2, width // 2)
-        else:
-            max_vid_index = max(height, width)
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
 
         max_len = max(txt_seq_lens)
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
     @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width):
+    def _compute_video_freqs(self, frame, height, width, idx=0):
         seq_lens = frame * height * width
         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 
-        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
             freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
             freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
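The new idx parameter offsets each image's frame coordinate in the precomputed frequency table, so two images packed into one sample never share RoPE positions. A rough standalone sketch of the idea (rope_params is paraphrased from the module; shapes are illustrative):

```python
import torch

def rope_params(index, dim, theta=10000):
    # 1D rotary table: positions x half-dim inverse frequencies, stored as
    # complex exponentials (paraphrased from QwenEmbedRope.rope_params).
    freqs = torch.outer(index.float(), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float() / dim))
    return torch.polar(torch.ones_like(freqs), freqs)

pos_freqs = rope_params(torch.arange(4096), 16)  # (4096, 8), complex64

# Two images in one sample: target latents at idx 0, reference image at idx 1.
# Each slice is a different row of the table, i.e. a distinct frame coordinate,
# so identical (height, width) positions no longer collide.
frame = 1
freqs_target = pos_freqs[0 : 0 + frame]
freqs_reference = pos_freqs[1 : 1 + frame]
print(torch.equal(freqs_target, freqs_reference))  # False
```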

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 1 deletion

@@ -391,6 +391,7 @@
         "QwenImagePipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
+        "QwenImageEditPipeline",
     ]
 try:
     if not is_onnx_available():
@@ -708,7 +709,12 @@
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-        from .qwenimage import QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline
+        from .qwenimage import (
+            QwenImageEditPipeline,
+            QwenImageImg2ImgPipeline,
+            QwenImageInpaintPipeline,
+            QwenImagePipeline,
+        )
         from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,7 @@
     _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
     _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
     _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
+    _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -35,6 +36,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .pipeline_qwenimage import QwenImagePipeline
+        from .pipeline_qwenimage_edit import QwenImageEditPipeline
         from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
         from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
 else:
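These __init__ edits follow diffusers' lazy-import convention: names are registered in _import_structure and only resolved on first attribute access. A simplified skeleton of the pattern as it would sit in a package __init__ (the real file also guards optional backends with try/except and honors TYPE_CHECKING / DIFFUSERS_SLOW_IMPORT):

```python
import sys

from diffusers.utils import _LazyModule

# Map submodule -> public names; nothing heavy is imported until a name is used.
_import_structure = {
    "pipeline_qwenimage": ["QwenImagePipeline"],
    "pipeline_qwenimage_edit": ["QwenImageEditPipeline"],
}

# Replace this module with a lazy proxy that resolves attributes on demand.
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```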

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 7 additions & 21 deletions

@@ -253,6 +253,9 @@ def encode_prompt(
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
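The two added lines cap the text-encoder output at max_sequence_length before it is repeated per image; the embeddings and their attention mask are truncated in lockstep so downstream length bookkeeping stays consistent. A toy illustration (all shapes are made up):

```python
import torch

max_sequence_length = 1024

# Illustrative shapes only; the real embeddings come from the Qwen2.5-VL encoder.
prompt_embeds = torch.randn(1, 1300, 3584)            # (batch, seq_len, hidden)
prompt_embeds_mask = torch.ones(1, 1300, dtype=torch.long)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

print(prompt_embeds.shape, prompt_embeds_mask.shape)
# torch.Size([1, 1024, 3584]) torch.Size([1, 1024])
```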
@@ -316,20 +319,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
 
-    @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
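The hunk above only shows the first line of _pack_latents. For reference, a sketch of the Flux-style packing that line implies, which folds each 2x2 latent patch into the channel dimension (a reconstruction under that assumption, not verbatim from this file):

```python
import torch

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, H/2 * W/2, C * 4): every 2x2 spatial patch becomes
    # one sequence token for the transformer.
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
    return latents

x = torch.randn(1, 16, 64, 64)
print(pack_latents(x, 1, 16, 64, 64).shape)  # torch.Size([1, 1024, 64])
```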
@@ -402,8 +391,7 @@ def prepare_latents(
         shape = (batch_size, 1, num_channels_latents, height, width)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -414,9 +402,7 @@ def prepare_latents(
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
+        return latents
 
     @property
     def guidance_scale(self):
@@ -594,7 +580,7 @@ def __call__(
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
+        latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -604,7 +590,7 @@ def __call__(
             generator,
             latents,
         )
-        img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
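The img_shapes change is the piece that connects this pipeline to the new multi-image RoPE: each batch entry is now a list of shapes rather than a single tuple. A short sketch of what the extra nesting enables (sizes are illustrative):

```python
batch_size = 2

# Before this commit: one (frames, height, width) tuple per batch sample.
img_shapes = [(1, 64, 64)] * batch_size

# After: a list of shapes per sample, so the edit pipeline can append the
# packed reference image alongside the target latents.
img_shapes = [[(1, 64, 64), (1, 48, 48)]] * batch_size

# QwenEmbedRope.forward iterates the inner list, computes frequencies per
# image with a distinct frame index, and concatenates along the sequence dim.
print(img_shapes[0])
```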
