From f98974892edda3ea90d7e0f3e60547ccca0d086e Mon Sep 17 00:00:00 2001
From: AlexCuadron
Date: Sat, 6 Sep 2025 07:30:47 +0000
Subject: [PATCH 1/5] WIP

---
 benchmark/executor.py                        |  11 +-
 benchmark/scripts/benchmark.py               | 980 +++++++++++++++++-
 sparse_attention_hub/adapters/huggingface.py | 266 ++++-
 .../research_attention/base.py               |   3 +
 4 files changed, 1228 insertions(+), 32 deletions(-)

diff --git a/benchmark/executor.py b/benchmark/executor.py
index 0dd5c60a..9d033068 100644
--- a/benchmark/executor.py
+++ b/benchmark/executor.py
@@ -217,12 +217,21 @@ def _benchmark_worker(
     # Import here to avoid issues with multiprocessing
     from sparse_attention_hub.adapters.huggingface import ModelAdapterHF
+    from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
+
+    # Extract recovery settings if available
+    recovery_kwargs = {}
+    if isinstance(stub.sparse_attention_config, ResearchAttentionConfig):
+        recovery_kwargs['recovery_enabled'] = stub.sparse_attention_config.recovery_enabled
+        recovery_kwargs['recovery_interval'] = stub.sparse_attention_config.recovery_interval
+        recovery_kwargs['recovery_dense_attention'] = stub.sparse_attention_config.recovery_dense_attention
 
     adapter = ModelAdapterHF(
         model_name=stub.model_name,
         sparse_attention_config=stub.sparse_attention_config,
         model_kwargs=stub.adapter_config.model_kwargs,
-        tokenizer_kwargs=stub.adapter_config.tokenizer_kwargs
+        tokenizer_kwargs=stub.adapter_config.tokenizer_kwargs,
+        **recovery_kwargs
     )
 
     # Create benchmark instance
diff --git a/benchmark/scripts/benchmark.py b/benchmark/scripts/benchmark.py
index 289fa0f1..df6167bf 100644
--- a/benchmark/scripts/benchmark.py
+++ b/benchmark/scripts/benchmark.py
@@ -15,7 +15,7 @@
 from pathlib import Path
 
 # Add the project root to the path
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
 from benchmark.executor import BenchmarkExecutor
 from benchmark.executor_config import BenchmarkConfig, AdapterConfig
@@ -23,31 +23,963 @@
 from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
     LocalMaskerConfig, SinkMaskerConfig
 )
+from sparse_attention_hub.sparse_attention import (
+    ChannelConfig,
+    HashAttentionTopKMaskerConfig
+)
+
+from sparse_attention_hub.sparse_attention.research_attention.maskers.sampling.implementations import (
+    AdaptiveSamplingMaskerConfig
+)
 
 # ============================================================================
 # CONFIGURATION
 # ============================================================================
 
 # GPU Configuration
-GPUS = [0,2,7] # Use all available GPUs
-MAX_CONCURRENT_RUNS = 3 # One per GPU
+GPUS = [0] # Use a single GPU (GPU 0)
+MAX_CONCURRENT_RUNS = 1 # One per GPU
 
 # Model List
 MODELS = [
-    "microsoft/Phi-4-mini-instruct",
-    "meta-llama/Llama-3.2-1B-Instruct",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 ]
 
+usa_weight_file = "/home/ubuntu/alex/sparse-attention-hub/HashAttention-1.0/artifacts/llama3.1-8b-patch.64K.v1.pt"
+weight_file = "/home/ubuntu/alex/sparse-attention-hub/HashAttention-1.0/artifacts/llama3.1-8b-patch.64K.v1.hat_weights.pkl"
+
+from sparse_attention_hub.sparse_attention.utils.hashattention_utils import create_hat_weights_file_from_usa
+create_hat_weights_file_from_usa(usa_weight_file, weight_file, num_layers=32, num_heads=32, device="cpu")
+
 # Sparse Attention Configurations
 SPARSE_CONFIGS = [
     # Dense baseline (no sparse attention)
-    ("dense", 
None), - - # StreamingLLM configurations - ("streaming_conservative", ResearchAttentionConfig(masker_configs=[ - SinkMaskerConfig(sink_size=4), - LocalMaskerConfig(window_size=16) - ])), + #("dense", None), + # hat2_NO_recovery_heavy_0.05 - 4 iterations + ("hat2_NO_recovery_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=False, + )), + + ("hat2_NO_recovery_heavy_0.05_2", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=False, + )), + + ("hat2_NO_recovery_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=False, + )), + + ("hat2_NO_recovery_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=False, + )), + + # hat2_recovery_10000_heavy_0.05 - 4 iterations + ("hat2_recovery_10000_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + 
hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=10000, + )), + + ("hat2_recovery_10000_heavy_0.05_2", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=10000, + )), + + ("hat2_recovery_10000_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=10000, + )), + + ("hat2_recovery_10000_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=10000, + )), + + # hat2_recovery_100_heavy_0.05 - 4 iterations + ("hat2_recovery_100_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=100, + )), + + ("hat2_recovery_100_heavy_0.05_2", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink 
attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=100, + )), + + ("hat2_recovery_100_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=100, + )), + + ("hat2_recovery_100_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=100, + )), + + # hat2_recovery_200_heavy_0.05 - 4 iterations + ("hat2_recovery_200_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=200, + )), + + ("hat2_recovery_200_heavy_0.05_2", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + 
recovery_enabled=True, + recovery_interval=200, + )), + + ("hat2_recovery_200_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=200, + )), + + ("hat2_recovery_200_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=200, + )), + + # hat2_recovery_300_heavy_0.05 - 4 iterations + ("hat2_recovery_300_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=300, + )), + + ("hat2_recovery_300_heavy_0.05_2", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=300, + )), + + ("hat2_recovery_300_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling 
rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=300, + )), + + ("hat2_recovery_300_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=300, + )), + + # hat2_recovery_500_heavy_0.05 - 4 iterations + ("hat2_recovery_500_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=500, + )), + + ("hat2_recovery_500_heavy_0.05_2", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=500, + )), + + ("hat2_recovery_500_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=500, + )), + + ("hat2_recovery_500_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + 
hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=500, + )), + + # hat2_recovery_1000_heavy_0.05 - 4 iterations + ("hat2_recovery_1000_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=1000, + )), + + ("hat2_recovery_1000_heavy_0.05_2", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=1000, + )), + + ("hat2_recovery_1000_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=1000, + )), + + ("hat2_recovery_1000_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=1000, + )), + + # hat2_recovery_2000_heavy_0.05 - 4 iterations + 
("hat2_recovery_2000_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=2000, + )), + + ("hat2_recovery_2000_heavy_0.05_2", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=2000, + )), + + ("hat2_recovery_2000_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=2000, + )), + + ("hat2_recovery_2000_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=2000, + )), + + # hat2_recovery_5000_heavy_0.05 - 4 iterations + ("hat2_recovery_5000_heavy_0.05_1", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + 
delta=0.05, # 20% confidence bound
+                init_offset=0.01, # Start sampling after local window
+                local_offset=0.01 # Sample within local context
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=5000,
+    )),
+
+    ("hat2_recovery_5000_heavy_0.05_2", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention)
+            LocalMaskerConfig(window_size=128), # Local attention window
+            HashAttentionTopKMaskerConfig(heavy_size=0.05,
+                        hat_bits=32,
+                        hat_mlp_layers=3,
+                        hat_mlp_hidden_size=128,
+                        hat_mlp_activation="silu",
+                        hat_weight_file=weight_file,
+                        hat_weights=None),
+            AdaptiveSamplingMaskerConfig(
+                base_rate_sampling=0.05, # 10% base sampling rate
+                epsilon=0.05, # 20% error bound
+                delta=0.05, # 20% confidence bound
+                init_offset=0.01, # Start sampling after local window
+                local_offset=0.01 # Sample within local context
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=5000,
+    )),
+
+    ("hat2_recovery_5000_heavy_0.05_3", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention)
+            LocalMaskerConfig(window_size=128), # Local attention window
+            HashAttentionTopKMaskerConfig(heavy_size=0.05,
+                        hat_bits=32,
+                        hat_mlp_layers=3,
+                        hat_mlp_hidden_size=128,
+                        hat_mlp_activation="silu",
+                        hat_weight_file=weight_file,
+                        hat_weights=None),
+            AdaptiveSamplingMaskerConfig(
+                base_rate_sampling=0.05, # 10% base sampling rate
+                epsilon=0.05, # 20% error bound
+                delta=0.05, # 20% confidence bound
+                init_offset=0.01, # Start sampling after local window
+                local_offset=0.01 # Sample within local context
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=5000,
+    )),
+
+    ("hat2_recovery_5000_heavy_0.05_4", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention)
+            LocalMaskerConfig(window_size=128), # Local attention window
+            HashAttentionTopKMaskerConfig(heavy_size=0.05,
+                        hat_bits=32,
+                        hat_mlp_layers=3,
+                        hat_mlp_hidden_size=128,
+                        hat_mlp_activation="silu",
+                        hat_weight_file=weight_file,
+                        hat_weights=None),
+            AdaptiveSamplingMaskerConfig(
+                base_rate_sampling=0.05, # 10% base sampling rate
+                epsilon=0.05, # 20% error bound
+                delta=0.05, # 20% confidence bound
+                init_offset=0.01, # Start sampling after local window
+                local_offset=0.01 # Sample within local context
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=5000,
+    )),
+
+    # hat2_recovery_20000_heavy_0.05 - 4 iterations
+    ("hat2_recovery_20000_heavy_0.05_1", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention)
+            LocalMaskerConfig(window_size=128), # Local attention window
+            HashAttentionTopKMaskerConfig(heavy_size=0.05,
+                        hat_bits=32,
+                        hat_mlp_layers=3,
+                        hat_mlp_hidden_size=128,
+                        hat_mlp_activation="silu",
+                        hat_weight_file=weight_file,
+                        hat_weights=None),
+            AdaptiveSamplingMaskerConfig(
+                base_rate_sampling=0.05, # 10% base sampling rate
+                epsilon=0.05, # 20% error bound
+                delta=0.05, # 20% confidence bound
+                init_offset=0.01, # Start sampling after local window
+                local_offset=0.01 # Sample within local context
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=20000,
+    )),
+
+    ("hat2_recovery_20000_heavy_0.05_2", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention)
+            LocalMaskerConfig(window_size=128), # Local attention window
+            
HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=20000, + )), + + ("hat2_recovery_20000_heavy_0.05_3", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=20000, + )), + + ("hat2_recovery_20000_heavy_0.05_4", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) + LocalMaskerConfig(window_size=128), # Local attention window + HashAttentionTopKMaskerConfig(heavy_size=0.05, + hat_bits=32, + hat_mlp_layers=3, + hat_mlp_hidden_size=128, + hat_mlp_activation="silu", + hat_weight_file=weight_file, + hat_weights=None), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.05, # 10% base sampling rate + epsilon=0.05, # 20% error bound + delta=0.05, # 20% confidence bound + init_offset=0.01, # Start sampling after local window + local_offset=0.01 # Sample within local context + ) + ], + recovery_enabled=True, + recovery_interval=20000, + )), ] # Benchmark List @@ -107,15 +1039,7 @@ # List of all sample configurations BENCHMARKS = [ - infinite_bench_config, - ruler_config, - loogle_config, - zero_scrolls_config, - longbenchv2_config, - aime2024_config, - aime2025_config, - longbench_config, - mock_benchmark_config + aime2024_config ] @@ -132,23 +1056,23 @@ # Generation Parameters GENERATION_KWARGS = { - "max_new_tokens": 50, - "do_sample": False, - "temperature": 1.0, - "top_p": 1.0, + "max_new_tokens": 32768, + "do_sample": True, + "temperature": 0.6, + "top_p": 0.95, "pad_token_id": None, } # Request Parameters REQUEST_KWARGS = { - "max_context_length": 256, - "max_requests": 2, # Limit for testing + "max_context_length": 32768, + "max_requests": 30, # Limit for testing } # Execution Settings RESULT_DIR = "./benchmark_results" ENABLE_RESUMABILITY = True -TIMEOUT_PER_BENCHMARK = 3600.0 # 1 hour +TIMEOUT_PER_BENCHMARK = 60 * 60 * 24 # 1 day # ============================================================================ # MAIN EXECUTION diff --git a/sparse_attention_hub/adapters/huggingface.py b/sparse_attention_hub/adapters/huggingface.py index 49a6f083..7d42f7f5 100644 --- a/sparse_attention_hub/adapters/huggingface.py +++ b/sparse_attention_hub/adapters/huggingface.py @@ -17,6 +17,95 @@ INT_MAX = 2**31 - 1 +def _apply_temperature_scaling(logits: torch.Tensor, temperature: float) -> torch.Tensor: + """Apply temperature scaling to logits. 
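+
+    For example, temperature=0.5 doubles every gap between logits (a sharper,
+    more deterministic distribution after softmax), while temperature=2.0
+    halves the gaps (a flatter distribution); temperature=1.0 is a no-op.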
+ + Args: + logits: Input logits tensor + temperature: Temperature parameter (1.0 = no scaling, <1.0 = sharper, >1.0 = smoother) + + Returns: + Temperature-scaled logits + """ + if temperature <= 0: + raise ValueError("Temperature must be positive") + return logits / temperature + + +def _apply_top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor: + """Apply top-p (nucleus) filtering to logits. + + Args: + logits: Input logits tensor of shape [..., vocab_size] + top_p: Cumulative probability threshold (0.0 to 1.0) + + Returns: + Filtered logits with low-probability tokens set to -inf + """ + if not (0.0 <= top_p <= 1.0): + raise ValueError("top_p must be between 0.0 and 1.0") + + if top_p >= 1.0: + return logits + + # Sort logits in descending order + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + + # Convert to probabilities and compute cumulative sum + sorted_probs = torch.nn.functional.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + + # Create mask for tokens to keep (cumulative probability <= top_p) + # Keep at least one token (the highest probability one) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 0] = False # Keep the top token + + # Scatter back to original indices + indices_to_remove = sorted_indices_to_remove.scatter( + dim=-1, index=sorted_indices, src=sorted_indices_to_remove + ) + + # Set filtered tokens to -inf + filtered_logits = logits.clone() + filtered_logits[indices_to_remove] = float('-inf') + + return filtered_logits + + +def _sample_token( + logits: torch.Tensor, + do_sample: bool = True, + temperature: float = 1.0, + top_p: float = 1.0, +) -> torch.Tensor: + """Sample next token from logits with optional temperature and top_p filtering. + + Args: + logits: Logits tensor of shape [..., vocab_size] + do_sample: Whether to use sampling (True) or greedy decoding (False) + temperature: Temperature for scaling logits (only used if do_sample=True) + top_p: Top-p threshold for nucleus sampling (only used if do_sample=True) + + Returns: + Sampled token indices + """ + if not do_sample: + # Greedy decoding + return torch.argmax(logits, dim=-1) + + # Apply temperature scaling + if temperature != 1.0: + logits = _apply_temperature_scaling(logits, temperature) + + # Apply top-p filtering + if top_p < 1.0: + logits = _apply_top_p_filtering(logits, top_p) + + # Convert to probabilities and sample + probs = torch.nn.functional.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).squeeze(-1) + + class ModelAdapterHF(ModelAdapter): """ModelAdapter for HuggingFace integration. Provides concrete implementations for huggingface's transformer library. @@ -29,6 +118,9 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, device: Optional[str] = None, + recovery_enabled: bool = False, + recovery_interval: int = 400, + recovery_dense_attention: Optional[str] = None, **kwargs: Dict[str, Any], ) -> None: """Initialize HuggingFace adapter. 
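
A quick standalone sanity check of the sampling helpers added above (the tensor
values are made up; `_sample_token` is the module-level helper introduced in
this patch):

    import torch
    from sparse_attention_hub.adapters.huggingface import _sample_token

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # [batch=1, vocab=4]

    # do_sample=False ignores temperature/top_p and returns the argmax id.
    assert _sample_token(logits, do_sample=False).item() == 0

    # do_sample=True: temperature=0.6 sharpens the distribution, then
    # top_p=0.95 keeps only the shortest sorted prefix whose cumulative mass
    # stays within 0.95 (ids 0 and 1 here), so the sample is one of those two.
    torch.manual_seed(0)
    tok = _sample_token(logits, do_sample=True, temperature=0.6, top_p=0.95)
    assert tok.item() in (0, 1)
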
@@ -39,13 +131,24 @@ def __init__( model_kwargs: Additional keyword arguments for model creation device: Device to run the model on TODO: support dynamic and multipledevice placement tokenizer_kwargs: Additional keyword arguments for tokenizer creation + recovery_enabled: Whether to enable recovery mechanism during generation + recovery_interval: Number of tokens after which to trigger recovery + (regenerate embeddings with full attention) + recovery_dense_attention: Override attention implementation for recovery + (if None, uses original implementation) """ super().__init__(model_name, sparse_attention_config, **kwargs) self._registered_attention_name: Optional[str] = None self._custom_attention_fn: Optional[Callable] = None + self._original_implementations: Dict[str, str] = {} self.model_kwargs = model_kwargs or {} self.tokenizer_kwargs = tokenizer_kwargs or {} + # Recovery mechanism parameters + self.recovery_enabled = recovery_enabled + self.recovery_interval = recovery_interval + self.recovery_dense_attention = recovery_dense_attention + # more useful parameters to store self.device = ( device if device else ("cuda" if torch.cuda.is_available() else "cpu") @@ -316,6 +419,9 @@ def enable_sparse_mode(self) -> Generator[None, None, None]: ): original_implementations[name] = module.config._attn_implementation + # Store original implementations as instance variable for use during generation + self._original_implementations = original_implementations + # Ensure custom attention function is registered (reuse if already registered) custom_attention_name: str = self._ensure_attention_registered() @@ -335,6 +441,9 @@ def enable_sparse_mode(self) -> Generator[None, None, None]: if name in original_implementations: module.config._attn_implementation = original_implementations[name] + # Clean up instance variable + self._original_implementations = {} + def _generate_response( self, question_tokens: torch.Tensor, @@ -363,6 +472,11 @@ def _generate_response( max_new_tokens: int = generation_kwargs.get("max_new_tokens", 50) # type: ignore context_length: int = context_outputs.past_key_values.get_seq_length() + # Extract sampling parameters from generation_kwargs + do_sample: bool = generation_kwargs.get("do_sample", False) + temperature: float = generation_kwargs.get("temperature", 1.0) + top_p: float = generation_kwargs.get("top_p", 1.0) + position_ids = torch.arange( context_length, context_length + question_tokens.shape[1], @@ -379,12 +493,36 @@ def _generate_response( ) position_ids = position_ids[:, -1:] + 1 - generated_ids = [question_outputs.logits[0, -1].argmax()] + + # Use proper sampling instead of greedy argmax + first_token = _sample_token( + question_outputs.logits[0, -1:], + do_sample=do_sample, + temperature=temperature, + top_p=top_p + ) + generated_ids = [first_token.squeeze()] should_stop_token_ids = self.model.generation_config.eos_token_id if not isinstance(should_stop_token_ids, list): should_stop_token_ids = [should_stop_token_ids] + # Track newly generated tokens for recovery mechanism + new_tokens_generated = 0 + generation_start_cache_length = context_outputs.past_key_values.get_seq_length() + + if self.recovery_enabled: + print(f"Recovery enabled: regenerate embeddings every " + f"{self.recovery_interval} new tokens") + else: + print("Recovery disabled: using sparse attention for answer generation") + + # Print sampling configuration + if do_sample: + print(f"Sampling enabled: temperature={temperature}, top_p={top_p}") + else: + print("Greedy decoding enabled") + for i in 
range(max_new_tokens - 1): with torch.no_grad(): outputs = self.model( @@ -393,14 +531,136 @@ def _generate_response( position_ids=position_ids + i, sparse_meta_data=sparse_meta_data, ) - # TODO: support other forms of decoding - new_id = outputs.logits[0, -1].argmax() + + # Use proper sampling instead of greedy argmax + new_id = _sample_token( + outputs.logits[0, -1:], + do_sample=do_sample, + temperature=temperature, + top_p=top_p + ).squeeze() generated_ids.append(new_id) + new_tokens_generated += 1 if new_id.item() in should_stop_token_ids: break + # Check if we need to regenerate embeddings (only if recovery is enabled) + if (self.recovery_enabled and + new_tokens_generated >= self.recovery_interval and + i < max_new_tokens - 2): # Don't regenerate on the last token + + print(f"Regenerating embeddings for {new_tokens_generated} " + f"newly generated tokens") + self._regenerate_embeddings_for_new_tokens( + context_outputs.past_key_values, + generated_ids[-new_tokens_generated:], # Last N generated tokens + generation_start_cache_length, + new_tokens_generated, + sparse_meta_data, + self._original_implementations + ) + new_tokens_generated = 0 # Reset counter + answer: str = self.tokenizer.decode( torch.stack(generated_ids), skip_special_tokens=True ) + return answer + + def _regenerate_embeddings_for_new_tokens( + self, + cache: Any, + new_token_ids: List[torch.Tensor], + start_cache_length: int, + num_new_tokens: int, + sparse_meta_data: Dict[str, Any], + original_implementations: Dict[str, str] + ) -> None: + """Regenerate embeddings for newly generated tokens using full attention. + + This removes the KV cache entries for the newly generated tokens and regenerates + them using full attention (dense mode), then continues with sparse attention. 
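+
+        For intuition (using the default settings): with recovery_interval=400,
+        after every 400 sparsely decoded tokens the last 400 KV-cache entries
+        are dropped and re-encoded in a single dense forward pass, so the
+        cached keys/values for those tokens reflect full attention over the
+        entire prefix before sparse decoding resumes.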
+ + Args: + cache: The KV cache to modify + new_token_ids: List of newly generated token IDs + start_cache_length: Cache length when generation started + num_new_tokens: Number of new tokens to regenerate embeddings for + sparse_meta_data: Sparse metadata dictionary + original_implementations: Dict mapping module names to their original + attention implementations + """ + current_cache_length = cache.get_seq_length() + + # Remove embeddings for the newly generated tokens (keep everything before them) + keep_length = current_cache_length - num_new_tokens + + print(f"Removing embeddings for {num_new_tokens} tokens " + f"(keeping first {keep_length} tokens)") + + # Truncate cache to remove new token embeddings + for layer_idx in range(len(cache.key_cache)): + if cache.key_cache[layer_idx] is not None: + cache.key_cache[layer_idx] = ( + cache.key_cache[layer_idx][:, :, :keep_length] + ) + if cache.value_cache[layer_idx] is not None: + cache.value_cache[layer_idx] = ( + cache.value_cache[layer_idx][:, :, :keep_length] + ) + + # Handle quantized caches if present + if hasattr(cache, "_quantized_key_cache"): + for layer_idx in range(len(cache._quantized_key_cache)): + if cache._quantized_key_cache[layer_idx] is not None: + cache._quantized_key_cache[layer_idx] = ( + cache._quantized_key_cache[layer_idx][:, :, :keep_length] + ) + if cache._quantized_value_cache[layer_idx] is not None: + cache._quantized_value_cache[layer_idx] = ( + cache._quantized_value_cache[layer_idx][:, :, :keep_length] + ) + # Regenerate embeddings using full attention (one forward pass) + print(f"Regenerating embeddings using full attention for {num_new_tokens} tokens") + + # Create input tensor for the new tokens + new_tokens_tensor = torch.stack(new_token_ids).unsqueeze(0).to(self.model.device) + + # Create position IDs for the new tokens + position_ids = torch.arange( + keep_length, keep_length + num_new_tokens, device=self.model.device + ).unsqueeze(0) + + # Temporarily disable sparse mode to force dense attention + print("Forcing dense attention for regeneration") + + # Store current sparse implementations and switch to dense implementations + current_sparse_implementations: Dict[str, str] = {} + for name, module in self.model.named_modules(): + has_config = hasattr(module, "config") + has_attn_impl = has_config and hasattr(module.config, "_attn_implementation") + if name in original_implementations and has_attn_impl: + current_sparse_implementations[name] = module.config._attn_implementation + # Use override if provided, otherwise use original implementation + dense_implementation = ( + self.recovery_dense_attention or original_implementations[name] + ) + module.config._attn_implementation = dense_implementation + try: + # Regenerate embeddings with dense attention + with torch.no_grad(): + self.model( + input_ids=new_tokens_tensor, + past_key_values=cache, + position_ids=position_ids, + ) + + print(f"Successfully regenerated embeddings. 
Cache length: {cache.get_seq_length()}")
+
+        finally:
+            # Restore sparse attention implementations
+            for name, module in self.model.named_modules():
+                if name in current_sparse_implementations:
+                    module.config._attn_implementation = current_sparse_implementations[name]
+            print("Restored sparse attention mode")
diff --git a/sparse_attention_hub/sparse_attention/research_attention/base.py b/sparse_attention_hub/sparse_attention/research_attention/base.py
index 34189068..aa709a9a 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/base.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/base.py
@@ -26,6 +26,9 @@ class ResearchAttentionConfig(SparseAttentionConfig):
     """Configuration class for research attention mechanisms."""
 
     masker_configs: List[MaskerConfig]
+    recovery_enabled: bool = False
+    recovery_interval: int = 400
+    recovery_dense_attention: Optional[str] = None
 
 
 class ResearchAttention(SparseAttention):

From 0d49f4289c93e99e9560b46a5b8886ad86c00a1 Mon Sep 17 00:00:00 2001
From: Aditya Desai
Date: Sat, 13 Sep 2025 04:19:20 +0000
Subject: [PATCH 2/5] Fixes

1. Fix restoring sparse attention
2. Load the correct patch for HashAttention
3. Do not cache HashAttention keys
---
 benchmark/scripts/benchmark.py               | 196 ++++++++++--------
 sparse_attention_hub/adapters/huggingface.py |  13 +-
 .../research_attention/base.py               |   2 +
 .../implementations/hashattention_top_k.py   |   2 +-
 4 files changed, 123 insertions(+), 90 deletions(-)

diff --git a/benchmark/scripts/benchmark.py b/benchmark/scripts/benchmark.py
index df6167bf..3aee2051 100644
--- a/benchmark/scripts/benchmark.py
+++ b/benchmark/scripts/benchmark.py
@@ -45,8 +45,8 @@
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 ]
 
-usa_weight_file = "/home/ubuntu/alex/sparse-attention-hub/HashAttention-1.0/artifacts/llama3.1-8b-patch.64K.v1.pt"
-weight_file = "/home/ubuntu/alex/sparse-attention-hub/HashAttention-1.0/artifacts/llama3.1-8b-patch.64K.v1.hat_weights.pkl"
+usa_weight_file = "/workspace/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.pt"
+weight_file = "/workspace/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.hat_weights.pkl"
 
 from sparse_attention_hub.sparse_attention.utils.hashattention_utils import create_hat_weights_file_from_usa
 create_hat_weights_file_from_usa(usa_weight_file, weight_file, num_layers=32, num_heads=32, device="cpu")
@@ -61,8 +61,8 @@
         SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention)
         LocalMaskerConfig(window_size=128), # Local attention window
         HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                        hat_bits=32,
-                        hat_mlp_layers=3,
+                        hat_bits=64,
+                        hat_mlp_layers=2,
                         hat_mlp_hidden_size=128,
                         hat_mlp_activation="silu",
                         hat_weight_file=weight_file,
@@ -83,8 +83,8 @@
         SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention)
         LocalMaskerConfig(window_size=128), # Local attention window
         HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                        hat_bits=32,
-                        hat_mlp_layers=3,
+                        hat_bits=64,
+                        hat_mlp_layers=2,
                         hat_mlp_hidden_size=128,
                         hat_mlp_activation="silu",
                         hat_weight_file=weight_file,
@@ -105,8 +105,8 @@
         SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention)
         LocalMaskerConfig(window_size=128), # Local attention window
         HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                        hat_bits=32,
-                        hat_mlp_layers=3,
+                        hat_bits=64,
+                        hat_mlp_layers=2,
                         hat_mlp_hidden_size=128,
                         hat_mlp_activation="silu",
                         hat_weight_file=weight_file,
@@ -127,8 +127,8 @@
         SinkMaskerConfig(sink_size=128), # Keep 
first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -150,8 +150,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -173,8 +173,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -196,8 +196,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -219,8 +219,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -243,8 +243,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -266,8 +266,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -289,8 +289,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -312,8 +312,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -336,8 +336,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -359,8 +359,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -382,8 +382,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -405,8 +405,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -429,8 +429,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -452,8 +452,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -475,8 +475,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -498,8 +498,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -522,8 +522,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -545,8 +545,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -568,8 +568,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -591,8 +591,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -615,8 +615,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -638,8 +638,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -661,8 +661,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -684,8 +684,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -708,8 +708,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -731,8 +731,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -754,8 +754,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -777,8 +777,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -801,8 +801,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -824,8 +824,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -847,8 +847,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -870,8 +870,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -894,8 +894,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -917,8 +917,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -940,8 +940,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -963,8 +963,8 @@
             SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
             LocalMaskerConfig(window_size=128),  # Local attention window
             HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
+                hat_bits=64,
+                hat_mlp_layers=2,
                 hat_mlp_hidden_size=128,
                 hat_mlp_activation="silu",
                 hat_weight_file=weight_file,
@@ -982,6 +982,36 @@
     )),
 ]
+
+SPARSE_CONFIGS = [
+    #("dense", None),
+    ("test_hat_adaptive_hat", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
+            LocalMaskerConfig(window_size=128),  # Local attention window
+            HashAttentionTopKMaskerConfig(heavy_size=0.05,
+                hat_bits=64,
+                hat_mlp_layers=2,
+                hat_mlp_hidden_size=128,
+                hat_mlp_activation="silu",
+                hat_weight_file=weight_file,
+                hat_weights=None),
+            AdaptiveSamplingMaskerConfig(
+                base_rate_sampling=0.05,  # 5% base sampling rate
+                epsilon=0.025,  # 2.5% error bound
+                delta=0.025,  # 2.5% confidence bound
+                init_offset=128,  # Start sampling after local window
+                local_offset=128  # Sample within local context
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=25,
+    )),
+
+]
+
+
 # Benchmark List
 # 1. InfiniteBench - using passkey task
 infinite_bench_config = BenchmarkConfig(
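
Note: in the adaptive sampling configs above, base_rate_sampling sets a floor on how many keys get sampled, while epsilon/delta behave like standard (error, confidence) parameters: shrinking either forces more samples. A minimal, self-contained sketch of the usual Hoeffding-style mapping from (epsilon, delta) to a sample count — illustrative only; the actual AdaptiveSamplingMasker may use a different estimator:

import math

def adaptive_sample_count(num_keys: int, base_rate: float, epsilon: float, delta: float) -> int:
    # Hoeffding bound: n >= ln(2/delta) / (2 * epsilon^2) samples keep the
    # estimated attention mass within +/- epsilon with prob. >= 1 - delta.
    concentration_bound = math.ceil(math.log(2.0 / delta) / (2.0 * epsilon ** 2))
    base_samples = int(base_rate * num_keys)
    return min(num_keys, max(base_samples, concentration_bound))

# e.g. epsilon = delta = 0.025 gives ceil(ln(80) / 0.00125) = 3506 sampled keys,
# so the bound, not the 2.5% base rate, dominates for contexts under ~140K tokens.
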
@@ -1070,7 +1100,7 @@
 }
 
 # Execution Settings
-RESULT_DIR = "./benchmark_results"
+RESULT_DIR = "./benchmark_results_test"
 ENABLE_RESUMABILITY = True
 TIMEOUT_PER_BENCHMARK = 60 * 60 * 24  # 1 day
 
diff --git a/sparse_attention_hub/adapters/huggingface.py b/sparse_attention_hub/adapters/huggingface.py
index 7d42f7f5..6e035945 100644
--- a/sparse_attention_hub/adapters/huggingface.py
+++ b/sparse_attention_hub/adapters/huggingface.py
@@ -431,6 +431,7 @@ def enable_sparse_mode(self) -> Generator[None, None, None]:
             if hasattr(module, "config") and hasattr(
                 module.config, "_attn_implementation"
             ):
+                #print(f"Switching to sparse attention for {name}", module.config._attn_implementation, "->", custom_attention_name, flush=True)
                 module.config._attn_implementation = custom_attention_name
 
         yield
@@ -439,6 +440,7 @@
         # Restore original implementations
         for name, module in self.model.named_modules():
             if name in original_implementations:
+                #print(f"Restoring original implementation for {name}", module.config._attn_implementation, "->", original_implementations[name], flush=True)
                 module.config._attn_implementation = original_implementations[name]
 
         # Clean up instance variable
@@ -523,7 +525,7 @@ def _generate_response(
         else:
             print("Greedy decoding enabled")
 
-        for i in range(max_new_tokens - 1):
+        for i in tqdm(range(max_new_tokens - 1)):
             with torch.no_grad():
                 outputs = self.model(
                     input_ids=generated_ids[-1].unsqueeze(0).unsqueeze(0),
@@ -636,12 +638,10 @@ def _regenerate_embeddings_for_new_tokens(
             print("Forcing dense attention for regeneration")
 
         # Store current sparse implementations and switch to dense implementations
-        current_sparse_implementations: Dict[str, str] = {}
         for name, module in self.model.named_modules():
             has_config = hasattr(module, "config")
             has_attn_impl = has_config and hasattr(module.config, "_attn_implementation")
             if name in original_implementations and has_attn_impl:
-                current_sparse_implementations[name] = module.config._attn_implementation
                 # Use override if provided, otherwise use original implementation
                 dense_implementation = (
                     self.recovery_dense_attention or original_implementations[name]
@@ -661,6 +661,7 @@
         finally:
             # Restore sparse attention implementations
             for name, module in self.model.named_modules():
-                if name in current_sparse_implementations:
-                    module.config._attn_implementation = current_sparse_implementations[name]
-            print("Restored sparse attention mode")
+                if hasattr(module, "config") and hasattr(
+                    module.config, "_attn_implementation"
+                ):
+                    module.config._attn_implementation = self._registered_attention_name
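
Note: the _regenerate_embeddings_for_new_tokens change above drops the per-module bookkeeping dict and simply restores self._registered_attention_name on every config-bearing module afterwards, which works because sparse mode registers a single custom implementation name for the whole model. A condensed sketch of the switch-to-dense half (assuming the same attribute layout; not the verbatim adapter code):

from typing import Dict, Optional

def force_dense_attention(model, original_implementations: Dict[str, str], override: Optional[str] = None) -> None:
    # Point every module that carries an _attn_implementation at a dense
    # kernel: either the explicit recovery override, or whatever the module
    # used before sparse mode was enabled.
    for name, module in model.named_modules():
        config = getattr(module, "config", None)
        if name in original_implementations and hasattr(config, "_attn_implementation"):
            config._attn_implementation = override or original_implementations[name]
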
""" + #if kwargs["layer_idx"] == 0: + # print(f"ResearchAttention.custom_attention called", flush=True) # Create an empty Mask object mask_shape: Tuple[int, int, int, int] = ( queries.shape[0], diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py index 83c4f877..7539a3fd 100644 --- a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py +++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py @@ -327,7 +327,7 @@ def _update_and_return_key_signatures( concatenated_signatures: torch.Tensor = torch.cat( [cached_signatures, new_signatures], dim=2 ) - sparse_meta_data["key"][layer_idx] = concatenated_signatures + #sparse_meta_data["key"][layer_idx] = concatenated_signatures return concatenated_signatures def _compute_hashattention_score( From ba91a8b9a8dfd7247e28536b1dd3233f16e17765 Mon Sep 17 00:00:00 2001 From: Aditya Desai Date: Tue, 16 Sep 2025 17:36:28 +0000 Subject: [PATCH 3/5] oracle-topk configs --- benchmark/scripts/benchmark.py | 97 +++++++++++++++++++++++++--------- 1 file changed, 71 insertions(+), 26 deletions(-) diff --git a/benchmark/scripts/benchmark.py b/benchmark/scripts/benchmark.py index 3aee2051..3089f798 100644 --- a/benchmark/scripts/benchmark.py +++ b/benchmark/scripts/benchmark.py @@ -25,13 +25,16 @@ ) from sparse_attention_hub.sparse_attention import ( ChannelConfig, - HashAttentionTopKMaskerConfig + HashAttentionTopKMaskerConfig, ) from sparse_attention_hub.sparse_attention.research_attention.maskers.sampling.implementations import ( AdaptiveSamplingMaskerConfig ) +from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import ( + OracleTopKConfig +) # ============================================================================ # CONFIGURATION # ============================================================================ @@ -52,7 +55,7 @@ create_hat_weights_file_from_usa(usa_weight_file, weight_file, num_layers=32, num_heads=32, device="cpu") # Sparse Attention Configurations -SPARSE_CONFIGS = [ +ALEX_SPARSE_CONFIGS = [ # Dense baseline (no sparse attention) #("dense", None), # hat2_NO_recovery_heavy_0.05 - 4 iterations @@ -985,33 +988,75 @@ SPARSE_CONFIGS = [ #("dense", None), - ("test_hat_adpative_hat", ResearchAttentionConfig( - masker_configs=[ - SinkMaskerConfig(sink_size=128), # Keep first 128 tokens (sink attention) - LocalMaskerConfig(window_size=128), # Local attention window - HashAttentionTopKMaskerConfig(heavy_size=0.05, - hat_bits=64, - hat_mlp_layers=2, - hat_mlp_hidden_size=128, - hat_mlp_activation="silu", - hat_weight_file=weight_file, - hat_weights=None), - AdaptiveSamplingMaskerConfig( - base_rate_sampling=0.05, # 10% base sampling rate - epsilon=0.025, # 20% error bound - delta=0.025, # 20% confidence bound - init_offset=128, # Start sampling after local window - local_offset=128 # Sample within local context - ) - ], - recovery_enabled=True, - recovery_interval=25, + ("test_oracle_topk_adaptive_norecovery", ResearchAttentionConfig( + masker_configs=[ + SinkMaskerConfig(sink_size=128), + LocalMaskerConfig(window_size=128), + OracleTopKConfig(heavy_size=0.025), + AdaptiveSamplingMaskerConfig( + base_rate_sampling=0.025, + epsilon=0.05, + delta=0.05, + init_offset=128, + local_offset=128 + ) + ], + recovery_enabled=False, + 
+        recovery_interval=32000,
+    )),
+    ("test_oracle_topk_adaptive_recovery_100", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128),
+            LocalMaskerConfig(window_size=128),
+            OracleTopKConfig(heavy_size=0.025),
+            AdaptiveSamplingMaskerConfig(
+                base_rate_sampling=0.025,
+                epsilon=0.05,
+                delta=0.05,
+                init_offset=128,
+                local_offset=128
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=100,
+    )),
+    ("test_oracle_topk_adaptive_recovery_400", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128),
+            LocalMaskerConfig(window_size=128),
+            OracleTopKConfig(heavy_size=0.025),
+            AdaptiveSamplingMaskerConfig(
+                base_rate_sampling=0.025,
+                epsilon=0.05,
+                delta=0.05,
+                init_offset=128,
+                local_offset=128
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=400,
+    )),
+    ("test_oracle_topk_adaptive_recovery_800", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128),
+            LocalMaskerConfig(window_size=128),
+            OracleTopKConfig(heavy_size=0.025),
+            AdaptiveSamplingMaskerConfig(
+                base_rate_sampling=0.025,
+                epsilon=0.05,
+                delta=0.05,
+                init_offset=128,
+                local_offset=128
+            )
+        ],
+        recovery_enabled=True,
+        recovery_interval=800,
+    )),
 ]
 
 
 # Benchmark List
 # 1. InfiniteBench - using passkey task
 infinite_bench_config = BenchmarkConfig(
@@ -1096,11 +1141,11 @@
 # Request Parameters
 REQUEST_KWARGS = {
     "max_context_length": 32768,
-    "max_requests": 30,  # Limit for testing
+    "max_requests": 2,  # Limit for testing
 }
 
 # Execution Settings
-RESULT_DIR = "./benchmark_results_test"
+RESULT_DIR = "./benchmark_results_test.1"
 ENABLE_RESUMABILITY = True
 TIMEOUT_PER_BENCHMARK = 60 * 60 * 24  # 1 day

From 701fcd81d28ee379d528b0c282cdd6f72cdcffb2 Mon Sep 17 00:00:00 2001
From: Alex Cuadron Lafuente
Date: Thu, 18 Sep 2025 14:57:08 -0700
Subject: [PATCH 4/5] WIP

---
 benchmark/scripts/benchmark.py | 999 +--------------------------------
 1 file changed, 29 insertions(+), 970 deletions(-)

diff --git a/benchmark/scripts/benchmark.py b/benchmark/scripts/benchmark.py
index 3089f798..0f8dd7bc 100644
--- a/benchmark/scripts/benchmark.py
+++ b/benchmark/scripts/benchmark.py
@@ -48,1010 +48,69 @@
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 ]
 
-usa_weight_file = "/workspace/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.pt"
-weight_file = "/workspace/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.hat_weights.pkl"
+usa_weight_file = "/nvme/sparse-attention-hub/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.pt"
+weight_file = "/nvme/sparse-attention-hub/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.hat_weights.pkl"
 
 from sparse_attention_hub.sparse_attention.utils.hashattention_utils import create_hat_weights_file_from_usa
 create_hat_weights_file_from_usa(usa_weight_file, weight_file, num_layers=32, num_heads=32, device="cpu")
 
 # Sparse Attention Configurations
-ALEX_SPARSE_CONFIGS = [
-    # Dense baseline (no sparse attention)
-    #("dense", None),
-    # hat2_NO_recovery_heavy_0.05 - 4 iterations
-    ("hat2_NO_recovery_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=False,
-    )),
-
-    ("hat2_NO_recovery_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=False,
-    )),
-
-    ("hat2_NO_recovery_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=False,
-    )),
-
-    ("hat2_NO_recovery_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=False,
-    )),
-
-    # hat2_recovery_10000_heavy_0.05 - 4 iterations
-    ("hat2_recovery_10000_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=10000,
-    )),
-
-    ("hat2_recovery_10000_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=10000,
-    )),
-
-    ("hat2_recovery_10000_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=10000,
-    )),
-
-    ("hat2_recovery_10000_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=10000,
-    )),
-
-    # hat2_recovery_100_heavy_0.05 - 4 iterations
-    ("hat2_recovery_100_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=100,
-    )),
-
-    ("hat2_recovery_100_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=100,
-    )),
-
-    ("hat2_recovery_100_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=100,
-    )),
-
-    ("hat2_recovery_100_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=100,
-    )),
-
-    # hat2_recovery_200_heavy_0.05 - 4 iterations
-    ("hat2_recovery_200_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=200,
-    )),
-
-    ("hat2_recovery_200_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=200,
-    )),
-
-    ("hat2_recovery_200_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=200,
-    )),
-
-    ("hat2_recovery_200_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=200,
-    )),
-
-    # hat2_recovery_300_heavy_0.05 - 4 iterations
-    ("hat2_recovery_300_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=300,
-    )),
-
-    ("hat2_recovery_300_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=300,
-    )),
-
-    ("hat2_recovery_300_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=300,
-    )),
-
-    ("hat2_recovery_300_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=300,
-    )),
-
-    # hat2_recovery_500_heavy_0.05 - 4 iterations
-    ("hat2_recovery_500_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=500,
-    )),
-
-    ("hat2_recovery_500_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=500,
-    )),
-
-    ("hat2_recovery_500_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=500,
-    )),
-
-    ("hat2_recovery_500_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=500,
-    )),
-
-    # hat2_recovery_1000_heavy_0.05 - 4 iterations
-    ("hat2_recovery_1000_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=1000,
-    )),
-
-    ("hat2_recovery_1000_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=1000,
-    )),
-
-    ("hat2_recovery_1000_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=1000,
-    )),
-
-    ("hat2_recovery_1000_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=1000,
-    )),
-
-    # hat2_recovery_2000_heavy_0.05 - 4 iterations
-    ("hat2_recovery_2000_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=2000,
-    )),
-
-    ("hat2_recovery_2000_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=2000,
-    )),
-
-    ("hat2_recovery_2000_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=2000,
-    )),
-
-    ("hat2_recovery_2000_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=2000,
-    )),
-
-    # hat2_recovery_5000_heavy_0.05 - 4 iterations
-    ("hat2_recovery_5000_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=5000,
-    )),
-
-    ("hat2_recovery_5000_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=5000,
-    )),
-
-    ("hat2_recovery_5000_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=5000,
-    )),
-
-    ("hat2_recovery_5000_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=5000,
-    )),
-
-    # hat2_recovery_10000_heavy_0.05 - 4 iterations (Note: This was duplicated earlier, so I'm placing it here in proper order)
-    ("hat2_recovery_20000_heavy_0.05_1", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=20000,
-    )),
-
-    ("hat2_recovery_20000_heavy_0.05_2", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=20000,
-    )),
-
-    ("hat2_recovery_20000_heavy_0.05_3", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=20000,
-    )),
-
-    ("hat2_recovery_20000_heavy_0.05_4", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),  # Keep first 128 tokens (sink attention)
-            LocalMaskerConfig(window_size=128),  # Local attention window
-            HashAttentionTopKMaskerConfig(heavy_size=0.05,
-                hat_bits=64,
-                hat_mlp_layers=2,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.05,  # 10% base sampling rate
-                epsilon=0.05,  # 20% error bound
-                delta=0.05,  # 20% confidence bound
-                init_offset=0.01,  # Start sampling after local window
-                local_offset=0.01  # Sample within local context
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=20000,
-    )),
-]
-
 SPARSE_CONFIGS = [
     #("dense", None),
-    ("test_oracle_topk_adaptive_norecovery", ResearchAttentionConfig(
+    ("test_oracle_topk_norecovery", ResearchAttentionConfig(
         masker_configs=[
             SinkMaskerConfig(sink_size=128),
             LocalMaskerConfig(window_size=128),
-            OracleTopKConfig(heavy_size=0.025),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.025,
-                epsilon=0.05,
-                delta=0.05,
-                init_offset=128,
-                local_offset=128
-            )
+            OracleTopKConfig(heavy_size=0.05),
         ],
         recovery_enabled=False,
         recovery_interval=32000,
     )),
-    ("test_oracle_topk_adaptive_recovery_100", ResearchAttentionConfig(
+
+    ("test_oracle_topk_recovery_100", ResearchAttentionConfig(
         masker_configs=[
             SinkMaskerConfig(sink_size=128),
             LocalMaskerConfig(window_size=128),
-            OracleTopKConfig(heavy_size=0.025),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.025,
-                epsilon=0.05,
-                delta=0.05,
-                init_offset=128,
-                local_offset=128
-            )
+            OracleTopKConfig(heavy_size=0.05),
         ],
         recovery_enabled=True,
         recovery_interval=100,
     )),
-    ("test_oracle_topk_adaptive_recovery_400", ResearchAttentionConfig(
+    ("test_hat_topk_recovery_100", ResearchAttentionConfig(
         masker_configs=[
             SinkMaskerConfig(sink_size=128),
             LocalMaskerConfig(window_size=128),
-            OracleTopKConfig(heavy_size=0.025),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.025,
-                epsilon=0.05,
-                delta=0.05,
-                init_offset=128,
-                local_offset=128
+            HashAttentionTopKMaskerConfig(
+                heavy_size=0.05,
+                hat_bits=32,
+                hat_mlp_layers=3,
+                hat_mlp_hidden_size=128,
+                hat_mlp_activation="silu",
+                hat_weight_file=weight_file,
+                hat_weights=None
             )
         ],
         recovery_enabled=True,
-        recovery_interval=400,
+        recovery_interval=100,
     )),
-    ("test_oracle_topk_adaptive_recovery_800", ResearchAttentionConfig(
+    ("test_hat_topk_no_recovery", ResearchAttentionConfig(
         masker_configs=[
             SinkMaskerConfig(sink_size=128),
             LocalMaskerConfig(window_size=128),
-            OracleTopKConfig(heavy_size=0.025),
-            AdaptiveSamplingMaskerConfig(
-                base_rate_sampling=0.025,
-                epsilon=0.05,
-                delta=0.05,
-                init_offset=128,
-                local_offset=128
+            HashAttentionTopKMaskerConfig(
+                heavy_size=0.05,
+                hat_bits=32,
+                hat_mlp_layers=3,
+                hat_mlp_hidden_size=128,
+                hat_mlp_activation="silu",
+                hat_weight_file=weight_file,
+                hat_weights=None
             )
         ],
-        recovery_enabled=True,
-        recovery_interval=800,
-    )),
+        recovery_enabled=False,
+        recovery_interval=32000,
+    ))
 ]

From bb17b9ca52b172000519621cd25f7a85f397c4fe Mon Sep 17 00:00:00 2001
From: Aditya Desai
Date: Tue, 23 Sep 2025 05:47:41 +0000
Subject: [PATCH 5/5] fixing max_requests, tqdm

---
 benchmark/scripts/benchmark.py               | 59 +++-----------------
 sparse_attention_hub/adapters/huggingface.py |  1 +
 2 files changed, 9 insertions(+), 51 deletions(-)

diff --git a/benchmark/scripts/benchmark.py b/benchmark/scripts/benchmark.py
index 0f8dd7bc..6434f3ea 100644
--- a/benchmark/scripts/benchmark.py
+++ b/benchmark/scripts/benchmark.py
@@ -48,65 +48,21 @@
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
 ]
 
-usa_weight_file = "/nvme/sparse-attention-hub/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.pt"
-weight_file = "/nvme/sparse-attention-hub/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.hat_weights.pkl"
+usa_weight_file = "/workspace/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.pt"
+weight_file = "/workspace/HashAttention-1.0/artifacts/DeepSeek-R1-Distill-Llama-8B-patch-layers2-dim64-max-context-24K.hat_weights.pkl"
 
-from sparse_attention_hub.sparse_attention.utils.hashattention_utils import create_hat_weights_file_from_usa
-create_hat_weights_file_from_usa(usa_weight_file, weight_file, num_layers=32, num_heads=32, device="cpu")
+#from sparse_attention_hub.sparse_attention.utils.hashattention_utils import create_hat_weights_file_from_usa
+#create_hat_weights_file_from_usa(usa_weight_file, weight_file, num_layers=32, num_heads=32, device="cpu")
 
 # Sparse Attention Configurations
 SPARSE_CONFIGS = [
     #("dense", None),
-    ("test_oracle_topk_norecovery", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),
-            LocalMaskerConfig(window_size=128),
-            OracleTopKConfig(heavy_size=0.05),
-        ],
-        recovery_enabled=False,
-        recovery_interval=32000,
-    )),
-
-    ("test_oracle_topk_recovery_100", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),
-            LocalMaskerConfig(window_size=128),
-            OracleTopKConfig(heavy_size=0.05),
-        ],
-        recovery_enabled=True,
-        recovery_interval=100,
-    )),
-    ("test_hat_topk_recovery_100", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),
-            LocalMaskerConfig(window_size=128),
-            HashAttentionTopKMaskerConfig(
-                heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None
-            )
-        ],
-        recovery_enabled=True,
-        recovery_interval=100,
-    )),
-    ("test_hat_topk_no_recovery", ResearchAttentionConfig(
-        masker_configs=[
-            SinkMaskerConfig(sink_size=128),
-            LocalMaskerConfig(window_size=128),
-            HashAttentionTopKMaskerConfig(
-                heavy_size=0.05,
-                hat_bits=32,
-                hat_mlp_layers=3,
-                hat_mlp_hidden_size=128,
-                hat_mlp_activation="silu",
-                hat_weight_file=weight_file,
-                hat_weights=None
-            )
+    ("test_oracle_topk_norecovery_10pct_r1", ResearchAttentionConfig(
+        masker_configs=[
+            SinkMaskerConfig(sink_size=128),
+            LocalMaskerConfig(window_size=128),
+            OracleTopKConfig(heavy_size=0.1),
         ],
         recovery_enabled=False,
         recovery_interval=32000,
@@ -182,6 +138,7 @@
     adapter_name="huggingface",
     model_kwargs={
         "torch_dtype": torch.bfloat16,
+        "attn_implementation": "flash_attention_2"
     },
     tokenizer_kwargs={
         "padding_side": "left",
@@ -200,7 +157,7 @@
 # Request Parameters
 REQUEST_KWARGS = {
     "max_context_length": 32768,
-    "max_requests": 2,  # Limit for testing
+    "max_requests": 30,  # Limit for testing
 }
 
 # Execution Settings
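
Note: OracleTopKConfig(heavy_size=0.1) keeps, for every query, the 10% of keys with the highest exact attention scores — an oracle upper bound on what HashAttention's learned top-k selection can achieve. A small illustrative sketch of that selection rule (assumed semantics; the real masker also composes with the sink and local masks above):

import torch

def oracle_top_k_mask(scores: torch.Tensor, heavy_size: float) -> torch.Tensor:
    # scores: (..., num_queries, num_keys) pre-softmax attention scores.
    # Returns a boolean mask keeping the top heavy_size fraction of keys.
    num_keys = scores.size(-1)
    k = max(1, int(heavy_size * num_keys))
    top_indices = scores.topk(k, dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    return mask.scatter(-1, top_indices, True)
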
diff --git a/sparse_attention_hub/adapters/huggingface.py b/sparse_attention_hub/adapters/huggingface.py
index 6e035945..888abd5d 100644
--- a/sparse_attention_hub/adapters/huggingface.py
+++ b/sparse_attention_hub/adapters/huggingface.py
@@ -13,6 +13,7 @@
 from ..sparse_attention.base import SparseAttention, SparseAttentionConfig
 from ..sparse_attention.research_attention.base import ResearchAttention
 from .base import ModelAdapter, Request, RequestResponse
+from tqdm import tqdm
 
 INT_MAX = 2**31 - 1
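
Note: patch 5/5 comments out the unconditional HAT weight conversion at import time; once the .hat_weights.pkl artifact exists, re-running it is pure overhead. A guarded variant (a sketch, reusing the helper and the paths defined earlier in benchmark.py) would keep the script runnable on a fresh checkout without the manual toggle:

from pathlib import Path
from sparse_attention_hub.sparse_attention.utils.hashattention_utils import create_hat_weights_file_from_usa

if not Path(weight_file).exists():
    create_hat_weights_file_from_usa(usa_weight_file, weight_file, num_layers=32, num_heads=32, device="cpu")
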