Description
Hello,
Thanks for your amazing work!
I was wondering if it would be possible to process multiple videos (of the same resolution and length) in parallel.
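Concretely, what I have in mind is just stacking several clips of identical shape into one batched tensor and running a single forward pass over the batch. Here is a minimal sketch of the input layout I'm thinking of (plain PyTorch with made-up shapes, and no actual pipeline call since I don't know what a batched API would look like):

```python
import torch

# Hypothetical illustration only: N clips with identical resolution and length,
# stacked along a new leading batch dimension.
num_videos, channels, frames, height, width = 2, 3, 16, 480, 832

videos = [torch.randn(channels, frames, height, width) for _ in range(num_videos)]

# (batch, channels, frames, height, width) -> torch.Size([2, 3, 16, 480, 832])
batch = torch.stack(videos, dim=0)
print(batch.shape)
```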
From a little investigation, I could see that it is made impossible by this snippet in `wan_video_edit.py`:
```python
@torch.no_grad()
def generate_draft_block_mask(batch_size, nheads, seqlen,
                              q_w, k_w, topk=10, local_attn_mask=None):
    assert batch_size == 1, "Only batch_size=1 supported for now"
    ...
```

I then tried changing the `SelfAttention.forward` method from

```python
attention_mask = generate_draft_block_mask(
    B, self.num_heads, seqlen, q_w, k_w,
    topk=topk, local_attn_mask=self.local_attn_mask,
)
```

to the following (computing the masks per sample and then concatenating the results):
```python
if B == 1:
    attention_mask = generate_draft_block_mask(
        B,
        self.num_heads,
        seqlen,
        q_w,
        k_w,
        topk=topk,
        local_attn_mask=self.local_attn_mask,
    )
else:
    masks = []
    for i in range(B):
        # Slice out the query/key windows belonging to sample i.
        q_w_i = q_w[i * block_n : (i + 1) * block_n]
        k_w_i = k_w[i * block_n_kv : (i + 1) * block_n_kv]
        mask_i = generate_draft_block_mask(
            1,
            self.num_heads,
            seqlen,
            q_w_i,
            k_w_i,
            topk=topk,
            local_attn_mask=self.local_attn_mask,
        )
        masks.append(mask_i)
    attention_mask = torch.cat(masks, dim=0)
```

With additional modifications in the `generate_draft_block_mask` and `flash_attention` functions, I got to a point where I could get it running, but it always causes an OOM error, even for a very small, low-resolution video at batch size 2 with the `TinyLongPipeline`. Considering that I'm trying this on an H100 with 80 GB of VRAM, I definitely think that I'm doing things wrong...
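In case it helps, here is the small helper I've been using to sanity-check how much memory a dense attention mask of a given shape would take on its own; the shapes below are made-up placeholders, not the actual values from my run:

```python
import torch

def mask_memory_gib(shape, dtype=torch.bool):
    """Rough memory footprint of a dense mask tensor of the given shape, in GiB."""
    numel = 1
    for dim in shape:
        numel *= dim
    bytes_per_elem = torch.tensor([], dtype=dtype).element_size()
    return numel * bytes_per_elem / (1024 ** 3)

# Placeholder shapes, not the real ones from my run:
# a block-level mask (batch, heads, q_blocks, kv_blocks) stays small...
print(mask_memory_gib((2, 40, 512, 512)))      # ~0.02 GiB as bool
# ...whereas a token-level mask (batch, heads, seqlen, seqlen) does not.
print(mask_memory_gib((2, 40, 32760, 32760)))  # ~80 GiB as bool
```

If the mask my modified code ends up materializing is closer to the second case than the first, that alone would explain the OOM, but I haven't been able to confirm it yet.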