diff --git a/docs/baselines/Bucket_Attn.md b/docs/baselines/Bucket_Attn.md
new file mode 100644
index 00000000..56acb3bf
--- /dev/null
+++ b/docs/baselines/Bucket_Attn.md
@@ -0,0 +1,189 @@
+# Bucket Attention Sparse Attention Baseline
+
+## 1. Overview
+
+**Bucket Attention** is a sparse attention mechanism inspired by **RACE (Repeated Arrays of Count Estimators)** and **LSH-based soft hashing**.
+
+Instead of evaluating all query–key dot products, Bucket Attention:
+
+1. **Hard-hashes keys** into buckets using Sign Random Projection (SRP).
+2. **Soft-hashes queries** to obtain probability distributions over the same buckets.
+3. **Selects the top-t buckets** per query for each hash table.
+4. Builds a **candidate set** by taking the union of all keys that fall into the selected buckets.
+5. Performs **value-aware collision ranking** to select the final Top-K keys used for attention.
+
+---
+
+## 2. Hard-Hashing Keys (Sign Random Projection)
+
+We use **L independent hash tables**, each containing **P random hyperplanes**.
+
+### 2.1 Projection onto hyperplanes
+
+For a key vector $k_i$: $z_{\ell,p}(i) = \langle k_i,\ w_{\ell,p} \rangle$
+
+### 2.2 Sign bit
+
+$$
+b_{\ell,p}(i) = \mathbf{1}[z_{\ell,p}(i) \ge 0]
+$$
+
+### 2.3 Bucket index (big-endian)
+
+$$
+\text{bucket}_\ell(i)
+= \sum_{p=1}^{P} b_{\ell,p}(i)\ 2^{P - p}
+$$
+
+Keys that hash to the same bucket ID are treated as belonging to the same locality cluster.
+
+---
+
+## 3. Soft-Hashing Queries
+
+Queries are "soft-assigned" to buckets using the same hyperplanes:
+
+1. Project queries: $z_{\ell,p}(q)$
+2. Apply a nonlinearity: $\tanh(z_{\ell,p}(q))$
+3. Compute similarity to all **R hypercube corners** $c_r \in \{-1,+1\}^P$:
+
+$$
+\text{logits}_{q,\ell,r}
+= \sum_{p=1}^{P} \tanh(z_{\ell,p}(q)) \cdot c_r[p]
+$$
+
+A softmax yields per-table bucket probabilities:
+
+$$
+P(r \mid q, \ell) = \text{softmax}_r(\text{logits}_{q,\ell,r})
+$$
+
+---
+
+## 4. Bucket Selection (Union of Matching Buckets Across Tables)
+
+Once keys and queries have been hashed, Bucket Attention determines which keys
+are *candidates* for each query by checking whether they land in any of the
+query's top-t buckets across the L hash tables.
+
+### 4.1 Key–Query Bucket Matching
+
+For each hash table ℓ:
+
+- Each key `i` has a hard bucket assignment
+
+$$
+\text{bucket}_\ell(i) \in \{0,\dots,R-1\}.
+$$
+
+- Each query `q` has a list of **top-t buckets**:
+
+$$
+\text{Top}_t(q,\ell) = \{r_1, \dots, r_t\}.
+$$
+
+A key `i` is considered a match for query `q` in table ℓ if:
+
+$$
+\text{bucket}_\ell(i) \in \text{Top}_t(q,\ell).
+$$
+
+### 4.2 Candidate Selection
+
+A key becomes a **candidate** if it matches in *any* of the L tables:
+
+$$
+\text{candidate}(q,i)
+= \bigvee_{\ell = 1}^{L}\ \mathbf{1}\big[
+\text{bucket}_\ell(i) \in \text{Top}_t(q,\ell)
+\big].
+$$
+
+This mask represents the **union of all selected buckets** across tables.
+
+### 4.3 Collision Counts
+
+Beyond binary membership, we count how many tables vote for each key:
+
+$$
+\text{collisions}(q,i)
+= \sum_{\ell=1}^{L}
+\mathbf{1}\big[
+\text{bucket}_\ell(i) \in \text{Top}_t(q,\ell)
+\big].
+$$
+
+- If `collisions = 0`: the key was never selected.
+- If `collisions >= 1`: the key is a candidate.
+- If `collisions` is large: multiple tables agree that the key is relevant.
+
+Collision counts act as a **soft similarity signal** and often correlate with the true attention weights.
+
+---
+
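+To make Sections 2–4 concrete, here is a minimal, self-contained PyTorch sketch of the hashing and bucket-selection pipeline. It mirrors the logic of the `hard_hash`, `soft_hash`, and `get_collision_counts` helpers in `bucket_utils.py`, but every tensor shape and variable name below is illustrative only, not the library's API.
+
+```
+import itertools
+
+import torch
+
+torch.manual_seed(0)
+B, H, N, Q, D = 1, 2, 32, 4, 16  # batch, heads, #keys, #queries, head dim
+L, P, top_t = 2, 4, 2            # hash tables, hyperplanes per table, buckets kept per table
+R = 1 << P                       # buckets per table (2**P)
+
+keys = torch.randn(B, H, N, D)
+queries = torch.randn(B, H, Q, D)
+planes = torch.randn(L, P, D)    # L independent sets of P random hyperplanes
+
+# Section 2: hard-hash keys -> sign bits -> big-endian bucket codes, [B, H, L, N]
+bits = torch.einsum("bhnd,lpd->bhnlp", keys, planes) >= 0
+weights = 2 ** torch.arange(P - 1, -1, -1)
+key_buckets = (bits.long() * weights).sum(-1).permute(0, 1, 3, 2)
+
+# Section 3: soft-hash queries against all R hypercube corners, [B, H, Q, L, R]
+# (the library's soft_hash additionally scales by 1/sqrt(D) before the softmax)
+corners = torch.tensor(list(itertools.product([-1.0, 1.0], repeat=P)))  # [R, P]
+q_proj = torch.einsum("bhqd,lpd->bhqlp", queries, planes)
+logits = torch.einsum("bhqlp,rp->bhqlr", torch.tanh(q_proj), corners)
+q_probs = logits.softmax(dim=-1)
+
+# Section 4: top-t buckets per (query, table), per-table matches, union, collision counts
+top_buckets = q_probs.topk(top_t, dim=-1).indices   # [B, H, Q, L, top_t]
+match = (
+    key_buckets[:, :, None, :, :, None] == top_buckets[:, :, :, :, None, :]
+).any(-1)                                           # [B, H, Q, L, N]
+candidate_mask = match.any(dim=3)                   # [B, H, Q, N]: union across tables
+collision_counts = match.sum(dim=3)                 # [B, H, Q, N]: number of agreeing tables
+```
+
+The masker then ranks these candidates with the value-aware score described in the next section.
+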
+## 5. Value-Aware Scoring (Final Ranking)
+
+Candidate keys are then ranked before selecting the final Top-K heavy tokens.
+The ranking combines:
+
+1. **Collision strength**
+2. **Value vector magnitude**
+
+### 5.1 Value Norm
+
+For each key value vector $v_i$:
+
+$$
+\|v_i\|_2
+= \sqrt{\sum_{d} v_{i,d}^2}.
+$$
+
+This norm measures how much the value vector can contribute to the
+output: keys with larger value norms have greater potential influence.
+
+### 5.2 Combined Collision Count + Value Score
+
+The final score for query $q$ and key $i$ is:
+
+$$
+\text{score}(q,i)
+= \text{collisions}(q,i)\ \cdot\ \|v_i\|_2.
+$$
+
+Interpretation:
+
+- **High collision count ⇒ the key is repeatedly hashed near the query.**
+- **High value norm ⇒ the key can contribute strongly to the output.**
+- The product balances structural similarity (hashing) with content magnitude (values).
+
+## 6. Example Configuration in sparse-attention-hub
+
+```
+config = ResearchAttentionConfig(masker_configs=[
+    SinkMaskerConfig(sink_size=128),
+    LocalMaskerConfig(window_size=128),
+    BucketMaskerConfig(K=4, L=31, top_t=4, heavy_size=0.2),
+])
+```
+
+## 7. Experimental Setup and Results
+
+We evaluate on a subset of tasks from the RULER benchmark.
+
+In general, as the token budget shrinks (i.e., sparsity increases), more hash tables (larger L) are needed:
+
+- Full recovery at a 20% token budget takes about 30-32 tables.
+- Full recovery at a 10% token budget takes about 50-52 tables.
+- Full recovery at a 5% token budget takes about 78-80 tables.
+
+Results with meta-llama/Llama-3.1-8B-Instruct:
+
+| Dataset        | Token Budget 1600 (0.05) | Token Budget 3200 (0.10) | Token Budget 6400 (0.20) |
+|:---------------|:------------------------:|:------------------------:|:------------------------:|
+| **vt**         |                          |                          | 92                       |
+| **fwe**        |                          |                          | 93.3                     |
+| **multikey_2** |                          | 94                       | 96                       |
+| **qa_2**       |                          | 56                       | 58                       |
+| **qa_1**       |                          | 80                       | 80                       |
+| **multikey_3** | 94                       | 100                      | 100                      |
+
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/__init__.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/__init__.py
index 83918321..42df4c36 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/__init__.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/__init__.py
@@ -7,6 +7,7 @@
     SinkMasker,
     SinkMaskerConfig,
 )
+from .bucket_top_k import BucketMasker, BucketMaskerConfig
 from .double_sparsity_top_k import (
     DoubleSparsityTopKMasker,
     DoubleSparsityTopKMaskerConfig,
@@ -24,6 +25,8 @@
     "OracleTopK",
     "QuestTopKMasker",
     "OracleTopPMasker",
+    "BucketMasker",
+    "BucketMaskerConfig",
     "PQCache",
     "HashAttentionTopKMasker",
     "DoubleSparsityTopKMasker",
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/bucket_top_k.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/bucket_top_k.py
new file mode 100644
index 00000000..b8ca9cc4
--- /dev/null
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/bucket_top_k.py
@@ -0,0 +1,252 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple
+
+import torch
+
+from sparse_attention_hub.sparse_attention.research_attention.maskers.base import (
+    AttentionTensorDimensions,
+    MaskerConfig,
+    MaskerRegistry,
+)
+from sparse_attention_hub.sparse_attention.utils.kv_utils import (
+    _get_num_key_value_groups,
+    repeat_kv,
+)
+from 
sparse_attention_hub.sparse_attention.utils.mask import Mask + +from ..base import TopKMasker, TopKMaskerConfig +from .utils.bucket_utils import ( + attention_mask_to_allowed_prob, + get_collision_counts, + get_hyper_planes, + get_protos_T, + hard_hash, + soft_hash, +) + + +@dataclass +class BucketMaskerConfig(TopKMaskerConfig): + """ + Minimal masker config: + + • K: # of hyperplanes per table (buckets = 2**K) + • L: # of hash tables (independent sketches) + • top_t: # of buckets selected per table (per (B,H,Q)) + + heavy_size (inherited from TopKMaskerConfig) is used as the *sample size*: + M = _calculate_effective_size(heavy_size, N_keys) + We select up to M keys from the union of selected-bucket tokens using a value-aware score. + """ + + K: int = 4 + L: int = 1 + top_t: int = 4 + + +@MaskerRegistry.register(BucketMaskerConfig) +class BucketMasker(TopKMasker): + """ + L-table sparsity (mask-only): + + 1) Hard SRP hash keys with L sets of K planes → bucket ids per table. + 2) Soft SRP hash queries per table (tanh + /√d vs hypercube corners). + 3) Select top_t buckets per table for each (B,H,Q). + 4) Candidate = union of tokens in any selected bucket across tables. + 5) Within candidates, select up to M keys per (B,H,Q) using a *value-aware* score: + score[b,h,q,i] ∝ (# collisions across tables) * ||v_i||. + + Returns a packed boolean mask [B,H,Q,N]. + """ + + def __init__(self, config: BucketMaskerConfig) -> None: + super().__init__(config) + + if config.K <= 0: + raise ValueError("K (hyperplanes) must be a positive integer") + if config.L <= 0: + raise ValueError("L (hash tables) must be a positive integer") + if config.top_t <= 0: + raise ValueError("top_t must be a positive integer") + + self.P: int = int(config.K) + self.L: int = int(config.L) + self.top_t: int = int(config.top_t) + self.heavy_size = config.heavy_size + + # caches + self._planes_cache: Dict[ + Tuple[int, torch.device, torch.dtype, int, int], torch.Tensor + ] = {} + self._protos_cache: Dict[ + Tuple[int, torch.device, torch.dtype], torch.Tensor + ] = {} + self._seed = 123456789 + self._rng_cache: Dict[torch.device, torch.Generator] = {} + + def _rng(self, device: torch.device) -> Optional[torch.Generator]: + if self._seed is None: + return None + g = self._rng_cache.get(device) + if g is None: + g = torch.Generator(device=device) + # Option: offset seeds per device to keep sequences distinct + g.manual_seed(self._seed + 7777) + self._rng_cache[device] = g + return g + + # ---------- Public API ---------- + + def add_mask( + self, + keys: torch.Tensor, # [B, H_k or G, N, D] + queries: torch.Tensor, # [B, H, Q, D] + values: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float, + sparse_meta_data: Dict, + previous_mask: Mask, + **kwargs, + ) -> Mask: + # Respect a fully-open previous mask + if previous_mask.is_full_mask(): + return previous_mask + + dims: AttentionTensorDimensions = self._extract_tensor_dimensions(keys, queries) + heavy_tokens: int = self._calculate_effective_size( + self.heavy_size, dims.seq_len_keys + ) + if self._should_use_full_attention(dims, heavy_tokens): + return self._create_full_mask( + dims, previous_mask.dtype, previous_mask.device + ) + + # 1) Align to MHA if KV are grouped (GQA/MQA) + ngroups = _get_num_key_value_groups(queries, keys) + keys_rep = repeat_kv(keys, ngroups) # [B,H,N,D] + B, H, N, D = keys_rep.shape + _, _, Q, _ = queries.shape + + # 2) SRP planes & corners + planes = get_hyper_planes( + cache=self._planes_cache, + D=D, + L=self.L, + 
P=self.P, + device=keys_rep.device, + dtype=keys_rep.dtype, + rng=self._rng(keys_rep.device), + ) # [L,P,D] + protosT = get_protos_T( + cache=self._protos_cache, + P=self.P, + device=keys_rep.device, + dtype=keys_rep.dtype, + ) # [P,R] + R = 1 << self.P + top_t = max(1, min(self.top_t, R)) + + # 3) Hard-hash keys per table → [B,H,L,N] + key_buckets = hard_hash(keys_rep, planes) # [B,H,L,N] + + # 4) Soft-hash queries per table → probs [B,H,Q,L,R] + q_probs = soft_hash(queries, planes, protosT) # [B,H,Q,L,R] + + # 5) Select top_t buckets per table → [B,H,Q,L,top_t] + top_buckets = torch.topk(q_probs, k=top_t, dim=-1, largest=True).indices + + # 6) Candidate union across tables + collision counts → [B,H,Q,N], [B,H,Q,N] + candidate_mask, collision_counts = get_collision_counts( + key_buckets, top_buckets + ) # candidate_mask: bool + + # Convert external attention mask to allowed probabilities in [0,1], + allowed_prob = None + if attention_mask is not None: + # [B,1,*,N] float in [0,1] + allowed_prob = attention_mask_to_allowed_prob(attention_mask, N) + + # For fallback when we have no candidates, we derive a boolean "allowed" mask + # from the probabilities (allowed iff prob > 0). + allowed_bool = allowed_prob > 0 + if allowed_bool.dim() == 3: + # [B,*,N] -> [B,1,*,N] to match allowed_prob + allowed_bool = allowed_bool.unsqueeze(1) + allowed_bool = allowed_bool.expand_as(candidate_mask) # [B,H,Q,N] + else: + # Everything allowed + allowed_bool = torch.ones_like(candidate_mask, dtype=torch.bool) + + no_cands = ~candidate_mask.any(dim=-1, keepdim=True) # [B,H,Q,1] + candidate_mask = torch.where( + no_cands, allowed_bool, candidate_mask + ) # [B,H,Q,N] + + # 8) Budget from heavy_size + M = max(0, min(int(self._calculate_effective_size(self.heavy_size, N)), N)) + if M == 0: + return previous_mask + Km = min(M, N) + + # 9a) Align values to heads and compute ||v_i|| per key + v_rep = repeat_kv( + values, _get_num_key_value_groups(queries, values) + ) # [B,H,N,Dv] + v_mag = torch.linalg.vector_norm(v_rep.float(), ord=2, dim=-1) # [B,H,N] + + # 9b) Value-aware score: score[b,h,q,i] = (# collisions) * ||v_i|| + collision_counts_f = collision_counts.to(torch.float32) # [B,H,Q,N] + raw_scores = collision_counts_f * v_mag.unsqueeze(2) # [B,H,Q,N] + + # 9c) Deterministic top-k on value-aware scores within candidates + scores = raw_scores.masked_fill(~candidate_mask, -torch.inf) # [B,H,Q,N] + top_idx = torch.topk(scores, k=Km, dim=-1, largest=True).indices # [B,H,Q,Km] + + # 9d) Enforce per-row effective K = min(M, #candidates) + cand_counts = candidate_mask.sum(dim=-1) # [B,H,Q] + k_each = cand_counts.clamp_max(M) # [B,H,Q] + keep = torch.arange(Km, device=keys_rep.device).view( + 1, 1, 1, Km + ) < k_each.unsqueeze( + -1 + ) # [B,H,Q,Km] bool + + # 9e) Scatter to boolean mask (robust to ties / duplicates) + acc = torch.zeros((B, H, Q, N), device=keys_rep.device, dtype=torch.int16) + acc.scatter_add_(dim=-1, index=top_idx, src=keep.to(acc.dtype)) + final_mask = acc > 0 # [B,H,Q,N] bool + + # Previous dense mask as probabilities in [0,1] + dense_prev = previous_mask.get_dense_mask() # [B,H,Q,N] + if not dense_prev.dtype.is_floating_point: + dense_prev = dense_prev.to(scores.dtype) + dense_prev = dense_prev.clamp_(0.0, 1.0) + + # Our new bucket mask as {0,1} float + dense_bucket = final_mask.to(dense_prev.dtype) # [B,H,Q,N] + + # Probabilistic OR: keep anything that either previous_mask or bucket mask allows + dense_mask = torch.maximum(dense_prev, dense_bucket) + + # Gate by external attention mask 
probabilities + if allowed_prob is not None: + ap = allowed_prob.to(dense_mask.dtype) # [B,1,*,N] + dense_mask = dense_mask * ap.expand_as(dense_mask) + + mask_shape = (B, H, Q, N) + return Mask.create_mask_from_dense_mask( + mask_shape, dense_mask, dtype=previous_mask.dtype + ) + + def _should_use_full_attention( + self, dims: AttentionTensorDimensions, heavy_tokens: int + ) -> bool: + """Full attention if the key sequence is within budget.""" + return dims.seq_len_keys <= max(1, heavy_tokens) + + @classmethod + def create_from_config(cls, config: MaskerConfig) -> "BucketMasker": + if not isinstance(config, BucketMaskerConfig): + raise ValueError(f"Invalid config type: {type(config)}") + return cls(config) diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/utils/bucket_utils.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/utils/bucket_utils.py new file mode 100644 index 00000000..fc479bd7 --- /dev/null +++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/utils/bucket_utils.py @@ -0,0 +1,175 @@ +"""Bucket utility functions.""" + +import itertools +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F + +PlanesCache = Dict[Tuple[int, torch.device, torch.dtype, int, int], torch.Tensor] +ProtosCache = Dict[Tuple[int, torch.device, torch.dtype], torch.Tensor] + + +def get_hyper_planes( + cache: PlanesCache, + D: int, + L: int, + P: int, + device: torch.device, + dtype: torch.dtype, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + """ + Independent SRP planes per table: + planes: [L, P, D] + + Caches in the provided `cache` dict so that multiple calls + with the same (D, device, dtype, L, P) reuse the planes. + """ + key = (D, device, dtype, L, P) + planes = cache.get(key) + if planes is None: + base = torch.randn( + (L, P, D), + device=device, + dtype=torch.float32, + generator=rng, + ) + planes = base.to(dtype) + cache[key] = planes + return planes + + +def get_protos_T( + cache: ProtosCache, + P: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Hypercube corners: protos_T in {-1,+1}^{P}, shape [P, R] + + Uses the given `cache` dict to memoize by: + (P, device, dtype) + """ + key = (P, device, dtype) + protos_T = cache.get(key) + if protos_T is None: + corners = torch.tensor( + list(itertools.product([-1.0, +1.0], repeat=P)), + device=device, + dtype=torch.float32, + ) # [R, P] + protos_T = corners.t().to(dtype) # [P, R] + cache[key] = protos_T + return protos_T + + +def pack_bits(bits: torch.Tensor) -> torch.Tensor: + """ + Pack last-dim bits into integer codes (big-endian). + bits: [..., P] bool + returns: [...] 
int16 + """ + P = bits.shape[-1] + weights = 1 << torch.arange( + P - 1, -1, -1, device=bits.device, dtype=torch.int16 + ) # [P] with MSB first + return torch.sum( + bits.to(torch.int16) * weights.view(*([1] * (bits.ndim - 1)), P), + dim=-1, + ) + + +def hard_hash(tensor: torch.Tensor, planes: torch.Tensor) -> torch.Tensor: + """ + tensor: [B,H,N,D], planes: [L,P,D] + returns bucket codes per table: [B,H,L,N] + """ + # [B,H,N,L,P] + proj = torch.einsum("bhnd,lkd->bhnlk", tensor, planes) + bits = proj >= 0 # bool + # [B,H,N,L] + codes = pack_bits(bits) + # [B,H,L,N] + return codes.permute(0, 1, 3, 2).contiguous() + + +def soft_hash( + queries: torch.Tensor, + planes: torch.Tensor, + protos_T: torch.Tensor, +) -> torch.Tensor: + """ + queries: [B,H,Q,D] + planes: [L,P,D] + protos_T: [P,R] + returns soft bucket probabilities: [B,H,Q,L,R] + """ + # [B,H,Q,L,P] + q_proj = torch.einsum("bhqd,lkd->bhqlk", queries, planes) + temp = math.sqrt(queries.size(-1)) + logits = torch.einsum( + "bhqlk,kr->bhqlr", + torch.tanh(q_proj) / max(temp, 1e-6), + protos_T, + ) # [B,H,Q,L,R] + return F.softmax(logits, dim=-1) + + +def get_collision_counts( + key_buckets: torch.Tensor, # [B,H,L,N] + top_buckets: torch.Tensor, # [B,H,Q,L,top_t] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + For each table ℓ, mark tokens whose bucket matches any selected bucket in that table. + Then union across tables, and also return per-(B,H,Q,N) collision counts. + + Returns: + candidate_mask: [B,H,Q,N] (bool) + collision_counts: [B,H,Q,N] (int) # # of tables where (q,i) matched + """ + B, H, L, N = key_buckets.shape + _, _, Q, _, top_t = top_buckets.shape + + # match_any[b,h,q,l,i] = True if key_buckets[b,h,l,i] equals + # any of top_buckets[b,h,q,l,t] over t. + match_any = torch.zeros( + (B, H, Q, L, N), dtype=torch.bool, device=key_buckets.device + ) + + # [B,H,1,L,N], broadcasts across Q and the last dim + kb = key_buckets.unsqueeze(2) # [B,H,1,L,N] + + for t in range(top_t): + # Select the t-th chosen bucket per (B,H,Q,L) + tb_t = top_buckets[..., t].unsqueeze(-1) # [B,H,Q,L,1] + match_any |= kb == tb_t # [B,H,Q,L,N] + + # Union across L tables → candidate mask [B,H,Q,N] + candidate_mask = match_any.any(dim=3) + + # Collision counts: number of tables where (q,i) matched + collision_counts = match_any.sum(dim=3) # [B,H,Q,N] + + return candidate_mask, collision_counts + + +def attention_mask_to_allowed_prob( + attention_mask: torch.Tensor, K: int +) -> torch.Tensor: + """ + Convert attention_mask to allowed-probabilities in [0,1], shape [B,1,*,K]. 
+ Heuristics: + - bool masks: 0 => allow (1.0), 1 => forbid (0.0) + - additive float masks: >=0 => allow (1.0), negative => forbid (0.0) + """ + am = attention_mask[..., :K] + if am.dtype == torch.bool: + allowed = (am == 0).to(torch.float32) + else: + allowed = (am >= 0).to(torch.float32) + if allowed.dim() == 3: + allowed = allowed.unsqueeze(1) # [B,1,*,K] + return allowed diff --git a/tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_bucket_attn.py b/tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_bucket_attn.py new file mode 100644 index 00000000..4383b01b --- /dev/null +++ b/tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_bucket_attn.py @@ -0,0 +1,262 @@ +import re + +import pytest +import torch + +from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations.bucket_top_k import ( + BucketMasker, + BucketMaskerConfig, +) +from sparse_attention_hub.sparse_attention.utils.mask import Mask + + +@pytest.mark.unit +class TestBucketMaskerImplementation: + """Tests for BucketMasker (bucket attention).""" + + def test_bucket_masker_config_creation(self): + """Config can be created and fields are set correctly.""" + config = BucketMaskerConfig( + heavy_size=0.05, + K=4, + L=2, + top_t=3, + ) + assert config is not None + assert config.heavy_size == 0.05 + assert config.K == 4 + assert config.L == 2 + assert config.top_t == 3 + + def test_bucket_masker_config_validation(self): + from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations.bucket_top_k import ( + BucketMasker, + BucketMaskerConfig, + ) + + msg = "K (hyperplanes) must be a positive integer" + + with pytest.raises(ValueError, match=re.escape(msg)): + config = BucketMaskerConfig(heavy_size=0.05, K=0, L=1, top_t=1) + BucketMasker(config) + + msg = "L (hash tables) must be a positive integer" + with pytest.raises(ValueError, match=re.escape(msg)): + config = BucketMaskerConfig(heavy_size=0.05, K=4, L=0, top_t=1) + BucketMasker(config) + + msg = "top_t must be a positive integer" + with pytest.raises(ValueError, match=re.escape(msg)): + config = BucketMaskerConfig(heavy_size=0.05, K=4, L=1, top_t=0) + BucketMasker(config) + + def test_bucket_masker_creation(self): + """BucketMasker can be created from config.""" + config = BucketMaskerConfig( + heavy_size=0.05, + K=4, + L=2, + top_t=3, + ) + masker = BucketMasker.create_from_config(config) + assert isinstance(masker, BucketMasker) + # Optional: check that config got attached + assert masker.heavy_size == config.heavy_size + assert masker.P == config.K + assert masker.L == config.L + assert masker.top_t == config.top_t + + def _make_dummy_inputs(self, device="cpu"): + """Helper to create small synthetic Q/K/V + attention_mask.""" + B, H, N, Q, D = 2, 4, 16, 3, 8 + torch.manual_seed(0) + + keys = torch.randn(B, H, N, D, device=device) + queries = torch.randn(B, H, Q, D, device=device) + values = torch.randn(B, H, N, D, device=device) + + # Standard [B,1,1,N] additive mask: allow all + attention_mask = torch.zeros(B, 1, 1, N, device=device) + + # Empty previous mask (all zeros) + dense_prev = torch.zeros(B, H, Q, N, device=device) + previous_mask = Mask.create_mask_from_dense_mask( + (B, H, Q, N), dense_prev, dtype=torch.float32 + ) + return keys, queries, values, attention_mask, previous_mask + + def test_bucket_masker_basic_add_mask_shapes(self): + """add_mask should produce a Mask with correct dense shape.""" + config = 
BucketMaskerConfig( + heavy_size=0.25, # select about 25% of tokens + K=4, + L=2, + top_t=2, + ) + masker = BucketMasker.create_from_config(config) + keys, queries, values, attention_mask, previous_mask = self._make_dummy_inputs() + + new_mask = masker.add_mask( + keys=keys, + queries=queries, + values=values, + attention_mask=attention_mask, + scaling=1.0, + dropout=0.0, + sparse_meta_data={}, + previous_mask=previous_mask, + ) + assert isinstance(new_mask, Mask) + + dense = new_mask.get_dense_mask() + B, H, Q, N = 2, 4, 3, 16 + assert dense.shape == (B, H, Q, N) + + # Values should be between 0 and 1 (Quest-style probabilities) + assert dense.min() >= 0.0 + assert dense.max() <= 1.0 + + def test_bucket_masker_respects_heavy_size_budget(self): + """Total selected tokens per (B,H,Q) should not exceed heavy_size-based budget.""" + B, H, Q, N = 2, 4, 3, 32 + config = BucketMaskerConfig( + heavy_size=0.25, # about 8 tokens out of 32 + K=4, + L=2, + top_t=2, + ) + masker = BucketMasker.create_from_config(config) + + keys = torch.randn(B, H, N, 8) + queries = torch.randn(B, H, Q, 8) + values = torch.randn(B, H, N, 8) + attention_mask = torch.zeros(B, 1, 1, N) + prev = Mask.create_mask_from_dense_mask( + (B, H, Q, N), + torch.zeros(B, H, Q, N), + dtype=torch.float32, + ) + + new_mask = masker.add_mask( + keys=keys, + queries=queries, + values=values, + attention_mask=attention_mask, + scaling=1.0, + dropout=0.0, + sparse_meta_data={}, + previous_mask=prev, + ) + dense = new_mask.get_dense_mask() # [B,H,Q,N] + + # Compute effective heavy tokens as used inside the masker + effective_M = masker._calculate_effective_size(masker.heavy_size, N) + # For each (b,h,q) row, number of active tokens should be <= effective_M + active_per_row = (dense > 0).sum(dim=-1) # [B,H,Q] + assert torch.all(active_per_row <= effective_M) + + def test_bucket_masker_attention_mask_boolean(self): + """Blocked positions in a boolean attention_mask should remain masked out.""" + config = BucketMaskerConfig( + heavy_size=0.5, + K=4, + L=2, + top_t=2, + ) + masker = BucketMasker.create_from_config(config) + + B, H, N, Q, D = 1, 2, 16, 2, 8 + keys = torch.randn(B, H, N, D) + queries = torch.randn(B, H, Q, D) + values = torch.randn(B, H, N, D) + + # Boolean mask: allow first half, forbid second half + attention_mask = torch.zeros(B, 1, 1, N, dtype=torch.bool) + attention_mask[..., N // 2 :] = True # blocked + + prev = Mask.create_mask_from_dense_mask( + (B, H, Q, N), + torch.zeros(B, H, Q, N), + dtype=torch.float32, + ) + + new_mask = masker.add_mask( + keys=keys, + queries=queries, + values=values, + attention_mask=attention_mask, + scaling=1.0, + dropout=0.0, + sparse_meta_data={}, + previous_mask=prev, + ) + dense = new_mask.get_dense_mask() # [B,H,Q,N] + + # Second half must be zeroed out by gating + tail = dense[..., N // 2 :] + assert torch.all(tail == 0.0) + + def test_bucket_masker_attention_mask_additive(self): + """Blocked positions in an additive mask (<0) should remain masked out.""" + config = BucketMaskerConfig( + heavy_size=0.5, + K=4, + L=2, + top_t=2, + ) + masker = BucketMasker.create_from_config(config) + + B, H, N, Q, D = 1, 2, 16, 2, 8 + keys = torch.randn(B, H, N, D) + queries = torch.randn(B, H, Q, D) + values = torch.randn(B, H, N, D) + + # Additive mask: 0 = allowed, -1e9 = blocked + attention_mask = torch.zeros(B, 1, 1, N) + attention_mask[..., N // 2 :] = -1e9 # blocked + + prev = Mask.create_mask_from_dense_mask( + (B, H, Q, N), + torch.zeros(B, H, Q, N), + dtype=torch.float32, + ) + + new_mask = 
masker.add_mask( + keys=keys, + queries=queries, + values=values, + attention_mask=attention_mask, + scaling=1.0, + dropout=0.0, + sparse_meta_data={}, + previous_mask=prev, + ) + dense = new_mask.get_dense_mask() # [B,H,Q,N] + + # Second half must be zeroed out + tail = dense[..., N // 2 :] + assert torch.all(tail == 0.0) + + def test_bucket_masker_deterministic_given_seed(self): + """With the same config and inputs, BucketMasker should be deterministic.""" + config = BucketMaskerConfig( + heavy_size=0.25, + K=4, + L=2, + top_t=2, + ) + masker1 = BucketMasker.create_from_config(config) + masker2 = BucketMasker.create_from_config(config) + + keys, queries, values, attention_mask, previous_mask = self._make_dummy_inputs() + + out1 = masker1.add_mask( + keys, queries, values, attention_mask, 1.0, 0.0, {}, previous_mask + ) + out2 = masker2.add_mask( + keys, queries, values, attention_mask, 1.0, 0.0, {}, previous_mask + ) + + dense1 = out1.get_dense_mask() + dense2 = out2.get_dense_mask() + assert torch.allclose(dense1, dense2) diff --git a/tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_bucket_attn_utils.py b/tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_bucket_attn_utils.py new file mode 100644 index 00000000..746e9966 --- /dev/null +++ b/tests/unit/sparse_attention/research_attention/maskers/fixed/implementations/test_bucket_attn_utils.py @@ -0,0 +1,272 @@ +import pytest +import torch + +from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations.utils.bucket_utils import ( + attention_mask_to_allowed_prob, + get_collision_counts, + get_hyper_planes, + get_protos_T, + hard_hash, + pack_bits, + soft_hash, +) + + +@pytest.mark.unit +class TestBucketUtils: + def test_get_hyper_planes_basic(self): + """get_hyper_planes returns correctly-shaped, cached tensors.""" + cache = {} + D, L, P = 16, 3, 4 + device = torch.device("cpu") + dtype = torch.float32 + + planes1 = get_hyper_planes( + cache=cache, + D=D, + L=L, + P=P, + device=device, + dtype=dtype, + rng=torch.Generator(device=device).manual_seed(0), + ) + assert planes1.shape == (L, P, D) + assert planes1.dtype == dtype + assert len(cache) == 1 + + # Same key -> same object from cache (no reallocation) + planes2 = get_hyper_planes( + cache=cache, + D=D, + L=L, + P=P, + device=device, + dtype=dtype, + rng=torch.Generator(device=device).manual_seed(123), + ) + assert planes2 is planes1 + + # Different (D,L,P) -> new entry in cache + planes3 = get_hyper_planes( + cache=cache, + D=8, + L=L, + P=P, + device=device, + dtype=dtype, + rng=torch.Generator(device=device).manual_seed(0), + ) + assert planes3.shape == (L, P, 8) + assert planes3 is not planes1 + assert len(cache) == 2 + + def test_get_protos_T_basic(self): + """get_protos_T returns hypercube corners with correct shape and caching.""" + cache = {} + P = 3 + R = 1 << P + device = torch.device("cpu") + dtype = torch.float32 + + protos1 = get_protos_T( + cache=cache, + P=P, + device=device, + dtype=dtype, + ) + assert protos1.shape == (P, R) + assert protos1.dtype == dtype + assert len(cache) == 1 + + # All entries must be ±1 + assert torch.all(torch.isin(protos1, torch.tensor([-1.0, 1.0]))) + + # Same key -> cached + protos2 = get_protos_T( + cache=cache, + P=P, + device=device, + dtype=dtype, + ) + assert protos2 is protos1 + + # Different P -> new entry + protos3 = get_protos_T( + cache=cache, + P=P + 1, + device=device, + dtype=dtype, + ) + assert protos3.shape == (P + 1, 1 << (P + 1)) + 
assert len(cache) == 2 + + def test_pack_bits_known_values(self): + """pack_bits should pack bit patterns into integers in big-endian order.""" + # bits: [..., P] + bits = torch.tensor( + [ + [0, 0, 0, 0], # 0 + [0, 0, 0, 1], # 1 + [0, 0, 1, 0], # 2 + [1, 0, 0, 0], # 8 + [1, 1, 1, 1], # 15 + ], + dtype=torch.bool, + ) + codes = pack_bits(bits) # [5] + expected = torch.tensor([0, 1, 2, 8, 15], dtype=torch.int64) + assert torch.equal(codes, expected) + + def test_hard_hash_simple_planes(self): + """hard_hash should assign the same buckets for identical inputs and respect planes.""" + # Use simple deterministic planes so behavior is predictable + B, H, N, _ = 1, 1, 2, 2 + L, _ = 1, 2 + + # Planes: identity-like projections + planes = torch.tensor( + [ + [ # table 0 + [1.0, 0.0], # hyperplane 0 + [0.0, 1.0], # hyperplane 1 + ] + ] + ) # [L,P,D] + + # Two keys: [1,1] and [-1,-1] + keys = torch.tensor([[[[1.0, 1.0], [-1.0, -1.0]]]]) # [B,H,N,D] + + codes = hard_hash(keys, planes) # [B,H,L,N] + assert codes.shape == (B, H, L, N) + + # First key: projections [1,1] => bits [1,1] => code b'11' = 3 + # Second key: projections [-1,-1] => bits [0,0] => code b'00' = 0 + assert codes[0, 0, 0, 0].item() == 3 + assert codes[0, 0, 0, 1].item() == 0 + + # Identical keys => identical codes + keys2 = keys.clone() + codes2 = hard_hash(keys2, planes) + assert torch.equal(codes, codes2) + + def test_soft_hash_shapes_and_probs(self): + """soft_hash returns valid probability distributions per bucket.""" + B, H, Q, D = 2, 3, 4, 5 + L, P = 2, 3 + R = 1 << P + + torch.manual_seed(0) + queries = torch.randn(B, H, Q, D) + planes = torch.randn(L, P, D) + protos_T = get_protos_T( + cache={}, + P=P, + device=queries.device, + dtype=queries.dtype, + ) # [P,R] + + q_probs = soft_hash(queries, planes, protos_T) # [B,H,Q,L,R] + assert q_probs.shape == (B, H, Q, L, R) + + # Probabilities should be non-negative and sum to ~1 along R + assert torch.all(q_probs >= 0) + probs_sum = q_probs.sum(dim=-1) # [B,H,Q,L] + assert torch.allclose( + probs_sum, torch.ones_like(probs_sum), atol=1e-5, rtol=1e-5 + ) + + def test_get_collision_counts_tiny_example(self): + """get_collision_counts should correctly compute candidate_mask and collision_counts.""" + # Small hand-constructed example + # B=1,H=1,L=2,N=3, Q=2, top_t=1 + # Table 0 buckets: [0, 1, 2] + # Table 1 buckets: [1, 1, 0] + key_buckets = torch.tensor( + [ + [ # B + [ # H + [0, 1, 2], # L=0 + [1, 1, 0], # L=1 + ] + ] + ] + ) # [1,1,2,3] => [B,H,L,N] + + # For q0: in table 0 pick bucket 1, in table 1 pick bucket 0 + # For q1: in table 0 pick bucket 2, in table 1 pick bucket 1 + top_buckets = torch.tensor( + [ + [ + [ + [ + [1], # q0, L=0 + [0], # q0, L=1 + ], + [ + [2], # q1, L=0 + [1], # q1, L=1 + ], + ] + ] + ] + ) # shape: [1, 1, 2, 2, 1] + + candidate_mask, collision_counts = get_collision_counts( + key_buckets, top_buckets + ) + # Shapes + assert candidate_mask.shape == (1, 1, 2, 3) + assert collision_counts.shape == (1, 1, 2, 3) + + # Let's reason expected collisions: + # keys indices: i=0,1,2 + + # q0: + # table 0 bucket=1 -> matches key1 only + # table 1 bucket=0 -> matches key2 only + # => collisions(q0) = [0,1,1] + expected_coll_q0 = torch.tensor([0, 1, 1]) + + # q1: + # table 0 bucket=2 -> matches key2 only + # table 1 bucket=1 -> matches key0? no, key0=1 in T1? 
actually T1: [1,1,0] + # -> matches key0 and key1 + # => collisions(q1) = [1,1,1] (key2 matched in table0 only) + expected_coll_q1 = torch.tensor([1, 1, 1]) + + assert torch.equal(collision_counts[0, 0, 0], expected_coll_q0) + assert torch.equal(collision_counts[0, 0, 1], expected_coll_q1) + + # candidate_mask is True where collisions > 0 + assert torch.equal(candidate_mask, collision_counts > 0) + + def test_attention_mask_to_allowed_prob_bool(self): + """attention_mask_to_allowed_prob for boolean masks.""" + B, K = 2, 5 + # True = blocked, False = allowed + attention_mask = torch.tensor( + [ + [[False, False, True, True, False]], + [[True, False, True, False, False]], + ], + dtype=torch.bool, + ) # [B,1,K] or [B,*,K] + + allowed_prob = attention_mask_to_allowed_prob(attention_mask, K) + # expected: allowed_prob = 1 where False, 0 where True + expected = (~attention_mask).to(torch.float32).unsqueeze(1) # [B,1,1,K] + assert allowed_prob.shape == (B, 1, 1, K) + assert torch.equal(allowed_prob, expected) + + def test_attention_mask_to_allowed_prob_additive(self): + """attention_mask_to_allowed_prob for additive (float) masks.""" + B, K = 1, 4 + # >=0 => allowed (1.0), <0 => forbidden (0.0) + attention_mask = torch.tensor([[[0.0, 1.0, -1e9, -0.5]]]) # [B,1,K] + + allowed_prob = attention_mask_to_allowed_prob(attention_mask, K) + assert allowed_prob.shape == (B, 1, 1, K) + + # positions 0,1 >=0 => 1.0; positions 2,3 <0 => 0.0 + expected = torch.tensor([[[[1.0, 1.0, 0.0, 0.0]]]]) + assert torch.equal(allowed_prob, expected)