Description
Why is the top-k selection done across all query-key pairs together, rather than per row, i.e., a separate top-k for each query? With the global selection, only a subset of the query tokens may end up being updated in the attention module.
FlashVSR/diffsynth/models/wan_video_dit.py
Lines 140 to 149 in 914dcd4
```python
attn_map = torch.softmax(scores, dim=-1)
# Merge the head and chunk dims: each (s1, s2) slice is one attention map.
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
loop_num, s1, s2 = attn_map.shape
# Flatten each map, so top-k is taken over ALL query-key pairs jointly.
flat = attn_map.reshape(loop_num, -1)
n = flat.shape[1]
apply_topk = min(flat.shape[1] - 1, topk)
# Threshold = value of the (k+1)-th largest entry in the flattened map.
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
thresholds = thresholds.unsqueeze(1)
# Keep only entries strictly above the threshold.
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)  # restore original layout
```
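
For reference, a minimal sketch of the per-row variant the question is asking about, assuming the same `(loop_num, s1, s2)` shape as after the first `rearrange` above. This is not the repo's code; `row_topk_mask` and its signature are illustrative.

```python
import torch

def row_topk_mask(attn_map: torch.Tensor, topk: int) -> torch.Tensor:
    # attn_map: (loop_num, s1, s2) -- one attention map per (head, chunk).
    # Per-row top-k: every query row keeps its own k largest attention
    # weights, so no query token is left entirely unselected.
    k = min(topk, attn_map.shape[-1])
    # Per-row threshold: value of the k-th largest entry along the key axis,
    # kept as shape (loop_num, s1, 1) for broadcasting.
    thresholds = torch.topk(attn_map, k=k, dim=-1, largest=True).values[..., -1:]
    return attn_map >= thresholds  # (loop_num, s1, s2) boolean mask
```

Note the `>=` comparison guarantees at least k entries per row but can keep slightly more when there are ties at the threshold, whereas the strict `>` in the original snippet can keep fewer.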