
Commit 1702008

Add logging of density and local attention error
1 parent 730f6ba commit 1702008

4 files changed: +63 additions, −4 deletions

benchmark/executor.py

Lines changed: 4 additions & 0 deletions

@@ -16,6 +16,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 from queue import Empty
 from contextlib import contextmanager
+from sparse_attention_hub.metric_logging.logger import MicroMetricLogger

 # Set multiprocessing start method to 'spawn' for CUDA compatibility
 if multiprocessing.get_start_method(allow_none=True) != 'spawn':

@@ -233,12 +234,15 @@ def _benchmark_worker(

     # Execute benchmark
     logger.info(f"Worker {worker_id}: Executing benchmark {stub.benchmark_name} on GPU {current_gpu_id}")
+    metric_logger = MicroMetricLogger()
+    metric_logger.configure_logging(log_path=stub.result_dir, enabled_metrics=["research_attention_density", "research_attention_output_error"])
     metrics = benchmark.run_benchmark(
         adapter=adapter,
         result_dir=stub.result_dir,
         generation_kwargs=stub.generation_kwargs,
         request_kwargs=stub.request_kwargs
     )
+    metric_logger.flush()

     execution_time = time.time() - start_time
     execution_success = True
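
For context, a minimal sketch (not part of this commit) of reading the flushed metrics back for analysis. The record layout is an assumption: it presumes each JSONL line is an object with at least metric, value, and metadata fields.

    import json
    import os
    from collections import defaultdict

    # Hypothetical reader for <result_dir>/micro_metrics.jsonl; the field names
    # "metric" and "value" are assumed, not taken from the logger implementation.
    result_dir = "/path/to/result_dir"  # placeholder
    values = defaultdict(list)
    with open(os.path.join(result_dir, "micro_metrics.jsonl")) as f:
        for line in f:
            event = json.loads(line)
            values[event["metric"]].append(event["value"])

    for metric, vals in values.items():
        print(metric, sum(vals) / len(vals))  # e.g. mean density / mean output error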

sparse_attention_hub/metric_logging/logger.py

Lines changed: 2 additions & 3 deletions

@@ -36,7 +36,7 @@ def __new__(cls, *args, **kwargs) -> "MicroMetricLogger":

     def __init__(self,
                  log_path: Optional[str] = None,
-                 flush_every: int = 100,  # Flush every N events
+                 flush_every: int = 1000,  # Flush every N events
                  flush_interval: float = 60.0,  # Flush every N seconds
                  enabled_metrics: Union[List[str], str] = None):  # List of string identifiers to enable, or "all"
         if not self._initialized:

@@ -181,8 +181,7 @@ def flush(self) -> None:
             return

         # Get current timestamp for filename
-        timestamp = datetime.now().strftime("%Y%m%d")
-        filename = f"metrics_{timestamp}.jsonl"
+        filename = f"micro_metrics.jsonl"
         filepath = os.path.join(self.log_path, filename)

         # Write events to file
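
Taken together, the call sites touched by this commit exercise the logger roughly as follows; a minimal usage sketch in which the log path, metric value, and layer index are illustrative:

    from sparse_attention_hub.metric_logging.logger import MicroMetricLogger

    # Metrics are registered once at import time (see research_attention/base.py below).
    MicroMetricLogger.register_metric("research_attention_density", float)

    # MicroMetricLogger is a singleton; the benchmark worker configures it per run.
    metric_logger = MicroMetricLogger()
    metric_logger.configure_logging(
        log_path="/tmp/results",  # illustrative; the worker passes stub.result_dir
        enabled_metrics=["research_attention_density"],
    )

    # Producers guard on is_metric_enabled() and attach metadata to each event.
    if metric_logger.is_metric_enabled("research_attention_density"):
        metric_logger.log("research_attention_density", 0.25, metadata={"layer_idx": 0})

    # Buffered events are written to <log_path>/micro_metrics.jsonl on flush.
    metric_logger.flush()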

sparse_attention_hub/sparse_attention/research_attention/base.py

Lines changed: 13 additions & 1 deletion

@@ -8,10 +8,13 @@

 from ..base import SparseAttention, SparseAttentionConfig
 from ..utils.mask import Mask
-from ..utils.mask_attention_utils import get_masked_attention_output
+from ..utils.mask_attention_utils import get_masked_attention_output, get_true_attention_output
 from .maskers.base import MaskerConfig, ResearchMasker
 from .maskers.sampling.base import SamplingMasker

+from sparse_attention_hub.metric_logging.logger import MicroMetricLogger
+MicroMetricLogger.register_metric("research_attention_density", float)
+MicroMetricLogger.register_metric("research_attention_output_error", float)

 @dataclass
 class ResearchAttentionConfig(SparseAttentionConfig):

@@ -101,6 +104,9 @@ def custom_attention(
            previous_mask=sparse_attention_mask,
            **kwargs,
        )
+
+        if MicroMetricLogger().is_metric_enabled("research_attention_density"):
+            MicroMetricLogger().log("research_attention_density", sparse_attention_mask.get_density(), metadata={"layer_idx": kwargs["layer_idx"]})

        # Call compute_masked_attention_output on the result of the last mask
        # Always request attention weights to match the expected return signature

@@ -118,6 +124,12 @@ def custom_attention(
            return_attention_weights=True,
            **kwargs,
        )
+
+        if MicroMetricLogger().is_metric_enabled("research_attention_output_error"):
+            true_attention_output, _ = get_true_attention_output(module, queries, keys, values, attention_mask, scaling, dropout, **kwargs)
+            error = torch.norm(true_attention_output - attention_output) / torch.norm(true_attention_output)
+            MicroMetricLogger().log("research_attention_output_error", float(error.item()), metadata={"layer_idx": kwargs["layer_idx"]})
+
        return attention_output, attention_weights

    @classmethod
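
The output-error metric logged above is a relative Frobenius norm between the dense and sparse attention outputs. A standalone sketch of the same computation on dummy tensors (shapes are illustrative):

    import torch

    # Stand-ins for the dense and sparse attention outputs produced in custom_attention.
    true_attention_output = torch.randn(1, 8, 128, 64)  # (batch, heads, seq, head_dim)
    attention_output = true_attention_output + 0.01 * torch.randn_like(true_attention_output)

    # Relative Frobenius-norm error, matching the expression logged above.
    error = torch.norm(true_attention_output - attention_output) / torch.norm(true_attention_output)
    print(float(error.item()))  # close to 0 when the sparse output tracks the dense one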

sparse_attention_hub/sparse_attention/utils/mask_attention_utils.py

Lines changed: 44 additions & 0 deletions

@@ -9,6 +9,50 @@
 from .mask import Mask


+def get_true_attention_output(
+    module: nn.Module,
+    queries: torch.Tensor,
+    keys: torch.Tensor,
+    values: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float,
+    **kwargs: Dict[str, Any],
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    """Get the true (dense) attention output from the module.
+
+    Args:
+        module: The attention module (used for dropout training flag).
+        queries: Query tensor of shape (..., seq_len_q, d_k).
+        keys: Key tensor of shape (..., seq_len_k, d_k).
+        values: Value tensor of shape (..., seq_len_k, d_v).
+        attention_mask: Optional mask tensor to apply to attention weights.
+        scaling: Scaling factor for attention logits.
+        dropout: Dropout probability for attention weights.
+        **kwargs: Additional keyword arguments (unused).
+
+    Returns:
+        Tuple containing:
+            - attention_output: Output tensor after applying attention.
+            - attention_weights: Softmax-normalized attention weights.
+    """
+    num_key_value_groups: int = _get_num_key_value_groups(queries, keys)
+    key_states = repeat_kv(keys, num_key_value_groups)
+    value_states = repeat_kv(values, num_key_value_groups)
+
+    attn_weights = torch.matmul(queries, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 def apply_inv_mask_sum(input_tensor: torch.Tensor, mask: Mask) -> torch.Tensor:
     """Apply inverse mask to input tensor and sum along the last dimension.

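
A hedged usage sketch of the new helper on random tensors: the bare nn.Module stand-in and the tensor shapes are assumptions (the helper only reads module.training for dropout), and grouped-query layouts would be handled by the existing repeat_kv / _get_num_key_value_groups helpers.

    import torch
    import torch.nn as nn

    from sparse_attention_hub.sparse_attention.utils.mask_attention_utils import get_true_attention_output

    # The helper only reads module.training, so a bare nn.Module works as a stand-in here.
    dummy_module = nn.Module()
    dummy_module.eval()

    batch, num_heads, seq_len, head_dim = 1, 8, 16, 64
    queries = torch.randn(batch, num_heads, seq_len, head_dim)
    keys = torch.randn(batch, num_heads, seq_len, head_dim)
    values = torch.randn(batch, num_heads, seq_len, head_dim)

    output, weights = get_true_attention_output(
        dummy_module, queries, keys, values,
        attention_mask=None, scaling=head_dim ** -0.5, dropout=0.0,
    )
    print(output.shape)   # (batch, seq_len, num_heads, head_dim) after the final transpose
    print(weights.shape)  # (batch, num_heads, seq_len, seq_len)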