
Commit 3d65a96

Fix linting errors
1 parent 1702008 commit 3d65a96

File tree

7 files changed: +103 −59 lines


sparse_attention_hub/metric_logging/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 
 from .logger import MicroMetricLogger
 
-__all__ = ["MicroMetricLogger"]
+__all__ = ["MicroMetricLogger"]
sparse_attention_hub/metric_logging/logger.py

Lines changed: 47 additions & 38 deletions
@@ -1,18 +1,19 @@
 """MicroMetricLogger implementation for sparse attention hub."""
 
+import inspect
 import json
 import os
 import time
 from collections import deque
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
-import inspect
 
 
 @dataclass
 class LogEvent:
     """Log event data structure."""
+
     timestamp: datetime
     metric: str  # Metric identifier string
     value: Union[None, Any]
@@ -25,7 +26,7 @@ class MicroMetricLogger:
 
     _instance: Optional["MicroMetricLogger"] = None
     _initialized: bool = False
-
+
     # Class-level storage for registered metrics (works without initialization)
     _registered_metrics: Dict[str, type] = {}  # identifier -> dtype mapping
 
@@ -34,34 +35,36 @@ def __new__(cls, *args, **kwargs) -> "MicroMetricLogger":
         cls._instance = super().__new__(cls)
         return cls._instance
 
-    def __init__(self,
-                 log_path: Optional[str] = None,
-                 flush_every: int = 1000,  # Flush every N events
-                 flush_interval: float = 60.0,  # Flush every N seconds
-                 enabled_metrics: Union[List[str], str] = None):  # List of string identifiers to enable, or "all"
+    def __init__(
+        self,
+        log_path: Optional[str] = None,
+        flush_every: int = 1000,  # Flush every N events
+        flush_interval: float = 60.0,  # Flush every N seconds
+        enabled_metrics: Union[List[str], str] = None,
+    ):  # List of string identifiers to enable, or "all"
         if not self._initialized:
             self.log_path = log_path
             self.flush_every = flush_every
             self.flush_interval = flush_interval
-
+
             # Internal state
             self.log_queue: deque = deque(maxlen=10000)  # Circular buffer
             self.enabled_metrics: set = set()
             self.last_flush_time = time.time()
-
+
             # Enable metrics if log_path is provided
             if self.log_path is not None:
                 self._ensure_log_directory()
                 self.enable_metrics(enabled_metrics)
-
+
             MicroMetricLogger._initialized = True
 
     # main registration function
 
     @classmethod
     def register_metric(cls, identifier: str, dtype: type) -> None:
         """Register a metric with its string identifier and expected data type.
-
+
         This works at class level and doesn't require initialization.
         """
         if identifier in cls._registered_metrics:
@@ -73,7 +76,6 @@ def get_registered_metrics(cls) -> Dict[str, type]:
         """Get all registered metrics at class level."""
         return cls._registered_metrics.copy()
 
-
     # helper methods
 
     def _ensure_log_directory(self) -> None:
@@ -88,19 +90,19 @@ def _get_calling_location(self) -> str:
         caller_frame = inspect.currentframe().f_back.f_back
         if caller_frame is None:
             return "unknown"
-
+
         # Get module name
         module = inspect.getmodule(caller_frame)
         module_name = module.__name__ if module else "unknown"
-
+
         # Get function/class name
         function_name = caller_frame.f_code.co_name
-
+
         # Try to get class name if it's a method
         class_name = None
-        if 'self' in caller_frame.f_locals:
-            class_name = caller_frame.f_locals['self'].__class__.__name__
-
+        if "self" in caller_frame.f_locals:
+            class_name = caller_frame.f_locals["self"].__class__.__name__
+
         if class_name:
             return f"{module_name}.{class_name}.{function_name}"
         else:
@@ -110,14 +112,13 @@ def _get_calling_location(self) -> str:
 
     def __del__(self):
         """Cleanup when logger is destroyed."""
-        self.flush() # Final flush
-
+        self.flush()  # Final flush
 
-    # api
+    # api
 
     def enable_metrics(self, metrics: Union[List[str], str] = None) -> None:
         """Enable logging for specific metrics.
-
+
         Args:
             metrics: List of metric identifiers to enable, or "all" for all registered metrics.
                      If None, enables no metrics (empty list).
@@ -129,46 +130,54 @@ def enable_metrics(self, metrics: Union[List[str], str] = None) -> None:
             valid_metrics = set(metrics) & set(self._registered_metrics.keys())
            invalid_metrics = set(metrics) - set(self._registered_metrics.keys())
             if invalid_metrics:
-                print(f"Warning: Attempting to enable unregistered metrics: {invalid_metrics}")
+                print(
+                    f"Warning: Attempting to enable unregistered metrics: {invalid_metrics}"
+                )
             self.enabled_metrics = valid_metrics
         else:
             # Default to empty set
             self.enabled_metrics = set()
 
     def log(self, identifier: str, value: Any, metadata: Dict[str, Any] = None) -> None:
         """Log a metric value with optional metadata. Location is auto-inferred.
-
+
         This only works if log_path is defined.
         """
         # Check if logging is configured
         if self.log_path is None:
-            print(f"Warning: Cannot log metric '{identifier}' - log_path not defined. Use configure_logging() first.")
+            print(
+                f"Warning: Cannot log metric '{identifier}' - log_path not defined. Use configure_logging() first."
+            )
             return
-
+
         # Check if metric is enabled
         if identifier not in self.enabled_metrics:
-            print(f"Warning: Attempting to log metric '{identifier}' which is not enabled")
+            print(
+                f"Warning: Attempting to log metric '{identifier}' which is not enabled"
+            )
            return
-
+
         # Create log event
         event = LogEvent(
             timestamp=datetime.now(),
             metric=identifier,
             value=value,
             metadata=metadata or {},
-            location=self._get_calling_location()
+            location=self._get_calling_location(),
         )
-
+
         # Add to queue
         self.log_queue.append(event)
-
+
         # Check if we should flush
         if len(self.log_queue) >= self.flush_every:
             self.flush()
 
-    def configure_logging(self, log_path: str, enabled_metrics: Union[List[str], str] = None) -> None:
+    def configure_logging(
+        self, log_path: str, enabled_metrics: Union[List[str], str] = None
+    ) -> None:
         """Configure logging with a log path and optionally enable metrics.
-
+
         This must be called before logging can work.
         """
         self.log_path = log_path
@@ -179,11 +188,11 @@ def flush(self) -> None:
         """Force flush the current queue to disk."""
         if not self.log_queue or self.log_path is None:
             return
-
+
         # Get current timestamp for filename
         filename = f"micro_metrics.jsonl"
         filepath = os.path.join(self.log_path, filename)
-
+
         # Write events to file
         with open(filepath, "a", encoding="utf-8") as f:
             while self.log_queue:
@@ -193,7 +202,7 @@ def flush(self) -> None:
                 # Convert datetime to ISO format string
                 event_dict["timestamp"] = event_dict["timestamp"].isoformat()
                 f.write(json.dumps(event_dict) + "\n")
-
+
         self.last_flush_time = time.time()
 
     def is_metric_enabled(self, identifier: str) -> bool:
@@ -206,4 +215,4 @@ def get_enabled_metrics(self) -> set:
 
     def is_logging_configured(self) -> bool:
         """Check if logging is configured (log_path is set)."""
-        return self.log_path is not None
+        return self.log_path is not None
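For orientation, here is a minimal usage sketch of the MicroMetricLogger API that this commit reformats. Only methods visible in the diff are used; the log directory and the metric name "demo_metric" are invented for illustration and are not part of the repository.

# Hedged usage sketch -- "demo_metric" and the /tmp path are illustrative only.
from sparse_attention_hub.metric_logging import MicroMetricLogger

# Registration is a classmethod and works before any instance is configured.
MicroMetricLogger.register_metric("demo_metric", float)

logger = MicroMetricLogger()  # singleton: later constructions return the same object
# Per the docstring, configure_logging sets log_path and can optionally enable metrics.
logger.configure_logging("/tmp/micro_metric_logs", enabled_metrics="all")

if logger.is_metric_enabled("demo_metric"):
    # The calling location (module.Class.function) is inferred inside log().
    logger.log("demo_metric", 0.42, metadata={"layer_idx": 0})

logger.flush()  # appends JSON lines to <log_path>/micro_metrics.jsonl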

sparse_attention_hub/sparse_attention/research_attention/base.py

Lines changed: 31 additions & 7 deletions
@@ -6,16 +6,21 @@
 import torch
 from torch import nn
 
+from sparse_attention_hub.metric_logging.logger import MicroMetricLogger
+
 from ..base import SparseAttention, SparseAttentionConfig
 from ..utils.mask import Mask
-from ..utils.mask_attention_utils import get_masked_attention_output, get_true_attention_output
+from ..utils.mask_attention_utils import (
+    get_masked_attention_output,
+    get_true_attention_output,
+)
 from .maskers.base import MaskerConfig, ResearchMasker
 from .maskers.sampling.base import SamplingMasker
 
-from sparse_attention_hub.metric_logging.logger import MicroMetricLogger
 MicroMetricLogger.register_metric("research_attention_density", float)
 MicroMetricLogger.register_metric("research_attention_output_error", float)
 
+
 @dataclass
 class ResearchAttentionConfig(SparseAttentionConfig):
     """Configuration class for research attention mechanisms."""
@@ -104,9 +109,13 @@ def custom_attention(
             previous_mask=sparse_attention_mask,
             **kwargs,
         )
-
+
         if MicroMetricLogger().is_metric_enabled("research_attention_density"):
-            MicroMetricLogger().log("research_attention_density", sparse_attention_mask.get_density(), metadata={"layer_idx" : kwargs["layer_idx"]})
+            MicroMetricLogger().log(
+                "research_attention_density",
+                sparse_attention_mask.get_density(),
+                metadata={"layer_idx": kwargs["layer_idx"]},
+            )
 
         # Call compute_masked_attention_output on the result of the last mask
         # Always request attention weights to match the expected return signature
@@ -126,9 +135,24 @@ def custom_attention(
         )
 
         if MicroMetricLogger().is_metric_enabled("research_attention_output_error"):
-            true_attention_output, _ = get_true_attention_output(module, queries, keys, values, attention_mask, scaling, dropout, **kwargs)
-            error = torch.norm(true_attention_output - attention_output) / torch.norm(true_attention_output)
-            MicroMetricLogger().log("research_attention_output_error", float(error.item()), metadata={"layer_idx" : kwargs["layer_idx"]})
+            true_attention_output, _ = get_true_attention_output(
+                module,
+                queries,
+                keys,
+                values,
+                attention_mask,
+                scaling,
+                dropout,
+                **kwargs,
+            )
+            error = torch.norm(true_attention_output - attention_output) / torch.norm(
+                true_attention_output
+            )
+            MicroMetricLogger().log(
+                "research_attention_output_error",
+                float(error.item()),
+                metadata={"layer_idx": kwargs["layer_idx"]},
+            )
 
         return attention_output, attention_weights
 
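The research_attention_output_error metric gated above is a relative Frobenius-norm error between the dense ("true") attention output and the sparse approximation. A self-contained sketch of that formula, using made-up tensors rather than anything from the repository:

import torch

# Illustrative shapes only: (batch, heads, seq_len, head_dim).
true_out = torch.randn(1, 8, 128, 64)
approx_out = true_out + 0.01 * torch.randn_like(true_out)  # stand-in for the sparse output

# Same expression as in custom_attention: ||true - approx|| / ||true||.
error = torch.norm(true_out - approx_out) / torch.norm(true_out)
print(float(error.item()))  # this scalar is what gets passed to MicroMetricLogger().log(...)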

sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/magic_pig.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@
     _get_num_key_value_groups,
     repeat_kv,
 )
-
 from sparse_attention_hub.sparse_attention.utils.mask import Mask
 
 from ..base import SamplingMasker, SamplingMaskerConfig

sparse_attention_hub/sparse_attention/utils/mask.py

Lines changed: 1 addition & 1 deletion
@@ -656,4 +656,4 @@ def get_density(self) -> float:
         elif self.from_index:
             return float(len(self.indices)) / float(np.prod(self.shape))
         else:
-            raise RuntimeError("Mask object is in an invalid state")
+            raise RuntimeError("Mask object is in an invalid state")

sparse_attention_hub/sparse_attention/utils/mask_attention_utils.py

Lines changed: 7 additions & 3 deletions
@@ -45,13 +45,17 @@ def get_true_attention_output(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
 
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+        queries.dtype
+    )
+    attn_weights = nn.functional.dropout(
+        attn_weights, p=dropout, training=module.training
+    )
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, attn_weights
-
+
 
 
 def apply_inv_mask_sum(input_tensor: torch.Tensor, mask: Mask) -> torch.Tensor:
     """Apply inverse mask to input tensor and sum along the last dimension.
