Skip to content

Commit 9cc2996

Browse files
committed
Fix base sampling budget in adaptive
1 parent 35da60f commit 9cc2996

File tree

1 file changed

+1
-1
lines changed
  • sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations

1 file changed

+1
-1
lines changed

sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/adaptive_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,13 +300,13 @@ def add_mask(
300300
num_base_samples,
301301
previous_mask.dtype,
302302
)
303-
304303
# Compute denominators and budget
305304
sampled_denominator = apply_inv_mask_sum(expwts, base_sampling_mask)
306305
estimated_denominator = static_denominator + sampled_denominator
307306
budget = self._compute_adaptive_budget(
308307
std_estimate, estimated_denominator, sampling_range
309308
)
309+
budget = torch.clamp(budget, min=num_base_samples, max=sampling_range)
310310

311311
# Create adaptive sampling mask
312312
sampling_probabilities = (budget / sampling_range).to(previous_mask.dtype)

0 commit comments

Comments
 (0)