
Commit a87d271

Fix Linting errors
1 parent 5824fa6 commit a87d271

5 files changed: 28 additions & 13 deletions

sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_p.py

Lines changed: 3 additions & 1 deletion
@@ -119,7 +119,9 @@ def _compute_top_p_thresholds(
         )
         # if top_p is 1.0, then threshold_positions will be equal to sorted_scores.shape[-1]
         # which is not a valid index, so we clamp it to the last valid index
-        threshold_positions = torch.clamp(threshold_positions, max=sorted_scores.shape[-1] - 1)
+        threshold_positions = torch.clamp(
+            threshold_positions, max=sorted_scores.shape[-1] - 1
+        )
         thresholds = torch.gather(sorted_scores, dim=-1, index=threshold_positions)
         return thresholds
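Why the clamp matters: with top_p = 1.0 the cumulative distribution is only exhausted at the final position, so a position search can return sorted_scores.shape[-1], one past the last valid index for torch.gather. A minimal standalone sketch of the edge case, using hypothetical values and assuming threshold_positions comes from something like torch.searchsorted over a cumulative sum (the repo's exact derivation is not shown in this hunk):

import torch

# Hypothetical illustration of the edge case the clamp above guards against.
sorted_scores = torch.tensor([[0.5, 0.25, 0.25]])
cumulative = torch.cumsum(sorted_scores, dim=-1)  # [[0.50, 0.75, 1.00]]
top_p = torch.full((1, 1), 1.0)

# right=True places 1.0 after the last entry: index 3 == sorted_scores.shape[-1],
# which is out of bounds for gather.
threshold_positions = torch.searchsorted(cumulative, top_p, right=True)

# Clamping to the last valid index (as in the fix) keeps gather in range.
threshold_positions = torch.clamp(threshold_positions, max=sorted_scores.shape[-1] - 1)
thresholds = torch.gather(sorted_scores, dim=-1, index=threshold_positions)  # [[0.25]]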

sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/adaptive_sampling.py

Lines changed: 5 additions & 5 deletions
Both hunks are whitespace-only: each paired removal and addition evidently strips trailing whitespace from a blank line, consistent with the commit's lint fixes.

@@ -187,13 +187,13 @@ def _compute_exp_attention_scores(

     def _get_sampling_range(self, seq_len_keys: int) -> tuple[int, int, int]:
         """Get sampling range and validate it.
-
+
         Args:
             seq_len_keys: Number of keys in the sequence.
-
+
         Returns:
             Tuple of (start_idx, end_idx, sampling_range).
-
+
         Raises:
             ValueError: If the computed sampling range is invalid.
         """
@@ -202,13 +202,13 @@ def _get_sampling_range(self, seq_len_keys: int) -> tuple[int, int, int]:
             start_idx: int = int(self.init_offset * seq_len_keys)
         else:
             start_idx = self.init_offset
-
+
         # Compute end index
         if isinstance(self.local_offset, float):
             end_idx: int = seq_len_keys - int(self.local_offset * seq_len_keys)
         else:
             end_idx = seq_len_keys - self.local_offset
-
+
         sampling_range = end_idx - start_idx

         if sampling_range <= 0:
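The offset handling in the hunk above treats a float as a fraction of the key length and an int as an absolute count. A minimal sketch of that logic, lifted out of the class for illustration (the free-function signature is an assumption, not the repo's API):

from typing import Union

def get_sampling_range(
    seq_len_keys: int,
    init_offset: Union[int, float],
    local_offset: Union[int, float],
) -> tuple[int, int, int]:
    # Float offsets are fractions of seq_len_keys; int offsets are absolute counts.
    start_idx = int(init_offset * seq_len_keys) if isinstance(init_offset, float) else init_offset
    tail = int(local_offset * seq_len_keys) if isinstance(local_offset, float) else local_offset
    end_idx = seq_len_keys - tail
    sampling_range = end_idx - start_idx
    if sampling_range <= 0:
        raise ValueError(f"invalid sampling range: {sampling_range}")
    return start_idx, end_idx, sampling_range

# e.g. 1024 keys, skipping the first 10% and the last 64 keys:
# get_sampling_range(1024, 0.1, 64) -> (102, 960, 858)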

tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_hashattention_top_k.py

Lines changed: 6 additions & 1 deletion
@@ -250,7 +250,12 @@ def test_compute_hashattetion_scores(self, basic_config, test_tensors):
             keys=test_tensors["keys"],
             queries=test_tensors["queries"],
             attention_mask=None,
-            previous_dense_mask=torch.zeros(test_tensors["batch_size"], test_tensors["num_heads"], test_tensors["seq_len_queries"], test_tensors["seq_len_keys"]),
+            previous_dense_mask=torch.zeros(
+                test_tensors["batch_size"],
+                test_tensors["num_heads"],
+                test_tensors["seq_len_queries"],
+                test_tensors["seq_len_keys"],
+            ),
             sparse_meta_data=sparse_meta_data,
             layer_idx=0,
         )

tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_oracle_top_p.py

Lines changed: 3 additions & 1 deletion
@@ -159,7 +159,9 @@ def max_normalized(x):
         scores = masker._compute_exp_attention_scores(
             keys,
             queries,
-            previous_dense_mask=torch.zeros(batch_size, num_heads, seq_len_queries, seq_len_keys),
+            previous_dense_mask=torch.zeros(
+                batch_size, num_heads, seq_len_queries, seq_len_keys
+            ),
             attention_mask=None,
             scaling=1.0,
         )
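Both test hunks above construct previous_dense_mask as all zeros with shape (batch_size, num_heads, seq_len_queries, seq_len_keys). One plausible reading, assuming the mask is applied additively to the attention logits before exponentiation (an assumption; the repo's semantics may differ):

import torch

# Hypothetical illustration: if previous_dense_mask is an additive logit bias,
# an all-zeros mask is a no-op baseline, which is what the tests appear to set up.
batch_size, num_heads, seq_len_queries, seq_len_keys = 2, 4, 8, 16
previous_dense_mask = torch.zeros(batch_size, num_heads, seq_len_queries, seq_len_keys)

logits = torch.randn(batch_size, num_heads, seq_len_queries, seq_len_keys)
assert torch.equal(logits + previous_dense_mask, logits)  # zeros leave logits unchanged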

tests/unit/sparse_attention/research_attention/maskers/sampling/test_adaptive_sampling.py

Lines changed: 11 additions & 5 deletions
@@ -196,7 +196,9 @@ def test_compute_exp_attention_scores(self, masker, sample_tensors):
         """Test exponential attention scores computation."""
         keys, queries, _, _ = sample_tensors

-        exp_scores = masker._compute_exp_attention_scores(queries, keys, scaling=1.0, attention_mask=None)
+        exp_scores = masker._compute_exp_attention_scores(
+            queries, keys, scaling=1.0, attention_mask=None
+        )

         assert exp_scores.shape == (2, 4, 8, 16)
         assert torch.all(exp_scores >= 0)  # Exponential should be non-negative
@@ -357,7 +359,8 @@ def test_add_mask_early_exit(self, masker, sample_tensors):
         # Create a full mask
         full_mask = Mask.create_full_mask((2, 4, 8, 16), dtype=torch.float32)

-        result = masker.add_mask(keys,
+        result = masker.add_mask(
+            keys,
             queries,
             values,
             attention_mask,
@@ -376,7 +379,8 @@ def test_add_mask_basic(self, masker, sample_tensors):
         # Create an empty mask
         empty_mask = Mask.create_empty_mask((2, 4, 8, 16), dtype=torch.float32)

-        result = masker.add_mask(keys,
+        result = masker.add_mask(
+            keys,
             queries,
             values,
             attention_mask,
@@ -420,7 +424,8 @@ def test_device_consistency(self, masker, sample_tensors):

         empty_mask = Mask.create_empty_mask((2, 4, 8, 16), dtype=torch.float32)

-        result = masker.add_mask(keys,
+        result = masker.add_mask(
+            keys,
             queries,
             values,
             attention_mask,
@@ -443,7 +448,8 @@ def test_numerical_stability(self, masker, sample_tensors):

         empty_mask = Mask.create_empty_mask((2, 4, 8, 16), dtype=torch.float32)

-        result = masker.add_mask(keys,
+        result = masker.add_mask(
+            keys,
             queries,
             values,
             attention_mask,
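The first hunk asserts that exp_scores has shape (2, 4, 8, 16) and is non-negative. A sketch of what _compute_exp_attention_scores plausibly computes, inferred from the name and the assertions rather than the repo's actual code; the row-max subtraction is the standard stabilization trick that a test like test_numerical_stability would exercise:

import torch

# Hypothetical sketch: exponentiated, max-stabilized attention logits.
batch, heads, q_len, k_len, head_dim = 2, 4, 8, 16, 32
queries = torch.randn(batch, heads, q_len, head_dim)
keys = torch.randn(batch, heads, k_len, head_dim)

logits = torch.matmul(queries, keys.transpose(-2, -1)) * 1.0        # scaling=1.0
exp_scores = torch.exp(logits - logits.amax(dim=-1, keepdim=True))  # stable exp

assert exp_scores.shape == (2, 4, 8, 16)
assert torch.all(exp_scores >= 0)  # exp(...) is always non-negative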
