
Commit 35da60f

Fix GQA in oracle-top-p and magicpig
1 parent d05fc36 commit 35da60f

File tree

2 files changed: +14 additions, −0 deletions

  • sparse_attention_hub/sparse_attention/research_attention/maskers


sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_p.py

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,10 @@
     MaskerConfig,
     MaskerRegistry,
 )
+from sparse_attention_hub.sparse_attention.utils.kv_utils import (
+    _get_num_key_value_groups,
+    repeat_kv,
+)
 from sparse_attention_hub.sparse_attention.utils.mask import Mask
 
 from ..base import TopPMasker, TopPMaskerConfig

@@ -71,6 +75,8 @@ def _compute_attention_scores(
         self, keys: torch.Tensor, queries: torch.Tensor
     ) -> torch.Tensor:
         """Compute exp(attention scores) between queries and keys."""
+        ngroups = _get_num_key_value_groups(queries, keys)
+        keys = repeat_kv(keys, ngroups)
         raw_attention_scores = queries @ keys.transpose(-2, -1)
         _max_attention_score = raw_attention_scores.max(dim=-1, keepdim=True)[0]
         adjusted = torch.exp(raw_attention_scores - _max_attention_score)
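
The change follows the usual grouped-query attention (GQA) pattern: the query tensor carries more heads than the key tensor, so the keys must be repeated along the head dimension before `queries @ keys.transpose(-2, -1)` is valid. The diff does not show the bodies of `_get_num_key_value_groups` and `repeat_kv`; the sketch below is an assumption of what they presumably do, following the common Hugging Face `repeat_kv` convention and a `(batch, num_heads, seq_len, head_dim)` tensor layout.

import torch


def _get_num_key_value_groups(queries: torch.Tensor, keys: torch.Tensor) -> int:
    # queries: (batch, num_query_heads, seq_len_q, head_dim)
    # keys:    (batch, num_kv_heads,    seq_len_k, head_dim)
    # Under GQA, num_query_heads is an integer multiple of num_kv_heads.
    return queries.shape[1] // keys.shape[1]


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat each KV head n_rep times so keys/values match the number of
    # query heads: (batch, num_kv_heads, s, d) -> (batch, num_kv_heads * n_rep, s, d).
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)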

sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/magic_pig.py

Lines changed: 8 additions & 0 deletions
@@ -15,6 +15,11 @@
     MaskerConfig,
     MaskerRegistry,
 )
+from sparse_attention_hub.sparse_attention.utils.kv_utils import (
+    _get_num_key_value_groups,
+    repeat_kv,
+)
+
 from sparse_attention_hub.sparse_attention.utils.mask import Mask
 
 from ..base import SamplingMasker, SamplingMaskerConfig

@@ -308,6 +313,9 @@ def add_mask(
         seq_len_queries: int = queries.shape[2]
         seq_len_keys: int = keys.shape[2]
 
+        ngroups = _get_num_key_value_groups(queries, keys)
+        keys = repeat_kv(keys, ngroups)
+
         probabilities: torch.Tensor = self._compute_probabilities(keys, queries)
         matches: torch.Tensor = self._compute_lsh_matches(keys, queries)
         dense_mask: torch.Tensor = matches * probabilities
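
The same two lines are added at the top of `add_mask` in magic_pig.py, so both the probability computation and the LSH matching operate on keys expanded to the full set of query heads. As a quick illustration of why the repetition is needed (shapes are arbitrary and the snippet reuses the sketched helpers above):

import torch

# Hypothetical GQA shapes: 8 query heads sharing 2 KV heads.
queries = torch.randn(1, 8, 16, 64)  # (batch, num_query_heads, seq_len_q, head_dim)
keys = torch.randn(1, 2, 32, 64)     # (batch, num_kv_heads,    seq_len_k, head_dim)

# Without repeating the KV heads, queries @ keys.transpose(-2, -1) raises an
# error, because the head dimensions (8 vs. 2) cannot be broadcast together.
ngroups = _get_num_key_value_groups(queries, keys)  # 8 // 2 = 4
keys = repeat_kv(keys, ngroups)                     # -> (1, 8, 32, 64)

scores = queries @ keys.transpose(-2, -1)           # -> (1, 8, 16, 32)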
