
Commit 41e149b

Fix: pass scaling, dropout to add_mask
1. Apply scaling when we use exp-weights.
2. Remove the previous mask before we do top-k / top-p.
1 parent b8c428d commit 41e149b
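A quick way to see why fix (1) matters: without the 1/sqrt(d) scaling, exp-weights are far more peaked than real softmax numerators, so a top-p cut keeps far fewer keys than intended. The sketch below is illustrative only, not repository code; the shapes, the seed, and the helper names exp_weights and keys_kept_for_top_p are made up for this example.

import torch

torch.manual_seed(0)
D = 128
queries = torch.randn(1, 1, 1, D)
keys = torch.randn(1, 1, 16, D)
scaling = D ** -0.5

logits = torch.matmul(queries, keys.transpose(-2, -1))      # (1, 1, 1, 16)

def exp_weights(x: torch.Tensor) -> torch.Tensor:
    """Numerically stable exp, as the maskers compute it."""
    return torch.exp(x - x.max(dim=-1, keepdim=True).values)

def keys_kept_for_top_p(weights: torch.Tensor, top_p: float = 0.9) -> int:
    """How many keys are needed to cover top_p of the total weight."""
    probs = weights / weights.sum(dim=-1, keepdim=True)
    cumulative = probs.sort(dim=-1, descending=True).values.cumsum(dim=-1)
    return int((cumulative < top_p).sum()) + 1

print(keys_kept_for_top_p(exp_weights(logits)))             # unscaled: typically only 1-2 keys
print(keys_kept_for_top_p(exp_weights(logits * scaling)))   # scaled: a much larger, intended fraction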


12 files changed: +104 additions, −36 deletions


sparse_attention_hub/metric_logging/logger.py

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ def flush(self) -> None:
        return

        # Get current timestamp for filename
-       filename = f"micro_metrics.jsonl"
+       filename = "micro_metrics.jsonl"
        filepath = os.path.join(self.log_path, filename)

        # Write events to file

sparse_attention_hub/sparse_attention/research_attention/base.py

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,8 @@ def custom_attention(
            queries=queries,
            values=values,
            attention_mask=attention_mask,
+           scaling=scaling,
+           dropout=dropout,
            sparse_meta_data=sparse_meta_data,
            previous_mask=sparse_attention_mask,
            **kwargs,

sparse_attention_hub/sparse_attention/research_attention/maskers/base.py

Lines changed: 2 additions & 0 deletions
@@ -160,6 +160,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],  # want to keep it general here.
        previous_mask: Mask,
        **kwargs: Dict[str, Any],
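For reference, a self-contained sketch of the widened interface. This is not the repository class: Mask is replaced by a plain tensor, and the parameters that precede queries in the real signature are not visible in this hunk, so the leading keys argument below is an assumption.

from abc import ABC, abstractmethod
from typing import Any, Dict

import torch


class MaskerSketch(ABC):
    """Toy stand-in for the repo's masker base class."""

    @abstractmethod
    def add_mask(
        self,
        keys: torch.Tensor,            # assumed; not shown in the hunk above
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
        scaling: float,                # now explicit instead of riding in **kwargs
        dropout: float,                # now explicit instead of riding in **kwargs
        sparse_meta_data: Dict[Any, Any],
        previous_mask: torch.Tensor,   # dense 0/1 tensor instead of the repo's Mask
        **kwargs: Dict[str, Any],
    ) -> torch.Tensor:
        ...


class KeepPreviousMasker(MaskerSketch):
    """Trivial implementation: ignores the new arguments and keeps the previous mask."""

    def add_mask(self, keys, queries, values, attention_mask, scaling, dropout,
                 sparse_meta_data, previous_mask, **kwargs):
        return previous_mask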

sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/basic_fixed.py

Lines changed: 6 additions & 0 deletions
@@ -39,6 +39,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],
@@ -142,6 +144,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],
@@ -182,6 +186,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],

sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/double_sparsity_top_k.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],

sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py

Lines changed: 21 additions & 2 deletions
@@ -57,6 +57,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[str, Dict[int, Optional[torch.Tensor]]],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],
@@ -85,6 +87,8 @@ def add_mask(
            effective_heavy_size,
            keys,
            queries,
+           attention_mask,
+           previous_mask.get_dense_mask(),
            sparse_meta_data,
            previous_mask,
            layer_idx,
@@ -143,14 +147,21 @@ def _create_hash_topk_mask(
        heavy_size: int,
        keys: torch.Tensor,
        queries: torch.Tensor,
+       attention_mask: torch.Tensor,
        sparse_meta_data: Dict[str, Dict[int, Optional[torch.Tensor]]],
        previous_mask: Mask,
        layer_idx: int,
        **kwargs: Dict[str, Any],
    ) -> Mask:
        """Create hash attention top-K mask using hash-based scoring."""
        scores: torch.Tensor = self._compute_hashattention_score(
-           queries, keys, sparse_meta_data, layer_idx, **kwargs
+           queries,
+           keys,
+           attention_mask,
+           previous_mask.get_dense_mask(),
+           sparse_meta_data,
+           layer_idx,
+           **kwargs,
        )
        top_k_indices: torch.Tensor = self._get_topk_indices_from_inactive_positions(
            scores, previous_mask, heavy_size
@@ -303,6 +314,8 @@ def _compute_hashattention_score(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
+       attention_mask: torch.Tensor,
+       previous_dense_mask: torch.Tensor,
        sparse_meta_data: Dict[str, Dict[int, Optional[torch.Tensor]]],
        layer_idx: int,
        **kwargs: Dict[str, Any],
@@ -319,7 +332,13 @@ def _compute_hashattention_score(
        )

        # (B, H, #queries, hat_bits) x (B, H, hat_bits, #keys) -> (B, H, #queries, #keys)
-       return torch.matmul(query_signatures, key_signatures.transpose(-2, -1))
+       scores: torch.Tensor = torch.matmul(
+           query_signatures, key_signatures.transpose(-2, -1)
+       )
+       if attention_mask is not None:
+           scores = scores + attention_mask[:, :, :, : keys.shape[-2]]
+       scores[previous_dense_mask != 0] = torch.finfo(scores.dtype).min
+       return scores

    @classmethod
    def create_from_config(cls, config: MaskerConfig) -> "HashAttentionTopKMasker":
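The net effect of this change, reduced to toy tensors. The block below is a sketch, not the module's code: the signatures are faked with random sign bits and hat_bits and the 0/1 previous_dense_mask are stand-ins. It shows that the signature-similarity scores now receive the additive attention mask and that already-active keys are pushed to the dtype minimum, so the subsequent top-k only spends its budget on keys that are both visible and not yet selected.

import torch

torch.manual_seed(0)
B, H, Q, K, hat_bits = 1, 2, 4, 16, 8
query_signatures = torch.randn(B, H, Q, hat_bits).sign()   # fake +-1 signatures
key_signatures = torch.randn(B, H, K, hat_bits).sign()
attention_mask = torch.zeros(B, 1, Q, K)                   # additive mask, 0 = visible
previous_dense_mask = torch.zeros(B, H, Q, K)
previous_dense_mask[..., :4] = 1.0                          # first 4 keys kept by earlier maskers

# (B, H, Q, hat_bits) x (B, H, hat_bits, K) -> (B, H, Q, K)
scores = torch.matmul(query_signatures, key_signatures.transpose(-2, -1))
if attention_mask is not None:
    scores = scores + attention_mask[:, :, :, :K]
scores[previous_dense_mask != 0] = torch.finfo(scores.dtype).min

heavy = torch.topk(scores, k=4, dim=-1).indices             # budget spent only on new, visible keys
print(heavy.shape)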

sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_k.py

Lines changed: 20 additions & 4 deletions
@@ -43,6 +43,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],
@@ -64,7 +66,12 @@ def add_mask(

        # Create oracle top-K mask
        oracle_mask: Mask = self._create_oracle_topk_mask(
-           tensor_dims, effective_heavy_size, keys, queries, previous_mask
+           tensor_dims,
+           effective_heavy_size,
+           keys,
+           queries,
+           attention_mask,
+           previous_mask,
        )
        return previous_mask.merge_mask(oracle_mask, inplace=False)

@@ -84,11 +91,12 @@ def _create_oracle_topk_mask(
        heavy_size: int,
        keys: torch.Tensor,
        queries: torch.Tensor,
+       attention_mask: torch.Tensor,
        previous_mask: Mask,
    ) -> Mask:
        """Create oracle top-K mask using raw attention scores."""
        raw_attention_scores: torch.Tensor = self._compute_raw_attention_scores(
-           keys, queries
+           keys, queries, attention_mask, previous_mask.get_dense_mask()
        )
        top_k_indices: torch.Tensor = self._get_topk_indices_from_inactive_positions(
            raw_attention_scores, previous_mask, heavy_size
@@ -98,12 +106,20 @@ def _create_oracle_topk_mask(
        )

    def _compute_raw_attention_scores(
-       self, keys: torch.Tensor, queries: torch.Tensor
+       self,
+       keys: torch.Tensor,
+       queries: torch.Tensor,
+       attention_mask: torch.Tensor,
+       previous_dense_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute raw attention scores using query-key dot product."""
        ngroups = _get_num_key_value_groups(queries, keys)
        keys = repeat_kv(keys, ngroups)
-       return torch.matmul(queries, keys.transpose(-2, -1))
+       scores: torch.Tensor = torch.matmul(queries, keys.transpose(-2, -1))
+       if attention_mask is not None:
+           scores = scores + attention_mask[:, :, :, : keys.shape[-2]]
+       scores[previous_dense_mask != 0] = torch.finfo(scores.dtype).min
+       return scores

    @classmethod
    def create_from_config(cls, config: MaskerConfig) -> "OracleTopK":
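Fix (2) in isolation, on toy tensors. This is a sketch, not the OracleTopK class: the Mask object is replaced by a dense 0/1 tensor and the attention mask is omitted. Without the new masking step, top-k can re-select keys an earlier masker already activated; with it, the heavy budget goes entirely to new keys.

import torch

torch.manual_seed(0)
B, H, Q, K, D = 1, 1, 1, 16, 8
queries = torch.randn(B, H, Q, D)
keys = torch.randn(B, H, K, D)
previous_dense_mask = torch.zeros(B, H, Q, K)
previous_dense_mask[..., -4:] = 1.0     # pretend a local-window masker already kept the last 4 keys

scores = torch.matmul(queries, keys.transpose(-2, -1))

naive = torch.topk(scores, k=4, dim=-1).indices        # may re-select already-active keys
masked = scores.masked_fill(previous_dense_mask != 0, torch.finfo(scores.dtype).min)
fixed = torch.topk(masked, k=4, dim=-1).indices        # always picks 4 keys that are not yet active
print(sorted(naive.flatten().tolist()), sorted(fixed.flatten().tolist()))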

sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_p.py

Lines changed: 30 additions & 25 deletions
@@ -44,6 +44,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],
@@ -62,7 +64,7 @@ def add_mask(

        # Create oracle top-P attention mask
        oracle_mask: Mask = self._create_oracle_top_p_mask(
-           tensor_dims, keys, queries, previous_mask
+           tensor_dims, keys, queries, previous_mask, attention_mask, scaling
        )
        return previous_mask.merge_mask(oracle_mask, inplace=False)

@@ -71,13 +73,25 @@ def _should_use_full_attention(self, dims: AttentionTensorDimensions) -> bool:
        effective_size: int = int(self.top_p * dims.seq_len_keys)
        return dims.seq_len_keys <= effective_size

-   def _compute_attention_scores(
-       self, keys: torch.Tensor, queries: torch.Tensor
+   def _compute_exp_attention_scores(
+       self,
+       keys: torch.Tensor,
+       queries: torch.Tensor,
+       previous_dense_mask: torch.Tensor,
+       attention_mask: torch.Tensor,
+       scaling: float,
    ) -> torch.Tensor:
        """Compute exp(attention scores) between queries and keys."""
        ngroups = _get_num_key_value_groups(queries, keys)
        keys = repeat_kv(keys, ngroups)
-       raw_attention_scores = queries @ keys.transpose(-2, -1)
+       raw_attention_scores = torch.matmul(queries, keys.transpose(2, 3)) * scaling
+       if attention_mask is not None:
+           raw_attention_scores = (
+               raw_attention_scores + attention_mask[:, :, :, : keys.shape[-2]]
+           )
+       raw_attention_scores[previous_dense_mask != 0] = torch.finfo(
+           raw_attention_scores.dtype
+       ).min
        _max_attention_score = raw_attention_scores.max(dim=-1, keepdim=True)[0]
        adjusted = torch.exp(raw_attention_scores - _max_attention_score)
        return adjusted
@@ -101,42 +115,33 @@ def _compute_top_p_thresholds(

        # Find positions where normalized_cumsum >= top_p
        threshold_positions = torch.searchsorted(
-           normalized_cumsum, top_p_tensor, side="left"
-       )
-
-       # Prepare indices for advanced indexing (shape-agnostic)
-       leading_shape = scores.shape[:-1]
-       idx_grids = torch.meshgrid(
-           *[torch.arange(s, device=scores.device) for s in leading_shape],
-           indexing="ij",
+           normalized_cumsum, top_p_tensor, side="right"
        )
-       thresholds = sorted_scores[idx_grids + (threshold_positions.squeeze(-1),)]
-
-       # Add trailing singleton dimension for broadcasting
-       return thresholds.unsqueeze(-1)
+       thresholds = torch.gather(sorted_scores, dim=-1, index=threshold_positions)
+       return thresholds

    def _create_oracle_top_p_mask(
        self,
        dims: AttentionTensorDimensions,
        keys: torch.Tensor,
        queries: torch.Tensor,
        previous_mask: Mask,
+       attention_mask: torch.Tensor,
+       scaling: float,
    ) -> Mask:
        """Create oracle top-P attention mask using vectorized computation."""
-       # Get attention scores
-       scores: torch.Tensor = self._compute_attention_scores(keys, queries)
-       # Get previous dense mask and mask out already active positions
+       # Get attention scores after masking out already active positions
        previous_dense_mask: torch.Tensor = previous_mask.get_dense_mask()
-       masked_scores: torch.Tensor = scores.clone()
-       masked_scores[previous_dense_mask != 0] = float("-inf")
+       scores: torch.Tensor = self._compute_exp_attention_scores(
+           keys, queries, previous_dense_mask, attention_mask, scaling
+       )

        # Compute thresholds using vectorized operations
-       thresholds: torch.Tensor = self._compute_top_p_thresholds(
-           masked_scores, self.top_p
-       )
+       thresholds: torch.Tensor = self._compute_top_p_thresholds(scores, self.top_p)
+       thresholds = thresholds.to(queries.dtype)

        # Create dense mask: scores >= thresholds
-       dense_mask: torch.Tensor = masked_scores >= thresholds
+       dense_mask: torch.Tensor = scores >= thresholds

        # Create mask object
        mask_shape: tuple = (
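The reworked threshold extraction in standalone form. This is a sketch under the assumption that the inputs are positive exp-weights; the final clamp is a guard added for this sketch only, since with top_p close to 1 searchsorted can return an index one past the end. Sorting, cumulative normalization, searchsorted with side="right", and a gather along the last dimension replace the old meshgrid-based advanced indexing.

import torch

torch.manual_seed(0)
scores = torch.rand(2, 4, 3, 16)       # positive exp-weights: (batch, heads, queries, keys)
top_p = 0.8

sorted_scores, _ = torch.sort(scores, dim=-1, descending=True)
cumsum = sorted_scores.cumsum(dim=-1)
normalized_cumsum = cumsum / cumsum[..., -1:]

# First position where the normalized cumulative mass exceeds top_p.
top_p_tensor = torch.full_like(normalized_cumsum[..., :1], top_p)
threshold_positions = torch.searchsorted(normalized_cumsum, top_p_tensor, side="right")
threshold_positions = threshold_positions.clamp(max=scores.shape[-1] - 1)  # sketch-only guard

# Per-(batch, head, query) score threshold, then the dense keep/drop decision.
thresholds = torch.gather(sorted_scores, dim=-1, index=threshold_positions)
dense_mask = scores >= thresholds
print(dense_mask.float().mean(dim=-1))   # fraction of keys kept per query

torch.gather keeps the computation shape-agnostic without materializing index grids, which is what the deleted meshgrid code was doing by hand.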

sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/pq_top_k.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],

sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/adaptive_sampling.py

Lines changed: 14 additions & 4 deletions
@@ -144,12 +144,18 @@ def __init__(self, config: AdaptiveSamplingMaskerConfig) -> None:
        self.delta_ppf = float(norm.ppf(1 - self.delta))

    def _compute_exp_attention_scores(
-       self, queries: torch.Tensor, keys: torch.Tensor
+       self,
+       queries: torch.Tensor,
+       keys: torch.Tensor,
+       scaling: float,
+       attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute exponential attention scores with numerical stability."""
        ngroups = _get_num_key_value_groups(queries, keys)
        keys = repeat_kv(keys, ngroups)
-       raw_scores = torch.matmul(queries, keys.transpose(-2, -1))
+       raw_scores = torch.matmul(queries, keys.transpose(-2, -1)) * scaling
+       if attention_mask is not None:
+           raw_scores = raw_scores + attention_mask[:, :, :, : keys.shape[-2]]
        max_scores = torch.max(raw_scores, dim=-1, keepdim=True)[0]
        return torch.exp(raw_scores - max_scores)

@@ -244,6 +250,8 @@ def add_mask(
        queries: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor,
+       scaling: float,
+       dropout: float,
        sparse_meta_data: Dict[Any, Any],
        previous_mask: Mask,
        **kwargs: Dict[str, Any],
@@ -280,8 +288,10 @@ def add_mask(
            dims.seq_len_queries,
            dims.seq_len_keys,
        )
-
-       expwts = self._compute_exp_attention_scores(queries, keys)
+       # Compute attention scores after removing attention_mask
+       expwts = self._compute_exp_attention_scores(
+           queries, keys, scaling, attention_mask
+       )
        static_denominator = apply_inv_mask_sum(expwts, previous_mask)

        # Get sampling parameters
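How the adaptive-sampling path now builds its exp-weights and the static part of the softmax denominator, on toy tensors. This is a sketch, not the module's code: apply_inv_mask_sum is approximated here by a plain masked sum, which assumes the previous mask can be treated as a dense 0/1 tensor.

import torch

torch.manual_seed(0)
B, H, Q, K, D = 1, 2, 3, 16, 8
queries = torch.randn(B, H, Q, D)
keys = torch.randn(B, H, K, D)
attention_mask = torch.zeros(B, 1, Q, K)         # additive mask, 0 = fully visible
previous_dense_mask = torch.zeros(B, H, Q, K)
previous_dense_mask[..., :4] = 1.0                # earlier maskers already kept the first 4 keys
scaling = D ** -0.5

# Scaled, attention-masked logits -> numerically stable exp-weights (fix 1).
raw_scores = torch.matmul(queries, keys.transpose(-2, -1)) * scaling
if attention_mask is not None:
    raw_scores = raw_scores + attention_mask[:, :, :, : keys.shape[-2]]
expwts = torch.exp(raw_scores - raw_scores.max(dim=-1, keepdim=True).values)

# Denominator mass already covered by the previous maskers; the sampler only has to
# estimate the remainder (stand-in for apply_inv_mask_sum).
static_denominator = (expwts * previous_dense_mask).sum(dim=-1, keepdim=True)
print(expwts.shape, static_denominator.shape)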
