
Commit 6cf8a1f

Improves attention bias numerical stability
Replaces $\exp(A\cdot\mathrm{softplus}(\Delta V))$ with $A\cdot\mathrm{softplus}(\Delta V)$ to prevent overflow/NaNs in the attention bias and stabilize training/inference. Preserves tensor shape/dtype and adds a clarifying comment on the rationale.
1 parent 78bb93d commit 6cf8a1f

File tree

1 file changed: +2 −1 lines changed

examples/modeling/modeling_doge.py

Lines changed: 2 additions & 1 deletion
@@ -217,7 +217,8 @@ def forward(
         dt_states = self.dt_proj(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
-        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)
+        # original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)
+        attn_bias = self.A * F.softplus(dt_states).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)
 
         attention_interface: Callable = flash_dynamic_mask_attention_forward

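To make the rationale concrete, here is a minimal standalone sketch of the overflow behaviour (assuming PyTorch; `A`, `dt_states`, and their magnitudes are hypothetical and not taken from the Doge module): exp(A * softplus(dt)) grows exponentially in dt, so moderately large projections already overflow to inf and then surface as NaNs, whereas A * softplus(dt) grows only linearly and stays finite.

# Minimal sketch of the numerical-stability rationale; not the model's actual code.
# Shapes and magnitudes are illustrative only.
import torch
import torch.nn.functional as F

A = torch.full((8,), 4.0)                    # hypothetical per-head scale (analogue of self.A)
dt_states = torch.randn(2, 128, 8) * 20.0    # hypothetical dt_proj output with large values

# Old form: exp(A * softplus(dt)) overflows once its argument exceeds ~88 in
# float32 (or ~11 in float16), so large dt values turn the bias into inf,
# which then propagates as NaN through the attention scores.
old_bias = torch.exp(A * F.softplus(dt_states)).to(torch.float16)
print("old form has inf:", torch.isinf(old_bias).any().item())   # True for these magnitudes

# New form: A * softplus(dt) grows only linearly in dt and stays finite,
# so the cast to the model dtype is safe.
new_bias = (A * F.softplus(dt_states)).to(torch.float16)
print("new form has inf:", torch.isinf(new_bias).any().item())   # False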