Skip to content

Commit a0adf1e

Browse files
committed
Add Benchmarking Example
1. Single Benchmark run (useful to run single model x config x dataset) - useful for debugging 2. Launcing a matrix using Executor
1 parent 9cc2996 commit a0adf1e

File tree

3 files changed

+356
-0
lines changed

3 files changed

+356
-0
lines changed
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Minimalistic benchmark runner for sparse attention evaluation.
4+
5+
This script defines models, sparse attention configurations, and benchmarks,
6+
then runs comprehensive experiments using BenchmarkExecutor.
7+
8+
Usage:
9+
python benchmark/benchmark.py
10+
"""
11+
12+
import os
13+
import sys
14+
import torch
15+
from pathlib import Path
16+
17+
# Add the project root to the path
18+
os.chdir("/home/apd10/code/sparse-attention-hub/")
19+
sys.path.insert(0, "/home/apd10/code/sparse-attention-hub/")
20+
21+
from benchmark.executor import BenchmarkExecutor
22+
from benchmark.executor_config import BenchmarkConfig, AdapterConfig
23+
from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
24+
from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
25+
LocalMaskerConfig, SinkMaskerConfig, OracleTopKConfig, OracleTopPMaskerConfig
26+
)
27+
from sparse_attention_hub.sparse_attention.research_attention.maskers.sampling.implementations import (
28+
AdaptiveSamplingMaskerConfig, RandomSamplingMaskerConfig, MagicPigConfig
29+
)
30+
31+
# ============================================================================
32+
# CONFIGURATION
33+
# ============================================================================
34+
35+
# GPU Configuration
36+
GPUS = [3] # Use all available GPUs
37+
MAX_CONCURRENT_RUNS = 1 # One per GPU
38+
39+
INTENDED_SPARSITY = 0.1
40+
41+
# Model List
42+
MODELS = [
43+
"meta-llama/Llama-3.1-8B-Instruct",
44+
]
45+
46+
# Sparse Attention Configurations
47+
SPARSE_CONFIGS = [
48+
# Dense baseline (no sparse attention)
49+
("dense", None),
50+
51+
# StreamingLLM configurations
52+
("streaming_conservative", ResearchAttentionConfig(masker_configs=[
53+
SinkMaskerConfig(sink_size=128),
54+
LocalMaskerConfig(window_size=INTENDED_SPARSITY)
55+
])),
56+
#Oracle-TopK
57+
("streaming_oracle_topk", ResearchAttentionConfig(masker_configs=[
58+
SinkMaskerConfig(sink_size=128),
59+
LocalMaskerConfig(window_size=128),
60+
OracleTopKConfig(heavy_size=INTENDED_SPARSITY)
61+
])),
62+
# Oracle-TopP
63+
("streaming_oracle_topp", ResearchAttentionConfig(masker_configs=[
64+
SinkMaskerConfig(sink_size=128),
65+
LocalMaskerConfig(window_size=128),
66+
OracleTopPMaskerConfig(top_p=0.85)
67+
])),
68+
# Adaptive Sampling
69+
("streaming_adaptive_sampling", ResearchAttentionConfig(masker_configs=[
70+
SinkMaskerConfig(sink_size=128),
71+
LocalMaskerConfig(window_size=128),
72+
OracleTopKConfig(heavy_size=128),
73+
AdaptiveSamplingMaskerConfig(base_rate_sampling=0.05, epsilon=0.25, delta=0.25, init_offset=128, local_offset=128)
74+
])),
75+
# Random Sampling
76+
("streaming_random_sampling", ResearchAttentionConfig(masker_configs=[
77+
SinkMaskerConfig(sink_size=128),
78+
LocalMaskerConfig(window_size=128),
79+
RandomSamplingMaskerConfig(sampling_rate=0.1)
80+
])),
81+
# MagicPig
82+
("streaming_magicpig", ResearchAttentionConfig(masker_configs=[
83+
SinkMaskerConfig(sink_size=128),
84+
LocalMaskerConfig(window_size=128),
85+
MagicPigConfig(lsh_l=8, lsh_k=8)
86+
])),
87+
]
88+
89+
# Benchmark List
90+
# 1. InfiniteBench - using passkey task
91+
infinite_bench_config = BenchmarkConfig(
92+
benchmark_name="infinite_bench",
93+
subsets=["passkey"]
94+
)
95+
96+
# 2. Ruler - using 4096 context length
97+
ruler_config = BenchmarkConfig(
98+
benchmark_name="ruler",
99+
subsets=["4096"]
100+
)
101+
102+
# 3. Loogle - using shortdep_qa task
103+
loogle_config = BenchmarkConfig(
104+
benchmark_name="loogle",
105+
subsets=["shortdep_qa"],
106+
#subsets=["longdep_qa"],
107+
#subsets=["shortdep_cloze"],
108+
#subsets=["longdep_summarization"],
109+
)
110+
111+
# 4. ZeroScrolls - using gov_report task
112+
zero_scrolls_config = BenchmarkConfig(
113+
benchmark_name="zero_scrolls",
114+
subsets=["default"]
115+
)
116+
117+
# 5. LongBenchv2 - using 0shot task
118+
longbenchv2_config = BenchmarkConfig(
119+
benchmark_name="longbenchv2",
120+
subsets=["0shot"]
121+
)
122+
123+
# 6. AIME2024 - using single task
124+
aime2024_config = BenchmarkConfig(
125+
benchmark_name="aime2024",
126+
subsets=["aime2024"]
127+
)
128+
129+
# 7. AIME2025 - using single task
130+
aime2025_config = BenchmarkConfig(
131+
benchmark_name="aime2025",
132+
subsets=["aime2025"]
133+
)
134+
135+
# 8. LongBench (existing) - using narrativeqa task
136+
longbench_config = BenchmarkConfig(
137+
benchmark_name="longbench",
138+
subsets=["passage_retrieval_en"]
139+
)
140+
141+
# 9. Mock Benchmark (existing) - using single task
142+
mock_benchmark_config = BenchmarkConfig(
143+
benchmark_name="mock_benchmark",
144+
subsets=["reading_comprehension"]
145+
)
146+
147+
# List of all sample configurations
148+
BENCHMARKS = [
149+
#infinite_bench_config,
150+
#ruler_config,
151+
loogle_config,
152+
#zero_scrolls_config,
153+
#longbenchv2_config,
154+
#aime2024_config,
155+
#aime2025_config,
156+
#longbench_config,
157+
#mock_benchmark_config
158+
]
159+
160+
161+
# Adapter Configuration
162+
ADAPTER_CONFIG = AdapterConfig(
163+
adapter_name="huggingface",
164+
model_kwargs={
165+
"torch_dtype": torch.bfloat16,
166+
"attn_implementation": "flash_attention_2",
167+
},
168+
tokenizer_kwargs={
169+
"padding_side": "left",
170+
}
171+
)
172+
173+
# Generation Parameters
174+
GENERATION_KWARGS = {
175+
"max_new_tokens": 32000,
176+
"do_sample": False,
177+
"temperature": 1.0,
178+
"top_p": 1.0,
179+
"pad_token_id": None,
180+
}
181+
182+
# Request Parameters
183+
REQUEST_KWARGS = {
184+
"max_context_length": 16000,
185+
}
186+
187+
# Execution Settings
188+
RESULT_DIR = "./benchmark_results"
189+
ENABLE_RESUMABILITY = True
190+
TIMEOUT_PER_BENCHMARK = 3600.0 # 1 hour
191+
192+
# ============================================================================
193+
# MAIN EXECUTION
194+
# ============================================================================
195+
196+
if __name__ == "__main__":
197+
print("🚀 Starting Minimalistic Benchmark Suite")
198+
print("=" * 50)
199+
200+
print(f"🔧 Configuration:")
201+
print(f" - GPUs: {GPUS}")
202+
print(f" - Models: {len(MODELS)}")
203+
for i, model in enumerate(MODELS, 1):
204+
print(f" {i}. {model}")
205+
print(f" - Sparse configs: {len(SPARSE_CONFIGS)}")
206+
for name, config in SPARSE_CONFIGS:
207+
if config is None:
208+
print(f" - {name}: dense (no sparse attention)")
209+
else:
210+
sink_size = config.masker_configs[0].sink_size
211+
window_size = config.masker_configs[1].window_size
212+
print(f" - {name}: Sink({sink_size}) + Local({window_size})")
213+
print(f" - Benchmarks: {len(BENCHMARKS)}")
214+
for i, benchmark in enumerate(BENCHMARKS, 1):
215+
if benchmark.subsets:
216+
print(f" {i}. {benchmark.benchmark_name}: {len(benchmark.subsets)} subsets")
217+
else:
218+
print(f" {i}. {benchmark.benchmark_name}: all subsets")
219+
print(f" - Max concurrent: {MAX_CONCURRENT_RUNS}")
220+
print(f" - Result dir: {RESULT_DIR}")
221+
print(f" - Resumability: {'enabled' if ENABLE_RESUMABILITY else 'disabled'}")
222+
223+
# Calculate total combinations
224+
total_models = len(MODELS)
225+
total_configs = len(SPARSE_CONFIGS)
226+
total_benchmarks = sum(len(b.subsets) if b.subsets else 1 for b in BENCHMARKS)
227+
total_combinations = total_models * total_configs * total_benchmarks
228+
229+
print(f"\n📊 Experiment Matrix: {total_combinations} total combinations")
230+
print(f" - Models: {total_models}")
231+
print(f" - Sparse configs: {total_configs}")
232+
print(f" - Benchmark-subsets: {total_benchmarks}")
233+
print(f" - Estimated time: {total_combinations * TIMEOUT_PER_BENCHMARK / 3600:.1f} hours (worst case)")
234+
235+
# Create executor
236+
print(f"\n🔧 Initializing BenchmarkExecutor...")
237+
executor = BenchmarkExecutor(
238+
gpu_ids=GPUS,
239+
max_concurrent_runs=MAX_CONCURRENT_RUNS,
240+
base_result_dir=RESULT_DIR,
241+
enable_resumability=ENABLE_RESUMABILITY,
242+
required_result_files=["raw_results.csv"],
243+
timeout_per_benchmark=TIMEOUT_PER_BENCHMARK,
244+
verbose=True
245+
)
246+
247+
# Run benchmarks
248+
print(f"\n🎯 Running Benchmark Matrix...")
249+
try:
250+
results = executor.run_benchmark_matrix(
251+
model_names=MODELS,
252+
sparse_attention_configs=SPARSE_CONFIGS,
253+
benchmark_configs=BENCHMARKS,
254+
adapter_config=ADAPTER_CONFIG,
255+
generation_kwargs=GENERATION_KWARGS,
256+
request_kwargs=REQUEST_KWARGS
257+
)
258+
259+
# Print summary
260+
print(f"\n✅ Benchmark Execution Completed!")
261+
print(f" - Total: {results.progress.total_stubs}")
262+
print(f" - Completed: {results.progress.completed_stubs}")
263+
print(f" - Failed: {results.progress.failed_stubs}")
264+
print(f" - Skipped: {results.progress.skipped_stubs}")
265+
print(f" - Results saved to: {RESULT_DIR}")
266+
267+
except KeyboardInterrupt:
268+
print(f"\n⚠️ Interrupted by user")
269+
print(f" Partial results in: {RESULT_DIR}")
270+
except Exception as e:
271+
print(f"\n❌ Execution failed: {e}")
272+
raise
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Simple Benchmark Example
4+
5+
A beginner-friendly example showing how to run a basic benchmark comparison
6+
between dense and sparse attention using the sparse-attention-hub framework.
7+
8+
This example uses the MockBenchmark (5 simple samples) for quick demonstration:
9+
- Easy-to-understand reading comprehension questions
10+
- Short contexts (<250 words each)
11+
- Fast execution for testing and learning
12+
13+
Usage:
14+
python 04_simple_benchmark_example.py
15+
"""
16+
17+
import os
18+
import time
19+
from pathlib import Path
20+
21+
import torch
22+
23+
# Ensure we're in the correct directory and add to Python path
24+
import sys
25+
26+
# Change to directory two levels below current location
27+
os.chdir('/home/apd10/code/sparse-attention-hub')
28+
sys.path.insert(0, '/home/apd10/code/sparse-attention-hub')
29+
30+
from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
31+
from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
32+
LocalMaskerConfig, SinkMaskerConfig, OracleTopKConfig
33+
)
34+
from sparse_attention_hub.sparse_attention.research_attention.maskers.sampling.implementations import (
35+
AdaptiveSamplingMaskerConfig
36+
)
37+
38+
from benchmark import LongBench
39+
from sparse_attention_hub.adapters import ModelAdapterHF
40+
41+
def main():
42+
model_name = "meta-llama/Llama-3.1-8B-Instruct"
43+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44+
45+
sparse_attention_config = ResearchAttentionConfig(masker_configs=[
46+
SinkMaskerConfig(sink_size=128),
47+
LocalMaskerConfig(window_size=128),
48+
OracleTopKConfig(heavy_size=128),
49+
AdaptiveSamplingMaskerConfig(base_rate_sampling=0.05, epsilon=0.25, delta=0.25, init_offset=128, local_offset=128)
50+
])
51+
52+
print(" ✓ Loading model...")
53+
adapter = ModelAdapterHF(
54+
model_name=model_name,
55+
sparse_attention_config=sparse_attention_config,
56+
model_kwargs= {"torch_dtype": torch.bfloat16, "attn_implementation": "flash_attention_2"},
57+
generate_kwargs={"max_new_tokens": 32},
58+
device=device
59+
)
60+
61+
benchmark = LongBench(["passage_retrieval_en"])
62+
63+
result_dir = Path("./test_results")
64+
result_dir.mkdir(exist_ok=True)
65+
66+
benchmark.run_benchmark(adapter, result_dir, request_kwargs={"max_requests": 1, "max_context_length": 16000})
67+
68+
if __name__ == "__main__":
69+
main()

sparse_attention_hub/sparse_attention/utils/mask.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,3 +642,18 @@ def merge_mask(
642642
data=final_data,
643643
dtype=self.dtype,
644644
)
645+
646+
def get_density(self) -> float:
647+
"""
648+
Get the sparsity of the mask.
649+
"""
650+
if self.is_full:
651+
return 1.0
652+
elif self.is_empty():
653+
return 0.0
654+
elif self.from_dense_mask:
655+
return float(torch.sum(self.mask > 0) / self.mask.numel())
656+
elif self.from_index:
657+
return float(len(self.indices)) / float(np.prod(self.shape))
658+
else:
659+
raise RuntimeError("Mask object is in an invalid state")

0 commit comments

Comments
 (0)