
Commit 9eaea4a

Fix CoreML iOS26 numerics in static attention (#16144)
Summary: This diff decomposes SDPA to fix iOS26 numerics in Core ML. It also removes repeat_interleave, which further improves performance on Core ML by about 10-15%, depending on the hardware.

Differential Revision: D88705980
1 parent d39d64b commit 9eaea4a
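
For intuition, here is a minimal standalone sketch (not part of the commit; sizes, the mask construction, and variable names are illustrative assumptions) checking that the grouped-matmul decomposition used in _forward_mha in the diff below matches the original repeat_interleave + F.scaled_dot_product_attention path:

import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not taken from any model config).
n_kv, n_rep, Tq, Tk, D = 2, 4, 8, 32, 16
n_heads = n_kv * n_rep

q = torch.randn(1, n_heads, Tq, D)
k = torch.randn(1, n_kv, Tk, D)
v = torch.randn(1, n_kv, Tk, D)

# Additive float mask: zeros over the cached prefix, causal over the current window.
mask = torch.zeros(Tq, Tk)
mask[:, Tk - Tq:] = torch.triu(torch.full((Tq, Tq), float("-inf")), diagonal=1)

# Reference path: expand KV heads with repeat_interleave, then fused SDPA.
ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    attn_mask=mask,
)

# Decomposed path: group Q per KV head and let broadcasting stand in for
# repeat_interleave, mirroring the new else-branch in _forward_mha.
q_g = q.view(n_kv, n_rep, Tq, D)
k_g = k.view(n_kv, 1, Tk, D)
v_g = v.view(n_kv, 1, Tk, D)
attn = (q_g @ k_g.transpose(-2, -1)) * D**-0.5
attn = attn.view(1, n_heads, Tq, Tk) + mask
attn = F.softmax(attn, dim=-1).view(n_kv, n_rep, Tq, Tk)
out = (attn @ v_g).view(1, n_heads, Tq, D)

torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)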

File tree

1 file changed: +49 -5 lines changed

examples/models/llama/static_attention.py

Lines changed: 49 additions & 5 deletions
@@ -768,6 +768,10 @@ def __init__(
         self.use_conv2d = False
         self.enable_qnn_masked_softmax = kwargs.get("enable_qnn_masked_softmax", False)

+        # This fixes numerics on iOS26 on Core ML
+        # Possibly disable in future, depending on bug fixes in Core ML runtime
+        self.decompose_sdpa_in_mha: bool = kwargs.get("decompose_sdpa_in_mha", False)
+
         if self.split_mha:
             self.wqs = nn.ModuleList(
                 [
@@ -1027,16 +1031,56 @@ def _forward_mha(
         k, out_cache_state = self.k_caches[0].update(k, in_cache_state, out_cache_state)
         v, out_cache_state = self.v_caches[0].update(v, in_cache_state, out_cache_state)

-        if self.n_rep > 1:
-            k = k.repeat_interleave(self.n_rep, dim=1)
-            v = v.repeat_interleave(self.n_rep, dim=1)
-
         mask = None
         masks = kwargs.get("masks")
         if masks:
             cache_len = k.size(-2) - seq_len
             mask = masks[cache_len]
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+
+        if not self.decompose_sdpa_in_mha:
+            if self.n_rep > 1:
+                k = k.repeat_interleave(self.n_rep, dim=1)
+                v = v.repeat_interleave(self.n_rep, dim=1)
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+        else:
+            # We drop the bsz dim to keep matmuls on 4D tensors;
+            # Core ML sometimes fails at runtime when given 5D tensors.
+            assert bsz == 1, "Batch size > 1 not supported yet"
+
+            n_kv = self.n_kv_heads
+            n_rep = self.n_rep
+            D = self.head_dim
+
+            # Explicitly track lengths; they are NOT necessarily equal.
+            Tq = q.size(-2)  # query length (current step/window), e.g. 64
+            Tk = k.size(-2)  # key/value length (cache length), e.g. 2048
+
+            # Group Q to match the KV layout:
+            # q: (bsz=1, n_heads, Tq, D), with n_heads = n_kv * n_rep,
+            # so 1 * n_heads * Tq * D == n_kv * n_rep * Tq * D
+            # q_grouped: (n_kv, n_rep, Tq, D)
+            q_grouped = q.view(n_kv, n_rep, Tq, D)
+
+            # Prepare K for the grouped QK^T matmul
+            # k: (1, n_kv, Tk, D) -> (n_kv, 1, Tk, D)
+            k_grouped = k.view(n_kv, 1, Tk, D)
+
+            # (n_kv, n_rep, Tq, Tk)
+            attn_grouped = q_grouped @ k_grouped.transpose(-2, -1)
+            attn_grouped = attn_grouped * self.inv_scale
+
+            # Ungroup, add the mask, then regroup
+            attn_grouped = attn_grouped.view(1, self.n_heads, Tq, Tk)
+            attn_grouped = attn_grouped + mask
+            attn_grouped = F.softmax(attn_grouped, dim=-1)
+            attn_grouped = attn_grouped.view(n_kv, n_rep, Tq, Tk)
+
+            # Group V and apply the attention weights
+            v_grouped = v.view(n_kv, 1, Tk, D)
+            y_grouped = attn_grouped @ v_grouped
+
+            # Ungroup y
+            y = y_grouped.view(1, self.n_heads, Tq, D)

         return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state
0 commit comments

Comments
 (0)