Skip to content

Commit f873f2d

Browse files
committed
Allow init_offset and local_offset to take float values in AdaptiveSamplingMasker
Tool: Cursor
1 parent 41e149b commit f873f2d

File tree

1 file changed

+60
-14
lines changed
  • sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations

1 file changed

+60
-14
lines changed

sparse_attention_hub/sparse_attention/research_attention/maskers/sampling/implementations/adaptive_sampling.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,19 @@ class AdaptiveSamplingMaskerConfig(SamplingMaskerConfig):
4343
If float, must be in (0,1); if int, must be positive.
4444
epsilon: Float in range (0,1) representing the error bound.
4545
delta: Float in range (0,1) representing the confidence bound.
46-
init_offset: Non-negative integer representing the start index for sampling.
47-
local_offset: Non-negative integer representing the end offset for sampling.
46+
init_offset: Union[int, float] representing the start index for sampling.
47+
If int, must be non-negative; if float, must be in [0,1] and will be
48+
multiplied by the number of keys to get the actual offset.
49+
local_offset: Union[int, float] representing the end offset for sampling.
50+
If int, must be non-negative; if float, must be in [0,1] and will be
51+
multiplied by the number of keys to get the actual offset.
4852
"""
4953

5054
base_rate_sampling: Union[int, float] # Base rate (0,1) if float
5155
epsilon: float # Error bound (0,1)
5256
delta: float # Confidence bound (0,1)
53-
init_offset: int # Start index
54-
local_offset: int # End offset
57+
init_offset: Union[int, float] # Start index
58+
local_offset: Union[int, float] # End offset
5559

5660
def __post_init__(self) -> None:
5761
"""Validate configuration parameters."""
@@ -76,14 +80,34 @@ def __post_init__(self) -> None:
7680
if not (0.0 < self.delta < 1.0):
7781
raise ValueError(f"delta must be in (0, 1), got {self.delta}")
7882

79-
if self.init_offset < 0:
83+
if isinstance(self.init_offset, float):
84+
if not (0.0 <= self.init_offset <= 1.0):
85+
raise ValueError(
86+
f"init_offset must be in [0, 1] if float, got {self.init_offset}"
87+
)
88+
elif isinstance(self.init_offset, int):
89+
if self.init_offset < 0:
90+
raise ValueError(
91+
f"init_offset must be non-negative if int, got {self.init_offset}"
92+
)
93+
else:
8094
raise ValueError(
81-
f"init_offset must be non-negative, got {self.init_offset}"
95+
f"init_offset must be int or float, got {type(self.init_offset)}"
8296
)
8397

84-
if self.local_offset < 0:
98+
if isinstance(self.local_offset, float):
99+
if not (0.0 <= self.local_offset <= 1.0):
100+
raise ValueError(
101+
f"local_offset must be in [0, 1] if float, got {self.local_offset}"
102+
)
103+
elif isinstance(self.local_offset, int):
104+
if self.local_offset < 0:
105+
raise ValueError(
106+
f"local_offset must be non-negative if int, got {self.local_offset}"
107+
)
108+
else:
85109
raise ValueError(
86-
f"local_offset must be non-negative, got {self.local_offset}"
110+
f"local_offset must be int or float, got {type(self.local_offset)}"
87111
)
88112

89113

@@ -102,8 +126,10 @@ class AdaptiveSamplingMasker(SamplingMasker):
102126
base_rate_sampling: The base sampling rate (int or float).
103127
epsilon: The error bound for statistical guarantees.
104128
delta: The confidence bound for statistical guarantees.
105-
init_offset: Starting index for sampling range.
106-
local_offset: Ending offset for sampling range.
129+
init_offset: Starting index for sampling range (int or float).
130+
If float, represents fraction of sequence length.
131+
local_offset: Ending offset for sampling range (int or float).
132+
If float, represents fraction of sequence length.
107133
delta_ppf: Pre-computed percentile point function for efficiency.
108134
109135
Important Notes:
@@ -116,7 +142,7 @@ class AdaptiveSamplingMasker(SamplingMasker):
116142
Example:
117143
>>> config = AdaptiveSamplingMaskerConfig(
118144
... base_rate_sampling=0.1, epsilon=0.1, delta=0.05,
119-
... init_offset=0, local_offset=0
145+
... init_offset=0.1, local_offset=0.2 # Use 10% from start, 20% from end
120146
... )
121147
>>> masker = AdaptiveSamplingMasker(config)
122148
>>> # Use masker.add_mask() to apply adaptive sampling to attention masks
@@ -160,9 +186,29 @@ def _compute_exp_attention_scores(
160186
return torch.exp(raw_scores - max_scores)
161187

162188
def _get_sampling_range(self, seq_len_keys: int) -> tuple[int, int, int]:
163-
"""Get sampling range and validate it."""
164-
start_idx = self.init_offset
165-
end_idx = seq_len_keys - self.local_offset
189+
"""Get sampling range and validate it.
190+
191+
Args:
192+
seq_len_keys: Number of keys in the sequence.
193+
194+
Returns:
195+
Tuple of (start_idx, end_idx, sampling_range).
196+
197+
Raises:
198+
ValueError: If the computed sampling range is invalid.
199+
"""
200+
# Compute start index
201+
if isinstance(self.init_offset, float):
202+
start_idx: int = int(self.init_offset * seq_len_keys)
203+
else:
204+
start_idx = self.init_offset
205+
206+
# Compute end index
207+
if isinstance(self.local_offset, float):
208+
end_idx: int = seq_len_keys - int(self.local_offset * seq_len_keys)
209+
else:
210+
end_idx = seq_len_keys - self.local_offset
211+
166212
sampling_range = end_idx - start_idx
167213

168214
if sampling_range <= 0:

0 commit comments

Comments
 (0)