[Feature Request] Batch inference #52

@Alex-experiments

Description

Hello,
Thanks for your amazing work!

I was wondering if it would be possible to process multiple videos (of the same resolution and length) in parallel.

From a little investigation, I can see that this is currently blocked by this assertion in wan_video_edit.py:

@torch.no_grad()
def generate_draft_block_mask(batch_size, nheads, seqlen,
                              q_w, k_w, topk=10, local_attn_mask=None):
    assert batch_size == 1, "Only batch_size=1 supported for now"
    ...

I then tried changing the SelfAttention.forward method from

attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)

to the following, which computes a mask per sample and concatenates the results:

        if B == 1:
            attention_mask = generate_draft_block_mask(
                B,
                self.num_heads,
                seqlen,
                q_w,
                k_w,
                topk=topk,
                local_attn_mask=self.local_attn_mask,
            )
        else:
            # Compute one mask per sample, then concatenate along the
            # batch dimension (block_n / block_n_kv are the per-sample
            # block counts for queries and keys).
            masks = []
            for i in range(B):
                q_w_i = q_w[i * block_n : (i + 1) * block_n]
                k_w_i = k_w[i * block_n_kv : (i + 1) * block_n_kv]
                mask_i = generate_draft_block_mask(
                    1,
                    self.num_heads,
                    seqlen,
                    q_w_i,
                    k_w_i,
                    topk=topk,
                    local_attn_mask=self.local_attn_mask,
                )
                masks.append(mask_i)
            attention_mask = torch.cat(masks, dim=0)

With additional modifications to the generate_draft_block_mask and flash_attention functions, I got it running, but it always triggers an OOM error, even for a very small, low-resolution video at batch size 2 with the TinyLongPipeline. Considering that I'm running this on an H100 with 80 GB of VRAM, I definitely think I'm doing something wrong...
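For what it's worth, a dense per-head boolean block mask on its own should be nowhere near 80 GB, which might help narrow down where the memory actually goes. A rough back-of-the-envelope estimate (all shapes below are hypothetical, just to get an order of magnitude):

```python
def block_mask_bytes(batch_size, num_heads, seqlen, block_size):
    """Estimated size of a dense boolean block mask of shape
    (batch_size, num_heads, n_blocks, n_blocks), one byte per bool."""
    n_blocks = (seqlen + block_size - 1) // block_size  # ceil division
    return batch_size * num_heads * n_blocks * n_blocks * 1

# Hypothetical shapes: 40 heads, 75k-token sequence, 128-token blocks.
for b in (1, 2):
    gib = block_mask_bytes(b, 40, 75_000, 128) / 2**30
    print(f"batch={b}: {gib:.3f} GiB")  # well under a gigabyte either way
```

So if doubling the batch doubles peak memory from a few GB straight past 80 GB, the culprit is presumably the attention computation itself (or an accidentally broadcast intermediate) rather than the mask tensors.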
