Commit 5824fa6 (parent: 0589b2f)

Fixing tests

1. changed signatures
2. fixed a bug in oracle top-p

File tree

8 files changed: +159 −14 lines


sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py

Lines changed: 0 additions & 1 deletion
@@ -88,7 +88,6 @@ def add_mask(
         keys,
         queries,
         attention_mask,
-        previous_mask.get_dense_mask(),
         sparse_meta_data,
         previous_mask,
         layer_idx,

sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_p.py

Lines changed: 3 additions & 0 deletions
@@ -117,6 +117,9 @@ def _compute_top_p_thresholds(
        threshold_positions = torch.searchsorted(
            normalized_cumsum, top_p_tensor, side="right"
        )
+       # if top_p is 1.0, threshold_positions will equal sorted_scores.shape[-1],
+       # which is not a valid index, so we clamp it to the last valid index
+       threshold_positions = torch.clamp(threshold_positions, max=sorted_scores.shape[-1] - 1)
        thresholds = torch.gather(sorted_scores, dim=-1, index=threshold_positions)
        return thresholds
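
The bug this hunk fixes is easy to reproduce in isolation. Below is a minimal sketch (tensor names and shapes are illustrative, mirroring the variables in _compute_top_p_thresholds): when top_p is 1.0, the last entry of the normalized cumulative sum is exactly 1.0, so torch.searchsorted with side="right" returns seq_len, one past the last valid index, and the subsequent torch.gather would fail without the clamp.

import torch

# Illustrative setup: (batch, seq_len) scores sorted descending, as in the masker.
sorted_scores, _ = torch.sort(torch.rand(2, 8), dim=-1, descending=True)
cumsum = torch.cumsum(sorted_scores, dim=-1)
normalized_cumsum = cumsum / cumsum[..., -1:]  # final entry is exactly 1.0
top_p_tensor = torch.full((2, 1), 1.0)

# With top_p == 1.0 and side="right", searchsorted returns seq_len (here 8),
# which is out of range as a gather index.
threshold_positions = torch.searchsorted(normalized_cumsum, top_p_tensor, side="right")
assert (threshold_positions == sorted_scores.shape[-1]).all()

# The committed fix: clamp to the last valid index before gathering.
threshold_positions = torch.clamp(threshold_positions, max=sorted_scores.shape[-1] - 1)
thresholds = torch.gather(sorted_scores, dim=-1, index=threshold_positions)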

tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_basic_fixed.py

Lines changed: 20 additions & 0 deletions
@@ -95,6 +95,8 @@ def test_local_masker_add_mask_full_previous(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=full_previous_mask,
         )

@@ -129,6 +131,8 @@ def test_local_masker_add_mask_small_sequence(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -163,6 +167,8 @@ def test_local_masker_add_mask_integer_window(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -216,6 +222,8 @@ def test_local_masker_add_mask_float_window(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -268,6 +276,8 @@ def test_local_masker_add_mask_merge_with_previous(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=previous_mask,
         )

@@ -318,6 +328,8 @@ def test_local_masker_add_mask_edge_case_window_size_zero(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -463,6 +475,8 @@ def test_sink_masker_add_mask(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=full_previous_mask,
         )

@@ -484,6 +498,8 @@ def test_sink_masker_add_mask(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -507,6 +523,8 @@ def test_sink_masker_add_mask(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=partial_previous_mask,
         )

@@ -536,6 +554,8 @@ def test_sink_masker_add_mask(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=partial_previous_mask,
         )
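
Taken together, these hunks show the signature change named in the commit message: every add_mask call now passes explicit scaling and dropout arguments between attention_mask and sparse_meta_data. A hedged sketch of the implied signature, reconstructed from the test calls only (parameter order, the comments, and the layer_idx default are assumptions, not the repository's actual definition):

def add_mask(
    self,
    keys,              # key tensor, (batch, num_heads, seq_len_keys, head_dim)
    queries,           # query tensor, (batch, num_heads, seq_len_queries, head_dim)
    values,            # value tensor, (batch, num_heads, seq_len_keys, head_dim)
    attention_mask,    # optional mask tensor, or None
    scaling,           # new: attention score scale; the tests pass 1.0
    dropout,           # new: attention dropout probability; the tests pass 0.0
    sparse_meta_data,  # masker-specific cache, a dict or None
    previous_mask,     # mask produced by earlier maskers in the pipeline
    layer_idx=None,    # assumed optional; the HashAttention tests pass 0
):
    ...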

tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_hashattention_top_k.py

Lines changed: 22 additions & 0 deletions
@@ -249,6 +249,8 @@ def test_compute_hashattetion_scores(self, basic_config, test_tensors):
         scores = masker._compute_hashattention_score(
             keys=test_tensors["keys"],
             queries=test_tensors["queries"],
+            attention_mask=None,
+            previous_dense_mask=torch.zeros(test_tensors["batch_size"], test_tensors["num_heads"], test_tensors["seq_len_queries"], test_tensors["seq_len_keys"]),
             sparse_meta_data=sparse_meta_data,
             layer_idx=0,
         )

@@ -336,6 +338,8 @@ def test_hash_attention_top_k_masker_add_mask_input_validation(
             queries=test_tensors["queries"],
             values=test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -347,6 +351,8 @@ def test_hash_attention_top_k_masker_add_mask_input_validation(
             queries=test_tensors["queries"],
             values=test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data={},
             previous_mask=empty_previous_mask,
         )

@@ -375,6 +381,8 @@ def test_hash_attention_top_k_masker_add_mask_full_previous(
             queries=test_tensors["queries"],
             values=test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data={},
             previous_mask=full_previous_mask,
             layer_idx=0,

@@ -408,6 +416,8 @@ def test_hash_attention_top_k_masker_add_mask_small_sequence(
             queries=large_test_tensors["queries"],
             values=large_test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data={},
             previous_mask=empty_previous_mask,
             layer_idx=0,

@@ -440,6 +450,8 @@ def test_hash_attention_top_k_masker_add_mask_integer_heavy_size(
             queries=test_tensors["queries"],
             values=test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data={},
             previous_mask=empty_previous_mask,
             layer_idx=0,

@@ -481,6 +493,8 @@ def test_hash_attention_top_k_masker_add_mask_float_heavy_size(
             queries=test_tensors["queries"],
             values=test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data={},
             previous_mask=empty_previous_mask,
             layer_idx=0,

@@ -525,6 +539,8 @@ def test_hash_attention_top_k_masker_add_mask_merge_with_previous(
             queries=large_test_tensors["queries"],
             values=large_test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data={},
             previous_mask=previous_mask,
             layer_idx=0,

@@ -577,6 +593,8 @@ def test_hash_attention_top_k_masker_add_mask_signature_caching(
             queries=test_tensors["queries"],
             values=test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=sparse_meta_data,
             previous_mask=empty_previous_mask,
             layer_idx=0,

@@ -605,6 +623,8 @@ def test_hash_attention_top_k_masker_add_mask_signature_caching(
             queries=test_tensors["queries"],
             values=test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=sparse_meta_data,
             previous_mask=empty_previous_mask,
             layer_idx=0,

@@ -640,6 +660,8 @@ def test_hash_attention_top_k_masker_add_mask_different_activations(
             queries=test_tensors["queries"],
             values=test_tensors["values"],
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data={},
             previous_mask=empty_previous_mask,
             layer_idx=0,
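
The first hunk above reflects the other half of the signature change: in the tests, _compute_hashattention_score now receives attention_mask and a dense previous_dense_mask of shape (batch_size, num_heads, seq_len_queries, seq_len_keys) explicitly (compare the previous_mask.get_dense_mask() argument removed from the internal call in hashattention_top_k.py at the top of this commit). As a loosely hedged illustration of why a score helper would take those two inputs, the standalone sketch below masks out positions that are already active so a top-k budget is spent on new ones; this is an assumption about intent, not the library's actual hash-scoring logic, which is omitted here.

import torch

def compute_scores_sketch(keys, queries, attention_mask, previous_dense_mask):
    # Hypothetical stand-in for the masking step of _compute_hashattention_score.
    scores = torch.matmul(queries, keys.transpose(-2, -1))  # (b, h, q, k)
    if attention_mask is not None:
        scores = scores + attention_mask
    # Bias already-active positions to -inf so they are not re-selected (assumed).
    return scores.masked_fill(previous_dense_mask > 0, torch.finfo(scores.dtype).min)

q, k = torch.randn(1, 4, 8, 32), torch.randn(1, 4, 16, 32)
scores = compute_scores_sketch(k, q, None, torch.zeros(1, 4, 8, 16))  # (1, 4, 8, 16)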

tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_oracle_top_k.py

Lines changed: 16 additions & 0 deletions
@@ -95,6 +95,8 @@ def test_oracle_top_k_masker_add_mask_full_previous(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=full_previous_mask,
         )

@@ -129,6 +131,8 @@ def test_oracle_top_k_masker_add_mask_small_sequence(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -163,6 +167,8 @@ def test_oracle_top_k_masker_add_mask_integer_heavy_size(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -205,6 +211,8 @@ def test_oracle_top_k_masker_add_mask_float_heavy_size(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -260,6 +268,8 @@ def test_oracle_top_k_masker_add_mask_avoids_previous_active(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=previous_mask,
         )

@@ -313,6 +323,8 @@ def test_oracle_top_k_masker_add_mask_merge_with_previous(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=previous_mask,
         )

@@ -364,6 +376,8 @@ def test_oracle_top_k_masker_add_mask_edge_case_heavy_size_zero(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )

@@ -398,6 +412,8 @@ def test_oracle_top_k_masker_add_mask_edge_case_heavy_size_one(self):
             queries=queries,
             values=values,
             attention_mask=None,
+            scaling=1.0,
+            dropout=0.0,
             sparse_meta_data=None,
             previous_mask=empty_previous_mask,
         )
