File tree (expand / collapse file tree): 2 files changed, +14 -0 lines changed
sparse_attention_hub/sparse_attention/research_attention/maskers (expand / collapse file tree): 2 files changed, +14 -0 lines changed
Original file line number | Diff line number | Diff line change
1010 MaskerConfig ,
1111 MaskerRegistry ,
1212)
13+ from sparse_attention_hub .sparse_attention .utils .kv_utils import (
14+ _get_num_key_value_groups ,
15+ repeat_kv ,
16+ )
1317from sparse_attention_hub .sparse_attention .utils .mask import Mask
1418
1519from ..base import TopPMasker , TopPMaskerConfig
@@ -71,6 +75,8 @@ def _compute_attention_scores(
7175 self , keys : torch .Tensor , queries : torch .Tensor
7276 ) -> torch .Tensor :
7377 """Compute exp(attention scores) between queries and keys."""
78+ ngroups = _get_num_key_value_groups (queries , keys )
79+ keys = repeat_kv (keys , ngroups )
7480 raw_attention_scores = queries @ keys .transpose (- 2 , - 1 )
7581 _max_attention_score = raw_attention_scores .max (dim = - 1 , keepdim = True )[0 ]
7682 adjusted = torch .exp (raw_attention_scores - _max_attention_score )
Original file line number | Diff line number | Diff line change
1515 MaskerConfig ,
1616 MaskerRegistry ,
1717)
18+ from sparse_attention_hub .sparse_attention .utils .kv_utils import (
19+ _get_num_key_value_groups ,
20+ repeat_kv ,
21+ )
22+
1823from sparse_attention_hub .sparse_attention .utils .mask import Mask
1924
2025from ..base import SamplingMasker , SamplingMaskerConfig
@@ -308,6 +313,9 @@ def add_mask(
308313 seq_len_queries : int = queries .shape [2 ]
309314 seq_len_keys : int = keys .shape [2 ]
310315
316+ ngroups = _get_num_key_value_groups (queries , keys )
317+ keys = repeat_kv (keys , ngroups )
318+
311319 probabilities : torch .Tensor = self ._compute_probabilities (keys , queries )
312320 matches : torch .Tensor = self ._compute_lsh_matches (keys , queries )
313321 dense_mask : torch .Tensor = matches * probabilities
You can’t perform that action at this time.
0 commit comments