diff --git a/benchmark/executor.py b/benchmark/executor.py
index 0dd5c60a..9d033068 100644
--- a/benchmark/executor.py
+++ b/benchmark/executor.py
@@ -217,12 +217,21 @@ def _benchmark_worker(
         # Import here to avoid issues with multiprocessing
         from sparse_attention_hub.adapters.huggingface import ModelAdapterHF
+        from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
+
+        # Extract recovery settings if available
+        recovery_kwargs = {}
+        if isinstance(stub.sparse_attention_config, ResearchAttentionConfig):
+            recovery_kwargs['recovery_enabled'] = stub.sparse_attention_config.recovery_enabled
+            recovery_kwargs['recovery_interval'] = stub.sparse_attention_config.recovery_interval
+            recovery_kwargs['recovery_dense_attention'] = stub.sparse_attention_config.recovery_dense_attention
 
         adapter = ModelAdapterHF(
             model_name=stub.model_name,
             sparse_attention_config=stub.sparse_attention_config,
             model_kwargs=stub.adapter_config.model_kwargs,
-            tokenizer_kwargs=stub.adapter_config.tokenizer_kwargs
+            tokenizer_kwargs=stub.adapter_config.tokenizer_kwargs,
+            **recovery_kwargs
         )
 
         # Create benchmark instance
diff --git a/benchmark/scripts/benchmark.py b/benchmark/scripts/benchmark.py
index 289fa0f1..6434f3ea 100644
--- a/benchmark/scripts/benchmark.py
+++ b/benchmark/scripts/benchmark.py
@@ -15,7 +15,7 @@ from pathlib import Path
 
 # Add the project root to the path
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
 from benchmark.executor import BenchmarkExecutor
 from benchmark.executor_config import BenchmarkConfig, AdapterConfig
 
@@ -23,33 +23,55 @@ from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
     LocalMaskerConfig, SinkMaskerConfig
 )
+from sparse_attention_hub.sparse_attention import (
+    ChannelConfig,
+    HashAttentionTopKMaskerConfig,
+)
+
+from sparse_attention_hub.sparse_attention.research_attention.maskers.sampling.implementations import (
+    AdaptiveSamplingMaskerConfig
+)
+from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
+    OracleTopKConfig
+)
 
 # ============================================================================
 # CONFIGURATION
 # ============================================================================
 
 # GPU Configuration
-GPUS = [0,2,7] # Use all available GPUs
-MAX_CONCURRENT_RUNS = 3 # One per GPU
+GPUS = [0]  # GPU ids to use
+MAX_CONCURRENT_RUNS = 1  # One per GPU
 
 # Model List
 MODELS = [
-    "microsoft/Phi-4-mini-instruct",
-    "meta-llama/Llama-3.2-1B-Instruct",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 ]
+usa_weight_file = "/workspace/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.pt"
+weight_file = "/workspace/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.hat_weights.pkl"
+
+#from sparse_attention_hub.sparse_attention.utils.hashattention_utils import create_hat_weights_file_from_usa
+#create_hat_weights_file_from_usa(usa_weight_file, weight_file, num_layers=32, num_heads=32, device="cpu")
+
 # Sparse Attention Configurations
+
 SPARSE_CONFIGS = [
-    # Dense baseline (no sparse attention)
-    ("dense", None),
-
-    # StreamingLLM configurations
-    ("streaming_conservative", ResearchAttentionConfig(masker_configs=[
-        SinkMaskerConfig(sink_size=4),
-        LocalMaskerConfig(window_size=16)
-    ])),
+    #("dense", None),
+    ("test_oracle_topk_norecovery_10pct_r1", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128),
+            LocalMaskerConfig(window_size=128),
+            OracleTopKConfig(heavy_size=0.1),
+        ],
+        recovery_enabled=False,
+        recovery_interval=32000,
+    ))
 ]
+
+
+
 # Benchmark List
 # 1. InfiniteBench - using passkey task
 infinite_bench_config = BenchmarkConfig(
@@ -107,15 +129,7 @@
 
 # List of all sample configurations
 BENCHMARKS = [
-    infinite_bench_config,
-    ruler_config,
-    loogle_config,
-    zero_scrolls_config,
-    longbenchv2_config,
-    aime2024_config,
-    aime2025_config,
-    longbench_config,
-    mock_benchmark_config
+    aime2024_config
 ]
 
@@ -124,6 +138,7 @@
     adapter_name="huggingface",
     model_kwargs={
         "torch_dtype": torch.bfloat16,
+        "attn_implementation": "flash_attention_2",
     },
     tokenizer_kwargs={
         "padding_side": "left",
@@ -132,23 +147,23 @@
 
 # Generation Parameters
 GENERATION_KWARGS = {
-    "max_new_tokens": 50,
-    "do_sample": False,
-    "temperature": 1.0,
-    "top_p": 1.0,
+    "max_new_tokens": 32768,
+    "do_sample": True,
+    "temperature": 0.6,
+    "top_p": 0.95,
     "pad_token_id": None,
 }
 
 # Request Parameters
 REQUEST_KWARGS = {
-    "max_context_length": 256,
-    "max_requests": 2,  # Limit for testing
+    "max_context_length": 32768,
+    "max_requests": 30,  # Limit for testing
 }
 
 # Execution Settings
-RESULT_DIR = "./benchmark_results"
+RESULT_DIR = "./benchmark_results_test.1"
 ENABLE_RESUMABILITY = True
-TIMEOUT_PER_BENCHMARK = 3600.0  # 1 hour
+TIMEOUT_PER_BENCHMARK = 60 * 60 * 24  # 1 day
 
 # ============================================================================
 # MAIN EXECUTION
diff --git a/sparse_attention_hub/adapters/huggingface.py b/sparse_attention_hub/adapters/huggingface.py
index 49a6f083..888abd5d 100644
--- a/sparse_attention_hub/adapters/huggingface.py
+++ b/sparse_attention_hub/adapters/huggingface.py
@@ -13,10 +13,100 @@
 from ..sparse_attention.base import SparseAttention, SparseAttentionConfig
 from ..sparse_attention.research_attention.base import ResearchAttention
 from .base import ModelAdapter, Request, RequestResponse
+from tqdm import tqdm
 
 INT_MAX = 2**31 - 1
 
 
+def _apply_temperature_scaling(logits: torch.Tensor, temperature: float) -> torch.Tensor:
+    """Apply temperature scaling to logits.
+
+    Args:
+        logits: Input logits tensor
+        temperature: Temperature parameter (1.0 = no scaling, <1.0 = sharper, >1.0 = smoother)
+
+    Returns:
+        Temperature-scaled logits
+    """
+    if temperature <= 0:
+        raise ValueError("Temperature must be positive")
+    return logits / temperature
+
+
+def _apply_top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
+    """Apply top-p (nucleus) filtering to logits.
+
+    Args:
+        logits: Input logits tensor of shape [..., vocab_size]
+        top_p: Cumulative probability threshold (0.0 to 1.0)
+
+    Returns:
+        Filtered logits with low-probability tokens set to -inf
+    """
+    if not (0.0 <= top_p <= 1.0):
+        raise ValueError("top_p must be between 0.0 and 1.0")
+
+    if top_p >= 1.0:
+        return logits
+
+    # Sort logits in descending order
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+
+    # Convert to probabilities and compute cumulative sum
+    sorted_probs = torch.nn.functional.softmax(sorted_logits, dim=-1)
+    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+    # Create mask for tokens to keep (cumulative probability <= top_p)
+    # Keep at least one token (the highest probability one)
+    sorted_indices_to_remove = cumulative_probs > top_p
+    sorted_indices_to_remove[..., 0] = False  # Keep the top token
+
+    # Scatter back to original indices
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+    )
+
+    # Set filtered tokens to -inf
+    filtered_logits = logits.clone()
+    filtered_logits[indices_to_remove] = float('-inf')
+
+    return filtered_logits
+
+
+def _sample_token(
+    logits: torch.Tensor,
+    do_sample: bool = True,
+    temperature: float = 1.0,
+    top_p: float = 1.0,
+) -> torch.Tensor:
+    """Sample next token from logits with optional temperature and top_p filtering.
+
+    Args:
+        logits: Logits tensor of shape [..., vocab_size]
+        do_sample: Whether to use sampling (True) or greedy decoding (False)
+        temperature: Temperature for scaling logits (only used if do_sample=True)
+        top_p: Top-p threshold for nucleus sampling (only used if do_sample=True)
+
+    Returns:
+        Sampled token indices
+    """
+    if not do_sample:
+        # Greedy decoding
+        return torch.argmax(logits, dim=-1)
+
+    # Apply temperature scaling
+    if temperature != 1.0:
+        logits = _apply_temperature_scaling(logits, temperature)
+
+    # Apply top-p filtering
+    if top_p < 1.0:
+        logits = _apply_top_p_filtering(logits, top_p)
+
+    # Convert to probabilities and sample
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+
 class ModelAdapterHF(ModelAdapter):
     """ModelAdapter for HuggingFace integration. Provides concrete implementations for
     huggingface's transformer library.
@@ -29,6 +119,9 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         device: Optional[str] = None,
+        recovery_enabled: bool = False,
+        recovery_interval: int = 400,
+        recovery_dense_attention: Optional[str] = None,
         **kwargs: Dict[str, Any],
     ) -> None:
         """Initialize HuggingFace adapter.
@@ -39,13 +132,24 @@ def __init__(
             model_kwargs: Additional keyword arguments for model creation
             device: Device to run the model on TODO: support dynamic and multipledevice placement
             tokenizer_kwargs: Additional keyword arguments for tokenizer creation
+            recovery_enabled: Whether to enable recovery mechanism during generation
+            recovery_interval: Number of tokens after which to trigger recovery
+                (regenerate embeddings with full attention)
+            recovery_dense_attention: Override attention implementation for recovery
+                (if None, uses original implementation)
         """
         super().__init__(model_name, sparse_attention_config, **kwargs)
         self._registered_attention_name: Optional[str] = None
         self._custom_attention_fn: Optional[Callable] = None
+        self._original_implementations: Dict[str, str] = {}
         self.model_kwargs = model_kwargs or {}
         self.tokenizer_kwargs = tokenizer_kwargs or {}
 
+        # Recovery mechanism parameters
+        self.recovery_enabled = recovery_enabled
+        self.recovery_interval = recovery_interval
+        self.recovery_dense_attention = recovery_dense_attention
+
         # more useful parameters to store
         self.device = (
             device if device else ("cuda" if torch.cuda.is_available() else "cpu")
@@ -316,6 +420,9 @@ def enable_sparse_mode(self) -> Generator[None, None, None]:
             ):
                 original_implementations[name] = module.config._attn_implementation
 
+        # Store original implementations as instance variable for use during generation
+        self._original_implementations = original_implementations
+
         # Ensure custom attention function is registered (reuse if already registered)
         custom_attention_name: str = self._ensure_attention_registered()
 
@@ -325,6 +432,7 @@ def enable_sparse_mode(self) -> Generator[None, None, None]:
             if hasattr(module, "config") and hasattr(
                 module.config, "_attn_implementation"
             ):
+                #print(f"Switching to sparse attention for {name}", module.config._attn_implementation, "->", custom_attention_name, flush=True)
                 module.config._attn_implementation = custom_attention_name
 
         yield
@@ -333,8 +441,12 @@ def enable_sparse_mode(self) -> Generator[None, None, None]:
         # Restore original implementations
         for name, module in self.model.named_modules():
             if name in original_implementations:
+                #print(f"Restoring original implementation for {name}", module.config._attn_implementation, "->", original_implementations[name], flush=True)
                 module.config._attn_implementation = original_implementations[name]
 
+        # Clean up instance variable
+        self._original_implementations = {}
+
     def _generate_response(
         self,
         question_tokens: torch.Tensor,
@@ -363,6 +475,11 @@ def _generate_response(
         max_new_tokens: int = generation_kwargs.get("max_new_tokens", 50)  # type: ignore
         context_length: int = context_outputs.past_key_values.get_seq_length()
 
+        # Extract sampling parameters from generation_kwargs
+        do_sample: bool = generation_kwargs.get("do_sample", False)
+        temperature: float = generation_kwargs.get("temperature", 1.0)
+        top_p: float = generation_kwargs.get("top_p", 1.0)
+
         position_ids = torch.arange(
             context_length,
             context_length + question_tokens.shape[1],
@@ -379,13 +496,37 @@ def _generate_response(
         )
         position_ids = position_ids[:, -1:] + 1
 
-        generated_ids = [question_outputs.logits[0, -1].argmax()]
+
+        # Use proper sampling instead of greedy argmax
+        first_token = _sample_token(
+            question_outputs.logits[0, -1:],
+            do_sample=do_sample,
+            temperature=temperature,
+            top_p=top_p
+        )
+        generated_ids = [first_token.squeeze()]
 
         should_stop_token_ids = self.model.generation_config.eos_token_id
         if not isinstance(should_stop_token_ids, list):
             should_stop_token_ids = [should_stop_token_ids]
 
-        for i in range(max_new_tokens - 1):
+        # Track newly generated tokens for recovery mechanism
+        new_tokens_generated = 0
+        generation_start_cache_length = context_outputs.past_key_values.get_seq_length()
+
+        if self.recovery_enabled:
+            print(f"Recovery enabled: regenerate embeddings every "
+                  f"{self.recovery_interval} new tokens")
+        else:
+            print("Recovery disabled: using sparse attention for answer generation")
+
+        # Print sampling configuration
+        if do_sample:
+            print(f"Sampling enabled: temperature={temperature}, top_p={top_p}")
+        else:
+            print("Greedy decoding enabled")
+
+        for i in tqdm(range(max_new_tokens - 1)):
             with torch.no_grad():
                 outputs = self.model(
                     input_ids=generated_ids[-1].unsqueeze(0).unsqueeze(0),
@@ -393,14 +534,135 @@ def _generate_response(
                     position_ids=position_ids + i,
                     sparse_meta_data=sparse_meta_data,
                 )
-            # TODO: support other forms of decoding
-            new_id = outputs.logits[0, -1].argmax()
+
+            # Use proper sampling instead of greedy argmax
+            new_id = _sample_token(
+                outputs.logits[0, -1:],
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p
+            ).squeeze()
             generated_ids.append(new_id)
+            new_tokens_generated += 1
 
             if new_id.item() in should_stop_token_ids:
                 break
 
+            # Check if we need to regenerate embeddings (only if recovery is enabled)
+            if (self.recovery_enabled and
+                new_tokens_generated >= self.recovery_interval and
+                i < max_new_tokens - 2):  # Don't regenerate on the last token
+
+                print(f"Regenerating embeddings for {new_tokens_generated} "
+                      f"newly generated tokens")
+                self._regenerate_embeddings_for_new_tokens(
+                    context_outputs.past_key_values,
+                    generated_ids[-new_tokens_generated:],  # Last N generated tokens
+                    generation_start_cache_length,
+                    new_tokens_generated,
+                    sparse_meta_data,
+                    self._original_implementations
+                )
+                new_tokens_generated = 0  # Reset counter
+
         answer: str = self.tokenizer.decode(
             torch.stack(generated_ids), skip_special_tokens=True
         )
+
         return answer
+
+    def _regenerate_embeddings_for_new_tokens(
+        self,
+        cache: Any,
+        new_token_ids: List[torch.Tensor],
+        start_cache_length: int,
+        num_new_tokens: int,
+        sparse_meta_data: Dict[str, Any],
+        original_implementations: Dict[str, str]
+    ) -> None:
+        """Regenerate embeddings for newly generated tokens using full attention.
+
+        This removes the KV cache entries for the newly generated tokens and regenerates
+        them using full attention (dense mode), then continues with sparse attention.
+
+        Args:
+            cache: The KV cache to modify
+            new_token_ids: List of newly generated token IDs
+            start_cache_length: Cache length when generation started
+            num_new_tokens: Number of new tokens to regenerate embeddings for
+            sparse_meta_data: Sparse metadata dictionary
+            original_implementations: Dict mapping module names to their original
+                attention implementations
+        """
+        current_cache_length = cache.get_seq_length()
+
+        # Remove embeddings for the newly generated tokens (keep everything before them)
+        keep_length = current_cache_length - num_new_tokens
+
+        print(f"Removing embeddings for {num_new_tokens} tokens "
+              f"(keeping first {keep_length} tokens)")
+
+        # Truncate cache to remove new token embeddings
+        for layer_idx in range(len(cache.key_cache)):
+            if cache.key_cache[layer_idx] is not None:
+                cache.key_cache[layer_idx] = (
+                    cache.key_cache[layer_idx][:, :, :keep_length]
+                )
+            if cache.value_cache[layer_idx] is not None:
+                cache.value_cache[layer_idx] = (
+                    cache.value_cache[layer_idx][:, :, :keep_length]
+                )
+
+        # Handle quantized caches if present
+        if hasattr(cache, "_quantized_key_cache"):
+            for layer_idx in range(len(cache._quantized_key_cache)):
+                if cache._quantized_key_cache[layer_idx] is not None:
+                    cache._quantized_key_cache[layer_idx] = (
+                        cache._quantized_key_cache[layer_idx][:, :, :keep_length]
+                    )
+                if cache._quantized_value_cache[layer_idx] is not None:
+                    cache._quantized_value_cache[layer_idx] = (
+                        cache._quantized_value_cache[layer_idx][:, :, :keep_length]
+                    )
+
+        # Regenerate embeddings using full attention (one forward pass)
+        print(f"Regenerating embeddings using full attention for {num_new_tokens} tokens")
+
+        # Create input tensor for the new tokens
+        new_tokens_tensor = torch.stack(new_token_ids).unsqueeze(0).to(self.model.device)
+
+        # Create position IDs for the new tokens
+        position_ids = torch.arange(
+            keep_length, keep_length + num_new_tokens, device=self.model.device
+        ).unsqueeze(0)
+
+        # Temporarily disable sparse mode to force dense attention
+        print("Forcing dense attention for regeneration")
+
+        # Store current sparse implementations and switch to dense implementations
+        for name, module in self.model.named_modules():
+            has_config = hasattr(module, "config")
+            has_attn_impl = has_config and hasattr(module.config, "_attn_implementation")
+            if name in original_implementations and has_attn_impl:
+                # Use override if provided, otherwise use original implementation
+                dense_implementation = (
+                    self.recovery_dense_attention or original_implementations[name]
+                )
+                module.config._attn_implementation = dense_implementation
+
+        try:
+            # Regenerate embeddings with dense attention
+            with torch.no_grad():
+                self.model(
+                    input_ids=new_tokens_tensor,
+                    past_key_values=cache,
+                    position_ids=position_ids,
+                )
+
+            print(f"Successfully regenerated embeddings. Cache length: {cache.get_seq_length()}")
+
+        finally:
+            # Restore sparse attention implementations
+            for name, module in self.model.named_modules():
+                if hasattr(module, "config") and hasattr(
+                    module.config, "_attn_implementation"
+                ):
+                    module.config._attn_implementation = self._registered_attention_name
diff --git a/sparse_attention_hub/sparse_attention/research_attention/base.py b/sparse_attention_hub/sparse_attention/research_attention/base.py
index 34189068..bd31374b 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/base.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/base.py
@@ -26,6 +26,9 @@ class ResearchAttentionConfig(SparseAttentionConfig):
     """Configuration class for research attention mechanisms."""
 
     masker_configs: List[MaskerConfig]
+    recovery_enabled: bool = False
+    recovery_interval: int = 400
+    recovery_dense_attention: Optional[str] = None
 
 
 class ResearchAttention(SparseAttention):
@@ -87,6 +90,8 @@ def custom_attention(
         Returns:
             Tuple of attention output and optional attention weights.
         """
+        #if kwargs["layer_idx"] == 0:
+        #    print(f"ResearchAttention.custom_attention called", flush=True)
         # Create an empty Mask object
         mask_shape: Tuple[int, int, int, int] = (
             queries.shape[0],
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py
index 83c4f877..7539a3fd 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py
@@ -327,7 +327,7 @@ def _update_and_return_key_signatures(
         concatenated_signatures: torch.Tensor = torch.cat(
             [cached_signatures, new_signatures], dim=2
         )
-        sparse_meta_data["key"][layer_idx] = concatenated_signatures
+        #sparse_meta_data["key"][layer_idx] = concatenated_signatures
         return concatenated_signatures
 
     def _compute_hashattention_score(
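
Reviewer note (not part of the patch): a minimal sketch of how the recovery knobs introduced above are meant to be wired up, assuming the constructor signatures shown in this diff. The masker sizes and the model name are placeholders; in the benchmark flow, BenchmarkExecutor copies the recovery fields from ResearchAttentionConfig into ModelAdapterHF kwargs (see the benchmark/executor.py hunk).

    import torch

    from sparse_attention_hub.adapters.huggingface import ModelAdapterHF
    from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
    from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
        LocalMaskerConfig,
        SinkMaskerConfig,
    )

    # Recovery settings live on ResearchAttentionConfig; the executor forwards them
    # to the adapter as keyword arguments.
    config = ResearchAttentionConfig(
        masker_configs=[SinkMaskerConfig(sink_size=128), LocalMaskerConfig(window_size=128)],
        recovery_enabled=True,       # every `recovery_interval` new tokens, truncate their KV-cache
        recovery_interval=400,       # entries and recompute them with dense attention
        recovery_dense_attention=None,  # None -> reuse the model's original attention implementation
    )

    adapter = ModelAdapterHF(
        model_name="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
        sparse_attention_config=config,
        model_kwargs={"torch_dtype": torch.bfloat16},
        recovery_enabled=config.recovery_enabled,
        recovery_interval=config.recovery_interval,
        recovery_dense_attention=config.recovery_dense_attention,
    )

When recovery_enabled is False (as in the benchmark script above), generation stays fully sparse and only the new sampling parameters (do_sample, temperature, top_p) change decoding behavior.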