diff --git a/benchmark/base.py b/benchmark/base.py
index 3dbfc478..1eb87147 100644
--- a/benchmark/base.py
+++ b/benchmark/base.py
@@ -169,8 +169,20 @@ def _process_all_requests(
# Group by context for efficiency (following HashAttention approach)
df_context = dataset_df.groupby("context")
- for context, df_group in tqdm(df_context, desc="Processing contexts", total=dataset_df["context"].nunique()):
+ # Track total questions processed
+ total_questions = len(dataset_df)
+ questions_processed = 0
+
+ pbar = tqdm(df_context, desc="Processing contexts", total=dataset_df["context"].nunique())
+ for idx, (context, df_group) in enumerate(pbar):
questions: List[str] = df_group["question"].to_list()
+ questions_processed += len(questions)
+        # Update progress bar postfix to show question counts
+ pbar.set_postfix({
+ 'ctx': idx+1,
+ 'q': len(questions),
+ 'total_q': f"{questions_processed}/{total_questions}"
+ })
try:
# Create request using current adapter interface (simplified)
diff --git a/benchmark/executor.py b/benchmark/executor.py
index 0dd5c60a..2cd5b256 100644
--- a/benchmark/executor.py
+++ b/benchmark/executor.py
@@ -436,8 +436,14 @@ def __init__(
def _setup_signal_handlers(self) -> None:
"""Set up signal handlers for graceful shutdown."""
- signal.signal(signal.SIGTERM, _signal_handler)
- signal.signal(signal.SIGINT, _signal_handler)
+ import threading
+
+ # Only set up signal handlers if we're in the main thread
+ if threading.current_thread() is threading.main_thread():
+ signal.signal(signal.SIGTERM, _signal_handler)
+ signal.signal(signal.SIGINT, _signal_handler)
+ else:
+ self.logger.info("Skipping signal handler setup - not in main thread (running in Ray worker)")
# Register cleanup function to run at exit
import atexit
diff --git a/benchmark/raytune/README.md b/benchmark/raytune/README.md
new file mode 100644
index 00000000..76950955
--- /dev/null
+++ b/benchmark/raytune/README.md
@@ -0,0 +1,185 @@
+# Benchmark Runner Scripts
+
+This directory contains scripts for running benchmarks with the optimal configurations discovered in Phase 1. The overall workflow has two phases:
+
+1. **Phase 1**: Hyperparameter search to find optimal configs for each (model, task, masker) combination
+2. **Phase 2**: Parallel benchmark execution using the discovered optimal configs
+
+## Quick Start
+
+```bash
+## install uv
+curl -LsSf https://astral.sh/uv/install.sh | sh
+source $HOME/.local/bin/env
+
+## clone repo and switch to the feature/raytune branch
+git clone https://github.com/xAlg-ai/sparse-attention-hub
+cd sparse-attention-hub
+git checkout feature/raytune
+
+## build env
+uv sync
+### install flash-attn after the initial sync to avoid torch build-dependency issues
+uv add flash-attn --no-build-isolation
+source .venv/bin/activate
+
+## login required for huggingface
+export HF_TOKEN=...
+huggingface-cli login --token $HF_TOKEN
+
+
+## run benchmark in debug mode
+python benchmark/raytune/run_full_benchmark.py --debug
+```
+
+Expected output:
+
+```
+.
+├── optimal_configs/ # Phase 1 outputs
+│ └── run_20240315_143022/ # Timestamped run directory
+│ ├── meta-llama_Llama-3.1-8B-Instruct_loogle_shortdep_qa_sink_local_5pct.json # Best config
+│ ├── meta-llama_Llama-3.1-8B-Instruct_loogle_shortdep_qa_sink_local_5pct_trials.json # Trial details
+│ ├── meta-llama_Llama-3.1-8B-Instruct_loogle_shortdep_qa_sink_local_5pct_analysis.csv # Ray analysis
+│ └── ... (3 files per model-task-masker combination)
+│
+├── ray_results/ # Ray Tune working directory
+│ └── search_runs/ # Hyperparameter search experiments
+│ └── ... (Ray Tune experiment artifacts)
+│
+└── benchmark_results/ # Phase 2 outputs
+ ├── benchmark_summary.json # Overall benchmark summary
+ └── meta-llama_Llama-3.1-8B-Instruct/ # Sanitized model name
+ ├── dense/ # Dense baseline config
+ │ └── loogle_shortdep_qa/ # Benchmark_subset
+ │ └── raw_results.csv
+ ├── sink_local_oracle_top_k_adaptive_sampling/ # Sparse config name
+ │ └── loogle_shortdep_qa/
+ │ ├── raw_results.csv
+ │ └── micro_metrics.jsonl # Sparse attention metrics
+ └── sink_local_random_sampling/ # Another sparse config
+ └── loogle_shortdep_qa/
+ ├── raw_results.csv
+ └── micro_metrics.jsonl
+```
+
+
+## Scripts
+
+### 1. run_ray_benchmarks.py
+The main benchmark runner using Ray for efficient parallel execution.
+
+**Features:**
+- Stateful Ray actors managing GPU resources
+- Fresh model initialization for each task (required because each task uses its own optimized parameters)
+- Real-time progress tracking with ETA
+- Dry run mode to preview the execution plan
+- Debug mode for testing with reduced parameters
+- Automatic GPU resource management
+- Resume capability (skips completed benchmarks)
+
+**Usage:**
+```bash
+# Basic usage (uses all available GPUs)
+python benchmark/raytune/run_ray_benchmarks.py --config-run run_20250818_203531
+
+# Dry run to see what will be executed
+python benchmark/raytune/run_ray_benchmarks.py --config-run run_20250818_203531 --dry-run
+
+# Debug mode - run 2-4 tasks with reduced parameters for testing
+python benchmark/raytune/run_ray_benchmarks.py --config-run run_20250818_203531 --debug
+
+# Single GPU execution
+python benchmark/raytune/run_ray_benchmarks.py --config-run run_20250818_203531 --num-actors 1
+
+# Maximum utilization with multiple actors per GPU (e.g., 2 actors per GPU)
+python benchmark/raytune/run_ray_benchmarks.py --config-run run_20250818_203531 --actors-per-gpu 2
+
+# Resume from previous run
+python benchmark/raytune/run_ray_benchmarks.py --config-run run_20250818_203531 --resume
+
+# Custom parameters
+python benchmark/raytune/run_ray_benchmarks.py \
+ --config-run run_20250818_203531 \
+ --max-new-tokens 200 \
+ --max-context-length 64000 \
+ --max-requests 50 \
+ --benchmark-results-dir ./my_results
+```
+
+### 2. list_benchmark_tasks.py
+Utility to list and inspect benchmark tasks from optimal configurations.
+
+**Usage:**
+```bash
+# List all tasks in table format
+python benchmark/raytune/list_benchmark_tasks.py --config-run run_20250818_203531
+
+# Group by model
+python benchmark/raytune/list_benchmark_tasks.py --config-run run_20250818_203531 --group-by model
+
+# Export to CSV
+python benchmark/raytune/list_benchmark_tasks.py --config-run run_20250818_203531 --format csv > tasks.csv
+
+# Filter tasks
+python benchmark/raytune/list_benchmark_tasks.py \
+ --config-run run_20250818_203531 \
+ --filter-task loogle \
+ --filter-masker adaptive
+
+# Simple format for scripting
+python benchmark/raytune/list_benchmark_tasks.py --config-run run_20250818_203531 --format simple
+```
+
+## Performance Tips
+
+1. **Model Loading**: Each task requires fresh model initialization due to unique optimized parameters from Phase 1. Model loading time is tracked and reported.
+
+2. **Actor Count**:
+ - Default: 1 actor per GPU for maximum parallelism
+ - Debug mode: Limited to 2 actors for faster testing
+ - Custom: Use `--num-actors` to control parallelism
+
+3. **Debug Mode**: Use `--debug` for quick testing:
+ - Runs only 2-4 diverse tasks
+ - Reduces max_new_tokens to 20
+ - Limits context length to 4096
+ - Processes only 2 requests per benchmark
+
+4. **Resume**: Completed benchmarks are automatically skipped based on the presence of `metrics.json` (see the sketch below).
+
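+A minimal sketch of the resume check (the actual logic lives in `run_ray_benchmarks.py`; the function name here is illustrative):
+
+```python
+from pathlib import Path
+
+def is_completed(task_result_dir: Path) -> bool:
+    """A benchmark is treated as finished once its metrics.json exists."""
+    return (task_result_dir / "metrics.json").exists()
+```
+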
+## Output Structure
+
+Results are saved in the following structure:
+```
+benchmark_results_ray/
+├── meta-llama_Llama-3.1-8B-Instruct/
+│ ├── dense/
+│ │ ├── loogle_longdep_qa/
+│ │ │ ├── raw_results.csv
+│ │ │ ├── metrics.json
+│ │ │ └── micro_metrics.jsonl
+│ │ └── ...
+│ ├── sink_local_random_sampling/
+│ │ └── ...
+│ └── ...
+└── ...
+```
+
+## Monitoring Progress
+
+The Ray runner provides real-time progress updates:
+- Current task completion status with execution time
+- Model loading time for each task
+- Average model load time statistics
+- Estimated time remaining (ETA)
+- Tasks per second throughput
+- Total execution and model loading time summary
\ No newline at end of file
diff --git a/benchmark/raytune/advanced_benchmark_analysis.py b/benchmark/raytune/advanced_benchmark_analysis.py
new file mode 100644
index 00000000..45355877
--- /dev/null
+++ b/benchmark/raytune/advanced_benchmark_analysis.py
@@ -0,0 +1,596 @@
+#!/usr/bin/env python3
+"""
+Advanced analysis and visualization for sparse attention benchmarks.
+
+This script provides:
+- Statistical analysis with confidence intervals
+- Pareto frontier analysis
+- Performance regression analysis
+- Detailed breakdowns by metric type
+- Export capabilities for publication-ready figures
+
+Usage:
+ python advanced_benchmark_analysis.py --results-dir benchmark_results_ray
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional, Any
+from collections import defaultdict
+import pandas as pd
+import numpy as np
+from scipy import stats
+from sklearn.preprocessing import StandardScaler
+from sklearn.decomposition import PCA
+
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+import plotly.express as px
+import plotly.figure_factory as ff
+
+
+class AdvancedBenchmarkAnalyzer:
+ """Advanced analysis for sparse attention benchmarks."""
+
+ def __init__(self, results_dir: Path):
+ self.results_dir = results_dir
+ self.data = self._load_comprehensive_results()
+ self._compute_statistics()
+ self._setup_professional_styling()
+
+ def _setup_professional_styling(self):
+ """Setup publication-quality styling."""
+ # Professional color palette
+ self.colors = px.colors.qualitative.D3
+ self.config_colors = {
+ 'dense': '#1f77b4',
+ 'sink_local_random_sampling': '#ff7f0e',
+ 'sink_local_oracle_top_k_adaptive_sampling': '#2ca02c',
+ 'sink_local_hash_attention_top_k_adaptive_sampling': '#d62728',
+ 'sink_local_oracle_top_p': '#9467bd',
+ 'sink_local_oracle_top_k': '#8c564b',
+ 'sink_local_hash_attention_top_k': '#e377c2',
+ 'sink_local_magic_pig': '#7f7f7f',
+ }
+
+ self.layout_template = go.layout.Template(
+ layout=go.Layout(
+ font=dict(family="Arial, sans-serif", size=14),
+ title_font=dict(size=22, family="Arial Black, sans-serif"),
+ hovermode='closest',
+ plot_bgcolor='rgba(240,240,240,0.1)',
+ paper_bgcolor='white',
+ xaxis=dict(
+ showgrid=True,
+ gridwidth=1,
+ gridcolor='rgba(128,128,128,0.2)',
+ showline=True,
+ linewidth=2,
+ linecolor='black',
+ zeroline=False
+ ),
+ yaxis=dict(
+ showgrid=True,
+ gridwidth=1,
+ gridcolor='rgba(128,128,128,0.2)',
+ showline=True,
+ linewidth=2,
+ linecolor='black',
+ zeroline=False
+ ),
+ margin=dict(l=100, r=100, t=120, b=100)
+ )
+ )
+
+ def _load_comprehensive_results(self) -> pd.DataFrame:
+ """Load results with detailed metrics and metadata."""
+ results = []
+
+ for model_dir in self.results_dir.iterdir():
+ if not model_dir.is_dir():
+ continue
+
+ model_name = model_dir.name
+
+ for config_dir in model_dir.iterdir():
+ if not config_dir.is_dir():
+ continue
+
+ config_name = config_dir.name
+
+ for task_dir in config_dir.iterdir():
+ if not task_dir.is_dir():
+ continue
+
+ task_name = task_dir.name
+
+ # Load all available data
+ result = self._load_task_result(
+ model_name, config_name, task_name, task_dir
+ )
+
+ if result:
+ results.append(result)
+
+ df = pd.DataFrame(results)
+
+ # Add derived metrics
+ if not df.empty:
+ df['efficiency_score'] = df.apply(
+ lambda x: x['overall_score'] / x['density'] if x['density'] > 0 else 0,
+ axis=1
+ )
+
+ # Normalize scores for comparison
+ if 'overall_score' in df.columns:
+ df['normalized_score'] = (df['overall_score'] - df['overall_score'].min()) / \
+ (df['overall_score'].max() - df['overall_score'].min())
+
+ return df
+
+ def _load_task_result(self, model: str, config: str, task: str,
+ task_dir: Path) -> Optional[Dict]:
+ """Load comprehensive result data for a single task."""
+ result = {
+ 'model': model,
+ 'config': config,
+ 'task': task,
+ 'config_type': 'sparse' if config != 'dense' else 'dense'
+ }
+
+ # Load metrics
+ metrics_file = task_dir / "metrics.json"
+ if not metrics_file.exists():
+ return None
+
+ with open(metrics_file, 'r') as f:
+ metrics = json.load(f)
+
+ result['overall_score'] = metrics.get('overall_score', 0)
+ result['total_samples'] = metrics.get('summary', {}).get('total_samples', 0)
+
+ # Extract all individual metrics
+ task_scores = metrics.get('task_scores', {})
+ if task_scores:
+ first_task = list(task_scores.values())[0]
+ for metric, value in first_task.items():
+ result[f'metric_{metric}'] = value
+
+ # Load micro metrics for sparse configs
+ if config != 'dense':
+ micro_stats = self._compute_micro_statistics(task_dir / "micro_metrics.jsonl")
+ result.update(micro_stats)
+ else:
+ # Dense baseline values
+ result['density'] = 1.0
+ result['attention_error'] = 0.0
+ result['density_std'] = 0.0
+ result['error_std'] = 0.0
+
+ return result
+
+ def _compute_micro_statistics(self, micro_metrics_file: Path) -> Dict:
+ """Compute statistics from micro metrics."""
+ stats = {
+ 'density': np.nan,
+ 'attention_error': np.nan,
+ 'density_std': np.nan,
+ 'error_std': np.nan,
+ 'density_percentiles': {},
+ 'error_percentiles': {}
+ }
+
+ if not micro_metrics_file.exists():
+ return stats
+
+ densities = []
+ errors = []
+
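+        # Each JSONL line is expected to be a JSON object of the form
+        #   {"metric": "research_attention_density", "value": 0.12, ...}
+        # (the value shown above is illustrative).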
+ with open(micro_metrics_file, 'r') as f:
+ for line in f:
+ try:
+ entry = json.loads(line.strip())
+ if entry.get("metric") == "research_attention_density":
+ densities.append(entry["value"])
+ elif entry.get("metric") == "research_attention_output_error":
+ errors.append(entry["value"])
+                except (json.JSONDecodeError, KeyError):
+                    continue
+
+ if densities:
+ stats['density'] = np.mean(densities)
+ stats['density_std'] = np.std(densities)
+ stats['density_percentiles'] = {
+ 'p25': np.percentile(densities, 25),
+ 'p50': np.percentile(densities, 50),
+ 'p75': np.percentile(densities, 75)
+ }
+
+ if errors:
+ stats['attention_error'] = np.mean(errors)
+ stats['error_std'] = np.std(errors)
+ stats['error_percentiles'] = {
+ 'p25': np.percentile(errors, 25),
+ 'p50': np.percentile(errors, 50),
+ 'p75': np.percentile(errors, 75)
+ }
+
+ return stats
+
+ def _compute_statistics(self):
+ """Compute statistical summaries and comparisons."""
+ if self.data.empty:
+ return
+
+ # Compute config-level statistics
+ self.config_stats = self.data.groupby('config').agg({
+ 'overall_score': ['mean', 'std', 'count'],
+ 'density': ['mean', 'std'],
+ 'attention_error': ['mean', 'std']
+ }).round(4)
+
+ # Compute task-level statistics
+ self.task_stats = self.data.groupby('task').agg({
+ 'overall_score': ['mean', 'std', 'count']
+ }).round(4)
+
+ # Statistical comparisons vs dense baseline
+ self.comparisons = self._compute_statistical_comparisons()
+
+ def _compute_statistical_comparisons(self) -> pd.DataFrame:
+ """Compute statistical comparisons against dense baseline."""
+ comparisons = []
+
+ dense_data = self.data[self.data['config'] == 'dense']
+ if dense_data.empty:
+ return pd.DataFrame()
+
+ for config in self.data['config'].unique():
+ if config == 'dense':
+ continue
+
+ config_data = self.data[self.data['config'] == config]
+
+ # Perform t-test for each task
+ for task in self.data['task'].unique():
+ dense_task = dense_data[dense_data['task'] == task]['overall_score']
+ config_task = config_data[config_data['task'] == task]['overall_score']
+
+ if len(dense_task) > 0 and len(config_task) > 0:
+ t_stat, p_value = stats.ttest_ind(dense_task, config_task)
+
+ comparisons.append({
+ 'config': config,
+ 'task': task,
+ 'dense_mean': dense_task.mean(),
+ 'config_mean': config_task.mean(),
+ 'difference': config_task.mean() - dense_task.mean(),
+ 'percent_change': ((config_task.mean() - dense_task.mean()) / dense_task.mean() * 100),
+ 't_statistic': t_stat,
+ 'p_value': p_value,
+ 'significant': p_value < 0.05
+ })
+
+ return pd.DataFrame(comparisons)
+
+ def create_pareto_frontier(self) -> go.Figure:
+ """Create Pareto frontier plot for density vs performance."""
+ sparse_data = self.data[self.data['config'] != 'dense'].copy()
+
+ # Compute Pareto frontier
+ pareto_points = []
+ sorted_data = sparse_data.sort_values('density')
+
+ max_score = -np.inf
+ for _, row in sorted_data.iterrows():
+ if row['overall_score'] >= max_score:
+ max_score = row['overall_score']
+ pareto_points.append(row)
+
+ pareto_df = pd.DataFrame(pareto_points)
+
+ # Create figure
+ fig = go.Figure()
+
+ # Add all points
+ for config in sparse_data['config'].unique():
+ config_data = sparse_data[sparse_data['config'] == config]
+
+ fig.add_trace(go.Scatter(
+ x=config_data['density'],
+ y=config_data['overall_score'],
+ mode='markers',
+ marker=dict(
+ size=12,
+ color=self.config_colors.get(config, '#000000'),
+ line=dict(width=2, color='white'),
+ opacity=0.8
+ ),
+ name=config.replace('_', ' ').title(),
+ text=config_data['task'],
+                hovertemplate='%{text}<br>Density: %{x:.3f}<br>Score: %{y:.3f}'
+ ))
+
+ # Add Pareto frontier
+ if not pareto_df.empty:
+ fig.add_trace(go.Scatter(
+ x=pareto_df['density'],
+ y=pareto_df['overall_score'],
+ mode='lines',
+ line=dict(color='red', width=3, dash='dash'),
+ name='Pareto Frontier',
+ showlegend=True
+ ))
+
+ # Add dense baseline
+ dense_score = self.data[self.data['config'] == 'dense']['overall_score'].mean()
+ fig.add_hline(
+ y=dense_score,
+ line_dash="dot",
+ line_color="black",
+ annotation_text="Dense Baseline",
+ annotation_position="right"
+ )
+
+ fig.update_layout(
+ template=self.layout_template,
+ title='Pareto Frontier: Density vs Performance Trade-off',
+ xaxis_title='Attention Density',
+ yaxis_title='Overall Performance Score',
+ height=700,
+ width=1000,
+ xaxis=dict(range=[0, 1.05]),
+ legend=dict(
+ yanchor="bottom",
+ y=0.01,
+ xanchor="right",
+ x=0.99,
+ bgcolor="rgba(255,255,255,0.8)",
+ bordercolor="black",
+ borderwidth=1
+ )
+ )
+
+ return fig
+
+ def create_statistical_comparison_plot(self) -> go.Figure:
+ """Create plot showing statistical comparisons vs baseline."""
+ if self.comparisons.empty:
+ return go.Figure()
+
+ # Aggregate by config
+ config_comparison = self.comparisons.groupby('config').agg({
+ 'percent_change': 'mean',
+ 'significant': 'sum',
+ 'task': 'count'
+ }).reset_index()
+
+ config_comparison.columns = ['config', 'avg_percent_change', 'num_significant', 'num_tasks']
+ config_comparison['percent_significant'] = config_comparison['num_significant'] / config_comparison['num_tasks'] * 100
+
+ # Create figure
+ fig = go.Figure()
+
+ # Add bars
+ fig.add_trace(go.Bar(
+ x=config_comparison['config'],
+ y=config_comparison['avg_percent_change'],
+ marker_color=[self.config_colors.get(c, '#000000') for c in config_comparison['config']],
+ text=config_comparison['percent_significant'].round(1),
+ texttemplate='%{text}% significant',
+ textposition='outside',
+            hovertemplate='Config: %{x}<br>Avg Change: %{y:.1f}%<br>Significant Tests: %{text}'
+ ))
+
+ # Add significance threshold
+ fig.add_hline(y=0, line_dash="solid", line_color="black", line_width=2)
+
+ fig.update_layout(
+ template=self.layout_template,
+            title='Performance Change vs Dense Baseline<br>Percentage of statistically significant differences shown',
+ xaxis_title='Sparse Attention Configuration',
+ yaxis_title='Average Performance Change (%)',
+ height=600,
+ xaxis_tickangle=-45,
+ showlegend=False
+ )
+
+ return fig
+
+ def create_comprehensive_dashboard(self, output_dir: str = "benchmark_analysis"):
+ """Create comprehensive analysis dashboard with multiple views."""
+ output_path = Path(output_dir)
+ output_path.mkdir(exist_ok=True)
+
+ # Create main dashboard
+ fig = make_subplots(
+ rows=3, cols=2,
+ subplot_titles=(
+ 'Pareto Frontier Analysis',
+ 'Statistical Comparisons',
+ 'Performance Distribution by Config',
+ 'Error vs Density Correlation',
+ 'Task Difficulty Analysis',
+ 'Efficiency Scores'
+ ),
+ row_heights=[0.35, 0.35, 0.3],
+ vertical_spacing=0.08,
+ horizontal_spacing=0.1,
+ specs=[
+ [{"type": "scatter"}, {"type": "bar"}],
+ [{"type": "violin"}, {"type": "scatter"}],
+ [{"type": "bar"}, {"type": "scatter"}]
+ ]
+ )
+
+ # 1. Pareto Frontier
+ pareto = self.create_pareto_frontier()
+ for trace in pareto.data:
+ fig.add_trace(trace, row=1, col=1)
+
+ # 2. Statistical Comparisons
+ stats_comp = self.create_statistical_comparison_plot()
+ for trace in stats_comp.data:
+ fig.add_trace(trace, row=1, col=2)
+
+ # 3. Performance Distribution
+ sparse_data = self.data[self.data['config'] != 'dense']
+ for config in sparse_data['config'].unique():
+ config_data = sparse_data[sparse_data['config'] == config]
+ fig.add_trace(go.Violin(
+ y=config_data['overall_score'],
+ name=config.replace('_', ' ').title(),
+ marker_color=self.config_colors.get(config, '#000000'),
+ box_visible=True,
+ meanline_visible=True
+ ), row=2, col=1)
+
+ # 4. Error vs Density
+ if 'attention_error' in sparse_data.columns:
+ fig.add_trace(go.Scatter(
+ x=sparse_data['density'],
+ y=sparse_data['attention_error'],
+ mode='markers',
+ marker=dict(
+ size=8,
+ color=sparse_data['overall_score'],
+ colorscale='Viridis',
+ showscale=True,
+ colorbar=dict(title="Score", x=1.02)
+ ),
+ text=sparse_data['config'],
+                hovertemplate='Config: %{text}<br>Density: %{x:.3f}<br>Error: %{y:.3f}'
+ ), row=2, col=2)
+
+ # 5. Task Difficulty
+ task_avg = self.data.groupby('task')['overall_score'].mean().sort_values()
+ fig.add_trace(go.Bar(
+ x=task_avg.values,
+ y=task_avg.index,
+ orientation='h',
+ marker_color='lightblue'
+ ), row=3, col=1)
+
+ # 6. Efficiency Scores
+ if 'efficiency_score' in self.data.columns:
+ efficiency_data = self.data[self.data['efficiency_score'] > 0]
+ for config in efficiency_data['config'].unique():
+ config_data = efficiency_data[efficiency_data['config'] == config]
+ fig.add_trace(go.Scatter(
+ x=config_data['density'],
+ y=config_data['efficiency_score'],
+ mode='markers',
+ name=config,
+ marker=dict(size=10)
+ ), row=3, col=2)
+
+ # Update layout
+ fig.update_layout(
+ template=self.layout_template,
+ title_text="Comprehensive Sparse Attention Benchmark Analysis",
+ title_font_size=26,
+ height=1800,
+ showlegend=False
+ )
+
+ # Save dashboard
+ dashboard_file = output_path / "comprehensive_dashboard.html"
+ fig.write_html(
+ dashboard_file,
+ include_plotlyjs='cdn'
+ )
+
+ # Generate additional analyses
+ self._generate_detailed_reports(output_path)
+
+ print(f"Analysis complete. Results saved to: {output_path}")
+
+ return fig
+
+ def _generate_detailed_reports(self, output_path: Path):
+ """Generate detailed reports and additional visualizations."""
+ # 1. Summary statistics
+ summary_stats = pd.DataFrame({
+ 'Configuration': self.config_stats.index,
+ 'Avg Score': self.config_stats[('overall_score', 'mean')],
+ 'Std Score': self.config_stats[('overall_score', 'std')],
+ 'Avg Density': self.config_stats[('density', 'mean')],
+ 'Avg Error': self.config_stats[('attention_error', 'mean')]
+ })
+ summary_stats.to_csv(output_path / "summary_statistics.csv", index=False)
+
+ # 2. Detailed comparisons
+ if not self.comparisons.empty:
+ self.comparisons.to_csv(output_path / "statistical_comparisons.csv", index=False)
+
+ # 3. Best configurations per task
+ best_configs = []
+ for task in self.data['task'].unique():
+ task_data = self.data[self.data['task'] == task]
+ best = task_data.loc[task_data['overall_score'].idxmax()]
+ best_configs.append({
+ 'task': task,
+ 'best_config': best['config'],
+ 'score': best['overall_score'],
+ 'density': best.get('density', 1.0)
+ })
+
+ pd.DataFrame(best_configs).to_csv(output_path / "best_configs_per_task.csv", index=False)
+
+ # 4. Performance correlation matrix
+ if len(self.data.columns) > 10:
+ metric_cols = [col for col in self.data.columns if col.startswith('metric_')]
+ if metric_cols:
+ corr_matrix = self.data[metric_cols].corr()
+
+ fig_corr = go.Figure(data=go.Heatmap(
+ z=corr_matrix.values,
+ x=corr_matrix.columns,
+ y=corr_matrix.columns,
+ colorscale='RdBu',
+ zmid=0,
+ text=np.round(corr_matrix.values, 2),
+ texttemplate='%{text}',
+ textfont={"size": 10}
+ ))
+
+ fig_corr.update_layout(
+ title='Metric Correlation Matrix',
+ height=800,
+ width=800
+ )
+
+ fig_corr.write_html(output_path / "metric_correlations.html")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Advanced analysis of sparse attention benchmarks")
+ parser.add_argument("--results-dir", type=str, default="benchmark_results_ray",
+ help="Directory containing benchmark results")
+ parser.add_argument("--output-dir", type=str, default="benchmark_analysis",
+ help="Output directory for analysis results")
+
+ args = parser.parse_args()
+
+ results_dir = Path(args.results_dir)
+ if not results_dir.exists():
+ print(f"Error: Results directory {results_dir} not found")
+ sys.exit(1)
+
+ # Run analysis
+ analyzer = AdvancedBenchmarkAnalyzer(results_dir)
+ analyzer.create_comprehensive_dashboard(args.output_dir)
+
+ # Print summary
+ print("\nConfiguration Performance Summary:")
+ print(analyzer.config_stats)
+
+
+if __name__ == "__main__":
+ main()
+
+
+
diff --git a/benchmark/raytune/analyze_trials.py b/benchmark/raytune/analyze_trials.py
new file mode 100755
index 00000000..b63c273d
--- /dev/null
+++ b/benchmark/raytune/analyze_trials.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+"""
+Utility script to analyze Ray Tune trial results from Phase 1.
+
+This script demonstrates how to access and analyze the metadata from Ray trials
+for post-analysis purposes.
+"""
+
+import argparse
+import json
+import pandas as pd
+from pathlib import Path
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
+def load_trial_data(optimal_configs_dir: Path):
+ """Load all trial data from the optimal configs directory."""
+ all_trials = []
+
+ # Find all trial JSON files
+ trial_files = list(optimal_configs_dir.glob("*_trials.json"))
+
+ for trial_file in trial_files:
+ with open(trial_file, 'r') as f:
+ data = json.load(f)
+
+ # Add metadata to each trial
+ for trial in data['trials']:
+ trial['model'] = data['model']
+ trial['task'] = data['task']
+ trial['masker_name'] = data['masker_name']
+ trial['objective_function'] = data['objective_function']
+ trial['is_best'] = trial['trial_id'] == data['best_trial_id']
+
+ all_trials.extend(data['trials'])
+
+ # Also check if CSV exists
+        csv_path = data.get('analysis_dataframe_path', '')
+        if csv_path and Path(csv_path).exists():
+            print(f" → Found detailed analysis CSV: {csv_path}")
+
+ return pd.DataFrame(all_trials)
+
+
+def analyze_objective_performance(df: pd.DataFrame):
+ """Analyze performance across different objective functions."""
+ print("\n" + "="*60)
+ print("OBJECTIVE FUNCTION ANALYSIS")
+ print("="*60)
+
+ # Group by objective function
+ obj_stats = df.groupby('objective_function')['score'].agg(['mean', 'min', 'max', 'count'])
+ print("\nScore statistics by objective function:")
+ print(obj_stats)
+
+ # Best trials only
+ best_trials = df[df['is_best']]
+ best_by_obj = best_trials.groupby('objective_function')['score'].agg(['mean', 'count'])
+ print("\nBest trial scores by objective function:")
+ print(best_by_obj)
+
+
+def analyze_hyperparameter_impact(df: pd.DataFrame):
+ """Analyze impact of different hyperparameters on scores."""
+ print("\n" + "="*60)
+ print("HYPERPARAMETER IMPACT ANALYSIS")
+ print("="*60)
+
+ # Extract hyperparameters from config
+ hyperparam_cols = []
+    for idx, row in df.iterrows():
+        config = row['config']
+        for key, value in config.items():
+            if key not in hyperparam_cols:
+                hyperparam_cols.append(key)
+            df.loc[idx, f'hp_{key}'] = value
+
+ # Analyze each hyperparameter's impact
+ for hp in hyperparam_cols:
+ hp_col = f'hp_{hp}'
+ if hp_col in df.columns:
+ print(f"\nImpact of {hp}:")
+ hp_stats = df.groupby(hp_col)['score'].agg(['mean', 'count', 'std'])
+ print(hp_stats.sort_values('mean').head(10))
+
+
+def analyze_sparsity_achievement(optimal_configs_dir: Path):
+ """Analyze how well different configs achieve target sparsity."""
+ print("\n" + "="*60)
+ print("SPARSITY ACHIEVEMENT ANALYSIS")
+ print("="*60)
+
+ # Load optimal configs to get actual densities
+ config_files = list(optimal_configs_dir.glob("*.json"))
+ config_files = [f for f in config_files if not f.name.endswith("_trials.json")]
+
+ sparsity_data = []
+ for config_file in config_files:
+ with open(config_file, 'r') as f:
+ config = json.load(f)
+
+ if 'score' in config:
+ sparsity_data.append({
+ 'model': config['model'],
+ 'task': config['task'],
+ 'masker_name': config['masker_name'],
+ 'score': config['score'],
+ 'num_trials': config.get('num_trials', 0)
+ })
+
+ if sparsity_data:
+ sparsity_df = pd.DataFrame(sparsity_data)
+ print("\nConfiguration performance summary:")
+ print(sparsity_df.groupby('masker_name')['score'].agg(['mean', 'min', 'max', 'count']))
+
+
+def plot_trial_scores(df: pd.DataFrame, output_dir: Path):
+ """Create visualizations of trial scores."""
+ output_dir.mkdir(exist_ok=True)
+
+ # Plot 1: Score distribution by objective function
+ plt.figure(figsize=(10, 6))
+ sns.boxplot(data=df, x='objective_function', y='score')
+ plt.xticks(rotation=45)
+ plt.title('Score Distribution by Objective Function')
+ plt.tight_layout()
+ plt.savefig(output_dir / 'scores_by_objective.png')
+ plt.close()
+
+ # Plot 2: Score vs trial for each task
+ tasks = df['task'].unique()
+ fig, axes = plt.subplots(len(tasks), 1, figsize=(10, 4*len(tasks)))
+ if len(tasks) == 1:
+ axes = [axes]
+
+ for ax, task in zip(axes, tasks):
+ task_df = df[df['task'] == task]
+ for masker in task_df['masker_name'].unique():
+ masker_df = task_df[task_df['masker_name'] == masker]
+ ax.scatter(range(len(masker_df)), masker_df['score'], label=masker[:20], alpha=0.6)
+ ax.set_title(f'Trial Scores for {task}')
+ ax.set_xlabel('Trial Number')
+ ax.set_ylabel('Score')
+ ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
+
+ plt.tight_layout()
+ plt.savefig(output_dir / 'trial_progression.png')
+ plt.close()
+
+ print(f"\nPlots saved to {output_dir}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Analyze Ray Tune trial results")
+ parser.add_argument("--optimal-configs-dir", default="./optimal_configs",
+ help="Directory containing optimal configs and trial data")
+ parser.add_argument("--output-dir", default="./trial_analysis",
+ help="Directory for output plots and analysis")
+ parser.add_argument("--run", type=str,
+ help="Specific run directory to analyze (e.g., 'run_20240315_143022')")
+ args = parser.parse_args()
+
+ base_optimal_configs_dir = Path(args.optimal_configs_dir)
+ output_dir = Path(args.output_dir)
+
+ if not base_optimal_configs_dir.exists():
+ print(f"Error: Directory {base_optimal_configs_dir} does not exist")
+ return
+
+ # Handle timestamped directories
+ if args.run:
+ optimal_configs_dir = base_optimal_configs_dir / args.run
+ if not optimal_configs_dir.exists():
+ print(f"Error: Specified run {optimal_configs_dir} does not exist")
+ return
+ else:
+ # Find the most recent run_* directory
+ run_dirs = sorted([d for d in base_optimal_configs_dir.glob("run_*") if d.is_dir()])
+ if run_dirs:
+ optimal_configs_dir = run_dirs[-1] # Most recent
+ print(f"Using most recent run: {optimal_configs_dir.name}")
+ else:
+ # Fallback to base directory for backward compatibility
+ optimal_configs_dir = base_optimal_configs_dir
+
+ print(f"Loading trial data from {optimal_configs_dir}")
+ df = load_trial_data(optimal_configs_dir)
+
+ if df.empty:
+ print("No trial data found!")
+ return
+
+ print(f"\nLoaded {len(df)} trials")
+ print(f"Models: {df['model'].unique()}")
+ print(f"Tasks: {df['task'].unique()}")
+ print(f"Masker types: {df['masker_name'].unique()[:5]}...") # Show first 5
+ print(f"Objective functions: {df['objective_function'].unique()}")
+
+ # Run analyses
+ analyze_objective_performance(df)
+ analyze_hyperparameter_impact(df)
+ analyze_sparsity_achievement(optimal_configs_dir)
+
+ # Create plots
+ try:
+ plot_trial_scores(df, output_dir)
+ except Exception as e:
+ print(f"Warning: Could not create plots: {e}")
+
+ # Save combined dataframe
+ output_file = output_dir / "all_trials_data.csv"
+ output_dir.mkdir(exist_ok=True)
+ df.to_csv(output_file, index=False)
+ print(f"\nAll trial data saved to {output_file}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/raytune/create_specific_plots.py b/benchmark/raytune/create_specific_plots.py
new file mode 100755
index 00000000..8bd8152c
--- /dev/null
+++ b/benchmark/raytune/create_specific_plots.py
@@ -0,0 +1,396 @@
+#!/usr/bin/env python3
+"""
+Create specific plots for sparse attention benchmark results.
+
+Plot 1: Density vs Performance per task (subplots)
+Plot 2: Dashboard with task-based comparisons
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional
+import numpy as np
+import pandas as pd
+
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+import plotly.express as px
+
+
+def load_benchmark_data(results_dir: Path) -> pd.DataFrame:
+ """Load benchmark results into a DataFrame."""
+ results = []
+
+ for model_dir in results_dir.iterdir():
+ if not model_dir.is_dir():
+ continue
+
+ model_name = model_dir.name
+
+ for config_dir in model_dir.iterdir():
+ if not config_dir.is_dir():
+ continue
+
+ config_name = config_dir.name
+
+ for task_dir in config_dir.iterdir():
+ if not task_dir.is_dir():
+ continue
+
+ task_name = task_dir.name
+
+ # Load metrics
+ metrics_file = task_dir / "metrics.json"
+ if not metrics_file.exists():
+ continue
+
+ with open(metrics_file, 'r') as f:
+ metrics = json.load(f)
+
+ result = {
+ 'model': model_name,
+ 'config': config_name,
+ 'task': task_name,
+ 'performance': metrics.get('overall_score', 0)
+ }
+
+ # Load density and error for sparse configs
+ if config_name != 'dense':
+ micro_metrics_file = task_dir / "micro_metrics.jsonl"
+ if micro_metrics_file.exists():
+ densities = []
+ errors = []
+
+ with open(micro_metrics_file, 'r') as f:
+ for line in f:
+ try:
+ entry = json.loads(line.strip())
+ if entry.get("metric") == "research_attention_density":
+ densities.append(entry["value"])
+ elif entry.get("metric") == "research_attention_output_error":
+ errors.append(entry["value"])
+                            except (json.JSONDecodeError, KeyError):
+                                continue
+
+ result['density'] = np.mean(densities) if densities else np.nan
+ result['error'] = np.mean(errors) if errors else np.nan
+ else:
+ # Dense baseline
+ result['density'] = 1.0
+ result['error'] = 0.0
+
+ results.append(result)
+
+ return pd.DataFrame(results)
+
+
+def create_density_performance_subplots(data: pd.DataFrame, output_path: Path):
+ """Create density vs performance plot with subplots per task."""
+ # Get unique tasks
+ tasks = sorted(data['task'].unique())
+
+ # Define markers for different configs
+ config_markers = {
+ 'dense': 'square',
+ 'sink_local_random_sampling': 'circle',
+ 'sink_local_oracle_top_k_adaptive_sampling': 'diamond',
+ 'sink_local_hash_attention_top_k_adaptive_sampling': 'cross',
+ 'sink_local_oracle_top_p': 'x',
+ 'sink_local_oracle_top_k': 'triangle-up',
+ 'sink_local_hash_attention_top_k': 'triangle-down',
+ 'sink_local_magic_pig': 'star'
+ }
+
+ # Define colors - blue to green gradient (dark to light)
+ config_colors = {
+ 'dense': '#08519c', # Dark blue
+ 'sink_local_random_sampling': '#2171b5', # Medium blue
+ 'sink_local_oracle_top_k_adaptive_sampling': '#4292c6', # Light blue
+ 'sink_local_hash_attention_top_k_adaptive_sampling': '#6baed6', # Lighter blue
+ 'sink_local_oracle_top_p': '#4eb3a6', # Blue-green
+ 'sink_local_oracle_top_k': '#41ab5d', # Medium green
+ 'sink_local_hash_attention_top_k': '#238b45', # Dark green
+ 'sink_local_magic_pig': '#005a32' # Darkest green
+ }
+
+ # Calculate grid size
+ n_tasks = len(tasks)
+ n_cols = 3
+ n_rows = (n_tasks + n_cols - 1) // n_cols
+
+ # Create subplots
+ fig = make_subplots(
+ rows=n_rows,
+ cols=n_cols,
+ subplot_titles=[task.replace('_', ' ').title() for task in tasks],
+ vertical_spacing=0.15,
+ horizontal_spacing=0.1
+ )
+
+ # Add traces for each task
+ for idx, task in enumerate(tasks):
+ row = idx // n_cols + 1
+ col = idx % n_cols + 1
+
+ task_data = data[data['task'] == task]
+
+ # Add scatter points for each config
+ for config in sorted(task_data['config'].unique()):
+ config_data = task_data[task_data['config'] == config]
+
+ fig.add_trace(
+ go.Scatter(
+ x=config_data['density'],
+ y=config_data['performance'],
+ mode='markers',
+ name=config.replace('_', ' ').title() if idx == 0 else None, # Only show legend for first subplot
+ showlegend=(idx == 0),
+ legendgroup=config, # Link legend across all subplots
+ marker=dict(
+ symbol=config_markers.get(config, 'circle'),
+ size=12,
+ color=config_colors.get(config, '#000000'),
+ line=dict(width=1, color='white')
+ ),
+                    hovertemplate=f'{config.replace("_", " ").title()}<br>Density: %{{x:.3f}}<br>Performance: %{{y:.3f}}'
+ ),
+ row=row,
+ col=col
+ )
+
+ # Update axes
+ fig.update_xaxes(title_text="Density", range=[0, 1.05], row=row, col=col)
+ fig.update_yaxes(title_text="Performance", row=row, col=col)
+
+ # Update layout
+ fig.update_layout(
+ title="Density vs Performance by Task",
+ height=300 * n_rows,
+ width=1400, # Increased width to accommodate legend
+ font=dict(size=12),
+ plot_bgcolor='white',
+ paper_bgcolor='white',
+ legend=dict(
+ orientation="v",
+ yanchor="middle",
+ y=0.5,
+ xanchor="left",
+ x=1.05,
+ bgcolor="rgba(255, 255, 255, 0.8)",
+ bordercolor="rgba(0, 0, 0, 0.2)",
+ borderwidth=1
+ ),
+ margin=dict(r=200) # Add right margin for legend
+ )
+
+ # Ensure subplot titles are horizontal
+ for annotation in fig['layout']['annotations']:
+ annotation['textangle'] = 0
+
+ # Save
+ output_file = output_path / "density_vs_performance_by_task.html"
+ fig.write_html(output_file)
+ print(f"Saved: {output_file}")
+
+
+def create_task_comparison_dashboard(data: pd.DataFrame, output_path: Path):
+ """Create dashboard with three plots comparing metrics across tasks."""
+ # Create subplots
+ fig = make_subplots(
+ rows=3,
+ cols=1,
+ subplot_titles=[
+ "Performance Delta from Dense Baseline",
+ "Average Density by Task",
+ "Average Error by Task"
+ ],
+ vertical_spacing=0.12,
+ row_heights=[0.33, 0.33, 0.34]
+ )
+
+ # Get unique tasks and configs
+ tasks = sorted(data['task'].unique())
+ configs = sorted(data['config'].unique())
+
+ # Define colors - blue to green gradient (dark to light)
+ config_colors = {
+ 'dense': '#08519c', # Dark blue
+ 'sink_local_random_sampling': '#2171b5', # Medium blue
+ 'sink_local_oracle_top_k_adaptive_sampling': '#4292c6', # Light blue
+ 'sink_local_hash_attention_top_k_adaptive_sampling': '#6baed6', # Lighter blue
+ 'sink_local_oracle_top_p': '#4eb3a6', # Blue-green
+ 'sink_local_oracle_top_k': '#41ab5d', # Medium green
+ 'sink_local_hash_attention_top_k': '#238b45', # Dark green
+ 'sink_local_magic_pig': '#005a32' # Darkest green
+ }
+
+ # Get dense baseline performance for each task
+ dense_performance = {}
+ dense_data = data[data['config'] == 'dense']
+ for task in tasks:
+ task_data = dense_data[dense_data['task'] == task]
+ dense_performance[task] = task_data['performance'].mean() if not task_data.empty else 0
+
+ # Plot 1: Performance difference from dense baseline
+ for config in configs:
+ if config == 'dense':
+ continue # Skip dense since we're showing delta from dense
+
+ config_data = data[data['config'] == config]
+
+ # Calculate mean performance difference per task
+ task_performance = []
+ for task in tasks:
+ task_data = config_data[config_data['task'] == task]
+ perf = task_data['performance'].mean() if not task_data.empty else 0
+ # Calculate difference from dense baseline
+ perf_diff = perf - dense_performance.get(task, 0)
+ task_performance.append(perf_diff)
+
+ fig.add_trace(
+ go.Bar(
+ name=config.replace('_', ' ').title(),
+ x=tasks,
+ y=task_performance,
+ marker_color=config_colors.get(config, '#000000'),
+                hovertemplate=f'{config.replace("_", " ").title()}<br>Task: %{{x}}<br>Performance Delta: %{{y:.3f}}',
+ legendgroup=config # Link legend across all plots
+ ),
+ row=1,
+ col=1
+ )
+
+ # Plot 2: Density by task (only sparse configs)
+ sparse_configs = [c for c in configs if c != 'dense']
+ for config in sparse_configs:
+ config_data = data[data['config'] == config]
+
+ # Calculate mean density per task
+ task_density = []
+ for task in tasks:
+ task_data = config_data[config_data['task'] == task]
+ density = task_data['density'].mean() if not task_data.empty else np.nan
+ task_density.append(density)
+
+ fig.add_trace(
+ go.Bar(
+ name=config.replace('_', ' ').title(),
+ x=tasks,
+ y=task_density,
+ marker_color=config_colors.get(config, '#000000'),
+                hovertemplate=f'{config.replace("_", " ").title()}<br>Task: %{{x}}<br>Density: %{{y:.3f}}',
+ showlegend=False, # Use same legend as plot 1
+ legendgroup=config # Link legend across all plots
+ ),
+ row=2,
+ col=1
+ )
+
+ # Plot 3: Error by task (only sparse configs)
+ for config in sparse_configs:
+ config_data = data[data['config'] == config]
+
+ # Calculate mean error per task
+ task_error = []
+ for task in tasks:
+ task_data = config_data[config_data['task'] == task]
+ error = task_data['error'].mean() if not task_data.empty else np.nan
+ task_error.append(error)
+
+ fig.add_trace(
+ go.Bar(
+ name=config.replace('_', ' ').title(),
+ x=tasks,
+ y=task_error,
+ marker_color=config_colors.get(config, '#000000'),
+                hovertemplate=f'{config.replace("_", " ").title()}<br>Task: %{{x}}<br>Error: %{{y:.3f}}',
+ showlegend=False, # Use same legend as plot 1
+ legendgroup=config # Link legend across all plots
+ ),
+ row=3,
+ col=1
+ )
+
+ # Update axes
+ fig.update_xaxes(title_text="Task", row=3, col=1)
+ fig.update_xaxes(tickangle=0)
+
+ fig.update_yaxes(title_text="Performance Delta", row=1, col=1)
+ fig.update_yaxes(title_text="Density", row=2, col=1)
+ fig.update_yaxes(title_text="Error", row=3, col=1)
+
+ # Update layout
+ fig.update_layout(
+ title="Task-wise Comparison Dashboard",
+ height=1200,
+ width=1200,
+ barmode='group',
+ font=dict(size=12),
+ plot_bgcolor='white',
+ paper_bgcolor='white',
+ legend=dict(
+ orientation="v",
+ yanchor="top",
+ y=0.98,
+ xanchor="left",
+ x=1.02,
+ bgcolor="rgba(255, 255, 255, 0.8)",
+ bordercolor="rgba(0, 0, 0, 0.2)",
+ borderwidth=1
+ ),
+ margin=dict(r=200) # Add right margin for legend
+ )
+
+ # Ensure subplot titles are horizontal
+ for annotation in fig['layout']['annotations']:
+ annotation['textangle'] = 0
+
+ # Save
+ output_file = output_path / "task_comparison_dashboard.html"
+ fig.write_html(output_file)
+ print(f"Saved: {output_file}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Create specific plots for benchmark results")
+ parser.add_argument("--results-dir", type=str, default="benchmark_results_ray",
+ help="Directory containing benchmark results")
+ parser.add_argument("--output-dir", type=str, default="plots",
+ help="Output directory for plots")
+
+ args = parser.parse_args()
+
+ results_dir = Path(args.results_dir)
+ if not results_dir.exists():
+ print(f"Error: Results directory {results_dir} not found")
+ sys.exit(1)
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(exist_ok=True)
+
+ # Load data
+ print("Loading benchmark data...")
+ data = load_benchmark_data(results_dir)
+
+ if data.empty:
+ print("No data found!")
+ sys.exit(1)
+
+ print(f"Loaded {len(data)} benchmark results")
+
+ # Create plots
+ print("\nCreating density vs performance subplots...")
+ create_density_performance_subplots(data, output_dir)
+
+ print("\nCreating task comparison dashboard...")
+ create_task_comparison_dashboard(data, output_dir)
+
+ print("\nAll plots created successfully!")
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/benchmark/raytune/generic_config_optimizer.py b/benchmark/raytune/generic_config_optimizer.py
new file mode 100755
index 00000000..764d5c7c
--- /dev/null
+++ b/benchmark/raytune/generic_config_optimizer.py
@@ -0,0 +1,340 @@
+"""Task-specific config optimizer for sparse attention configs.
+
+This module provides optimizers that work with masker configs that define their own
+search spaces, enabling per-task optimization and caching.
+
+Key Features:
+- Each masker config defines its own get_search_space() method
+- Per-task optimization and caching
+- Support for composite configs (ResearchAttentionConfig with multiple maskers)
+- Task-specific parameter tuning
+- Benchmark integration
+"""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Type, List
+
+from ray import tune
+
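+# Illustrative sketch (not part of the sparse_attention_hub API): the optimizers
+# below assume each masker config class exposes a `get_search_space(task_name)`
+# classmethod that returns a Ray Tune search space, roughly like:
+#
+#     class ExampleLocalMaskerConfig:  # hypothetical masker config
+#         def __init__(self, window_size: int = 128):
+#             self.window_size = window_size
+#
+#         @classmethod
+#         def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+#             return {"window_size": tune.choice([64, 128, 256, 512])}
+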
+
+class SparseConfigOptimizer(ABC):
+ """Base class for sparse attention config optimizers."""
+
+ @abstractmethod
+ def create_search_space(self, task_name: str) -> Dict[str, Any]:
+ """Create Ray Tune search space for the config type and task."""
+ pass
+
+ @abstractmethod
+ def create_config_from_params(self, params: Dict[str, Any]) -> Any:
+ """Create config instance from optimization parameters."""
+ pass
+
+ @abstractmethod
+ def optimize_for_task(self, task_name: str, num_samples: int = 10) -> Any:
+ """Run optimization for a specific task and return best config."""
+ pass
+
+ @property
+ @abstractmethod
+ def config_type_name(self) -> str:
+ """Get the name of the config type for caching."""
+ pass
+
+
+class CompositeConfigOptimizer(SparseConfigOptimizer):
+ """Optimizer for composite configs like ResearchAttentionConfig with multiple maskers."""
+
+ def __init__(self, masker_configs: List[Type], config_name: str, overrides: Optional[Dict[str, Any]] = None):
+ """Initialize composite optimizer.
+
+ Args:
+ masker_configs: List of masker config classes to optimize
+ config_name: Name for caching purposes
+ overrides: Optional manual overrides for specific fields (prefixed by masker name)
+ """
+ self.masker_configs = masker_configs
+ self._config_name = config_name
+ self.overrides = overrides or {}
+ self.logger = logging.getLogger(__name__)
+
+ # Validate that all masker configs have get_search_space method
+ for masker_class in masker_configs:
+ if not hasattr(masker_class, 'get_search_space'):
+ raise ValueError(f"Masker config {masker_class.__name__} must implement get_search_space() method")
+
+ # Cache for task-specific best configs
+ self.task_cache = {}
+
+ def create_search_space(self, task_name: str) -> Dict[str, Any]:
+ """Create combined search space from all masker configs for a specific task."""
+ combined_space = {}
+
+ for masker_class in self.masker_configs:
+ masker_name = masker_class.__name__.lower().replace('config', '')
+
+ # Get search space from the masker config class
+ masker_space = masker_class.get_search_space(task_name)
+
+ # Apply any overrides for this masker
+ prefix = f"{masker_name}_"
+ for key, value in self.overrides.items():
+ if key.startswith(prefix):
+ param_name = key[len(prefix):]
+ masker_space[param_name] = value
+
+ # Prefix each parameter with masker name to avoid conflicts
+ for param_name, param_space in masker_space.items():
+ combined_space[f"{masker_name}_{param_name}"] = param_space
+
+ return combined_space
+
+ def create_config_from_params(self, params: Dict[str, Any]) -> Any:
+ """Create ResearchAttentionConfig from optimization parameters."""
+ from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
+
+ masker_instances = []
+
+ for masker_class in self.masker_configs:
+ masker_name = masker_class.__name__.lower().replace('config', '')
+
+ # Extract parameters for this masker
+ masker_params = {}
+ prefix = f"{masker_name}_"
+ for param_name, param_value in params.items():
+ if param_name.startswith(prefix):
+ masker_params[param_name[len(prefix):]] = param_value
+
+ # Create masker instance
+ masker_instance = masker_class(**masker_params)
+ masker_instances.append(masker_instance)
+
+ return ResearchAttentionConfig(masker_configs=masker_instances)
+
+ def optimize_for_task(self, task_name: str, num_samples: int = 10) -> Any:
+ """Run optimization for a specific task and return best config."""
+ # Check cache first
+ cache_key = f"{task_name}_{num_samples}"
+ if cache_key in self.task_cache:
+ self.logger.info(f"Using cached best config for task {task_name}")
+ return self.task_cache[cache_key]
+
+ self.logger.info(f"Starting optimization for task {task_name} with {num_samples} samples")
+
+ # Create search space for this task
+ search_space = self.create_search_space(task_name)
+
+ # Run Ray Tune optimization
+ analysis = tune.run(
+ self._objective_function,
+ config=search_space,
+ num_samples=num_samples,
+ resources_per_trial={"cpu": 1, "gpu": 0.25},
+ name=f"optimize_{self._config_name}_{task_name}",
+ local_dir="./ray_results"
+ )
+
+ # Get best config
+ best_trial = analysis.get_best_trial("score", "max", "last")
+ best_config = self.create_config_from_params(best_trial.config)
+
+ # Cache the result
+ self.task_cache[cache_key] = best_config
+
+ self.logger.info(f"Best config for {task_name}: {best_config}")
+ return best_config
+
+ def _objective_function(self, config: Dict[str, Any]) -> Dict[str, float]:
+ """Objective function for Ray Tune optimization."""
+ # Create config instance
+ attention_config = self.create_config_from_params(config)
+
+ # TODO: Integrate with benchmark runner
+ # For now, return random score - replace with actual benchmark
+ import random
+ score = random.random()
+
+ return {"score": score}
+
+ @property
+ def config_type_name(self) -> str:
+ """Get the name of the config type for caching."""
+ return self._config_name
+
+
+class SingleConfigOptimizer(SparseConfigOptimizer):
+ """Optimizer for single masker configs."""
+
+ def __init__(self, config_class: Type, config_name: str, overrides: Optional[Dict[str, Any]] = None):
+ """Initialize single config optimizer.
+
+ Args:
+ config_class: The masker config class to optimize
+ config_name: Name for caching purposes
+ overrides: Optional manual overrides for specific fields
+ """
+ self.config_class = config_class
+ self._config_name = config_name
+ self.overrides = overrides or {}
+ self.logger = logging.getLogger(__name__)
+
+ # Validate that the config class has get_search_space method
+ if not hasattr(config_class, 'get_search_space'):
+ raise ValueError(f"Config class {config_class.__name__} must implement get_search_space() method")
+
+ # Cache for task-specific best configs
+ self.task_cache = {}
+
+ def create_search_space(self, task_name: str) -> Dict[str, Any]:
+ """Create search space from the config class for a specific task."""
+ search_space = self.config_class.get_search_space(task_name)
+
+ # Apply any overrides
+ for key, value in self.overrides.items():
+ search_space[key] = value
+
+ return search_space
+
+ def create_config_from_params(self, params: Dict[str, Any]) -> Any:
+ """Create config instance from optimization parameters."""
+ return self.config_class(**params)
+
+ def optimize_for_task(self, task_name: str, num_samples: int = 10) -> Any:
+ """Run optimization for a specific task and return best config."""
+ # Check cache first
+ cache_key = f"{task_name}_{num_samples}"
+ if cache_key in self.task_cache:
+ self.logger.info(f"Using cached best config for task {task_name}")
+ return self.task_cache[cache_key]
+
+ self.logger.info(f"Starting optimization for task {task_name} with {num_samples} samples")
+
+ # Create search space for this task
+ search_space = self.create_search_space(task_name)
+
+ # Run Ray Tune optimization
+ analysis = tune.run(
+ self._objective_function,
+ config=search_space,
+ num_samples=num_samples,
+ resources_per_trial={"cpu": 1, "gpu": 0.25},
+ name=f"optimize_{self._config_name}_{task_name}",
+ local_dir="./ray_results"
+ )
+
+ # Get best config
+ best_trial = analysis.get_best_trial("score", "max", "last")
+ best_config = self.create_config_from_params(best_trial.config)
+
+ # Cache the result
+ self.task_cache[cache_key] = best_config
+
+ self.logger.info(f"Best config for {task_name}: {best_config}")
+ return best_config
+
+ def _objective_function(self, config: Dict[str, Any]) -> Dict[str, float]:
+ """Objective function for Ray Tune optimization."""
+ # Create config instance
+ attention_config = self.create_config_from_params(config)
+
+ # TODO: Integrate with benchmark runner
+ # For now, return random score - replace with actual benchmark
+ import random
+ score = random.random()
+
+ return {"score": score}
+
+ @property
+ def config_type_name(self) -> str:
+ """Get the name of the config type for caching."""
+ return self._config_name
+
+
+def create_optimizer_for_config(config_class: Type, config_name: str, overrides: Optional[Dict[str, Any]] = None) -> SingleConfigOptimizer:
+ """Factory function to create a single config optimizer.
+
+ Args:
+ config_class: The masker config class to optimize
+ config_name: Name for caching purposes
+ overrides: Optional manual overrides for specific fields
+
+ Returns:
+ SingleConfigOptimizer instance
+
+ Example:
+ >>> from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations.basic_fixed import LocalMaskerConfig
+ >>> optimizer = create_optimizer_for_config(
+ ... LocalMaskerConfig,
+ ... "local_masker"
+ ... )
+ >>> best_config = optimizer.optimize_for_task("longbench_qasper", num_samples=20)
+ """
+ return SingleConfigOptimizer(config_class, config_name, overrides)
+
+
+def auto_create_composite_optimizer(masker_configs: List[Type], config_name: str, overrides: Optional[Dict[str, Any]] = None) -> CompositeConfigOptimizer:
+ """Factory function to create a composite optimizer with automatic search space discovery.
+
+ This is similar to create_composite_optimizer but emphasizes that it uses auto-discovery.
+
+ Args:
+ masker_configs: List of masker config classes to optimize
+ config_name: Name for caching purposes
+ overrides: Optional manual overrides for specific fields (prefixed by masker name)
+
+ Returns:
+ CompositeConfigOptimizer instance
+
+ Example:
+ >>> from sparse_attention_hub.sparse_attention.research_attention.maskers import MagicPigConfig, LocalMaskerConfig
+ >>> optimizer = auto_create_composite_optimizer(
+ ... [MagicPigConfig, LocalMaskerConfig],
+ ... "magic_pig_local"
+ ... )
+ >>> best_config = optimizer.optimize_for_task("longbench_qasper", num_samples=20)
+ """
+ return create_composite_optimizer(masker_configs, config_name, overrides)
+
+
+def create_composite_optimizer(masker_configs: List[Type], config_name: str, overrides: Optional[Dict[str, Any]] = None) -> CompositeConfigOptimizer:
+ """Factory function to create a composite optimizer for ResearchAttentionConfig.
+
+ Args:
+ masker_configs: List of masker config classes to optimize
+ config_name: Name for caching purposes
+ overrides: Optional manual overrides for specific fields (prefixed by masker name)
+
+ Returns:
+ CompositeConfigOptimizer instance
+
+ Example:
+ >>> from sparse_attention_hub.sparse_attention.research_attention.maskers import MagicPigConfig, LocalMaskerConfig
+ >>> optimizer = create_composite_optimizer(
+ ... [MagicPigConfig, LocalMaskerConfig],
+ ... "magic_pig_local",
+ ... overrides={"magicpig_lsh_l": tune.choice([4, 8, 12])}
+ ... )
+ >>> best_config = optimizer.optimize_for_task("longbench_qasper", num_samples=20)
+ """
+ return CompositeConfigOptimizer(masker_configs, config_name, overrides)
+
+
+# Task-specific optimization utilities
+def optimize_configs_for_all_tasks(optimizer: CompositeConfigOptimizer,
+ tasks: List[str],
+ num_samples: int = 10) -> Dict[str, Any]:
+ """Optimize configs for multiple tasks.
+
+ Args:
+ optimizer: CompositeConfigOptimizer instance
+ tasks: List of task names to optimize for
+ num_samples: Number of optimization samples per task
+
+ Returns:
+ Dictionary mapping task names to best configs
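+
+    Example (a sketch; task names are illustrative):
+        >>> optimizer = create_composite_optimizer(
+        ...     [MagicPigConfig, LocalMaskerConfig], "magic_pig_local"
+        ... )
+        >>> best_configs = optimize_configs_for_all_tasks(
+        ...     optimizer, ["longbench_qasper", "loogle_shortdep_qa"], num_samples=10
+        ... )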
+ """
+ results = {}
+ for task in tasks:
+ results[task] = optimizer.optimize_for_task(task, num_samples)
+ return results
diff --git a/benchmark/raytune/list_benchmark_tasks.py b/benchmark/raytune/list_benchmark_tasks.py
new file mode 100644
index 00000000..98d66ef4
--- /dev/null
+++ b/benchmark/raytune/list_benchmark_tasks.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+"""
+List all benchmark tasks from optimal configs for easy inspection.
+
+Usage:
+ python benchmark/raytune/list_benchmark_tasks.py --config-run run_20250818_203531
+ python benchmark/raytune/list_benchmark_tasks.py --config-run run_20250818_203531 --format csv > tasks.csv
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from collections import defaultdict
+import csv
+
+
+def main():
+ parser = argparse.ArgumentParser(description="List benchmark tasks from optimal configs")
+ parser.add_argument("--config-run", type=str, required=True,
+ help="Config run directory name")
+ parser.add_argument("--optimal-configs-dir", default="./optimal_configs",
+ help="Base directory for optimal configurations")
+ parser.add_argument("--format", choices=["table", "csv", "json", "simple"], default="table",
+ help="Output format")
+ parser.add_argument("--group-by", choices=["model", "task", "masker", "none"], default="none",
+ help="Group tasks by field")
+ parser.add_argument("--filter-model", type=str, help="Filter by model name (substring match)")
+ parser.add_argument("--filter-task", type=str, help="Filter by task name (substring match)")
+ parser.add_argument("--filter-masker", type=str, help="Filter by masker name (substring match)")
+
+ args = parser.parse_args()
+
+ # Load configurations
+ config_dir = Path(args.optimal_configs_dir) / args.config_run
+ if not config_dir.exists():
+ print(f"Error: Config directory {config_dir} not found", file=sys.stderr)
+ sys.exit(1)
+
+ tasks = []
+ for config_file in sorted(config_dir.glob("*.json")):
+ if config_file.name.endswith(("_trials.json", "_analysis.csv")):
+ continue
+
+ try:
+ with open(config_file, "r") as f:
+ data = json.load(f)
+
+ # Apply filters
+ if args.filter_model and args.filter_model not in data["model"]:
+ continue
+ if args.filter_task and args.filter_task not in data["task"]:
+ continue
+ if args.filter_masker and args.filter_masker not in data["masker_name"]:
+ continue
+
+ tasks.append({
+ "model": data["model"],
+ "task": data["task"],
+ "masker": data["masker_name"],
+ "score": data.get("score", "N/A"),
+ "search_time": data.get("search_time", 0),
+ "num_trials": data.get("num_trials", 0),
+ "file": config_file.name
+ })
+ except Exception as e:
+ print(f"Warning: Failed to load {config_file}: {e}", file=sys.stderr)
+
+ if not tasks:
+ print("No tasks found matching criteria", file=sys.stderr)
+ sys.exit(1)
+
+ # Sort tasks
+ tasks.sort(key=lambda x: (x["model"], x["task"], x["masker"]))
+
+ # Output based on format
+ if args.format == "json":
+ print(json.dumps(tasks, indent=2))
+
+ elif args.format == "csv":
+ writer = csv.DictWriter(sys.stdout, fieldnames=["model", "task", "masker", "score", "search_time", "num_trials", "file"])
+ writer.writeheader()
+ writer.writerows(tasks)
+
+ elif args.format == "simple":
+ for task in tasks:
+ print(f"{task['model']} | {task['task']} | {task['masker']}")
+
+ else: # table format
+ # Group if requested
+ if args.group_by != "none":
+ groups = defaultdict(list)
+ for task in tasks:
+ key = task[args.group_by]
+ groups[key].append(task)
+
+ print(f"Tasks grouped by {args.group_by}:")
+ print("=" * 80)
+
+ for key in sorted(groups.keys()):
+ print(f"\n{args.group_by.upper()}: {key}")
+ print("-" * 80)
+
+ for task in groups[key]:
+ if args.group_by == "model":
+ print(f" {task['task']:30} | {task['masker']:30} | Score: {task['score']}")
+ elif args.group_by == "task":
+ print(f" {task['model']:30} | {task['masker']:30} | Score: {task['score']}")
+ else: # masker
+ print(f" {task['model']:30} | {task['task']:30} | Score: {task['score']}")
+
+ print(f" Total: {len(groups[key])} configurations")
+
+ else:
+ # Regular table
+ print(f"Benchmark Tasks from {args.config_run}")
+ print("=" * 120)
+ print(f"{'Model':35} | {'Task':25} | {'Masker':30} | {'Score':8} | {'Trials':6}")
+ print("-" * 120)
+
+ for task in tasks:
+ score_str = f"{task['score']:.4f}" if isinstance(task['score'], (int, float)) else str(task['score'])
+ print(f"{task['model']:35} | {task['task']:25} | {task['masker']:30} | {score_str:8} | {task['num_trials']:6}")
+
+ print("-" * 120)
+ print(f"Total: {len(tasks)} configurations")
+
+ # Summary statistics
+ print(f"\nSummary:")
+ models = set(t["model"] for t in tasks)
+ tasks_set = set(t["task"] for t in tasks)
+ maskers = set(t["masker"] for t in tasks)
+
+ print(f" Models: {len(models)}")
+ for model in sorted(models):
+ count = sum(1 for t in tasks if t["model"] == model)
+ print(f" - {model}: {count} configs")
+
+ print(f" Tasks: {len(tasks_set)}")
+ for task in sorted(tasks_set):
+ count = sum(1 for t in tasks if t["task"] == task)
+ print(f" - {task}: {count} configs")
+
+ print(f" Maskers: {len(maskers)}")
+ for masker in sorted(maskers):
+ count = sum(1 for t in tasks if t["masker"] == masker)
+ print(f" - {masker}: {count} configs")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/raytune/optimizer_factory.py b/benchmark/raytune/optimizer_factory.py
new file mode 100755
index 00000000..52818447
--- /dev/null
+++ b/benchmark/raytune/optimizer_factory.py
@@ -0,0 +1,140 @@
+"""
+Optimizer Factory for Sparse Attention Configurations.
+
+This module provides the core engine for creating optimizer objects that can
+translate sparse attention masker configurations into Ray Tune search spaces.
+
+The key design principle is that each masker's configuration class is responsible
+for defining its own tunable parameters via a `get_search_space()` static method.
+This factory then assembles these individual search spaces for optimization.
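+
+A minimal sketch of the expected contract (illustrative only; the real masker config classes
+live in sparse_attention_hub and may expose additional parameters):
+
+    class MyMaskerConfig:
+        def __init__(self, window_size: int):
+            self.window_size = window_size
+
+        @staticmethod
+        def get_search_space(task_name: str) -> dict:
+            from ray import tune
+            return {"window_size": tune.choice([32, 64, 128, 256])}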
+"""
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Type, Optional
+
+from sparse_attention_hub.sparse_attention.research_attention import (
+ ResearchAttentionConfig,
+)
+
+class SparseConfigOptimizer(ABC):
+ """
+ Abstract Base Class for sparse attention config optimizers.
+
+ An optimizer's main responsibilities are to create a search space for Ray Tune
+ and to instantiate a valid attention configuration from a set of parameters
+ produced by a Ray Tune trial.
+ """
+
+ @abstractmethod
+ def create_search_space(self, task_name: str) -> Dict[str, Any]:
+ """Creates the Ray Tune search space for a given task."""
+ pass
+
+ @abstractmethod
+ def create_config_from_params(self, params: Dict[str, Any]) -> Any:
+ """Creates an attention configuration instance from a dictionary of parameters."""
+ pass
+
+class SingleConfigOptimizer(SparseConfigOptimizer):
+ """Optimizer for a single, non-composite masker configuration class."""
+
+ def __init__(self, config_class: Type):
+ if not hasattr(config_class, "get_search_space"):
+ raise TypeError(
+ f"Config class {config_class.__name__} must implement a "
+ "`get_search_space(task_name)` static method."
+ )
+ self.config_class = config_class
+
+ def create_search_space(self, task_name: str) -> Dict[str, Any]:
+ return self.config_class.get_search_space(task_name)
+
+ def create_config_from_params(self, params: Dict[str, Any]) -> Any:
+ return self.config_class(**params)
+
+class CompositeConfigOptimizer(SparseConfigOptimizer):
+ """Optimizer for a `ResearchAttentionConfig` composed of multiple maskers."""
+
+ def __init__(self, masker_configs: List[Type], template_config: Optional[ResearchAttentionConfig] = None):
+ self.masker_configs = []
+ self.template_config = template_config
+
+ # Create a mapping from masker class to template instance if template is provided
+ self.template_instances = {}
+ if template_config:
+ for template_masker in template_config.masker_configs:
+ self.template_instances[type(template_masker)] = template_masker
+
+ for masker_class in masker_configs:
+ if not hasattr(masker_class, "get_search_space"):
+ raise TypeError(
+ f"Masker config {masker_class.__name__} must implement a "
+ "`get_search_space(task_name)` static method."
+ )
+ self.masker_configs.append(masker_class)
+
+ def create_search_space(self, task_name: str) -> Dict[str, Any]:
+ """
+ Creates a combined search space from all component masker configs.
+ Each parameter is prefixed with its masker's name to prevent conflicts.
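+        For example, a "sink_size" parameter from SinkMaskerConfig would appear as "sinkmasker_sink_size".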
+ """
+ combined_space = {}
+ for masker_class in self.masker_configs:
+ masker_name = masker_class.__name__.lower().replace("config", "")
+ masker_space = masker_class.get_search_space(task_name)
+ for param_name, param_space in masker_space.items():
+ combined_space[f"{masker_name}_{param_name}"] = param_space
+ return combined_space
+
+ def create_config_from_params(self, params: Dict[str, Any]) -> ResearchAttentionConfig:
+ """Creates a ResearchAttentionConfig instance from the combined parameters."""
+ masker_instances = []
+ for masker_class in self.masker_configs:
+ masker_name = masker_class.__name__.lower().replace("config", "")
+ prefix = f"{masker_name}_"
+ masker_params = {
+ k[len(prefix) :]: v for k, v in params.items() if k.startswith(prefix)
+ }
+
+ # If we have a template for this masker type, use its fixed parameters
+ if masker_class in self.template_instances:
+ template_masker = self.template_instances[masker_class]
+ # Get all attributes from the template
+ template_dict = {}
+ for attr in dir(template_masker):
+ if not attr.startswith('_') and not callable(getattr(template_masker, attr)):
+ try:
+ value = getattr(template_masker, attr)
+ # Only include simple types that can be serialized
+ if isinstance(value, (int, float, str, bool, type(None))):
+ template_dict[attr] = value
+                        except Exception:
+                            pass
+
+ # Update template with search params (search params override template)
+ template_dict.update(masker_params)
+ masker_instances.append(masker_class(**template_dict))
+ else:
+ masker_instances.append(masker_class(**masker_params))
+
+ return ResearchAttentionConfig(masker_configs=masker_instances)
+
+def create_optimizer(masker_configs: List[Type], template_config: Optional[ResearchAttentionConfig] = None) -> SparseConfigOptimizer:
+ """
+ Factory function to create the appropriate optimizer.
+
+ This function inspects the list of masker configurations and returns the
+ correct optimizer type.
+
+ Args:
+ masker_configs: List of masker configuration classes to optimize
+ template_config: Optional template configuration with fixed parameters
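+
+    Example (illustrative; assumes the masker configs implement get_search_space):
+        >>> from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
+        ...     SinkMaskerConfig,
+        ...     LocalMaskerConfig,
+        ... )
+        >>> optimizer = create_optimizer([SinkMaskerConfig, LocalMaskerConfig])
+        >>> space = optimizer.create_search_space("loogle/shortdep_qa")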
+ """
+ if not isinstance(masker_configs, list) or not masker_configs:
+ raise ValueError("`masker_configs` must be a non-empty list of config classes.")
+
+ logging.info(f"Creating optimizer for: {[c.__name__ for c in masker_configs]}")
+
+ if len(masker_configs) == 1:
+ return SingleConfigOptimizer(masker_configs[0])
+ return CompositeConfigOptimizer(masker_configs, template_config)
\ No newline at end of file
diff --git a/benchmark/raytune/run_full_benchmark.py b/benchmark/raytune/run_full_benchmark.py
new file mode 100755
index 00000000..3822c69f
--- /dev/null
+++ b/benchmark/raytune/run_full_benchmark.py
@@ -0,0 +1,1075 @@
+#!/usr/bin/env python3
+"""
+Two-Phase Benchmark System for Sparse Attention Methods.
+
+Phase 1: Hyperparameter search to find optimal configs for each (model, task, masker) combination
+Phase 2: Parallel benchmark execution using the discovered optimal configs
+
+Usage:
+ # Run both phases (default)
+    python benchmark/raytune/run_full_benchmark.py
+
+    # Run only Phase 1 (config search)
+    python benchmark/raytune/run_full_benchmark.py --phase 1
+
+    # Run only Phase 2 (benchmark execution)
+    python benchmark/raytune/run_full_benchmark.py --phase 2
+
+    # Debug mode (minimal configs, fast execution)
+    python benchmark/raytune/run_full_benchmark.py --debug
+
+    # Force re-search in Phase 1
+    python benchmark/raytune/run_full_benchmark.py --phase 1 --force-search
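+
+    # Target ~10% density during the Phase 1 search (see OBJECTIVE_FUNCTIONS below)
+    python benchmark/raytune/run_full_benchmark.py --phase 1 --objective sparsity_10
+
+    # Run Phase 2 against a specific Phase 1 run directory
+    python benchmark/raytune/run_full_benchmark.py --phase 2 --config-run run_20240315_143022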
+"""
+
+import argparse
+import json
+import logging
+import math
+import os
+import sys
+import time
+import traceback
+from pathlib import Path
+from datetime import datetime
+from typing import Dict, List, Any, Optional, Tuple
+from dataclasses import dataclass, asdict, field
+import pickle
+
+# Path setup
+current_dir = Path(__file__).parent
+root_path = current_dir.parent.parent
+sys.path.extend([str(current_dir), str(root_path)])
+os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + f":{current_dir}:{root_path}"
+
+import torch
+import pandas as pd
+from benchmark.executor import BenchmarkExecutor
+from benchmark.executor_config import AdapterConfig, BenchmarkConfig, BenchmarkResult
+from optimizer_factory import create_optimizer
+
+# Import all masker configs
+from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
+from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
+ LocalMaskerConfig,
+ SinkMaskerConfig,
+ OracleTopKConfig,
+ OracleTopPMaskerConfig,
+ HashAttentionTopKMaskerConfig,
+)
+from sparse_attention_hub.sparse_attention.research_attention.maskers.sampling.implementations import (
+ AdaptiveSamplingMaskerConfig,
+ RandomSamplingMaskerConfig,
+ MagicPigConfig,
+)
+
+try:
+ import ray
+ from ray import tune
+ from ray.tune.schedulers import ASHAScheduler
+ from ray.tune.search.hyperopt import HyperOptSearch
+except ImportError:
+ print("Error: Ray Tune required. Install with: pip install 'ray[tune]' hyperopt")
+ sys.exit(1)
+
+
+# Note: Configuration names are based on the masker classes used, not parameter values
+# Parameter values come from Ray Tune search, not from these initial configs
+
+
+@dataclass
+class OptimalConfig:
+ """Stores optimal configuration found in Phase 1."""
+ model: str
+ task: str
+ masker_name: str
+ sparse_config: Optional[ResearchAttentionConfig]
+ masker_classes: Optional[List] = field(default=None)
+ hyperparams: Dict[str, Any] = field(default_factory=dict)
+ score: float = 0.0
+ search_time: float = 0.0
+ num_trials: int = 0
+
+
+def create_sparsity_objective(target_density: float, penalty_weight: float = 10.0):
+ """Create an objective function that targets a specific sparsity level.
+
+ Args:
+ target_density: Target density level (e.g., 0.05 for 5% density)
+ penalty_weight: Weight for penalty when density exceeds target
+
+ Returns:
+ Objective function that can be used for optimization
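+
+    Example:
+        >>> # Illustrative: with a 10% target, error=0.2 and density=0.15 gives
+        >>> # 0.99*0.2 + 0.01*0.15 + 10.0*(0.15 - 0.10) ~= 0.6995
+        >>> objective = create_sparsity_objective(0.10)
+        >>> score = objective(error=0.2, density=0.15)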
+ """
+ def objective(error: float, density: float) -> float:
+ # Base objective: heavily weight error, lightly weight density
+ base_score = 0.99 * error + 0.01 * density
+
+ # Add penalty if density exceeds target
+ penalty = penalty_weight * max(0, density - target_density)
+
+ return base_score + penalty
+
+ objective.__name__ = f"objective_sparsity_{int(target_density * 100)}_percent"
+ return objective
+
+
+# Pre-defined objective functions for common sparsity levels
+OBJECTIVE_FUNCTIONS = {
+ "sparsity_5": create_sparsity_objective(0.05),
+ "sparsity_10": create_sparsity_objective(0.10),
+ "sparsity_15": create_sparsity_objective(0.15),
+ "sparsity_20": create_sparsity_objective(0.20),
+ "sparsity_25": create_sparsity_objective(0.25),
+ "default": lambda error, density: error + 0.1 * density + (5.0 if density > 0.5 else 0.0),
+}
+
+
+class Phase1BenchmarkRunner:
+ """Handles individual benchmark runs during config search."""
+
+ def __init__(self, config: dict):
+ self.config = config
+ self.executor = BenchmarkExecutor(
+ gpu_ids=[0], # Single GPU per trial
+ max_concurrent_runs=1,
+ base_result_dir=config["search_result_dir"],
+ enable_resumability=False,
+ required_result_files=["raw_results.csv"],
+ timeout_per_benchmark=config["search_timeout"],
+ verbose=False,
+ )
+ self.adapter_config = AdapterConfig(
+ adapter_name="huggingface",
+ model_kwargs={"torch_dtype": torch.bfloat16},
+ tokenizer_kwargs={"padding_side": "left"},
+ )
+ self.generation_kwargs = {
+ "max_new_tokens": config["search_max_new_tokens"],
+ "do_sample": False
+ }
+ self.request_kwargs = {
+ "max_context_length": config["search_max_context_length"],
+ "max_requests": config["search_max_requests"],
+ }
+
+ # Get objective function
+ self.objective_name = config.get("objective_function", "default")
+ self.objective_function = OBJECTIVE_FUNCTIONS.get(self.objective_name, OBJECTIVE_FUNCTIONS["default"])
+ logging.info(f"Using objective function: {self.objective_name}")
+
+ def __call__(self, attention_config, task_name: str, model_name: str) -> Tuple[float, float, float]:
+ """Run benchmark and return (score, density, error) tuple."""
+ try:
+ benchmark_name, subset_name = task_name.split("/", 1) if "/" in task_name else (task_name, None)
+ benchmark_config = BenchmarkConfig(
+ benchmark_name=benchmark_name,
+ subsets=[subset_name] if subset_name else None
+ )
+
+ results = self.executor.run_benchmark_matrix(
+ model_names=[model_name],
+ sparse_attention_configs=[("search", attention_config)],
+ benchmark_configs=[benchmark_config],
+ adapter_config=self.adapter_config,
+ generation_kwargs=self.generation_kwargs,
+ request_kwargs=self.request_kwargs,
+ )
+
+ # Extract score from results
+ if results.progress.completed_stubs > 0 and hasattr(results, "individual_results"):
+ completed = [r for r in results.individual_results if isinstance(r, BenchmarkResult)]
+ if completed:
+ result_dir = Path(completed[0].stub.result_dir)
+ metrics = self._extract_micro_metrics(result_dir)
+ error, density = metrics["attention_error"], metrics["density"]
+
+ # For dense configuration (density=1.0, error=0.0), use a simple score
+ if density == 1.0 and error == 0.0:
+                        # Dense baseline: no sparse metrics to optimize, so assign a small fixed score
+                        score = 0.1  # Small baseline score for dense
+ else:
+ # Use the selected objective function
+ score = self.objective_function(error, density)
+ # Also print to stdout so the test script can detect it
+ print(f"Objective: {self.objective_name}, Error: {error:.4f}, Density: {density:.4f}, Score: {score:.4f}")
+ logging.info(f"Objective: {self.objective_name}, Error: {error:.4f}, Density: {density:.4f}, Score: {score:.4f}")
+
+ return score, density, error
+
+ except Exception as e:
+ logging.error(f"Benchmark failed: {e}")
+
+ return 5.0, 1.0, 1.0 # Penalty score, worst-case density, and worst-case error
+
+ def _extract_micro_metrics(self, result_dir: Path) -> dict:
+ """Extract attention error and density from micro metrics."""
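+        # Each line of micro_metrics.jsonl is expected to be a JSON object such as (illustrative value):
+        #   {"metric": "research_attention_density", "value": 0.07}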
+ micro_metrics_file = result_dir / "micro_metrics.jsonl"
+ if not micro_metrics_file.exists():
+ # For dense configuration, micro_metrics.jsonl won't exist since no sparse attention is used
+ # Return default values: 0 error (perfect) and 1.0 density (fully dense)
+ logging.info(f"micro_metrics.jsonl not found in {result_dir}, using dense defaults")
+ return {"attention_error": 0.0, "density": 1.0}
+
+ errors, densities = [], []
+ with open(micro_metrics_file, "r") as f:
+ for line in f:
+ try:
+ entry = json.loads(line.strip())
+ metric, value = entry.get("metric"), entry.get("value")
+ if value is not None and not (isinstance(value, float) and math.isnan(value)):
+ if metric == "research_attention_output_error":
+ errors.append(float(value))
+ elif metric == "research_attention_density":
+ densities.append(float(value))
+ except (json.JSONDecodeError, ValueError, TypeError):
+ continue
+
+ return {
+ "attention_error": sum(errors) / len(errors) if errors else 1.0,
+ "density": sum(densities) / len(densities) if densities else 1.0
+ }
+
+
+class ConfigSearchManager:
+ """Manages Phase 1: Hyperparameter search for optimal configs."""
+
+ def __init__(self, base_config: dict):
+ self.config = base_config
+ # Add timestamp to the results directory
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ base_dir = Path(base_config["optimal_configs_dir"])
+ self.results_dir = base_dir / f"run_{timestamp}"
+ self.results_dir.mkdir(parents=True, exist_ok=True)
+ self.timestamp = timestamp
+ print(f"Saving optimal configs to: {self.results_dir}")
+
+ def search_optimal_config(
+ self,
+ model: str,
+ task: str,
+ masker_name: str,
+ masker_classes: Optional[List],
+ full_sparse_config: Optional[ResearchAttentionConfig] = None
+ ) -> OptimalConfig:
+ """Search for optimal hyperparameters for a single combination."""
+
+ config_file = self.results_dir / f"{model}_{task}_{masker_name}.json".replace("/", "_")
+
+ # Check if already exists
+ if config_file.exists() and not self.config.get("force_search", False):
+ print(f" → Loading existing config")
+ return self._load_config(config_file)
+
+ # Handle dense config (no optimization needed)
+ if masker_classes is None:
+ optimal = OptimalConfig(
+ model=model,
+ task=task,
+ masker_name=masker_name,
+ sparse_config=None,
+ masker_classes=None,
+ hyperparams={},
+ score=0.0,
+ search_time=0.0,
+ num_trials=1
+ )
+ self._save_config(optimal, config_file)
+ return optimal
+
+ # Run hyperparameter search
+ print(f" → Running hyperparameter search...")
+ start_time = time.time()
+
+ try:
+ # Create optimizer with template config for fixed parameters
+ optimizer = create_optimizer(masker_classes, full_sparse_config)
+
+ # Show what we're searching
+ search_space = optimizer.create_search_space(task)
+ print(f" → Search space parameters:")
+ for param, space_obj in search_space.items():
+ # Extract actual values from Ray Tune objects
+ if hasattr(space_obj, 'categories'):
+ values = space_obj.categories
+ print(f" - {param}: {values}")
+ else:
+ print(f" - {param}: {space_obj}")
+
+ # Create objective function
+ def objective(trial_config):
+ runner = Phase1BenchmarkRunner(self.config)
+ attention_config = optimizer.create_config_from_params(trial_config)
+ score, density, error = runner(attention_config, task, model)
+ return {"combined_score": score, "density": density, "error": error}
+
+ # Get Ray Tune components
+ search_space = optimizer.create_search_space(task)
+ scheduler = ASHAScheduler(
+ time_attr="training_iteration",
+ max_t=20,
+ grace_period=5,
+ reduction_factor=2
+ )
+ search_alg = HyperOptSearch(
+ metric="combined_score",
+ mode="min",
+ n_initial_points=max(1, self.config["num_samples"] // 4)
+ )
+
+ # Run Ray Tune
+ sanitized_name = f"{model}_{task}_{masker_name}".replace("/", "_")
+ analysis = tune.run(
+ objective,
+ config=search_space,
+ num_samples=self.config["num_samples"],
+ metric="combined_score",
+ mode="min",
+ scheduler=scheduler,
+ search_alg=search_alg,
+ resources_per_trial={"CPU": 1, "GPU": 1.0},
+ storage_path=os.path.abspath(self.config["ray_results_dir"]),
+ name=sanitized_name,
+ verbose=1, # Show Ray Tune progress
+ stop={"training_iteration": 1}, # One evaluation per config
+ )
+
+ # Get best config
+ best_trial = analysis.get_best_trial("combined_score", "min", "last")
+ best_config = optimizer.create_config_from_params(best_trial.config)
+
+ # Save detailed trial information for post-analysis
+ trials_info = []
+ for trial in analysis.trials:
+ trial_info = {
+ "trial_id": trial.trial_id,
+ "config": trial.config,
+ "score": trial.last_result.get("combined_score", float('inf')) if trial.last_result else float('inf'),
+ "status": trial.status,
+ "start_time": trial.start_time.isoformat() if hasattr(trial, 'start_time') and trial.start_time else None,
+ "metric_history": trial.metric_analysis.get("combined_score", {}) if hasattr(trial, 'metric_analysis') else {}
+ }
+ trials_info.append(trial_info)
+
+ # Save trial details to separate file
+ trials_file = self.results_dir / f"{model}_{task}_{masker_name}_trials.json".replace("/", "_")
+ with open(trials_file, "w") as f:
+ json.dump({
+ "model": model,
+ "task": task,
+ "masker_name": masker_name,
+ "objective_function": self.config.get("objective_function", "default"),
+ "best_trial_id": best_trial.trial_id,
+ "trials": trials_info,
+ "analysis_dataframe_path": str(self.results_dir / f"{model}_{task}_{masker_name}_analysis.csv".replace("/", "_"))
+ }, f, indent=2)
+
+ # Save Ray analysis dataframe for detailed analysis
+ df = analysis.dataframe()
+ df.to_csv(self.results_dir / f"{model}_{task}_{masker_name}_analysis.csv".replace("/", "_"), index=False)
+
+ optimal = OptimalConfig(
+ model=model,
+ task=task,
+ masker_name=masker_name,
+ sparse_config=best_config,
+ masker_classes=masker_classes,
+ hyperparams=best_trial.config,
+ score=best_trial.last_result["combined_score"],
+ search_time=time.time() - start_time,
+ num_trials=len(analysis.trials)
+ )
+
+ self._save_config(optimal, config_file)
+ return optimal
+
+ except Exception as e:
+ print(f" ✗ Search failed: {e}")
+ traceback.print_exc()
+ # Return failure config
+ optimal = OptimalConfig(
+ model=model,
+ task=task,
+ masker_name=masker_name,
+ sparse_config=full_sparse_config, # Use the full config passed in
+ masker_classes=masker_classes,
+ hyperparams={},
+ score=5.0,
+ search_time=time.time() - start_time,
+ num_trials=0
+ )
+ self._save_config(optimal, config_file)
+ return optimal
+
+ def _save_config(self, config: OptimalConfig, filepath: Path):
+ """Save configuration to JSON."""
+ data = asdict(config)
+
+ # Convert sparse config to serializable format
+ if config.sparse_config:
+ data["sparse_config"] = self._serialize_sparse_config(config.sparse_config)
+
+ # Convert masker classes to strings
+ if config.masker_classes:
+ data["masker_classes"] = [cls.__name__ for cls in config.masker_classes]
+
+ with open(filepath, "w") as f:
+ json.dump(data, f, indent=2)
+
+ def _load_config(self, filepath: Path) -> OptimalConfig:
+ """Load configuration from JSON."""
+ with open(filepath, "r") as f:
+ data = json.load(f)
+
+ # Reconstruct sparse config if present
+ if data.get("sparse_config"):
+ data["sparse_config"] = self._deserialize_sparse_config(data["sparse_config"])
+
+ # Reconstruct masker classes from strings
+ if data.get("masker_classes"):
+ # Map class names to actual classes
+ class_map = {
+ "LocalMaskerConfig": LocalMaskerConfig,
+ "SinkMaskerConfig": SinkMaskerConfig,
+ "OracleTopKConfig": OracleTopKConfig,
+ "OracleTopPMaskerConfig": OracleTopPMaskerConfig,
+ "HashAttentionTopKMaskerConfig": HashAttentionTopKMaskerConfig,
+ "AdaptiveSamplingMaskerConfig": AdaptiveSamplingMaskerConfig,
+ "RandomSamplingMaskerConfig": RandomSamplingMaskerConfig,
+ "MagicPigConfig": MagicPigConfig,
+ }
+ data["masker_classes"] = [class_map[name] for name in data["masker_classes"]]
+
+ return OptimalConfig(**data)
+
+ def _serialize_sparse_config(self, config: ResearchAttentionConfig) -> dict:
+ """Convert ResearchAttentionConfig to JSON-serializable format."""
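+        # Produces, e.g. (illustrative values):
+        #   {"type": "ResearchAttentionConfig",
+        #    "masker_configs": [{"type": "SinkMaskerConfig", "params": {"sink_size": 32}}, ...]}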
+ if config is None:
+ return None
+
+ # Serialize each masker config
+ masker_configs = []
+ for masker in config.masker_configs:
+ masker_dict = {
+ "type": type(masker).__name__,
+ "params": {}
+ }
+ # Add all attributes
+ for attr in dir(masker):
+ if not attr.startswith("_") and hasattr(masker, attr):
+ value = getattr(masker, attr)
+ if isinstance(value, (int, float, str, bool, type(None))):
+ masker_dict["params"][attr] = value
+ masker_configs.append(masker_dict)
+
+ return {
+ "type": "ResearchAttentionConfig",
+ "masker_configs": masker_configs
+ }
+
+ def _deserialize_sparse_config(self, data: dict) -> ResearchAttentionConfig:
+ """Reconstruct ResearchAttentionConfig from JSON data."""
+ if data is None:
+ return None
+
+ if data.get("type") != "ResearchAttentionConfig":
+ return None
+
+ # Map config types to classes
+ config_map = {
+ "LocalMaskerConfig": LocalMaskerConfig,
+ "SinkMaskerConfig": SinkMaskerConfig,
+ "OracleTopKConfig": OracleTopKConfig,
+ "OracleTopPMaskerConfig": OracleTopPMaskerConfig,
+ "HashAttentionTopKMaskerConfig": HashAttentionTopKMaskerConfig,
+ "AdaptiveSamplingMaskerConfig": AdaptiveSamplingMaskerConfig,
+ "RandomSamplingMaskerConfig": RandomSamplingMaskerConfig,
+ "MagicPigConfig": MagicPigConfig,
+ }
+
+ # Reconstruct masker configs
+ masker_configs = []
+ for masker_data in data.get("masker_configs", []):
+ config_class = config_map.get(masker_data["type"])
+ if config_class:
+ # Create instance with parameters
+ params = masker_data.get("params", {})
+ masker_configs.append(config_class(**params))
+
+ return ResearchAttentionConfig(masker_configs=masker_configs)
+
+
+def run_phase_one(config: dict) -> Dict[str, OptimalConfig]:
+ """Phase 1: Find optimal configurations for all combinations."""
+ print("\n" + "="*80)
+ print("PHASE 1: HYPERPARAMETER SEARCH")
+ print("="*80)
+ print(f"Models: {len(config['models'])}")
+ print(f"Tasks: {len(config['tasks'])}")
+ print(f"Sparse Configs: {len(config['sparse_configs'])}")
+ print(f"Total Combinations: {len(config['models']) * len(config['tasks']) * len(config['sparse_configs'])}")
+ print(f"Samples per search: {config['num_samples']}")
+ print(f"Objective Function: {config['objective_function']}")
+
+ # Display objective function details
+ if config['objective_function'].startswith('sparsity_'):
+ target = int(config['objective_function'].split('_')[1])
+ print(f" → Targeting {target}% density (0.{target:02d} fraction)")
+ print(f" → Formula: 0.99 * error + 0.01 * density + penalty for exceeding target")
+
+ print("\nSearch Configuration:")
+ print(f" → Max new tokens: {config['search_max_new_tokens']}")
+ print(f" → Max context length: {config['search_max_context_length']}")
+ print(f" → Max requests per trial: {config['search_max_requests']}")
+ print(f" → Timeout per trial: {config['search_timeout']}s")
+
+ print("\nNote: For each sparse config, Ray Tune will search different hyperparameter")
+ print("values (e.g., window_size, sink_size, sampling_rate) to find the best combination.")
+ print("="*80)
+
+ manager = ConfigSearchManager(config)
+ optimal_configs = {}
+
+ total = len(config["models"]) * len(config["tasks"]) * len(config["sparse_configs"])
+ current = 0
+
+ for model in config["models"]:
+ print(f"\nModel: {model}")
+ print("-" * 60)
+
+ for task in config["tasks"]:
+ for masker_name, (masker_classes, full_config) in config["sparse_configs_map"].items():
+ current += 1
+ key = f"{model}_{task}_{masker_name}".replace("/", "_")
+
+ print(f"\n[{current}/{total}] Task: {task} | Config: {masker_name}")
+
+ # Explain what this config is
+ if masker_classes:
+ print(f" → Config contains: {[cls.__name__ for cls in masker_classes]}")
+ else:
+ print(f" → Dense configuration (no sparse attention)")
+
+ try:
+ optimal = manager.search_optimal_config(
+ model, task, masker_name, masker_classes, full_config
+ )
+ optimal_configs[key] = optimal
+
+ if optimal.num_trials > 0:
+ print(f" ✓ Best score: {optimal.score:.4f} (searched {optimal.num_trials} configs in {optimal.search_time:.1f}s)")
+ else:
+ print(f" ✓ Score: {optimal.score:.4f} (no search needed)")
+
+ except Exception as e:
+ print(f" ✗ Failed: {e}")
+ continue
+
+ print(f"\n{'='*80}")
+ print(f"Phase 1 complete. Found {len(optimal_configs)} optimal configurations.")
+ print(f"Configs saved to: {manager.results_dir}")
+ print(f"Run identifier: {manager.timestamp}")
+ print(f"\nTo use these configs in Phase 2:")
+ print(f" python {sys.argv[0]} --phase 2 # Uses most recent configs")
+ print(f" python {sys.argv[0]} --phase 2 --config-run run_{manager.timestamp} # Uses this specific run")
+ print(f"{'='*80}")
+
+ return optimal_configs
+
+
+def run_phase_two(config: dict, optimal_configs: Dict[str, OptimalConfig]) -> dict:
+ """Phase 2: Run benchmarks with optimal configurations."""
+ print("\n" + "="*80)
+ print("PHASE 2: BENCHMARK EXECUTION")
+ print("="*80)
+
+ # Build unique sparse configs from optimal configs
+ unique_sparse_configs = []
+ seen = set()
+ config_usage = {} # Track which (model, task) use each config
+
+ for key, opt_config in optimal_configs.items():
+ config_str = str(opt_config.sparse_config) if opt_config.sparse_config else "None"
+ if config_str not in seen:
+ seen.add(config_str)
+ unique_sparse_configs.append((
+ opt_config.masker_name,
+ opt_config.sparse_config
+ ))
+ config_usage[config_str] = []
+ config_usage[config_str].append((opt_config.model, opt_config.task))
+
+ print(f"Unique sparse configurations: {len(unique_sparse_configs)}")
+ print(f"Models: {len(config['models'])}")
+ print(f"Tasks: {len(config['tasks'])}")
+ print(f"Total benchmark runs: {len(config['models']) * len(config['tasks']) * len(unique_sparse_configs)}")
+ print(f"GPUs available: {len(config['gpu_ids'])}")
+ print("="*80)
+
+ # Create executor
+ executor = BenchmarkExecutor(
+ gpu_ids=config["gpu_ids"],
+ max_concurrent_runs=len(config["gpu_ids"]),
+ base_result_dir=config["benchmark_results_dir"],
+ enable_resumability=True,
+ required_result_files=["raw_results.csv"],
+ timeout_per_benchmark=config["benchmark_timeout"],
+ verbose=True
+ )
+
+ # Create benchmark configs
+ benchmark_configs = []
+ for task in config["tasks"]:
+ if "/" in task:
+ name, subset = task.split("/", 1)
+ benchmark_configs.append(BenchmarkConfig(
+ benchmark_name=name,
+ subsets=[subset]
+ ))
+ else:
+ benchmark_configs.append(BenchmarkConfig(
+ benchmark_name=task,
+ subsets=None
+ ))
+
+ # Run benchmarks
+ print("\nStarting benchmark execution...")
+ results = executor.run_benchmark_matrix(
+ model_names=config["models"],
+ sparse_attention_configs=unique_sparse_configs,
+ benchmark_configs=benchmark_configs,
+ adapter_config=AdapterConfig(
+ adapter_name="huggingface",
+ model_kwargs={"torch_dtype": torch.bfloat16},
+ tokenizer_kwargs={"padding_side": "left"}
+ ),
+ generation_kwargs={
+ "max_new_tokens": config["benchmark_max_new_tokens"],
+ "do_sample": False,
+ "temperature": 1.0,
+ "top_p": 1.0,
+ "pad_token_id": None,
+ },
+ request_kwargs={
+ "max_context_length": config["benchmark_max_context_length"],
+ "max_requests": config["benchmark_max_requests"]
+ }
+ )
+
+ # Save summary
+ summary = {
+ "timestamp": datetime.now().isoformat(),
+ "objective_function": config["objective_function"],
+ "config_run_used": config.get("config_run_dir", "unknown"),
+ "phase1_optimal_configs": {
+ k: {
+ "model": v.model,
+ "task": v.task,
+ "masker_name": v.masker_name,
+ "score": v.score,
+ "hyperparams": v.hyperparams,
+ "search_time": v.search_time,
+ "num_trials": v.num_trials
+ } for k, v in optimal_configs.items()
+ },
+ "phase2_results": {
+ "total": results.progress.total_stubs,
+ "completed": results.progress.completed_stubs,
+ "failed": results.progress.failed_stubs,
+ "skipped": results.progress.skipped_stubs,
+ },
+ "configuration": {
+ "models": config["models"],
+ "tasks": config["tasks"],
+ "num_sparse_configs": len(unique_sparse_configs),
+ "objective_function": config["objective_function"],
+ "benchmark_timeout": config["benchmark_timeout"],
+ "max_new_tokens": config["benchmark_max_new_tokens"],
+ "max_context_length": config["benchmark_max_context_length"],
+ }
+ }
+
+ summary_file = Path(config["benchmark_results_dir"]) / "benchmark_summary.json"
+ summary_file.parent.mkdir(parents=True, exist_ok=True)
+ with open(summary_file, "w") as f:
+ json.dump(summary, f, indent=2, default=str)
+
+ print(f"\n{'='*80}")
+ print(f"Phase 2 complete.")
+ print(f"Results saved to: {config['benchmark_results_dir']}")
+ print(f"Summary saved to: {summary_file}")
+ print(f"Completed: {results.progress.completed_stubs}/{results.progress.total_stubs}")
+ print(f"Failed: {results.progress.failed_stubs}")
+ print(f"{'='*80}")
+
+ return summary
+
+
+def get_masker_list_name(masker_classes: List) -> str:
+ """Generate a name based on the masker classes being used."""
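+    # For example: [SinkMaskerConfig, LocalMaskerConfig, RandomSamplingMaskerConfig]
+    # -> "sink_local_random_sampling"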
+ if not masker_classes:
+ return "dense"
+
+ # Extract just the key part of each masker name
+ parts = []
+ for cls in masker_classes:
+ name = cls.__name__.replace("MaskerConfig", "").replace("Config", "")
+ # Convert camelCase to lowercase
+ name = ''.join(['_' + c.lower() if c.isupper() else c for c in name]).lstrip('_')
+ parts.append(name)
+
+ return "_".join(parts)
+
+
+def get_all_sparse_configs(weight_file: str = None) -> List[Tuple[str, Optional[ResearchAttentionConfig], Optional[List]]]:
+ """Get all sparse attention configurations.
+ Returns list of (name, full_config, masker_classes) tuples.
+
+ Note: The configs returned here are only used to determine which masker classes
+ to use. The actual parameter values will be determined by Ray Tune search.
+ """
+ configs = []
+
+ # Dense baseline
+ configs.append(("dense", None, None))
+
+ # ==================== Config Set 1: Basic Sampling =================
+ # Random sampling with sink and local
+ classes = [SinkMaskerConfig, LocalMaskerConfig, RandomSamplingMaskerConfig]
+ name = get_masker_list_name(classes)
+ config = ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=32), # Middle value from search space [4, 8, 16, 32, 64, 128]
+ LocalMaskerConfig(window_size=128), # Middle value from search space [32, 64, 128, 256]
+ RandomSamplingMaskerConfig(sampling_rate=0.1) # Middle value from search space [0.01, 0.05, 0.1, 0.2, 0.3, 0.5]
+ ])
+ configs.append((name, config, classes))
+
+ # Adaptive sampling with oracle top k
+ classes = [SinkMaskerConfig, LocalMaskerConfig, OracleTopKConfig, AdaptiveSamplingMaskerConfig]
+ name = get_masker_list_name(classes)
+ config = ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=32),
+ LocalMaskerConfig(window_size=128),
+ OracleTopKConfig(heavy_size=0.05), # Middle value from search space
+ AdaptiveSamplingMaskerConfig(
+ base_rate_sampling=0.1, # Middle value
+ epsilon=0.25, # Middle value
+ delta=0.25, # Middle value
+ init_offset=0.005, # Middle value
+ local_offset=0.005 # Middle value
+ )
+ ])
+ configs.append((name, config, classes))
+
+ # Adaptive sampling with HAT top k
+ if weight_file:
+ classes = [SinkMaskerConfig, LocalMaskerConfig, HashAttentionTopKMaskerConfig, AdaptiveSamplingMaskerConfig]
+ name = get_masker_list_name(classes)
+ config = ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=32),
+ LocalMaskerConfig(window_size=128),
+ HashAttentionTopKMaskerConfig(
+ heavy_size=0.05, # Required parameter
+ hat_bits=32, # Required parameter
+ hat_mlp_layers=3, # Required parameter
+ hat_mlp_hidden_size=128, # Required parameter
+ hat_mlp_activation="silu", # Required parameter
+ hat_weight_file=weight_file # Weight file is required
+ ),
+ AdaptiveSamplingMaskerConfig(
+ base_rate_sampling=0.1,
+ epsilon=0.25,
+ delta=0.25,
+ init_offset=0.005,
+ local_offset=0.005
+ )
+ ])
+ configs.append((name, config, classes))
+
+ # HAT top k (without adaptive)
+ classes = [SinkMaskerConfig, LocalMaskerConfig, HashAttentionTopKMaskerConfig]
+ name = get_masker_list_name(classes)
+ config = ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=32),
+ LocalMaskerConfig(window_size=128),
+ HashAttentionTopKMaskerConfig(
+ heavy_size=0.05,
+ hat_bits=32,
+ hat_mlp_layers=3,
+ hat_mlp_hidden_size=128,
+ hat_mlp_activation="silu",
+ hat_weight_file=weight_file
+ ),
+ ])
+ configs.append((name, config, classes))
+
+ # Oracle top p
+ classes = [SinkMaskerConfig, LocalMaskerConfig, OracleTopPMaskerConfig]
+ name = get_masker_list_name(classes)
+ config = ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=32),
+ LocalMaskerConfig(window_size=128),
+ OracleTopPMaskerConfig(top_p=0.9) # Default middle value from search space
+ ])
+ configs.append((name, config, classes))
+
+ # Oracle top k (already included above with adaptive, but also standalone)
+ classes = [SinkMaskerConfig, LocalMaskerConfig, OracleTopKConfig]
+ name = get_masker_list_name(classes)
+ config = ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=32),
+ LocalMaskerConfig(window_size=128),
+ OracleTopKConfig(heavy_size=0.05)
+ ])
+ configs.append((name, config, classes))
+
+ # MagicPig config
+ classes = [SinkMaskerConfig, LocalMaskerConfig, MagicPigConfig]
+ name = get_masker_list_name(classes)
+ config = ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=32),
+ LocalMaskerConfig(window_size=128),
+ MagicPigConfig(
+ lsh_l=8, # Default value from search space
+ lsh_k=8 # Default value from search space
+ )
+ ])
+ configs.append((name, config, classes))
+
+ return configs
+
+
+def get_run_configuration(args: argparse.Namespace) -> dict:
+ """Build complete configuration from command-line arguments."""
+ num_gpus = torch.cuda.device_count()
+
+ # Get HashAttention weights file
+ machine_key = "ubuntu"
+ weight_file = f"/home/{machine_key}/scratch/krishna/artifacts/llama3.1-8b-patch.64K.v1.hat_weights.pkl"
+ if not os.path.exists(weight_file):
+ weight_file = "./hat_weights.pkl"
+ print(f"Warning: HashAttention weights not found, using {weight_file}")
+
+ # Get all sparse configs
+ all_sparse_configs = get_all_sparse_configs(weight_file)
+
+ # Filter configs based on debug mode
+ if args.debug:
+ sparse_configs = all_sparse_configs[:3] # Just first 3 for debug
+ models = ["meta-llama/Llama-3.1-8B-Instruct"]
+ tasks = ["loogle/shortdep_qa"]
+ num_samples = 8
+ else:
+ sparse_configs = all_sparse_configs
+ models = ["meta-llama/Llama-3.1-8B-Instruct"]
+ tasks = [
+ # "infinite_bench/passkey",
+ # "ruler/4096",
+ "loogle/longdep_summarization",
+ "loogle/longdep_qa",
+ "loogle/shortdep_qa",
+ "loogle/shortdep_cloze",
+ # "zero_scrolls/default",
+ # "longbenchv2/0shot",
+ # "aime2024/aime2024",
+ # "aime2025/aime2025",
+ # "longbench/passage_retrieval_en",
+ # "mock_benchmark/reading_comprehension",
+ ]
+ num_samples = args.num_samples
+
+ # Build config maps
+ sparse_configs_map = {}
+ for name, full_config, classes in sparse_configs:
+ sparse_configs_map[name] = (classes, full_config)
+
+ return {
+ "models": models,
+ "tasks": tasks,
+ "sparse_configs": sparse_configs,
+ "sparse_configs_map": sparse_configs_map,
+ "gpu_ids": list(range(num_gpus)),
+ "num_samples": num_samples,
+ "objective_function": args.objective,
+
+ # Directories
+ "optimal_configs_dir": args.optimal_configs_dir,
+ "benchmark_results_dir": args.benchmark_results_dir,
+ "ray_results_dir": args.ray_results_dir,
+ "search_result_dir": os.path.join(args.ray_results_dir, "search_runs"),
+
+ # Phase 1 params
+ "search_timeout": args.search_timeout,
+ "search_max_new_tokens": args.search_max_new_tokens,
+ "search_max_context_length": args.search_max_context_length,
+ "search_max_requests": args.search_max_requests,
+ "force_search": args.force_search,
+
+ # Phase 2 params
+ "benchmark_timeout": args.benchmark_timeout,
+ "benchmark_max_new_tokens": args.benchmark_max_new_tokens,
+ "benchmark_max_context_length": args.benchmark_max_context_length,
+ "benchmark_max_requests": args.benchmark_max_requests,
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Two-phase benchmark system for sparse attention methods",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ # Phase control
+ parser.add_argument("--phase", type=int, choices=[1, 2],
+ help="Run specific phase only (1=search, 2=benchmark)")
+ parser.add_argument("--debug", action="store_true",
+ help="Debug mode with minimal configs")
+ parser.add_argument("--force-search", action="store_true",
+ help="Force re-run of Phase 1 even if configs exist")
+
+ # Objective function selection
+ parser.add_argument("--objective", type=str, default="default",
+ choices=list(OBJECTIVE_FUNCTIONS.keys()),
+ help="Objective function to use for optimization")
+
+ # Config run selection for Phase 2
+ parser.add_argument("--config-run", type=str,
+ help="Specific config run directory to use for Phase 2 (e.g., 'run_20240315_143022')")
+
+ # Directories
+ parser.add_argument("--optimal-configs-dir", default="./optimal_configs",
+ help="Directory for storing optimal configurations")
+ parser.add_argument("--benchmark-results-dir", default="./benchmark_results",
+ help="Directory for benchmark results")
+ parser.add_argument("--ray-results-dir", default="./ray_results",
+ help="Directory for Ray Tune results")
+
+ # Phase 1 arguments
+ phase1_group = parser.add_argument_group('Phase 1 - Config Search')
+ phase1_group.add_argument("--num-samples", type=int, default=50,
+ help="Number of samples per hyperparameter search")
+ phase1_group.add_argument("--search-timeout", type=int, default=900,
+ help="Timeout per search trial (seconds)")
+ phase1_group.add_argument("--search-max-new-tokens", type=int, default=20,
+ help="Max new tokens for search trials")
+ phase1_group.add_argument("--search-max-context-length", type=int, default=8192,
+ help="Max context length for search trials")
+ phase1_group.add_argument("--search-max-requests", type=int, default=5,
+ help="Max requests per search trial")
+
+ # Phase 2 arguments
+ phase2_group = parser.add_argument_group('Phase 2 - Benchmark Execution')
+ phase2_group.add_argument("--benchmark-timeout", type=int, default=3600,
+ help="Timeout per benchmark (seconds)")
+ phase2_group.add_argument("--benchmark-max-new-tokens", type=int, default=100,
+ help="Max new tokens for benchmarks")
+ phase2_group.add_argument("--benchmark-max-context-length", type=int, default=32000,
+ help="Max context length for benchmarks")
+ phase2_group.add_argument("--benchmark-max-requests", type=int, default=25,
+ help="Max requests per benchmark")
+
+ args = parser.parse_args()
+
+ # Build configuration
+ config = get_run_configuration(args)
+
+ print("Two-Phase Benchmark System")
+ print(f"Ray Version: {ray.__version__}, GPUs Available: {torch.cuda.device_count()}")
+ print(f"Mode: {'Debug' if args.debug else 'Production'}")
+
+ # Initialize Ray
+ if not ray.is_initialized():
+ ray.init(ignore_reinit_error=True, log_to_driver=False,
+ runtime_env={"working_dir": str(root_path)})
+
+    start_time = time.time()
+    results = None  # Set by Phase 2; stays None if only Phase 1 runs
+
+ try:
+ # Phase 1: Config Search
+ if args.phase is None or args.phase == 1:
+ optimal_configs = run_phase_one(config)
+            # If running both phases, record the timestamped run_* directory Phase 1 just wrote to
+            if args.phase is None and optimal_configs:
+                run_dirs = sorted(Path(config["optimal_configs_dir"]).glob("run_*"))
+                if run_dirs:
+                    config["config_run_dir"] = str(run_dirs[-1])
+ else:
+ # Load existing configs for Phase 2
+ print("\nLoading existing optimal configurations...")
+ base_config_dir = Path(args.optimal_configs_dir)
+
+ # Find the most recent run directory or use specified one
+ if args.config_run:
+ config_dir = base_config_dir / args.config_run
+ if not config_dir.exists():
+ print(f"Error: Specified config run {config_dir} does not exist.")
+ return
+ else:
+ # Find the most recent run_* directory
+ run_dirs = sorted([d for d in base_config_dir.glob("run_*") if d.is_dir()])
+ if not run_dirs:
+ # Fallback to base directory for backward compatibility
+ config_dir = base_config_dir
+ if not any(config_dir.glob("*.json")):
+ print(f"Error: No optimal configs found. Run Phase 1 first.")
+ return
+ else:
+ config_dir = run_dirs[-1] # Most recent
+ print(f"Using most recent config run: {config_dir.name}")
+
+ # Create a dummy manager just for loading
+ manager = ConfigSearchManager(config)
+ manager.results_dir = config_dir # Override the directory
+
+ optimal_configs = {}
+ for config_file in config_dir.glob("*.json"):
+ if config_file.name.endswith("_trials.json"):
+ continue
+ try:
+ opt_config = manager._load_config(config_file)
+ key = config_file.stem
+ optimal_configs[key] = opt_config
+ except Exception as e:
+ print(f"Warning: Failed to load {config_file}: {e}")
+
+ print(f"Loaded {len(optimal_configs)} configurations from {config_dir}")
+ # Store which config run was used
+ config["config_run_dir"] = str(config_dir)
+
+ # Phase 2: Benchmark Execution
+ if args.phase is None or args.phase == 2:
+ if not optimal_configs:
+ print("\nError: No optimal configurations found. Run Phase 1 first.")
+ return
+
+ results = run_phase_two(config, optimal_configs)
+
+ # Print final summary
+ print("\n" + "="*80)
+ print("FINAL SUMMARY")
+ print("="*80)
+ if args.phase is None:
+ print(f"Phase 1: Found {len(optimal_configs)} optimal configurations")
+ if results:
+ print(f"Phase 2: Completed {results['phase2_results']['completed']}/{results['phase2_results']['total']} benchmarks")
+ print(f" Failed: {results['phase2_results']['failed']}")
+
+ except KeyboardInterrupt:
+ print("\nInterrupted by user")
+ except Exception as e:
+ print(f"\nError: {e}")
+ traceback.print_exc()
+ finally:
+ total_time = time.time() - start_time
+ print(f"\nTotal execution time: {total_time / 3600:.2f} hours ({total_time:.0f} seconds)")
+ ray.shutdown()
+ print("Done.")
+
+
+if __name__ == "__main__":
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ main()
diff --git a/benchmark/raytune/run_full_benchmark_interleave.py b/benchmark/raytune/run_full_benchmark_interleave.py
new file mode 100755
index 00000000..b8cb3b65
--- /dev/null
+++ b/benchmark/raytune/run_full_benchmark_interleave.py
@@ -0,0 +1,796 @@
+#!/usr/bin/env python3
+"""
+Full End-to-End Benchmark Execution and Optimizer for Sparse Attention Methods.
+
+This script performs a robust, two-stage process for each combination of
+model, benchmark, and sparse attention configuration:
+1. **Search**: It uses Ray Tune to run a hyperparameter search with lightweight
+ settings to quickly discover the optimal parameters.
+2. **Validate**: It takes the single best configuration found during the search
+ and runs a final, thorough benchmark with it to get a definitive score.
+
+## Usage Examples
+
+### Basic Usage
+```bash
+# Run full benchmark suite with all sparse attention configs
+python benchmark/raytune/run_full_benchmark_interleave.py
+
+# Run in debug mode (quick test with minimal configs)
+python benchmark/raytune/run_full_benchmark_interleave.py --debug
+
+# Run only dense baseline (no sparse attention)
+python benchmark/raytune/run_full_benchmark_interleave.py --dense-only
+
+# Print all available configurations without running
+python benchmark/raytune/run_full_benchmark_interleave.py --print-configs
+```
+
+### Advanced Usage
+```bash
+# Custom search parameters for faster exploration
+python benchmark/raytune/run_full_benchmark_interleave.py \
+    --search-timeout 600 \
+    --search-max-new-tokens 10 \
+    --search-max-context-length 4096 \
+    --num-samples 20
+
+# Custom validation parameters for thorough evaluation
+python benchmark/raytune/run_full_benchmark_interleave.py \
+    --validation-timeout 7200 \
+    --validation-max-new-tokens 200 \
+    --validation-max-context-length 64000 \
+    --validation-max-requests 50
+
+# Run with custom result directory suffix
+python benchmark/raytune/run_full_benchmark_interleave.py --result-dir-suffix "_experiment_v1"
+```
+
+## Command-Line Arguments
+
+### General Options
+- `--debug`: Run quick test configuration with minimal settings
+- `--num-samples`: Number of Ray Tune samples per optimization (default: 50)
+- `--dense-only`: Run only dense configuration without sparse attention
+- `--result-dir-suffix`: Suffix to add to result directory names
+- `--print-configs`: Print all sparse configurations and exit
+
+### Search Phase Parameters (for finding optimal configs)
+- `--search-timeout`: Timeout for each search trial in seconds (default: 1800)
+- `--search-max-new-tokens`: Max new tokens for search trials (default: 50)
+- `--search-max-context-length`: Max context length for search trials (default: 16384)
+- `--search-max-requests`: Max requests for search trials (default: 15)
+
+### Validation Phase Parameters (for final evaluation)
+- `--validation-timeout`: Timeout for final validation in seconds (default: 3600)
+- `--validation-max-new-tokens`: Max new tokens for validation (default: 100)
+- `--validation-max-context-length`: Max context length for validation (default: 32000)
+- `--validation-max-requests`: Max requests for validation (default: 25)
+
+## Sparse Attention Configurations
+
+The script evaluates 19 different sparse attention configurations across 3 sparsity levels:
+
+### 5% Sparsity
+- Random Sampling (2% sink + 2% window + 1% sampling)
+- Adaptive Sampling with Oracle Top-K
+- Adaptive Sampling with HashAttention Top-K
+- HashAttention Top-K
+- Oracle Top-P (75%)
+- Oracle Top-K
+
+### 10% Sparsity
+- Random Sampling (0.1% sink + 0.1% window + 10% sampling)
+- Adaptive Sampling with Oracle Top-K
+- Adaptive Sampling with HashAttention Top-K
+- HashAttention Top-K
+- Oracle Top-P (80%)
+- Oracle Top-K
+
+### 20% Sparsity
+- Random Sampling (2% sink + 2% window + 20% sampling)
+- Adaptive Sampling with Oracle Top-K
+- Adaptive Sampling with HashAttention Top-K
+- HashAttention Top-K
+- Oracle Top-P (95%)
+- Oracle Top-K
+
+## Benchmarks
+
+The script runs the following benchmark tasks:
+- **InfiniteBench**: passkey task for extreme long context
+- **Ruler**: 4096 context length evaluation
+- **Loogle**: longdep_summarization, longdep_qa, shortdep_qa, shortdep_cloze
+- **ZeroScrolls**: default configuration
+- **LongBenchv2**: 0-shot evaluation
+- **AIME2024/2025**: Mathematical reasoning tasks
+- **LongBench**: passage_retrieval_en
+- **Mock Benchmark**: reading_comprehension (for testing)
+
+## Output Structure
+
+Results are saved in two directories:
+- `./search_results/`: Ray Tune optimization results
+- `./validation_results/`: Final validation results for best configurations
+
+Each run produces:
+- Raw benchmark results (CSV)
+- Micro metrics (JSONL) with attention errors and density
+- Final summary (JSON) with all scores and best configurations
+
+## Notes
+
+- Requires GPU(s) with CUDA support
+- HashAttention weights file should be available at the specified path
+- Ray Tune must be installed: `pip install "ray[tune]" hyperopt`
+- The script automatically handles resumability for interrupted runs
+
+To add new models, benchmarks, or masker presets, modify the `get_run_configurations` function.
+"""
+import argparse
+import json
+import logging
+import os
+import sys
+import time
+import traceback
+from datetime import datetime
+from pathlib import Path
+
+# --- Path Setup ---
+current_dir = Path(__file__).parent
+root_path = current_dir.parent.parent
+sys.path.extend([str(current_dir), str(root_path)])
+os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + f":{current_dir}:{root_path}"
+
+# --- Core Imports ---
+import torch
+from benchmark.executor import BenchmarkExecutor
+from benchmark.executor_config import AdapterConfig, BenchmarkConfig, BenchmarkResult
+from optimizer_factory import create_optimizer
+
+# --- Masker Config Imports ---
+from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
+from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
+ LocalMaskerConfig,
+ SinkMaskerConfig,
+ OracleTopKConfig,
+ OracleTopPMaskerConfig,
+ HashAttentionTopKMaskerConfig,
+)
+from sparse_attention_hub.sparse_attention.research_attention.maskers.sampling.implementations import (
+ AdaptiveSamplingMaskerConfig,
+ RandomSamplingMaskerConfig,
+ MagicPigConfig,
+)
+
+# --- Ray Tune Imports ---
+try:
+ import ray
+ from ray import tune
+ from ray.tune.schedulers import ASHAScheduler
+ from ray.tune.search.hyperopt import HyperOptSearch
+ from ray.tune.stopper import TrialPlateauStopper
+except ImportError:
+ print("Error: Ray Tune is required. Install with: pip install \"ray[tune]\" hyperopt")
+ sys.exit(1)
+
+
+class ComprehensiveBenchmarkRunner:
+ """Runs a benchmark for a model and sparse attention config, returning a score."""
+
+ def __init__(self, config: dict, verbose: bool = False):
+ self.config = config
+ self.executor = BenchmarkExecutor(
+ gpu_ids=config["gpu_ids"],
+ max_concurrent_runs=config["max_concurrent_runs"],
+ base_result_dir=config["result_dir"],
+ enable_resumability=True,
+ required_result_files=["raw_results.csv"],
+ timeout_per_benchmark=config["timeout_per_benchmark"],
+ verbose=verbose,
+ )
+ self.adapter_config = AdapterConfig(
+ adapter_name="huggingface",
+ model_kwargs={
+ "torch_dtype": torch.bfloat16,
+ "attn_implementation": "flash_attention_2",
+ },
+ tokenizer_kwargs={"padding_side": "left"},
+ )
+ self.generation_kwargs = {"max_new_tokens": config["max_new_tokens"], "do_sample": False}
+ self.request_kwargs = {
+ "max_context_length": config["max_context_length"],
+ "max_requests": config["max_requests"],
+ }
+ self.results_cache = {}
+
+ def _extract_micro_metrics(self, result_dir: Path) -> dict:
+ import math
+ micro_metrics_file = result_dir / "micro_metrics.jsonl"
+ if not micro_metrics_file.exists():
+ # For dense configuration, micro_metrics.jsonl won't exist since no sparse attention is used
+ # Return default values: 0 error (perfect) and 1.0 density (fully dense)
+ print(f" Note: micro_metrics.jsonl not found in {result_dir}, using dense defaults")
+ return {"attention_error": 0.0, "density": 1.0}
+
+ errors, densities = [], []
+ with open(micro_metrics_file, "r") as f:
+ for line in f:
+ try:
+ entry = json.loads(line.strip())
+ metric, value = entry.get("metric"), entry.get("value")
+ if value is not None and not (isinstance(value, float) and math.isnan(value)):
+                        if metric == "research_attention_output_error":
+                            errors.append(float(value))
+                        elif metric == "research_attention_density":
+                            densities.append(float(value))
+                except (json.JSONDecodeError, ValueError, TypeError):
+                    continue
+ return {"attention_error": sum(errors) / len(errors) if errors else 1.0, "density": sum(densities) / len(densities) if densities else 1.0}
+
+ def __call__(self, attention_config, task_name: str, model_name: str) -> float:
+ config_key = f"{model_name}_{task_name}_{hash(str(attention_config))}"
+ if config_key in self.results_cache: return self.results_cache[config_key]
+
+ try:
+ if "/" in task_name:
+ benchmark_name, subset_name = task_name.split("/", 1)
+ else:
+ benchmark_name, subset_name = task_name, None
+
+ benchmark_config = BenchmarkConfig(
+ benchmark_name=benchmark_name,
+ subsets=[subset_name] if subset_name else None
+ )
+
+ results = self.executor.run_benchmark_matrix(
+ model_names=[model_name],
+ sparse_attention_configs=[("optimized", attention_config)],
+ benchmark_configs=[benchmark_config],
+ adapter_config=self.adapter_config,
+ generation_kwargs=self.generation_kwargs,
+ request_kwargs=self.request_kwargs,
+ )
+
+ if results.progress.completed_stubs > 0 and hasattr(results, "individual_results"):
+ completed = [r for r in results.individual_results if isinstance(r, BenchmarkResult)]
+ if completed:
+ result_dir = Path(completed[0].stub.result_dir)
+ metrics = self._extract_micro_metrics(result_dir)
+ error, density = metrics["attention_error"], metrics["density"]
+
+                    # For the dense configuration (density=1.0, error=0.0), sparse metrics are not meaningful
+                    if density == 1.0 and error == 0.0:
+                        # Dense baseline: assign a small fixed score rather than deriving one from sparse metrics
+                        score = 0.1  # Small baseline score for dense
+ else:
+ # For sparse configurations: penalize both error and excessive density
+ score = error + 0.1 * density + (5.0 if density > 0.5 else 0.0)
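+                        # Illustrative example (assumed values): error=0.08, density=0.30
+                        # -> score = 0.08 + 0.1*0.30 + 0.0 = 0.11; any density above 0.5 adds a +5.0 penalty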
+
+ self.results_cache[config_key] = score
+ return score
+ except Exception as e:
+ print(f" ✗ Error in benchmark runner: {e}")
+ traceback.print_exc()
+
+ print(f" Warning: Could not compute a valid score for {model_name} on {task_name}. Returning penalty.")
+ self.results_cache[config_key] = 5.0
+ return 5.0
+
+# Helper functions for generating configuration names
+def get_adaptive_config_name(sink_size, window_size, heavy_size, base_rate_sampling, epsilon, delta):
+ return f"adaptive_sampling.sink_{sink_size}_window_{window_size}_heavy_{heavy_size}_base_{base_rate_sampling}_epsilon_{epsilon}_delta_{delta}"
+
+def get_adaptive_hat_config_name(sink_size, window_size, heavy_size, base_rate_sampling, epsilon, delta):
+ return f"adaptive_sampling_hat.sink_{sink_size}_window_{window_size}_heavy_{heavy_size}_base_{base_rate_sampling}_epsilon_{epsilon}_delta_{delta}"
+
+def get_oracle_top_p_config_name(sink_size, window_size, top_p):
+ return f"oracle_top_p_{top_p}.sink_{sink_size}_window_{window_size}"
+
+def get_oracle_top_k_config_name(sink_size, window_size, top_k):
+ return f"oracle_top_k_{top_k}.sink_{sink_size}_window_{window_size}"
+
+def get_hashattention_config_name(sink_size, window_size, top_k):
+ return f"hashattention.sink_{sink_size}_window_{window_size}_top_k_{top_k}"
+
+def get_random_sampling_config_name(sink_size, window_size, sampling_rate):
+ return f"random_sampling.sink_{sink_size}_window_{window_size}_sampling_rate_{sampling_rate}"
+
+def get_run_configurations(args: argparse.Namespace) -> dict:
+ """Defines the complete configuration for the optimization run."""
+ num_gpus = torch.cuda.device_count()
+
+ # Get the HashAttention weights file path
+ machine_key = "ubuntu"
+ weight_file = f"/home/{machine_key}/HashAttention-1.0/artifacts/llama3.1-8b-patch.64K.v1.hat_weights.pkl"
+
+ # If weight file doesn't exist, try a fallback path
+ if not os.path.exists(weight_file):
+ weight_file = "./hat_weights.pkl" # fallback to local file
+ print(f"Warning: HashAttention weights not found at expected path, using {weight_file}")
+
+    # Generate the full set of sparse attention configurations to evaluate
+ sparse_configs = []
+
+ # Dense baseline
+ sparse_configs.append(("dense", None))
+
+ # ==================== 5% sparsity configs =================
+ # Random sampling 5%
+ sparse_configs.append((get_random_sampling_config_name(0.02, 0.02, 0.01), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.02),
+ LocalMaskerConfig(window_size=0.02),
+ RandomSamplingMaskerConfig(sampling_rate=0.01)
+ ])))
+
+ # Adaptive sampling with oracle top k 5%
+ sparse_configs.append((get_adaptive_config_name(0.001, 0.001, 0.02, 0.01, 0.1, 0.1), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.001),
+ LocalMaskerConfig(window_size=0.001),
+ OracleTopKConfig(heavy_size=0.02),
+ AdaptiveSamplingMaskerConfig(base_rate_sampling=0.01, epsilon=0.1, delta=0.1, init_offset=0.001, local_offset=0.001)
+ ])))
+
+ # Adaptive sampling with HAT top k 5%
+ sparse_configs.append((get_adaptive_hat_config_name(0.01, 0.01, 0.02, 0.01, 0.25, 0.25), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.01),
+ LocalMaskerConfig(window_size=0.01),
+ HashAttentionTopKMaskerConfig(heavy_size=0.02, hat_bits=32, hat_mlp_layers=3, hat_mlp_hidden_size=128, hat_mlp_activation="silu", hat_weight_file=weight_file, hat_weights=None),
+ AdaptiveSamplingMaskerConfig(base_rate_sampling=0.01, epsilon=0.25, delta=0.25, init_offset=0.001, local_offset=0.001)
+ ])))
+
+ # HAT top k 5%
+ sparse_configs.append((get_hashattention_config_name(0.005, 0.005, 0.04), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.005),
+ LocalMaskerConfig(window_size=0.005),
+ HashAttentionTopKMaskerConfig(heavy_size=0.04, hat_bits=32, hat_mlp_layers=3, hat_mlp_hidden_size=128, hat_mlp_activation="silu", hat_weight_file=weight_file, hat_weights=None),
+ ])))
+
+ # Oracle top p 5%
+ sparse_configs.append((get_oracle_top_p_config_name(0.001, 0.001, 0.75), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.001),
+ LocalMaskerConfig(window_size=0.001),
+ OracleTopPMaskerConfig(top_p=0.75)
+ ])))
+
+ # Oracle top k 5%
+ sparse_configs.append((get_oracle_top_k_config_name(0.005, 0.005, 0.04), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.005),
+ LocalMaskerConfig(window_size=0.005),
+ OracleTopKConfig(heavy_size=0.04)
+ ])))
+
+ # ==================== 10% sparsity configs =================
+ # Random sampling 10%
+ sparse_configs.append((get_random_sampling_config_name(0.001, 0.001, 0.1), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.001),
+ LocalMaskerConfig(window_size=0.001),
+ RandomSamplingMaskerConfig(sampling_rate=0.1)
+ ])))
+
+ # Adaptive sampling with oracle top k 10%
+ sparse_configs.append((get_adaptive_config_name(0.001, 0.001, 0.05, 0.05, 0.25, 0.25), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.001),
+ LocalMaskerConfig(window_size=0.001),
+ OracleTopKConfig(heavy_size=0.05),
+ AdaptiveSamplingMaskerConfig(base_rate_sampling=0.05, epsilon=0.25, delta=0.25, init_offset=0.001, local_offset=0.001)
+ ])))
+
+ # Adaptive sampling with HAT top k 10%
+ sparse_configs.append((get_adaptive_hat_config_name(0.001, 0.001, 0.05, 0.05, 0.4, 0.4), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.001),
+ LocalMaskerConfig(window_size=0.001),
+ HashAttentionTopKMaskerConfig(heavy_size=0.05, hat_bits=32, hat_mlp_layers=3, hat_mlp_hidden_size=128, hat_mlp_activation="silu", hat_weight_file=weight_file, hat_weights=None),
+ AdaptiveSamplingMaskerConfig(base_rate_sampling=0.05, epsilon=0.4, delta=0.4, init_offset=0.001, local_offset=0.001)
+ ])))
+
+ # HAT top k 10%
+ sparse_configs.append((get_hashattention_config_name(0.001, 0.001, 0.09), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.001),
+ LocalMaskerConfig(window_size=0.001),
+ HashAttentionTopKMaskerConfig(heavy_size=0.09, hat_bits=32, hat_mlp_layers=3, hat_mlp_hidden_size=128, hat_mlp_activation="silu", hat_weight_file=weight_file, hat_weights=None),
+ ])))
+
+ # Oracle top p 10%
+ sparse_configs.append((get_oracle_top_p_config_name(0.02, 0.02, 0.8), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.02),
+ LocalMaskerConfig(window_size=0.02),
+ OracleTopPMaskerConfig(top_p=0.8)
+ ])))
+
+ # Oracle top k 10%
+ sparse_configs.append((get_oracle_top_k_config_name(0.001, 0.001, 0.1), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.001),
+ LocalMaskerConfig(window_size=0.001),
+ OracleTopKConfig(heavy_size=0.1)
+ ])))
+
+ # ==================== 20% sparsity configs =================
+ # Random sampling 20%
+ sparse_configs.append((get_random_sampling_config_name(0.02, 0.02, 0.2), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.02),
+ LocalMaskerConfig(window_size=0.02),
+ RandomSamplingMaskerConfig(sampling_rate=0.2)
+ ])))
+
+ # Adaptive sampling with oracle top k 20%
+ sparse_configs.append((get_adaptive_config_name(0.02, 0.02, 0.05, 0.1, 0.3, 0.3), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.02),
+ LocalMaskerConfig(window_size=0.02),
+ OracleTopKConfig(heavy_size=0.05),
+ AdaptiveSamplingMaskerConfig(base_rate_sampling=0.1, epsilon=0.3, delta=0.3, init_offset=0.02, local_offset=0.02)
+ ])))
+
+ # Adaptive sampling with HAT top k 20%
+ sparse_configs.append((get_adaptive_hat_config_name(0.005, 0.005, 0.1, 0.1, 0.25, 0.25), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.005),
+ LocalMaskerConfig(window_size=0.005),
+ HashAttentionTopKMaskerConfig(heavy_size=0.1, hat_bits=32, hat_mlp_layers=3, hat_mlp_hidden_size=128, hat_mlp_activation="silu", hat_weight_file=weight_file, hat_weights=None),
+ AdaptiveSamplingMaskerConfig(base_rate_sampling=0.1, epsilon=0.25, delta=0.25, init_offset=0.005, local_offset=0.005)
+ ])))
+
+ # HAT top k 20%
+ sparse_configs.append((get_hashattention_config_name(0.005, 0.005, 0.19), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.005),
+ LocalMaskerConfig(window_size=0.005),
+ HashAttentionTopKMaskerConfig(heavy_size=0.19, hat_bits=32, hat_mlp_layers=3, hat_mlp_hidden_size=128, hat_mlp_activation="silu", hat_weight_file=weight_file, hat_weights=None),
+ ])))
+
+ # Oracle top p 20%
+ sparse_configs.append((get_oracle_top_p_config_name(0.01, 0.01, 0.95), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.01),
+ LocalMaskerConfig(window_size=0.01),
+ OracleTopPMaskerConfig(top_p=0.95)
+ ])))
+
+ # Oracle top k 20%
+ sparse_configs.append((get_oracle_top_k_config_name(0.005, 0.005, 0.19), ResearchAttentionConfig(masker_configs=[
+ SinkMaskerConfig(sink_size=0.005),
+ LocalMaskerConfig(window_size=0.005),
+ OracleTopKConfig(heavy_size=0.19)
+ ])))
+
+    # Select which sparse configurations to run for this optimization run,
+    # based on the CLI flags (dense only, a small debug subset, or the full set)
+ if args.dense_only:
+ # Only run dense configuration
+ selected_sparse_configs = [("dense", None)]
+ elif args.debug:
+ # In debug mode, just test a few configs
+ selected_sparse_configs = sparse_configs[:3]
+ else:
+ # In production, we might want to optimize across all configs or a subset
+ # For now, let's use the full set
+ selected_sparse_configs = sparse_configs
+
+ # Convert configs to optimizer-compatible format
+ # The optimizer expects classes, not instances
+ masker_config_presets = {} # For optimizer (classes)
+ sparse_attention_configs = {} # For validation (full configs)
+
+ for name, config in selected_sparse_configs:
+ if config is not None:
+ # Extract the class from each config instance for the optimizer
+ masker_classes = []
+ for masker_config in config.masker_configs:
+ masker_classes.append(type(masker_config))
+ masker_config_presets[name] = masker_classes
+ sparse_attention_configs[name] = config # Store full config for validation
+ else:
+ masker_config_presets[name] = None
+ sparse_attention_configs[name] = None
+
+ test_suites = {"default": list(masker_config_presets.keys()), "debug": list(masker_config_presets.keys())[:3]}
+
+ # --- Decouple Search and Validation Parameters ---
+ if args.debug:
+ # Use smaller, faster settings for the search phase in debug mode
+ search_params = {
+ "timeout_per_benchmark": 300, "max_new_tokens": 10,
+ "max_context_length": 4096, "max_requests": 2,
+ }
+ # Use slightly more thorough settings for debug validation
+ validation_params = {
+ "timeout_per_benchmark": 600, "max_new_tokens": 30,
+ "max_context_length": 16384, "max_requests": 5,
+ }
+ base_config = {
+ "models": ["meta-llama/Llama-3.1-8B-Instruct"],
+ "benchmarks": [
+ "loogle/shortdep_qa", # Quick benchmark for debug
+ ],
+ "masker_presets": {p: masker_config_presets[p] for p in test_suites["debug"]},
+ "num_samples": 8,
+ }
+ else:
+ # For production, use specific flags for each stage
+ search_params = {
+ "timeout_per_benchmark": args.search_timeout, "max_new_tokens": args.search_max_new_tokens,
+ "max_context_length": args.search_max_context_length, "max_requests": args.search_max_requests,
+ }
+ validation_params = {
+ "timeout_per_benchmark": args.validation_timeout, "max_new_tokens": args.validation_max_new_tokens,
+ "max_context_length": args.validation_max_context_length, "max_requests": args.validation_max_requests,
+ }
+ base_config = {
+ "models": ["meta-llama/Llama-3.1-8B-Instruct"],
+ "benchmarks": [
+ # InfiniteBench
+ "infinite_bench/passkey",
+ # Ruler
+ "ruler/4096",
+ # Loogle
+ "loogle/longdep_summarization",
+ "loogle/longdep_qa",
+ "loogle/shortdep_qa",
+ "loogle/shortdep_cloze",
+ # ZeroScrolls
+ "zero_scrolls/default",
+ # LongBenchv2
+ "longbenchv2/0shot",
+ # AIME benchmarks
+ "aime2024/aime2024",
+ "aime2025/aime2025",
+ # LongBench
+ "longbench/passage_retrieval_en",
+ # Mock benchmark for testing
+ "mock_benchmark/reading_comprehension",
+ ],
+ "masker_presets": {p: masker_config_presets[p] for p in test_suites["default"]},
+ "num_samples": args.num_samples,
+ }
+
+ # Combine into a final, structured configuration
+ return {
+ **base_config,
+ "search_params": search_params,
+ "validation_params": validation_params,
+ "gpu_ids": list(range(num_gpus)),
+ "max_concurrent_runs": num_gpus,
+ "result_dir": f"./search_results{args.result_dir_suffix}", # Base directory for the search phase
+ "detailed_result_dir": f"./validation_results{args.result_dir_suffix}", # Base directory for the validation phase
+ "sparse_configs": selected_sparse_configs, # Store the full list for reference
+ "sparse_attention_configs": sparse_attention_configs, # Store full config objects for validation
+ }
+
+def get_ray_tune_components(config: dict) -> dict:
+ scheduler = ASHAScheduler(time_attr="training_iteration", max_t=20, grace_period=5, reduction_factor=2)
+ search_alg = HyperOptSearch(metric="combined_score", mode="min", n_initial_points=max(1, config["num_samples"] // 4))
+ stopper = TrialPlateauStopper(metric="combined_score", std=0.005, num_results=5, grace_period=8, mode="min")
+ return {"scheduler": scheduler, "search_alg": search_alg, "stop": stopper}
+
+def create_optimization_objective(config: dict, model_name: str, task_name: str, optimizer):
+ """Creates the objective function that Ray Tune will execute for each trial."""
+ def objective(trial_config: dict):
+ # The worker always uses the lighter search parameters for speed
+ worker_config = {**config, **config["search_params"]}
+ worker_config["gpu_ids"] = [0]
+ worker_config["max_concurrent_runs"] = 1
+
+ benchmark_runner = ComprehensiveBenchmarkRunner(worker_config)
+ attention_config = optimizer.create_config_from_params(trial_config)
+ score = benchmark_runner(attention_config, task_name, model_name)
+ return {"combined_score": score}
+ return objective
+
+def run_optimization_and_validation(model_name: str, benchmark_task: str, preset_name: str, masker_configs: list, config: dict, full_sparse_config=None) -> dict:
+ """Runs the two-stage Search-then-Validate process for one combination."""
+ print(f"\n--- Running: {model_name} | {benchmark_task} | {preset_name} ---")
+
+ # Handle dense configuration (no masker configs)
+ if masker_configs is None or preset_name == "dense":
+ print(" Running dense configuration (no optimization needed)...")
+ validation_config = {**config, **config["validation_params"]}
+ validation_config["result_dir"] = os.path.join(config["detailed_result_dir"], preset_name)
+
+ validator = ComprehensiveBenchmarkRunner(validation_config, verbose=True)
+ start_time = time.time()
+ print(f" Running validation benchmark: {model_name} on {benchmark_task}...")
+ final_score = validator(full_sparse_config, benchmark_task, model_name) # Use full config
+ runtime = time.time() - start_time
+ print(f" Validation benchmark completed in {runtime:.1f}s")
+ print(f" ✓ Final validation score: {final_score:.4f}")
+
+ return {
+ "best_search_score": final_score,
+ "final_validation_score": final_score,
+ "best_config": None,
+ "best_params": {},
+ "num_trials": 1,
+ }
+
+ # Stage 1: Search using the lighter 'search_params'
+ print(" 1. Searching for optimal configuration...")
+ try:
+ optimizer = create_optimizer(masker_configs)
+ objective = create_optimization_objective(config, model_name, benchmark_task, optimizer)
+ tune_components = get_ray_tune_components(config)
+ sanitized_task_name = benchmark_task.replace('/', '_')
+
+ analysis = tune.run(
+ objective, config=optimizer.create_search_space(benchmark_task),
+ num_samples=config["num_samples"], metric="combined_score", mode="min",
+ resources_per_trial={"CPU": 1, "GPU": 1.0},
+ name=f"opt_{model_name.split('/')[-1]}_{sanitized_task_name}_{preset_name}",
+ storage_path=config["storage_path"], verbose=1, resume=False,
+ max_concurrent_trials=config["max_concurrent_runs"], **tune_components
+ )
+ best_trial = analysis.get_best_trial("combined_score", "min", "last")
+ best_config_obj = optimizer.create_config_from_params(best_trial.config)
+ best_search_score = best_trial.last_result['combined_score']
+ print(f" ✓ Best search score: {best_search_score:.4f}")
+ except Exception as e:
+ print(f" ✗ Search stage failed: {e}"); traceback.print_exc()
+ return {"error": f"Search failed: {e}"}
+
+ # Stage 2: Validate using the more thorough 'validation_params'
+ print(" 2. Validating the best configuration...")
+ try:
+ # Create a new config for validation by merging base and validation params
+ validation_config = {**config, **config["validation_params"]}
+ validation_config["result_dir"] = os.path.join(config["detailed_result_dir"], preset_name)
+
+ validator = ComprehensiveBenchmarkRunner(validation_config, verbose=True)
+ start_time = time.time()
+ print(f" Running validation benchmark: {model_name} on {benchmark_task}...")
+ final_score = validator(best_config_obj, benchmark_task, model_name)
+ runtime = time.time() - start_time
+ print(f" Validation benchmark completed in {runtime:.1f}s")
+ print(f" ✓ Final validation score: {final_score:.4f}")
+ except Exception as e:
+ print(f" ✗ Validation stage failed: {e}"); traceback.print_exc()
+ return {"error": f"Validation failed: {e}"}
+
+ return {
+ "best_search_score": best_search_score,
+ "final_validation_score": final_score,
+ "best_config": best_config_obj,
+ "best_params": best_trial.config,
+ "num_trials": len(analysis.trials),
+ }
+
+def run_optimization_matrix(config: dict) -> tuple[dict, str]:
+ print("Starting Full Benchmark Optimization and Validation Matrix"); print("=" * 80)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ storage_path = os.path.abspath(f"./ray_results_{timestamp}")
+ config["storage_path"] = storage_path
+ print(f"Ray Tune results will be saved to: {storage_path}")
+
+ all_results = {}
+ for model_name in config["models"]:
+ all_results[model_name] = {}
+ print(f"\nModel: {model_name}"); print("-" * 60)
+ for benchmark_task in config["benchmarks"]:
+ all_results[model_name][benchmark_task] = {}
+ for preset_name, masker_configs in config["masker_presets"].items():
+ full_sparse_config = config.get("sparse_attention_configs", {}).get(preset_name)
+ combo_result = run_optimization_and_validation(model_name, benchmark_task, preset_name, masker_configs, config, full_sparse_config)
+ all_results[model_name][benchmark_task][preset_name] = combo_result
+ return all_results, storage_path
+
+def print_summary(results: dict):
+ print("\n" + "=" * 80); print("--- FINAL BENCHMARK SUMMARY ---"); print("=" * 80)
+ best_overall_score, best_overall_config = float("inf"), {}
+ for model_name, model_results in results.items():
+ print(f"\nModel: {model_name}"); print("-" * 70)
+ for benchmark_task, task_results in model_results.items():
+ print(f"\n Benchmark: {benchmark_task}")
+ for masker_preset, result in task_results.items():
+ if "error" in result:
+ print(f" {masker_preset:25s}: FAILED ({result['error']})"); continue
+ score = result.get("final_validation_score", float("inf"))
+ search_score = result.get("best_search_score", float("inf"))
+ print(f" {masker_preset:25s}: {score:.4f} (Search score: {search_score:.4f})")
+ if score < best_overall_score:
+ best_overall_score = score
+ best_overall_config = {"model": model_name, "benchmark": benchmark_task, "masker": masker_preset, "score": score, "params": result.get("best_params")}
+ print("\n" + "--- Best Overall Configuration ---")
+ if best_overall_config:
+ for key, value in best_overall_config.items(): print(f" {key.capitalize():<12}: {value}")
+ else: print(" No successful runs completed.")
+ print("-" * 32)
+
+def define_cli_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="Full benchmark optimization and validation runner.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ # General arguments
+ parser.add_argument("--debug", action="store_true", help="Run a quick test configuration, ignoring other flags.")
+ parser.add_argument("--num-samples", type=int, default=50, help="Number of Ray Tune samples per optimization search.")
+ parser.add_argument("--dense-only", action="store_true", help="Run only dense configuration without sparse attention.")
+ parser.add_argument("--result-dir-suffix", type=str, default="", help="Suffix to add to result directory names.")
+ parser.add_argument("--print-configs", action="store_true", help="Print all sparse configurations and exit.")
+
+ # Search-specific arguments
+ search_group = parser.add_argument_group('Search Parameters (for finding the best config)')
+ search_group.add_argument("--search-timeout", type=int, default=1800, help="Timeout for each search trial.")
+ search_group.add_argument("--search-max-new-tokens", type=int, default=50, help="Max new tokens for search trials.")
+ search_group.add_argument("--search-max-context-length", type=int, default=16384, help="Max context length for search trials.")
+ search_group.add_argument("--search-max-requests", type=int, default=15, help="Max requests for search trials.")
+
+ # Validation-specific arguments
+ validation_group = parser.add_argument_group('Validation Parameters (for the final run with the best config)')
+ validation_group.add_argument("--validation-timeout", type=int, default=3600, help="Timeout for the final validation run.")
+ validation_group.add_argument("--validation-max-new-tokens", type=int, default=100, help="Max new tokens for the final validation run.")
+ validation_group.add_argument("--validation-max-context-length", type=int, default=32000, help="Max context length for the final validation run.")
+ validation_group.add_argument("--validation-max-requests", type=int, default=25, help="Max requests for the final validation run.")
+
+ return parser.parse_args()
+
+def main():
+ args = define_cli_args()
+ config = get_run_configurations(args)
+
+ # Print configurations if requested
+ if args.print_configs:
+ print("\n" + "=" * 80)
+ print("SPARSE ATTENTION CONFIGURATIONS")
+ print("=" * 80)
+ for i, (name, cfg) in enumerate(config.get("sparse_configs", [])):
+ print(f"\n{i+1}. {name}")
+ if cfg is not None:
+ print(" Maskers:")
+ for masker in cfg.masker_configs:
+ print(f" - {masker.__class__.__name__}")
+ else:
+ print(" Dense (no sparse attention)")
+ print("\n" + "=" * 80)
+ print(f"Total configurations: {len(config.get('sparse_configs', []))}")
+ print("=" * 80)
+ return
+
+ if not ray.is_initialized():
+ ray.init(ignore_reinit_error=True, log_to_driver=False, runtime_env={"working_dir": str(root_path)})
+
+ mode = "Quick Test" if args.debug else "Full Production"
+ print(f"Starting {mode} Optimization & Validation...")
+ print(f"Ray Version: {ray.__version__}, GPUs Available: {torch.cuda.device_count()}")
+
+ # Print execution summary
+ print("\n" + "=" * 80)
+ print("EXECUTION SUMMARY")
+ print("=" * 80)
+ print(f"Models ({len(config['models'])}):")
+ for model in config['models']:
+ print(f" - {model}")
+ print(f"\nBenchmarks ({len(config['benchmarks'])}):")
+ for benchmark in config['benchmarks']:
+ print(f" - {benchmark}")
+ print(f"\nSparse Configurations ({len(config['masker_presets'])}):")
+ for i, preset in enumerate(list(config['masker_presets'].keys())[:5]):
+ print(f" - {preset}")
+ if len(config['masker_presets']) > 5:
+ print(f" ... and {len(config['masker_presets']) - 5} more")
+
+ total_combinations = len(config['models']) * len(config['benchmarks']) * len(config['masker_presets'])
+ print(f"\nTotal combinations to run: {total_combinations}")
+ print("=" * 80 + "\n")
+
+ start_time = time.time()
+ try:
+ results, storage_path = run_optimization_matrix(config)
+ print_summary(results)
+ print(f"\nDetailed validation results saved to: {config['detailed_result_dir']}")
+ print(f"View optimization progress with: tensorboard --logdir {storage_path}")
+
+ results_file = Path(storage_path) / "final_summary.json"
+ def json_serializer(obj): return str(obj)
+
+ print(f"Saving summary to: {results_file}")
+ # Create directory if it doesn't exist
+ results_file.parent.mkdir(parents=True, exist_ok=True)
+ with open(results_file, "w") as f: json.dump(results, f, indent=2, default=json_serializer)
+ print("Summary saved successfully.")
+ except KeyboardInterrupt:
+ print("\nWarning: Optimization interrupted by user.")
+ except Exception as e:
+ print(f"\n✗ An unexpected error occurred: {e}"); traceback.print_exc()
+ finally:
+ total_time = time.time() - start_time
+ print(f"\nTotal script time: {total_time / 3600:.2f} hours ({total_time:.0f} seconds)")
+ ray.shutdown()
+ print("Script finished.")
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ main()
\ No newline at end of file
diff --git a/benchmark/raytune/run_ray_benchmarks.py b/benchmark/raytune/run_ray_benchmarks.py
new file mode 100755
index 00000000..dc7dd443
--- /dev/null
+++ b/benchmark/raytune/run_ray_benchmarks.py
@@ -0,0 +1,691 @@
+#!/usr/bin/env python3
+"""
+Ray-based parallel benchmark runner with efficient resource management.
+
+This implementation uses Ray for:
+- Distributed execution with automatic resource management
+- Efficient model caching through Ray actors
+- Built-in fault tolerance and progress tracking
+- Optimal task scheduling to minimize model loading
+
+Usage:
+ python benchmark/raytune/run_ray_benchmarks.py --config-run run_20250818_203531
+ python benchmark/raytune/run_ray_benchmarks.py --config-run run_20250818_203531 --resume
+"""
+
+import argparse
+import json
+import logging
+import os
+import sys
+import time
+import torch
+from pathlib import Path
+from datetime import datetime
+from typing import Dict, List, Tuple, Optional, Any
+from dataclasses import dataclass, asdict
+from collections import defaultdict
+import traceback
+
+# Path setup
+current_dir = Path(__file__).parent
+root_path = current_dir.parent.parent
+sys.path.extend([str(current_dir), str(root_path)])
+
+import ray
+from ray.util.queue import Queue as RayQueue
+from ray.util.actor_pool import ActorPool
+
+from benchmark.executor_config import AdapterConfig
+from benchmark.benchmark_registry import create_benchmark_instance
+from sparse_attention_hub.adapters.huggingface import ModelAdapterHF
+from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
+from sparse_attention_hub.metric_logging.logger import MicroMetricLogger
+
+# Import all masker configs
+from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import *
+from sparse_attention_hub.sparse_attention.research_attention.maskers.sampling.implementations import *
+
+
+@dataclass
+class BenchmarkTask:
+ """Single benchmark task to execute."""
+ task_id: str
+ model_name: str
+ task_name: str
+ masker_name: str
+ sparse_config: Optional[Dict] # JSON-serializable config
+ result_dir: str
+ generation_kwargs: Dict[str, Any]
+ request_kwargs: Dict[str, Any]
+
+
+@dataclass
+class BenchmarkResult:
+ """Result from a benchmark execution."""
+ task_id: str
+ success: bool
+ metrics: Optional[Dict[str, Any]] = None
+ error: Optional[str] = None
+ execution_time: float = 0.0
+ gpu_id: Optional[int] = None
+ model_load_time: float = 0.0
+
+
+@ray.remote(num_gpus=1)
+class GPUBenchmarkActor:
+ """Ray actor that runs benchmarks on a specific GPU with fresh model initialization for each task."""
+
+ def __init__(self, actor_id: int, adapter_config: Dict):
+ self.actor_id = actor_id
+ self.adapter_config = adapter_config
+
+ # Ray sets CUDA_VISIBLE_DEVICES for us, so GPU 0 is always the correct device
+ self.gpu_id = 0 # Always use device 0 in the actor's visible GPU space
+ torch.cuda.set_device(self.gpu_id)
+
+ # Get actual GPU info for logging
+ gpu_name = torch.cuda.get_device_name(self.gpu_id)
+ logging.info(f"Actor {actor_id} initialized on GPU {gpu_name}")
+
+ def _reconstruct_sparse_config(self, config_data: Optional[Dict]) -> Optional[ResearchAttentionConfig]:
+ """Reconstruct ResearchAttentionConfig from JSON data."""
+ if not config_data or not config_data.get("masker_configs"):
+ return None
+
+ config_class_map = {
+ "LocalMaskerConfig": LocalMaskerConfig,
+ "SinkMaskerConfig": SinkMaskerConfig,
+ "OracleTopKConfig": OracleTopKConfig,
+ "OracleTopPMaskerConfig": OracleTopPMaskerConfig,
+ "HashAttentionTopKMaskerConfig": HashAttentionTopKMaskerConfig,
+ "AdaptiveSamplingMaskerConfig": AdaptiveSamplingMaskerConfig,
+ "RandomSamplingMaskerConfig": RandomSamplingMaskerConfig,
+ "MagicPigConfig": MagicPigConfig,
+ }
+
+ masker_configs = []
+ for masker_data in config_data["masker_configs"]:
+ config_class = config_class_map.get(masker_data["type"])
+ if config_class:
+ try:
+ params = masker_data.get("params", {})
+ masker_configs.append(config_class(**params))
+ except Exception as e:
+ logging.warning(f"Failed to create {masker_data['type']}: {e}")
+
+ return ResearchAttentionConfig(masker_configs=masker_configs) if masker_configs else None
+
+ def _create_fresh_model(self, model_name: str, sparse_config: Optional[Dict],
+ masker_name: str, task_name: str) -> Tuple[ModelAdapterHF, float]:
+ """Create a fresh model from scratch for each task.
+
+ This ensures no state leakage between tasks with different sparse configs.
+ Returns (model, load_time).
+ """
+ logging.info(f"Actor {self.actor_id}: Creating fresh model for {task_name} with {masker_name}")
+
+ # Clear any GPU cache before loading
+ torch.cuda.empty_cache()
+
+ start_time = time.time()
+
+ # Reconstruct sparse config
+ sparse_attention_config = self._reconstruct_sparse_config(sparse_config)
+
+ # Create completely fresh model instance
+ adapter = ModelAdapterHF(
+ model_name=model_name,
+ sparse_attention_config=sparse_attention_config,
+ model_kwargs=self.adapter_config["model_kwargs"],
+ tokenizer_kwargs=self.adapter_config["tokenizer_kwargs"]
+ )
+
+ load_time = time.time() - start_time
+ logging.info(f"Actor {self.actor_id}: Model created in {load_time:.1f}s")
+
+ return adapter, load_time
+
+ def run_benchmark(self, task: BenchmarkTask) -> BenchmarkResult:
+ """Execute a single benchmark task."""
+ total_start = time.time()
+
+ adapter = None
+ try:
+ # Create fresh model for this task
+ adapter, model_load_time = self._create_fresh_model(
+ task.model_name, task.sparse_config, task.masker_name, task.task_name
+ )
+
+ # Parse benchmark info
+ benchmark_name, subset = (task.task_name.split("/", 1)
+ if "/" in task.task_name
+ else (task.task_name, None))
+
+ # Create benchmark instance
+ benchmark = create_benchmark_instance(
+ benchmark_name=benchmark_name,
+ subsets=[subset] if subset else None
+ )
+
+ # Setup result directory
+ Path(task.result_dir).mkdir(parents=True, exist_ok=True)
+
+ # Check if already completed
+ metrics_file = Path(task.result_dir) / "metrics.json"
+ if metrics_file.exists():
+ logging.info(f"Actor {self.actor_id}: Skipping completed {task.task_id}")
+ with open(metrics_file, 'r') as f:
+ metrics = json.load(f)
+ return BenchmarkResult(
+ task_id=task.task_id,
+ success=True,
+ metrics=metrics,
+ execution_time=0.0,
+ gpu_id=None,
+ model_load_time=0.0
+ )
+
+ # Setup micro metrics
+ metric_logger = MicroMetricLogger()
+ metric_logger.configure_logging(
+ log_path=task.result_dir,
+ enabled_metrics=["research_attention_density", "research_attention_output_error"]
+ )
+
+ # Run benchmark
+ benchmark_start = time.time()
+ logging.info(f"Actor {self.actor_id}: Running {task.task_id}")
+
+ metrics = benchmark.run_benchmark(
+ adapter=adapter,
+ result_dir=task.result_dir,
+ generation_kwargs=task.generation_kwargs,
+ request_kwargs=task.request_kwargs
+ )
+
+ metric_logger.flush()
+
+ execution_time = time.time() - total_start
+
+ return BenchmarkResult(
+ task_id=task.task_id,
+ success=True,
+ metrics=metrics,
+ execution_time=execution_time,
+ gpu_id=None,
+ model_load_time=model_load_time
+ )
+
+ except Exception as e:
+ logging.error(f"Actor {self.actor_id}: Task {task.task_id} failed: {e}")
+ traceback.print_exc()
+
+ return BenchmarkResult(
+ task_id=task.task_id,
+ success=False,
+ error=str(e),
+ execution_time=time.time() - total_start,
+ gpu_id=None
+ )
+
+ finally:
+ # Always clean up the model to ensure no state leakage
+ if adapter is not None:
+ logging.info(f"Actor {self.actor_id}: Cleaning up model for {task.task_id}")
+ try:
+ del adapter
+ torch.cuda.empty_cache()
+ except Exception as e:
+ logging.warning(f"Actor {self.actor_id}: Cleanup error: {e}")
+
+ def get_stats(self) -> Dict:
+ """Return actor statistics."""
+ return {
+ "actor_id": self.actor_id,
+ "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A",
+ "status": "active"
+ }
+
+ def cleanup(self):
+ """Clean up resources."""
+ logging.info(f"Actor {self.actor_id}: Final cleanup")
+ torch.cuda.empty_cache()
+
+
+def prepare_tasks(tasks: List[BenchmarkTask]) -> List[BenchmarkTask]:
+ """Prepare tasks for execution.
+
+ Since each task has unique optimized parameters from Phase 1,
+ every task requires fresh model initialization.
+ """
+ return tasks
+
+
+def serialize_sparse_config(config: Optional[ResearchAttentionConfig]) -> Optional[Dict]:
+ """Convert ResearchAttentionConfig to JSON-serializable format."""
+ if config is None:
+ return None
+
+ masker_configs = []
+ for masker in config.masker_configs:
+ masker_dict = {
+ "type": type(masker).__name__,
+ "params": {}
+ }
+ # Extract all public attributes
+ for attr in dir(masker):
+ if not attr.startswith("_"):
+ value = getattr(masker, attr)
+ if isinstance(value, (int, float, str, bool, type(None))):
+ masker_dict["params"][attr] = value
+ masker_configs.append(masker_dict)
+
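+    # Illustrative result for a Sink + Local config (parameter names taken from the configs above):
+    #   {"type": "ResearchAttentionConfig",
+    #    "masker_configs": [{"type": "SinkMaskerConfig", "params": {"sink_size": 0.02}},
+    #                       {"type": "LocalMaskerConfig", "params": {"window_size": 0.02}}]}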
+ return {
+ "type": "ResearchAttentionConfig",
+ "masker_configs": masker_configs
+ }
+
+
+def load_optimal_configs(config_dir: Path) -> List[BenchmarkTask]:
+ """Load optimal configurations and create benchmark tasks."""
+ tasks = []
+
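+    # Each Phase 1 config file is assumed to contain at least these keys (illustrative values):
+    #   {"model": "meta-llama/Llama-3.1-8B-Instruct", "task": "loogle/shortdep_qa",
+    #    "masker_name": "...", "sparse_config": {...}}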
+ for config_file in config_dir.glob("*.json"):
+ if config_file.name.endswith(("_trials.json", "_analysis.csv")):
+ continue
+
+ try:
+ with open(config_file, "r") as f:
+ data = json.load(f)
+
+ task_id = f"{data['model']}_{data['task']}_{data['masker_name']}".replace("/", "_")
+
+ task = BenchmarkTask(
+ task_id=task_id,
+ model_name=data["model"],
+ task_name=data["task"],
+ masker_name=data["masker_name"],
+ sparse_config=data.get("sparse_config"),
+ result_dir="", # Will be set later
+ generation_kwargs={}, # Will be set later
+ request_kwargs={} # Will be set later
+ )
+ tasks.append(task)
+
+ except Exception as e:
+ logging.warning(f"Failed to load {config_file}: {e}")
+
+ return tasks
+
+
+@ray.remote
+def progress_reporter(total_tasks: int, result_queue: RayQueue) -> None:
+ """Ray task that reports progress from result queue."""
+ completed = 0
+ failed = 0
+ start_time = time.time()
+ total_model_load_time = 0.0
+
+ while completed + failed < total_tasks:
+ try:
+ result = result_queue.get(timeout=10)
+
+ if result.success:
+ completed += 1
+ total_model_load_time += result.model_load_time
+
+ print(f"[{completed + failed}/{total_tasks}] ✓ {result.task_id} "
+ f"({result.execution_time:.1f}s, model load: {result.model_load_time:.1f}s)")
+ else:
+ failed += 1
+ print(f"[{completed + failed}/{total_tasks}] ✗ {result.task_id} - {result.error}")
+
+ # Print progress stats every 10 tasks
+ if (completed + failed) % 10 == 0:
+ elapsed = time.time() - start_time
+ rate = (completed + failed) / elapsed
+ eta = (total_tasks - completed - failed) / rate if rate > 0 else 0
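+            # Illustrative: 20 of 100 tasks in 600 s -> rate = 0.033 tasks/s, ETA ~ 40 min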
+ avg_load_time = total_model_load_time / max(1, completed)
+ print(f"\n--- Progress: {completed + failed}/{total_tasks} "
+ f"({rate:.2f} tasks/s, ETA: {eta/60:.1f} min) ---")
+ print(f"--- Avg model load time: {avg_load_time:.1f}s ---\n")
+
+ except Exception:
+ continue
+
+ # Final summary
+ total_time = time.time() - start_time
+ print(f"\n{'='*80}")
+ print(f"Completed: {completed}, Failed: {failed}")
+ print(f"Total execution time: {total_time/60:.1f} minutes")
+ print(f"Total model load time: {total_model_load_time/60:.1f} minutes")
+ print(f"Throughput: {completed/total_time*3600:.1f} tasks/hour")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Ray-based parallel benchmark runner")
+ parser.add_argument("--config-run", type=str, required=True,
+ help="Config run directory name")
+ parser.add_argument("--optimal-configs-dir", default="./optimal_configs")
+ parser.add_argument("--benchmark-results-dir", default="./benchmark_results_ray_16k_100req")
+ parser.add_argument("--max-new-tokens", type=int, default=100)
+ parser.add_argument("--max-context-length", type=int, default=16000)
+ parser.add_argument("--max-requests", type=int, default=100)
+ parser.add_argument("--num-actors", type=int, default=None,
+ help="Number of Ray actors (default: number of GPUs)")
+ parser.add_argument("--actors-per-gpu", type=int, default=None,
+ help="Number of actors per GPU for better utilization (overrides --num-actors)")
+ parser.add_argument("--resume", action="store_true",
+ help="Resume from existing results")
+ parser.add_argument("--dry-run", action="store_true",
+ help="Show what would be executed without running benchmarks")
+ parser.add_argument("--debug", action="store_true",
+ help="Debug mode - run only 2-4 benchmarks to test functionality")
+
+ args = parser.parse_args()
+
+ # Setup logging
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+
+ print(f"\n{'='*80}")
+ print(f"RAY BENCHMARK RUNNER")
+ print(f"{'='*80}")
+
+ # Initialize Ray
+ if not ray.is_initialized():
+ ray.init(ignore_reinit_error=True)
+
+ # Get GPU info
+ num_gpus = int(ray.available_resources().get("GPU", 0))
+ if num_gpus == 0:
+ print("Error: No GPUs available")
+ sys.exit(1)
+
+ # Determine number of actors
+ if args.actors_per_gpu:
+ num_actors = num_gpus * args.actors_per_gpu
+ print(f"Creating {args.actors_per_gpu} actors per GPU for maximum utilization")
+ elif args.num_actors:
+ num_actors = args.num_actors
+ else:
+ # Default to number of GPUs
+ num_actors = num_gpus
+ # In debug mode, still use all GPUs unless specified
+ if args.debug:
+ print(f"Debug mode: using all {num_actors} GPUs for maximum utilization")
+
+ print(f"Ray cluster: {ray.available_resources()}")
+ print(f"Using {num_actors} actors on {num_gpus} GPUs")
+
+ # Load configurations
+ config_dir = Path(args.optimal_configs_dir) / args.config_run
+ if not config_dir.exists():
+ print(f"Error: Config directory {config_dir} not found")
+ sys.exit(1)
+
+ print(f"\nLoading configurations from {config_dir}...")
+ tasks = load_optimal_configs(config_dir)
+ print(f"Loaded {len(tasks)} configurations")
+
+ # Debug mode adjustments
+ if args.debug:
+ print("\n⚠️ DEBUG MODE ENABLED ⚠️")
+ print(" - Will run only a subset of benchmarks")
+ print(" - Using reduced parameters for faster testing")
+
+ # Filter tasks for debug mode - take diverse samples
+ debug_tasks = []
+ # Get one dense config
+ dense_tasks = [t for t in tasks if t.masker_name == "dense"]
+ if dense_tasks:
+ debug_tasks.append(dense_tasks[0])
+
+ # Get 2-3 sparse configs with different maskers
+ sparse_tasks = [t for t in tasks if t.masker_name != "dense"]
+ seen_maskers = set()
+ for task in sparse_tasks:
+ if task.masker_name not in seen_maskers and len(debug_tasks) < 4:
+ debug_tasks.append(task)
+ seen_maskers.add(task.masker_name)
+
+ tasks = debug_tasks
+ print(f" - Selected {len(tasks)} tasks for debug run:")
+ for task in tasks:
+ print(f" * {task.model_name} / {task.masker_name} / {task.task_name}")
+
+ # Override parameters for faster execution
+ generation_kwargs = {
+ "max_new_tokens": 20, # Much smaller for debug
+ "do_sample": False,
+ }
+
+ request_kwargs = {
+ "max_context_length": 4096, # Smaller context
+ "max_requests": 2, # Just 2 requests per benchmark
+ }
+
+ print(f"\n Debug parameters:")
+ print(f" - max_new_tokens: 20 (vs {args.max_new_tokens})")
+ print(f" - max_context_length: 4096 (vs {args.max_context_length})")
+ print(f" - max_requests: 2 (vs {args.max_requests})")
+
+ else:
+ # Normal mode - use full parameters
+ generation_kwargs = {
+ "max_new_tokens": args.max_new_tokens,
+ "do_sample": False,
+ }
+
+ request_kwargs = {
+ "max_context_length": args.max_context_length,
+ "max_requests": args.max_requests
+ }
+
+ # Update tasks with full configuration
+ for task in tasks:
+ task.result_dir = os.path.join(
+ args.benchmark_results_dir,
+ task.model_name.replace("/", "_"),
+ task.masker_name,
+ task.task_name.replace("/", "_")
+ )
+ task.generation_kwargs = generation_kwargs
+ task.request_kwargs = request_kwargs
+
+ # Prepare tasks
+ print("\nPreparing tasks...")
+ tasks = prepare_tasks(tasks)
+
+ # Dry run mode - show what would be executed
+ if args.dry_run:
+ print(f"\n{'='*80}")
+ if args.debug:
+ print("DRY RUN MODE (DEBUG) - No benchmarks will be executed")
+ else:
+ print("DRY RUN MODE - No benchmarks will be executed")
+ print(f"{'='*80}")
+
+ # Group tasks by model and masker for analysis
+ task_groups = defaultdict(list)
+ for task in tasks:
+ key = (task.model_name, task.masker_name)
+ task_groups[key].append(task)
+
+ print(f"\nTask Summary:")
+ print(f" Total tasks: {len(tasks)}")
+ print(f" Unique model/masker combinations: {len(task_groups)}")
+ print(f" Actors to be created: {num_actors}")
+
+ # Check existing results
+ completed_count = 0
+ for task in tasks:
+ metrics_file = Path(task.result_dir) / "metrics.json"
+ if metrics_file.exists() and not args.resume:
+ completed_count += 1
+
+ if completed_count > 0:
+ print(f" Already completed: {completed_count} (would be skipped)")
+ print(f" To be executed: {len(tasks) - completed_count}")
+
+ print(f"\nTask Groups (optimized order):")
+ print("-" * 80)
+
+ for i, ((model, masker), group_tasks) in enumerate(task_groups.items()):
+ print(f"\n{i+1}. {model} + {masker}")
+ print(f" Tasks ({len(group_tasks)}):")
+ for task in group_tasks[:3]: # Show first 3
+ status = "✓" if (Path(task.result_dir) / "metrics.json").exists() else "○"
+ print(f" {status} {task.task_name}")
+ if len(group_tasks) > 3:
+ print(f" ... and {len(group_tasks) - 3} more")
+
+ # Estimate resource usage
+ print(f"\nResource Estimates:")
+ print("-" * 80)
+ model_sizes = {
+ "Llama-3.1-8B": 16, # GB in bfloat16
+ "Phi-4-mini": 7,
+ # Add more model estimates
+ }
+
+ est_model_size = 16 # Default estimate
+ for model_key in model_sizes:
+ if model_key in tasks[0].model_name:
+ est_model_size = model_sizes[model_key]
+ break
+
+ print(f" Estimated model size: ~{est_model_size} GB per model")
+ print(f" Total unique model configurations: {len(tasks)}")
+ print(f" GPU memory required per actor: ~{est_model_size} GB")
+
+ # Execution plan
+ print(f"\nExecution Plan:")
+ print("-" * 80)
+ print(f" 1. Initialize Ray with {num_actors} GPU actors")
+ print(f" 2. Each actor processes tasks independently")
+ print(f" 3. Fresh model initialization for each task:")
+ print(f" - Each task has unique optimized parameters from Phase 1")
+ print(f" - Total model loads: {len(tasks)} (one per task)")
+
+ # Show example of different configs
+ if len(tasks) >= 2:
+ print(f"\nExample configurations showing parameter differences:")
+
+ # Find tasks with same masker but different parameters
+ masker_groups = defaultdict(list)
+ for task in tasks:
+ masker_groups[task.masker_name].append(task)
+
+ # Show first group with multiple tasks
+ for masker_name, group_tasks in masker_groups.items():
+ if len(group_tasks) >= 2 and masker_name != "dense":
+ for i, task in enumerate(group_tasks[:2]):
+ print(f"\n {task.masker_name} for {task.task_name}:")
+ if task.sparse_config and task.sparse_config.get("masker_configs"):
+ for masker in task.sparse_config["masker_configs"][:2]:
+ params = masker.get("params", {})
+ param_str = ", ".join([f"{k}={v}" for k, v in sorted(params.items())[:3]])
+ print(f" - {masker['type']}: {param_str}...")
+ break
+
+ print(f"\nGeneration Configuration:")
+ print(f" max_new_tokens: {args.max_new_tokens}")
+ print(f" max_context_length: {args.max_context_length}")
+ print(f" max_requests: {args.max_requests}")
+
+ print(f"\nResults will be saved to:")
+ print(f" {args.benchmark_results_dir}/")
+ print(f" └── /")
+ print(f" └── /")
+ print(f" └── /")
+ print(f" ├── raw_results.csv")
+ print(f" ├── metrics.json")
+ print(f" └── micro_metrics.jsonl")
+
+ print(f"\n{'='*80}")
+ print("Dry run complete. Remove --dry-run to execute benchmarks.")
+ print(f"{'='*80}")
+ return
+
+ # Create adapter config
+ adapter_config = {
+ "adapter_name": "huggingface",
+ "model_kwargs": {"torch_dtype": torch.bfloat16},
+ "tokenizer_kwargs": {"padding_side": "left"}
+ }
+
+ # Create Ray actors
+ print(f"\nCreating {num_actors} Ray actors...")
+ actors = []
+
+ # Calculate GPU resources per actor
+ if args.actors_per_gpu and args.actors_per_gpu > 1:
+ # When multiple actors per GPU, each gets a fraction
+ gpu_per_actor = 1.0 / args.actors_per_gpu
+ print(f"Each actor will use {gpu_per_actor:.2f} GPU resources")
+
+ # Create actors with fractional GPU resources
+ for i in range(num_actors):
+            # Use .options() to assign the fractional GPU share to each actor
+ actor = GPUBenchmarkActor.options(num_gpus=gpu_per_actor).remote(i, adapter_config)
+ actors.append(actor)
+ else:
+ # Standard: one actor per GPU
+ for i in range(num_actors):
+ actor = GPUBenchmarkActor.remote(i, adapter_config)
+ actors.append(actor)
+
+ # Create result queue and progress reporter
+ result_queue = RayQueue(maxsize=len(tasks))
+ progress_task = progress_reporter.remote(len(tasks), result_queue)
+
+ # Create actor pool for load balancing
+ pool = ActorPool(actors)
+
+ # Submit all tasks
+ print(f"\nSubmitting {len(tasks)} tasks...")
+ print("-" * 80)
+
+ start_time = time.time()
+
+ # Submit tasks to actor pool
+ # ActorPool.submit expects (fn, value) where fn(actor, value) is called
+ for task in tasks:
+ pool.submit(lambda actor, task: actor.run_benchmark.remote(task), task)
+
+ # Collect results
+ while pool.has_next():
+ result = pool.get_next()
+ result_queue.put(result)
+
+ # Wait for progress reporter
+ ray.get(progress_task)
+
+ # Get actor statistics
+ print("\nActor statistics:")
+ for actor in actors:
+ stats = ray.get(actor.get_stats.remote())
+ print(f" Actor {stats['actor_id']} ({stats['gpu_name']}): {stats['status']}")
+
+ # Cleanup
+ print("\nCleaning up...")
+ for actor in actors:
+ ray.get(actor.cleanup.remote())
+
+ total_time = time.time() - start_time
+ print(f"\n{'='*80}")
+ print(f"EXECUTION COMPLETE")
+ print(f"{'='*80}")
+ print(f"Total time: {total_time/3600:.2f} hours")
+ print(f"Results saved to: {args.benchmark_results_dir}")
+ print(f"{'='*80}")
+
+ ray.shutdown()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/raytune/test_phase1_objectives.py b/benchmark/raytune/test_phase1_objectives.py
new file mode 100755
index 00000000..c41acee8
--- /dev/null
+++ b/benchmark/raytune/test_phase1_objectives.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+"""
+Test script to verify Phase 1 works with different objective functions.
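+
+Example usage (illustrative; flags match the argparse options defined below):
+    python benchmark/raytune/test_phase1_objectives.py --objectives default sparsity_5 --full-output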
+"""
+
+import subprocess
+import sys
+import os
+import argparse
+
+def test_objective_function(objective_name, show_full_output=False):
+ """Test Phase 1 with a specific objective function."""
+ print(f"\n{'='*60}")
+ print(f"Testing Phase 1 with objective: {objective_name}")
+ print(f"{'='*60}")
+
+ cmd = [
+ sys.executable,
+ "benchmark/raytune/run_two_phase_benchmark.py",
+ "--phase", "1",
+ "--debug", # Use debug mode for faster testing
+ "--objective", objective_name,
+ "--num-samples", "5", # Fewer samples for testing
+ "--search-timeout", "300",
+ "--force-search" # Force re-search to test the objective
+ ]
+
+ try:
+ # Run subprocess with real-time output
+ print("\n--- Starting Phase 1 run ---")
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+ text=True, bufsize=1, universal_newlines=True)
+
+ output_lines = []
+ found_objective = False
+ found_score_logging = False
+
+ # Stream output line by line
+ for line in process.stdout:
+ # Store for later analysis
+ output_lines.append(line.rstrip())
+
+ # Check for key indicators
+ if f"Objective Function: {objective_name}" in line:
+ found_objective = True
+ if "Error:" in line and "Density:" in line and "Score:" in line:
+ found_score_logging = True
+
+ # Print based on preference
+ if show_full_output:
+ print(line.rstrip())
+ else:
+ # Only print important lines for default mode
+ if any(keyword in line for keyword in [
+ "Objective Function:", "Objective:", "Error:", "Density:", "Score:",
+ "Targeting", "Formula", "Best score:", "✓", "✗", "Phase 1 complete",
+ "ERROR", "Exception", "Traceback", "Failed", "Warning"
+ ]):
+ print(f" > {line.rstrip()}")
+
+ # Wait for process to complete
+ return_code = process.wait()
+ print("--- Phase 1 run completed ---\n")
+
+ if return_code == 0:
+ print("✓ Phase 1 completed successfully")
+
+ if found_objective:
+ print(f"✓ Objective function '{objective_name}' was properly logged")
+ else:
+ print(f"✗ Objective function '{objective_name}' was not found in output")
+
+ if found_score_logging:
+ print("✓ Density, error, and score logging is working")
+ else:
+ print("✗ Score logging not detected")
+
+ return True
+ else:
+ print(f"✗ Phase 1 failed with exit code {return_code}")
+ return False
+
+ except Exception as e:
+ print(f"✗ Test failed with exception: {e}")
+ return False
+
+def main():
+ """Test different objective functions."""
+ parser = argparse.ArgumentParser(description="Test Phase 1 with different objective functions")
+ parser.add_argument("--full-output", action="store_true",
+ help="Show full output from each test run instead of just key lines")
+ parser.add_argument("--objectives", nargs="+",
+ default=["default", "sparsity_5", "sparsity_10", "sparsity_15"],
+ help="List of objectives to test")
+ args = parser.parse_args()
+
+ print("Testing Phase 1 with different objective functions")
+ if args.full_output:
+ print("(Full output mode enabled)")
+
+ # Change to project root
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ os.chdir(project_root)
+
+ # Test different objectives
+ objectives_to_test = args.objectives
+
+ results = {}
+ for obj in objectives_to_test:
+ results[obj] = test_objective_function(obj, show_full_output=args.full_output)
+
+ # Summary
+ print(f"\n{'='*60}")
+ print("SUMMARY")
+ print(f"{'='*60}")
+ for obj, success in results.items():
+ status = "✓ PASSED" if success else "✗ FAILED"
+ print(f"{obj}: {status}")
+
+ # Overall result
+ all_passed = all(results.values())
+ if all_passed:
+ print("\n✓ All tests passed!")
+ else:
+ print("\n✗ Some tests failed!")
+ sys.exit(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmark/raytune/visualize_benchmark_results.py b/benchmark/raytune/visualize_benchmark_results.py
new file mode 100644
index 00000000..3ecfa186
--- /dev/null
+++ b/benchmark/raytune/visualize_benchmark_results.py
@@ -0,0 +1,470 @@
+#!/usr/bin/env python3
+"""
+Production-quality interactive visualization dashboard for sparse attention benchmark results.
+
+This script creates professional-grade interactive plots using Plotly to visualize:
+- Model performance across different tasks
+- Sparse attention density vs error trade-offs
+- Comparative analysis across different sparse attention methods
+
+Usage:
+ python visualize_benchmark_results.py --results-dir benchmark_results_ray --output dashboard.html
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional, Any
+from collections import defaultdict
+import pandas as pd
+import numpy as np
+
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+import plotly.express as px
+from plotly.colors import qualitative
+
+
+class BenchmarkResultsVisualizer:
+ """Production-grade visualizer for sparse attention benchmark results."""
+
+ def __init__(self, results_dir: Path):
+ self.results_dir = results_dir
+ self.data = self._load_all_results()
+ self._setup_styling()
+
+ def _setup_styling(self):
+ """Setup consistent styling for all plots."""
+ self.colors = {
+ 'dense': '#1f77b4', # Blue
+ 'sink_local_random_sampling': '#ff7f0e', # Orange
+ 'sink_local_oracle_top_k_adaptive_sampling': '#2ca02c', # Green
+ 'sink_local_hash_attention_top_k_adaptive_sampling': '#d62728', # Red
+ 'sink_local_oracle_top_p': '#9467bd', # Purple
+ 'sink_local_oracle_top_k': '#8c564b', # Brown
+ 'sink_local_hash_attention_top_k': '#e377c2', # Pink
+ 'sink_local_magic_pig': '#7f7f7f', # Gray
+ }
+
+ self.plot_config = {
+ 'displayModeBar': True,
+ 'toImageButtonOptions': {
+ 'format': 'png',
+ 'filename': 'sparse_attention_benchmark',
+ 'height': 1200,
+ 'width': 1600,
+ 'scale': 2
+ }
+ }
+
+ self.layout_template = {
+ 'font': {'family': 'Arial, sans-serif', 'size': 12},
+ 'title_font': {'size': 20, 'family': 'Arial Black, sans-serif'},
+ 'hovermode': 'x unified',
+ 'plot_bgcolor': 'white',
+ 'paper_bgcolor': 'white',
+ 'margin': {'l': 80, 'r': 80, 't': 100, 'b': 80}
+ }
+
+ def _load_all_results(self) -> pd.DataFrame:
+ """Load all benchmark results into a structured DataFrame."""
+ results = []
+
+ for model_dir in self.results_dir.iterdir():
+ if not model_dir.is_dir():
+ continue
+
+ model_name = model_dir.name
+
+ for config_dir in model_dir.iterdir():
+ if not config_dir.is_dir():
+ continue
+
+ config_name = config_dir.name
+
+ for task_dir in config_dir.iterdir():
+ if not task_dir.is_dir():
+ continue
+
+ task_name = task_dir.name
+
+ # Load metrics
+ metrics_file = task_dir / "metrics.json"
+ if metrics_file.exists():
+ with open(metrics_file, 'r') as f:
+ metrics = json.load(f)
+
+ # Load micro metrics for sparse configs
+ density = None
+ attention_error = None
+ micro_metrics_file = task_dir / "micro_metrics.jsonl"
+
+ if micro_metrics_file.exists() and config_name != "dense":
+ densities = []
+ errors = []
+
+ with open(micro_metrics_file, 'r') as f:
+ for line in f:
+ try:
+ entry = json.loads(line.strip())
+ if entry.get("metric") == "research_attention_density":
+ densities.append(entry["value"])
+ elif entry.get("metric") == "research_attention_output_error":
+ errors.append(entry["value"])
+                                except (json.JSONDecodeError, KeyError):
+ continue
+
+ if densities:
+ density = np.mean(densities)
+ if errors:
+ attention_error = np.mean(errors)
+
+ # Extract performance metrics
+ result = {
+ 'model': model_name,
+ 'config': config_name,
+ 'task': task_name,
+ 'overall_score': metrics.get('overall_score', 0),
+ 'density': density,
+ 'attention_error': attention_error,
+ 'total_samples': metrics.get('summary', {}).get('total_samples', 0)
+ }
+
+ # Add task-specific scores
+ task_scores = metrics.get('task_scores', {})
+ if task_scores:
+ first_task = list(task_scores.values())[0]
+ for metric, value in first_task.items():
+ result[f'metric_{metric}'] = value
+
+ results.append(result)
+
+ return pd.DataFrame(results)
+
+ def create_performance_heatmap(self) -> go.Figure:
+ """Create a heatmap showing performance across tasks and configs."""
+ # Pivot data for heatmap
+ pivot_data = self.data.pivot_table(
+ index='config',
+ columns='task',
+ values='overall_score',
+ aggfunc='mean'
+ )
+
+ # Sort configs by average performance
+ config_order = pivot_data.mean(axis=1).sort_values(ascending=False).index
+ pivot_data = pivot_data.loc[config_order]
+
+ # Create heatmap
+ fig = go.Figure(data=go.Heatmap(
+ z=pivot_data.values,
+ x=pivot_data.columns,
+ y=pivot_data.index,
+ colorscale='RdBu_r',
+ text=np.round(pivot_data.values, 3),
+ texttemplate='%{text}',
+ textfont={"size": 10},
+ colorbar=dict(title="Overall Score"),
+            hovertemplate='Config: %{y}<br>Task: %{x}<br>Score: %{z:.3f}'
+ ))
+
+ fig.update_layout(
+ title='Performance Heatmap: Sparse Attention Methods vs Tasks',
+ xaxis_title='Benchmark Task',
+ yaxis_title='Sparse Attention Configuration',
+ height=600,
+ **self.layout_template
+ )
+
+ return fig
+
+ def create_density_vs_performance_scatter(self) -> go.Figure:
+ """Create scatter plot showing density vs performance trade-off."""
+ # Filter out dense baseline
+ sparse_data = self.data[self.data['config'] != 'dense'].copy()
+
+ fig = go.Figure()
+
+ # Add scatter points for each config
+ for config in sparse_data['config'].unique():
+ config_data = sparse_data[sparse_data['config'] == config]
+
+ fig.add_trace(go.Scatter(
+ x=config_data['density'],
+ y=config_data['overall_score'],
+ mode='markers',
+ marker=dict(
+ size=10,
+ color=self.colors.get(config, '#000000'),
+ line=dict(width=1, color='white')
+ ),
+ name=config.replace('_', ' ').title(),
+ text=config_data['task'],
+ hovertemplate='%{text}<br>Density: %{x:.3f}<br>Score: %{y:.3f}'
+ ))
+
+ # Add dense baseline as horizontal line
+ dense_scores = self.data[self.data['config'] == 'dense']['overall_score']
+ if not dense_scores.empty:
+ fig.add_hline(
+ y=dense_scores.mean(),
+ line_dash="dash",
+ line_color="gray",
+ annotation_text="Dense Baseline",
+ annotation_position="right"
+ )
+
+ fig.update_layout(
+ title='Density vs Performance Trade-off',
+ xaxis_title='Average Attention Density',
+ yaxis_title='Overall Score',
+ height=600,
+ xaxis=dict(range=[0, 1]),
+ showlegend=True,
+ legend=dict(
+ yanchor="top",
+ y=0.99,
+ xanchor="left",
+ x=1.02,
+ bgcolor="rgba(255, 255, 255, 0.8)",
+ bordercolor="rgba(0, 0, 0, 0.2)",
+ borderwidth=1
+ ),
+ margin=dict(r=150), # Add right margin for legend
+ **self.layout_template
+ )
+
+ return fig
+
+ def create_error_vs_density_scatter(self) -> go.Figure:
+ """Create scatter plot showing attention error vs density."""
+ # Filter out dense baseline and data without error metrics
+ sparse_data = self.data[
+ (self.data['config'] != 'dense') &
+ (self.data['attention_error'].notna())
+ ].copy()
+
+ fig = go.Figure()
+
+ # Add scatter points for each task
+ for task in sparse_data['task'].unique():
+ task_data = sparse_data[sparse_data['task'] == task]
+
+ fig.add_trace(go.Scatter(
+ x=task_data['density'],
+ y=task_data['attention_error'],
+ mode='markers',
+ marker=dict(
+ size=10,
+ symbol='circle',
+ line=dict(width=1, color='white')
+ ),
+ name=task.replace('_', ' ').title(),
+ text=task_data['config'],
+ hovertemplate='%{text}<br>Density: %{x:.3f}<br>Error: %{y:.3f}'
+ ))
+
+ # Add ideal line (y=0)
+ fig.add_hline(
+ y=0,
+ line_dash="dash",
+ line_color="green",
+ annotation_text="Perfect Attention",
+ annotation_position="right"
+ )
+
+ fig.update_layout(
+ title='Attention Error vs Density by Task',
+ xaxis_title='Average Attention Density',
+ yaxis_title='Average Attention Error',
+ height=600,
+ xaxis=dict(range=[0, 1]),
+ yaxis=dict(range=[0, max(0.5, sparse_data['attention_error'].max() * 1.1)]),
+ showlegend=True,
+ **self.layout_template
+ )
+
+ return fig
+
+ def create_performance_by_task_bar(self) -> go.Figure:
+ """Create grouped bar chart showing performance by task."""
+ fig = go.Figure()
+
+ # Get unique tasks and configs
+ tasks = sorted(self.data['task'].unique())
+ configs = sorted(self.data['config'].unique())
+
+ # Create grouped bars
+ for config in configs:
+ config_data = self.data[self.data['config'] == config]
+
+ # Calculate mean score per task
+ task_scores = []
+ for task in tasks:
+ task_data = config_data[config_data['task'] == task]
+ score = task_data['overall_score'].mean() if not task_data.empty else 0
+ task_scores.append(score)
+
+ fig.add_trace(go.Bar(
+ name=config.replace('_', ' ').title(),
+ x=tasks,
+ y=task_scores,
+ marker_color=self.colors.get(config, '#000000'),
+ hovertemplate='Task: %{x}<br>Score: %{y:.3f}'
+ ))
+
+ fig.update_layout(
+ title='Performance Comparison by Task',
+ xaxis_title='Benchmark Task',
+ yaxis_title='Overall Score',
+ barmode='group',
+ height=600,
+ xaxis_tickangle=-45,
+ **self.layout_template
+ )
+
+ return fig
+
+ def create_dashboard(self, output_file: str = "benchmark_dashboard.html"):
+ """Create a comprehensive dashboard with all visualizations."""
+ # Create subplots with specific layout
+ fig = make_subplots(
+ rows=2, cols=2,
+ subplot_titles=(
+ 'Performance Heatmap',
+ 'Density vs Performance Trade-off',
+ 'Performance by Task',
+ 'Attention Error vs Density'
+ ),
+ specs=[
+ [{"type": "heatmap"}, {"type": "scatter"}],
+ [{"type": "bar"}, {"type": "scatter"}]
+ ],
+ vertical_spacing=0.15,
+ horizontal_spacing=0.12
+ )
+
+ # Create individual plots
+ heatmap = self.create_performance_heatmap()
+ density_perf = self.create_density_vs_performance_scatter()
+ task_bars = self.create_performance_by_task_bar()
+ error_density = self.create_error_vs_density_scatter()
+
+ # Add traces to subplots
+ for trace in heatmap.data:
+ fig.add_trace(trace, row=1, col=1)
+
+ for trace in density_perf.data:
+ fig.add_trace(trace, row=1, col=2)
+
+ for trace in task_bars.data:
+ fig.add_trace(trace, row=2, col=1)
+
+ for trace in error_density.data:
+ fig.add_trace(trace, row=2, col=2)
+
+ # Update layout
+ fig.update_layout(
+ title_text="Sparse Attention Benchmark Results Dashboard",
+ title_font_size=24,
+ height=1200,
+ showlegend=False, # Individual plots have their own legends
+ **self.layout_template
+ )
+
+ # Update axes labels
+ fig.update_xaxes(title_text="Benchmark Task", row=1, col=1)
+ fig.update_yaxes(title_text="Configuration", row=1, col=1)
+
+ fig.update_xaxes(title_text="Density", row=1, col=2)
+ fig.update_yaxes(title_text="Overall Score", row=1, col=2)
+
+ fig.update_xaxes(title_text="Task", row=2, col=1)
+ fig.update_yaxes(title_text="Score", row=2, col=1)
+
+ fig.update_xaxes(title_text="Density", row=2, col=2)
+ fig.update_yaxes(title_text="Attention Error", row=2, col=2)
+
+ # Save dashboard
+ fig.write_html(
+ output_file,
+ config=self.plot_config,
+ include_plotlyjs='cdn'
+ )
+
+ # Also create individual plots
+ output_dir = Path(output_file).parent
+
+ # Save individual plots
+ heatmap.write_html(output_dir / "performance_heatmap.html", config=self.plot_config)
+ density_perf.write_html(output_dir / "density_vs_performance.html", config=self.plot_config)
+ task_bars.write_html(output_dir / "performance_by_task.html", config=self.plot_config)
+ error_density.write_html(output_dir / "error_vs_density.html", config=self.plot_config)
+
+ print(f"Dashboard saved to: {output_file}")
+ print(f"Individual plots saved to: {output_dir}/")
+
+ return fig
+
+ def generate_summary_stats(self) -> pd.DataFrame:
+ """Generate summary statistics for the benchmark results."""
+ summary = []
+
+ for config in self.data['config'].unique():
+ config_data = self.data[self.data['config'] == config]
+
+ stats = {
+ 'config': config,
+ 'avg_score': config_data['overall_score'].mean(),
+ 'std_score': config_data['overall_score'].std(),
+ 'avg_density': config_data['density'].mean() if config != 'dense' else 1.0,
+ 'avg_error': config_data['attention_error'].mean() if config != 'dense' else 0.0,
+ 'num_tasks': len(config_data),
+ 'best_task': config_data.loc[config_data['overall_score'].idxmax(), 'task'] if not config_data.empty else None,
+ 'worst_task': config_data.loc[config_data['overall_score'].idxmin(), 'task'] if not config_data.empty else None
+ }
+
+ summary.append(stats)
+
+ summary_df = pd.DataFrame(summary)
+ summary_df = summary_df.sort_values('avg_score', ascending=False)
+
+ # Save summary
+ summary_df.to_csv(self.results_dir.parent / "benchmark_summary.csv", index=False)
+
+ return summary_df
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Visualize sparse attention benchmark results")
+ parser.add_argument("--results-dir", type=str, default="benchmark_results_ray",
+ help="Directory containing benchmark results")
+ parser.add_argument("--output", type=str, default="benchmark_dashboard.html",
+ help="Output HTML file for dashboard")
+
+ args = parser.parse_args()
+
+ results_dir = Path(args.results_dir)
+ if not results_dir.exists():
+ print(f"Error: Results directory {results_dir} not found")
+ sys.exit(1)
+
+ # Create visualizer and generate dashboard
+ visualizer = BenchmarkResultsVisualizer(results_dir)
+
+ # Generate dashboard
+ visualizer.create_dashboard(args.output)
+
+ # Generate summary statistics
+ summary = visualizer.generate_summary_stats()
+ print("\nBenchmark Summary:")
+ print(summary.to_string(index=False))
+
+
+if __name__ == "__main__":
+ main()
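+# Usage sketch (script path is a placeholder; the flags match the argparse definitions above):
+#   python <this_script>.py --results-dir benchmark_results_ray --output benchmark_dashboard.html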
+
+
+
diff --git a/benchmark/raytune/visualize_error_vs_density.py b/benchmark/raytune/visualize_error_vs_density.py
new file mode 100755
index 00000000..a5121953
--- /dev/null
+++ b/benchmark/raytune/visualize_error_vs_density.py
@@ -0,0 +1,566 @@
+#!/usr/bin/env python3
+"""
+Interactive HTML visualization for error vs density across benchmarks and configurations.
+
+This script creates an interactive Plotly dashboard to visualize the relationship
+between error and density metrics across different models, benchmarks, and attention
+configurations from Ray Tune optimization results.
+"""
+
+import json
+from pathlib import Path
+from typing import Dict, Tuple
+
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+
+def parse_experiment_name(experiment_dir: str) -> Tuple[str, str, str, str]:
+ """
+ Parse experiment directory name to extract model, benchmark, task, and config.
+
+ Args:
+ experiment_dir: Directory name like 'meta-llama_Llama-3.1-8B-Instruct_loogle_shortdep_qa_sink_local_random_sampling'
+
+ Returns:
+ Tuple of (model, benchmark, task, config_type)
+ """
+ parts = experiment_dir.split('_')
+
+ # Handle model name with underscores
+ if parts[0] == 'meta-llama':
+ model = f"{parts[0]}/{parts[1]}"
+ remaining = parts[2:]
+ else:
+ model = parts[0]
+ remaining = parts[1:]
+
+ # Extract benchmark
+ benchmark = remaining[0] if len(remaining) > 0 else "unknown"
+
+ # The task is everything between benchmark and sink_local
+ # Find where 'sink_local' starts
+ sink_idx = -1
+ for i in range(1, len(remaining)):
+ if remaining[i] == 'sink' and i+1 < len(remaining) and remaining[i+1] == 'local':
+ sink_idx = i
+ break
+
+ if sink_idx > 1:
+ # Task is everything between benchmark and sink_local
+ task = '_'.join(remaining[1:sink_idx])
+ # Config type is everything from sink_local onwards
+ config_type = '_'.join(remaining[sink_idx:])
+ else:
+ # Fallback parsing
+ task = remaining[1] if len(remaining) > 1 else "unknown"
+ config_type = '_'.join(remaining[2:]) if len(remaining) > 2 else "unknown"
+
+ return model, benchmark, task, config_type
+
+
+def extract_config_params(config: Dict) -> str:
+ """
+ Extract and format configuration parameters for display.
+
+ Args:
+ config: Configuration dictionary from result.json
+
+ Returns:
+ Formatted string of configuration parameters
+ """
+ params = []
+ for key, value in sorted(config.items()):
+ # Shorten parameter names for display
+ short_key = key.replace('masker_', '').replace('_size', '').replace('_rate', '')
+ if isinstance(value, float):
+ params.append(f"{short_key}={value:.3f}")
+ else:
+ params.append(f"{short_key}={value}")
+ return ", ".join(params)
+
+
+def collect_results(ray_results_dir: Path) -> pd.DataFrame:
+ """
+ Collect all results from ray_results directory.
+
+ Args:
+ ray_results_dir: Path to ray_results directory
+
+ Returns:
+ DataFrame with columns: model, benchmark, task, config_type, density, error, config_params, trial_id
+ """
+ results = []
+
+ for experiment_dir in ray_results_dir.iterdir():
+ if not experiment_dir.is_dir():
+ continue
+
+ # Parse experiment name
+ model, benchmark, task, config_type = parse_experiment_name(experiment_dir.name)
+
+ # Process each trial in the experiment
+ for trial_dir in experiment_dir.iterdir():
+ if not trial_dir.is_dir() or not trial_dir.name.startswith('objective_'):
+ continue
+
+ result_file = trial_dir / 'result.json'
+ if not result_file.exists():
+ continue
+
+ try:
+ with open(result_file, 'r') as f:
+ data = json.load(f)
+
+ # Extract metrics
+ density = data.get('density', None)
+ error = data.get('error', None)
+
+ if density is None or error is None:
+ continue
+
+ # Extract trial ID
+ trial_id = data.get('trial_id', trial_dir.name.split('_')[1])
+
+ # Format config parameters
+ config_params = extract_config_params(data.get('config', {}))
+
+ results.append({
+ 'model': model,
+ 'benchmark': benchmark,
+ 'task': task,
+ 'config_type': config_type,
+ 'density': density,
+ 'error': error,
+ 'config_params': config_params,
+ 'trial_id': trial_id,
+ 'combined_score': data.get('combined_score', None)
+ })
+
+ except Exception as e:
+ print(f"Error processing {result_file}: {e}")
+ continue
+
+ return pd.DataFrame(results)
+
+
+def create_interactive_dashboard(df: pd.DataFrame, output_file: str = "error_vs_density_dashboard.html", output_dir: Path = None):
+ """
+ Create an interactive Plotly dashboard for error vs density visualization.
+
+ Args:
+ df: DataFrame with results
+ output_file: Output HTML file name
+ output_dir: Output directory for additional files
+ """
+ # Get unique tasks
+ tasks = sorted(df['task'].unique())
+ n_tasks = len(tasks)
+
+ # Arrange subplots on a 2-column grid; rows scale with the number of tasks (2x2 for the common 4-task case)
+ n_cols = 2
+ n_rows = max(1, (n_tasks + n_cols - 1) // n_cols)
+
+ # Create subplot titles with better formatting
+ subplot_titles = [f"{task.replace('_', ' ').title()}" for task in tasks]
+
+ # Create the main figure with subplots
+ fig = make_subplots(
+ rows=n_rows, cols=n_cols,
+ subplot_titles=subplot_titles,
+ horizontal_spacing=0.12,
+ vertical_spacing=0.15,
+ specs=[[{"type": "scatter"} for _ in range(n_cols)] for _ in range(n_rows)]
+ )
+
+ # Define dark color palette for config types
+ config_types = sorted(df['config_type'].unique())
+ # Using dark, vibrant colors for better visibility
+ dark_colors = [
+ '#1f77b4', # dark blue
+ '#ff7f0e', # dark orange
+ '#2ca02c', # dark green
+ '#d62728', # dark red
+ '#9467bd', # dark purple
+ '#8c564b', # dark brown
+ '#e377c2', # dark pink
+ '#7f7f7f', # dark gray
+ '#bcbd22', # dark olive
+ '#17becf', # dark cyan
+ '#393b79', # midnight blue
+ '#637939', # dark olive green
+ '#8c6d31', # dark tan
+ '#843c39', # dark maroon
+ '#7b4173', # dark magenta
+ '#5254a3', # dark indigo
+ '#6b6ecf', # dark lavender
+ '#9c9ede', # dark periwinkle
+ '#bd9e39', # dark gold
+ '#ad494a', # dark coral
+ '#a55194', # dark orchid
+ ]
+ color_map = {config: dark_colors[i % len(dark_colors)] for i, config in enumerate(config_types)}
+
+ # Define marker symbols for better distinction
+ symbols = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down',
+ 'triangle-left', 'triangle-right', 'pentagon', 'hexagon', 'star']
+ symbol_map = {config: symbols[i % len(symbols)] for i, config in enumerate(config_types)}
+
+ # Track if we've added each config type to legend
+ added_to_legend = set()
+
+ # For each task, create a subplot
+ for idx, task in enumerate(tasks):
+ row = idx // n_cols + 1
+ col = idx % n_cols + 1
+
+ task_df = df[df['task'] == task]
+
+ # Find best configs for this task at different density levels
+ best_configs = []
+ for density_threshold in [0.1, 0.2, 0.3, 0.4, 0.5]:
+ subset = task_df[task_df['density'] <= density_threshold]
+ if not subset.empty:
+ best_idx = subset['error'].idxmin()
+ best_configs.append(subset.loc[best_idx])
+
+ # Add traces for each config type in this task
+ for config_type in config_types:
+ config_task_df = task_df[task_df['config_type'] == config_type]
+
+ if config_task_df.empty:
+ continue
+
+ # Check if we should show in legend
+ show_legend = config_type not in added_to_legend
+ if show_legend:
+ added_to_legend.add(config_type)
+
+ fig.add_trace(
+ go.Scatter(
+ x=config_task_df['density'],
+ y=config_task_df['error'],
+ mode='markers',
+ name=config_type.replace('sink_local_', '').replace('_', ' '),
+ marker=dict(
+ size=10,
+ color=color_map[config_type],
+ symbol=symbol_map[config_type],
+ line=dict(width=1, color='white'),
+ opacity=0.9
+ ),
+ customdata=config_task_df[['model', 'benchmark', 'task', 'config_params', 'trial_id', 'combined_score']],
+ hovertemplate=(
+ "%{fullData.name}
" +
+ "Density: %{x:.3f}
" +
+ "Error: %{y:.3f}
" +
+ "Model: %{customdata[0]}
" +
+ "Benchmark: %{customdata[1]}
" +
+ "Task: %{customdata[2]}
" +
+ "Config: %{customdata[3]}
" +
+ "Trial ID: %{customdata[4]}
" +
+ "Combined Score: %{customdata[5]:.3f}
" +
+ ""
+ ),
+ showlegend=show_legend,
+ legendgroup=config_type
+ ),
+ row=row, col=col
+ )
+
+ # Highlight best performers with larger markers
+ if best_configs:
+ best_df = pd.DataFrame(best_configs)
+ fig.add_trace(
+ go.Scatter(
+ x=best_df['density'],
+ y=best_df['error'],
+ mode='markers',
+ name='Best at density level',
+ marker=dict(
+ size=16,
+ color='#8B0000', # dark red
+ symbol='star',
+ line=dict(width=2, color='#4B0000') # even darker red
+ ),
+ customdata=best_df[['config_type', 'config_params']],
+ hovertemplate=(
+ "BEST at density %.1f
" +
+ "Config: %{customdata[0]}
" +
+ "Params: %{customdata[1]}
" +
+ "Density: %{x:.3f}
" +
+ "Error: %{y:.3f}
" +
+ ""
+ ),
+ showlegend=(idx == 0), # Only show in legend once
+ legendgroup='best'
+ ),
+ row=row, col=col
+ )
+
+ # Update all axes
+ for i in range(1, n_rows + 1):
+ for j in range(1, n_cols + 1):
+ # Update x-axis
+ fig.update_xaxes(
+ title=dict(text="Density", font={'size': 14}),
+ tickfont={'size': 12},
+ gridcolor='rgba(128, 128, 128, 0.2)',
+ zeroline=False,
+ range=[-0.05, 1.05], # Fixed range for better comparison
+ row=i, col=j
+ )
+ # Update y-axis
+ fig.update_yaxes(
+ title=dict(text="Error", font={'size': 14}),
+ tickfont={'size': 12},
+ gridcolor='rgba(128, 128, 128, 0.2)',
+ zeroline=False,
+ range=[-0.05, 0.9], # Fixed range for better comparison
+ row=i, col=j
+ )
+
+ # Update layout with aesthetic styling
+ fig.update_layout(
+ title={
+ 'text': f"Error vs Density Analysis by Task
{df['benchmark'].iloc[0]} benchmark on {df['model'].iloc[0]}",
+ 'font': {'size': 24, 'family': 'Arial, sans-serif'},
+ 'x': 0.5,
+ 'xanchor': 'center'
+ },
+ plot_bgcolor='white',
+ paper_bgcolor='white',
+ hovermode='closest',
+ legend=dict(
+ title=dict(text="Configuration Type", font={'size': 14}),
+ font={'size': 11},
+ bgcolor='rgba(255, 255, 255, 0.9)',
+ bordercolor='rgba(0, 0, 0, 0.2)',
+ borderwidth=1,
+ itemsizing='constant',
+ x=1.02,
+ y=1,
+ xanchor='left',
+ yanchor='top'
+ ),
+ height=400 * n_rows,
+ width=1400,
+ margin=dict(l=80, r=250, t=120, b=80),
+ showlegend=True
+ )
+
+ # Save the figure
+ fig.write_html(
+ output_file,
+ config={'displayModeBar': True, 'displaylogo': False}
+ )
+
+ print(f"Dashboard saved to {output_file}")
+
+ # Also create separate plots by benchmark and task
+ if output_dir:
+ create_faceted_plots(df, str(output_dir / "error_vs_density_by_benchmark.html"))
+ else:
+ create_faceted_plots(df, "error_vs_density_by_benchmark.html")
+
+
+def create_faceted_plots(df: pd.DataFrame, output_file: str):
+ """
+ Create faceted plots showing error vs density grouped by benchmark and task.
+
+ Args:
+ df: DataFrame with results
+ output_file: Output HTML file name
+ """
+ # Create a more detailed visualization with facets
+ fig = px.scatter(
+ df,
+ x='density',
+ y='error',
+ color='config_type',
+ facet_col='task',
+ facet_row='benchmark',
+ hover_data=['model', 'config_params', 'trial_id', 'combined_score'],
+ title="Error vs Density by Benchmark and Task",
+ labels={
+ 'density': 'Density',
+ 'error': 'Error',
+ 'config_type': 'Configuration Type'
+ },
+ height=1200,
+ width=1600
+ )
+
+ # Update styling
+ fig.update_traces(marker=dict(size=8, line=dict(width=1, color='white')))
+
+ fig.update_layout(
+ font={'family': 'Arial, sans-serif'},
+ plot_bgcolor='rgba(240, 240, 240, 0.5)',
+ paper_bgcolor='white',
+ hovermode='closest'
+ )
+
+ # Update axes
+ fig.update_xaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
+ fig.update_yaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
+
+ # Save the figure
+ fig.write_html(
+ output_file,
+ config={'displayModeBar': True, 'displaylogo': False}
+ )
+
+ print(f"Faceted plots saved to {output_file}")
+
+
+def create_best_config_summary(df: pd.DataFrame, output_file: str):
+ """
+ Create a summary visualization showing best configurations for each task.
+
+ Args:
+ df: DataFrame with results
+ output_file: Output HTML file name
+ """
+ tasks = sorted(df['task'].unique())
+ density_levels = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0]
+
+ # Create summary data
+ summary_data = []
+ for task in tasks:
+ task_df = df[df['task'] == task]
+ for density_level in density_levels:
+ subset = task_df[task_df['density'] <= density_level]
+ if not subset.empty:
+ best_idx = subset['error'].idxmin()
+ best_row = subset.loc[best_idx]
+ summary_data.append({
+ 'task': task,
+ 'density_level': density_level,
+ 'best_config': best_row['config_type'].replace('sink_local_', ''),
+ 'error': best_row['error'],
+ 'actual_density': best_row['density'],
+ 'params': best_row['config_params']
+ })
+
+ summary_df = pd.DataFrame(summary_data)
+
+ # Create a heatmap-style visualization
+ fig = go.Figure()
+
+ # Create a trace for each config type
+ config_types = summary_df['best_config'].unique()
+ colors = px.colors.qualitative.Set3
+ color_map = {config: colors[i % len(colors)] for i, config in enumerate(config_types)}
+
+ for task in tasks:
+ task_data = summary_df[summary_df['task'] == task]
+
+ # Add bar chart showing best config at each density level
+ fig.add_trace(
+ go.Bar(
+ name=task,
+ x=[f"≤{d:.0%}" for d in task_data['density_level']],
+ y=task_data['error'],
+ text=[f"{row['best_config']}
Error: {row['error']:.3f}"
+ for _, row in task_data.iterrows()],
+ textposition='auto',
+ marker_color=[color_map[config] for config in task_data['best_config']],
+ customdata=task_data[['best_config', 'actual_density', 'params']],
+ hovertemplate=(
+ "Task: %{fullData.name}
" +
+ "Density Level: %{x}
" +
+ "Best Config: %{customdata[0]}
" +
+ "Error: %{y:.3f}
" +
+ "Actual Density: %{customdata[1]:.3f}
" +
+ "Parameters: %{customdata[2]}
" +
+ ""
+ )
+ )
+ )
+
+ fig.update_layout(
+ title={
+ 'text': "Best Configurations by Task and Density Level",
+ 'font': {'size': 20, 'family': 'Arial, sans-serif'},
+ 'x': 0.5,
+ 'xanchor': 'center'
+ },
+ xaxis=dict(
+ title="Maximum Density Level",
+ tickfont={'size': 12}
+ ),
+ yaxis=dict(
+ title="Error",
+ tickfont={'size': 12}
+ ),
+ barmode='group',
+ height=600,
+ width=1200,
+ showlegend=True,
+ legend=dict(
+ title="Task",
+ font={'size': 12},
+ x=1.02,
+ y=1,
+ xanchor='left',
+ yanchor='top',
+ bgcolor='rgba(255, 255, 255, 0.8)',
+ bordercolor='rgba(0, 0, 0, 0.2)',
+ borderwidth=1
+ ),
+ margin=dict(r=150), # Add right margin for legend
+ plot_bgcolor='rgba(240, 240, 240, 0.5)',
+ paper_bgcolor='white'
+ )
+
+ fig.write_html(output_file, config={'displayModeBar': True, 'displaylogo': False})
+ print(f"Best config summary saved to {output_file}")
+
+
+def main():
+ """Main function to generate the visualization."""
+ # Get the ray_results directory
+ ray_results_dir = Path(__file__).parent.parent.parent / "ray_results"
+
+ if not ray_results_dir.exists():
+ print(f"Error: ray_results directory not found at {ray_results_dir}")
+ return
+
+ print("Collecting results from ray_results directory...")
+ df = collect_results(ray_results_dir)
+
+ if df.empty:
+ print("No results found!")
+ return
+
+ print(f"Found {len(df)} results across {df['model'].nunique()} models, "
+ f"{df['benchmark'].nunique()} benchmarks, and {df['config_type'].nunique()} configuration types")
+
+ # Create output directory
+ output_dir = Path(__file__).parent / "visualizations"
+ output_dir.mkdir(exist_ok=True)
+
+ # Generate visualizations
+ print("\nGenerating interactive dashboard...")
+ create_interactive_dashboard(df, str(output_dir / "error_vs_density_by_task.html"), output_dir)
+
+ # Create best config summary
+ print("\nGenerating best config summary...")
+ create_best_config_summary(df, str(output_dir / "best_configs_summary.html"))
+
+ # Print a clean summary
+ print("\nVisualization complete! Generated files:")
+ print(f" - error_vs_density_by_task.html (task-wise subplots)")
+ print(f" - best_configs_summary.html (best configs at each density level)")
+ print(f" - error_vs_density_by_benchmark.html (faceted by benchmark/task)")
+
+ print(f"\nAnalyzed {len(df)} configurations across {len(df['task'].unique())} tasks")
+
+
+if __name__ == "__main__":
+ main()
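+# Usage sketch: `python benchmark/raytune/visualize_error_vs_density.py` reads <repo_root>/ray_results
+# (resolved relative to this file) and writes the HTML dashboards into benchmark/raytune/visualizations/.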
diff --git a/pyproject.toml b/pyproject.toml
index 3bcc130d..f204daca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ classifiers = [
keywords = ["attention", "sparse", "transformer", "deep-learning", "pytorch"]
dependencies = [
- "torch>=1.9.0",
+ "torch==2.7.1",
"numpy>=1.21.0",
"matplotlib>=3.5.0",
"seaborn>=0.11.0",
@@ -49,6 +49,8 @@ dependencies = [
"pandas>=2.3.1",
"pynvml>=12.0.0",
"colorama>=0.4.6",
+ "ray[tune]>=2.48.0",
+ "hyperopt>=0.2.7",
]
[project.optional-dependencies]
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/basic_fixed.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/basic_fixed.py
index 1a7547a6..1141654d 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/basic_fixed.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/basic_fixed.py
@@ -21,6 +21,22 @@ class LocalMaskerConfig(FixedMaskerConfig):
window_size: Union[float, int]
+ @classmethod
+ def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+ """Get Ray Tune search space for Local masker.
+
+ Args:
+ task_name: Name of the benchmark task to optimize for
+
+ Returns:
+ Dictionary mapping parameter names to Ray Tune distributions
+ """
+ from ray import tune
+
+ return {
+ "window_size": tune.choice([0.01])
+ }
+
@MaskerRegistry.register(LocalMaskerConfig)
class LocalMasker(FixedMasker):
@@ -168,6 +184,22 @@ class SinkMaskerConfig(FixedMaskerConfig):
sink_size: Union[float, int]
+ @classmethod
+ def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+ """Get Ray Tune search space for Sink masker.
+
+ Args:
+ task_name: Name of the benchmark task to optimize for
+
+ Returns:
+ Dictionary mapping parameter names to Ray Tune distributions
+ """
+ from ray import tune
+
+ return {
+ "sink_size": tune.choice([0.01])
+ }
+
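+ # Illustrative sketch (objective_fn and the merge step are assumptions, not defined here):
+ # a tuning driver can combine the per-masker spaces and hand them to Ray Tune, e.g.
+ #   space = {**SinkMaskerConfig.get_search_space(task), **LocalMaskerConfig.get_search_space(task)}
+ #   tuner = tune.Tuner(objective_fn, param_space=space)
+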
@MaskerRegistry.register(SinkMaskerConfig)
class SinkMasker(FixedMasker):
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py
index 83c4f877..65457e93 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/hashattention_top_k.py
@@ -32,6 +32,25 @@ class HashAttentionTopKMaskerConfig(TopKMaskerConfig):
hat_mlp_activation: str
hat_weights: Optional[Dict[int, Dict[str, List[torch.Tensor]]]] = None
hat_weight_file: Optional[str] = None
+
+ @classmethod
+ def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+ """Get Ray Tune search space for HashAttentionTopK masker.
+
+ Args:
+ task_name: Name of the benchmark task to optimize for
+
+ Returns:
+ Dictionary mapping parameter names to Ray Tune distributions
+ """
+ from ray import tune
+
+ # Only tune heavy_size, other parameters are fixed by the pre-trained model
+ return {
+ "heavy_size": tune.choice([0.01, 0.02, 0.03])
+ }
+ ## set in benchmarking config
+ # return {}
@MaskerRegistry.register(HashAttentionTopKMaskerConfig)
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_k.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_k.py
index 256dd0e2..74c6015b 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_k.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_k.py
@@ -24,6 +24,23 @@ class OracleTopKConfig(TopKMaskerConfig):
"""Configuration for OracleTopK masker."""
pass
+
+ @classmethod
+ def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+ """Get Ray Tune search space for OracleTopK masker.
+
+ Args:
+ task_name: Name of the benchmark task to optimize for
+
+ Returns:
+ Dictionary mapping parameter names to Ray Tune distributions
+ """
+ from ray import tune
+
+ return {
+ "heavy_size": tune.choice([0.01, 0.02, 0.03])
+ }
+ # return {}
@MaskerRegistry.register(OracleTopKConfig)
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_p.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_p.py
index 65e3f27d..45427ee3 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_p.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/fixed/implementations/oracle_top_p.py
@@ -24,6 +24,22 @@ class OracleTopPMaskerConfig(TopPMaskerConfig):
"""Configuration for OracleTopPMasker."""
pass # Inherits top_p from parent with validation
+
+ @classmethod
+ def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+ """Get Ray Tune search space for OracleTopP masker.
+
+ Args:
+ task_name: Name of the benchmark task to optimize for
+
+ Returns:
+ Dictionary mapping parameter names to Ray Tune distributions
+ """
+ from ray import tune
+
+ return {
+ "top_p": tune.choice([0.5, 0.6, 0.7, 0.8, 0.9, 0.95])
+ }
@MaskerRegistry.register(OracleTopPMaskerConfig)
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/adaptive_sampling.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/adaptive_sampling.py
index 9c4928cd..5b3d0284 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/adaptive_sampling.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/adaptive_sampling.py
@@ -109,6 +109,26 @@ def __post_init__(self) -> None:
raise ValueError(
f"local_offset must be int or float, got {type(self.local_offset)}"
)
+
+ @classmethod
+ def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+ """Get Ray Tune search space for AdaptiveSampling masker.
+
+ Args:
+ task_name: Name of the benchmark task to optimize for
+
+ Returns:
+ Dictionary mapping parameter names to Ray Tune distributions
+ """
+ from ray import tune
+
+ return {
+ "base_rate_sampling": tune.choice([0.01, 0.02, 0.03]),
+ "epsilon": tune.choice([0.1, 0.2, 0.3, 0.4]),
+ "delta": tune.choice([0.1, 0.2, 0.3, 0.4]),
+ "init_offset": tune.choice([0.01]),
+ "local_offset": tune.choice([0.01])
+ }
@MaskerRegistry.register(AdaptiveSamplingMaskerConfig)
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/magic_pig.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/magic_pig.py
index 1ec655a6..d72f10d2 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/magic_pig.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/magic_pig.py
@@ -65,6 +65,25 @@ def __post_init__(self) -> None:
)
if self.seed is None:
raise ValueError("seed cannot be None")
+
+ @classmethod
+ def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+ """Get Ray Tune search space for MagicPig masker.
+
+ Args:
+ task_name: Name of the benchmark task to optimize for
+
+ Returns:
+ Dictionary mapping parameter names to Ray Tune distributions
+ """
+ from ray import tune
+
+ return {
+ "lsh_l": tune.choice([32, 64, 128]),
+ "lsh_k": tune.choice([8, 16, 32]),
+ "center": tune.choice([True]),
+ "packing": tune.choice(["int64"])
+ }
@MaskerRegistry.register(MagicPigConfig)
diff --git a/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/random_sampling.py b/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/random_sampling.py
index ed72b255..ca461c55 100644
--- a/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/random_sampling.py
+++ b/sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/random_sampling.py
@@ -45,6 +45,23 @@ def __post_init__(self) -> None:
raise ValueError(
f"sampling_rate must be in range [0, 1], got {self.sampling_rate}"
)
+
+ @classmethod
+ def get_search_space(cls, task_name: str) -> Dict[str, Any]:
+ """Get Ray Tune search space for RandomSampling masker.
+
+ Args:
+ task_name: Name of the benchmark task to optimize for
+
+ Returns:
+ Dictionary mapping parameter names to Ray Tune distributions
+ """
+ from ray import tune
+
+ # return {
+ # "sampling_rate": tune.choice([0.01, 0.05, 0.1, 0.2, 0.3, 0.5])
+ # }
+ return {}
@MaskerRegistry.register(RandomSamplingMaskerConfig)