@@ -768,6 +768,10 @@ def __init__(
         self.use_conv2d = False
         self.enable_qnn_masked_softmax = kwargs.get("enable_qnn_masked_softmax", False)
 
+        # This fixes numerics on iOS 26 with Core ML.
+        # Possibly disable in the future, depending on bug fixes in the Core ML runtime.
+        self.decompose_sdpa_in_mha: bool = kwargs.get("decompose_sdpa_in_mha", False)
+
         if self.split_mha:
             self.wqs = nn.ModuleList(
                 [
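The new flag is read from the keyword arguments, so callers opt in at construction time. A minimal usage sketch, assuming a hypothetical StaticAttention(config, layer_id, rope, **kwargs) constructor that forwards keyword arguments as above; only the decompose_sdpa_in_mha kwarg itself comes from this change:

    # Hypothetical construction call; the class and positional argument names are placeholders.
    attn = StaticAttention(config, layer_id, rope, decompose_sdpa_in_mha=True)
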
@@ -1027,16 +1031,56 @@ def _forward_mha(
         k, out_cache_state = self.k_caches[0].update(k, in_cache_state, out_cache_state)
         v, out_cache_state = self.v_caches[0].update(v, in_cache_state, out_cache_state)
 
-        if self.n_rep > 1:
-            k = k.repeat_interleave(self.n_rep, dim=1)
-            v = v.repeat_interleave(self.n_rep, dim=1)
-
         mask = None
         masks = kwargs.get("masks")
         if masks:
             cache_len = k.size(-2) - seq_len
             mask = masks[cache_len]
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+
+        if not self.decompose_sdpa_in_mha:
+            if self.n_rep > 1:
+                k = k.repeat_interleave(self.n_rep, dim=1)
+                v = v.repeat_interleave(self.n_rep, dim=1)
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+        else:
+            # Drop the bsz dim to keep the matmuls on 4D tensors;
+            # Core ML sometimes fails at runtime when given 5D tensors.
+            assert bsz == 1, "Batch size > 1 not supported yet"
+
+            n_kv = self.n_kv_heads
+            n_rep = self.n_rep
+            D = self.head_dim
+
+            # Explicitly track lengths; they are NOT necessarily equal.
+            Tq = q.size(-2)  # query length (current step/window), e.g. 64
+            Tk = k.size(-2)  # key/value length (cache length), e.g. 2048
+
+            # Group Q to match the KV layout.
+            # q: (bsz=1, n_heads, Tq, D), with n_heads = n_kv * n_rep
+            # 1 * n_heads * Tq * D == n_kv * n_rep * Tq * D
+            # q_grouped: (n_kv, n_rep, Tq, D)
+            q_grouped = q.view(n_kv, n_rep, Tq, D)
+
+            # Prepare K for the grouped KV matmul.
+            # k: (1, n_kv, Tk, D) -> (n_kv, 1, Tk, D)
+            k_grouped = k.view(n_kv, 1, Tk, D)
+
+            # (n_kv, n_rep, Tq, Tk)
+            attn_grouped = q_grouped @ k_grouped.transpose(-2, -1)
+            attn_grouped = attn_grouped * self.inv_scale
+
+            # Ungroup, add the mask, and regroup.
+            attn_grouped = attn_grouped.view(1, self.n_heads, Tq, Tk)
+            attn_grouped = attn_grouped + mask
+            attn_grouped = F.softmax(attn_grouped, dim=-1)
+            attn_grouped = attn_grouped.view(n_kv, n_rep, Tq, Tk)
+
+            # Group V.
+            v_grouped = v.view(n_kv, 1, Tk, D)
+            y_grouped = attn_grouped @ v_grouped
+
+            # Ungroup y.
+            y = y_grouped.view(1, self.n_heads, Tq, D)
 
         return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state
 
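For reference, the decomposed branch is grouped-query attention written as explicit 4D matmuls, so it should agree numerically with the repeat_interleave + F.scaled_dot_product_attention path it replaces. Below is a minimal standalone sketch (not part of the commit) that checks this equivalence; n_kv, n_rep, and head_dim mirror the names above, the shapes are arbitrary, and the standard 1/sqrt(head_dim) factor stands in for self.inv_scale, which is assumed to hold the same value:

    import torch
    import torch.nn.functional as F

    def decomposed_sdpa(q, k, v, mask, n_kv, n_rep, head_dim):
        # q: (1, n_kv * n_rep, Tq, D); k, v: (1, n_kv, Tk, D); mask: (Tq, Tk), additive float mask
        Tq, Tk, D = q.size(-2), k.size(-2), head_dim
        n_heads = n_kv * n_rep
        q_g = q.view(n_kv, n_rep, Tq, D)   # group query heads by their KV head
        k_g = k.view(n_kv, 1, Tk, D)       # one K per group, broadcast over n_rep
        attn = (q_g @ k_g.transpose(-2, -1)) * (1.0 / D**0.5)
        attn = attn.view(1, n_heads, Tq, Tk) + mask
        attn = F.softmax(attn, dim=-1).view(n_kv, n_rep, Tq, Tk)
        v_g = v.view(n_kv, 1, Tk, D)
        return (attn @ v_g).view(1, n_heads, Tq, D)

    torch.manual_seed(0)
    n_kv, n_rep, D, Tq, Tk = 2, 4, 16, 8, 32
    q = torch.randn(1, n_kv * n_rep, Tq, D)
    k = torch.randn(1, n_kv, Tk, D)
    v = torch.randn(1, n_kv, Tk, D)
    mask = torch.zeros(Tq, Tk)  # zero mask for simplicity; a real mask adds -inf at masked positions

    ref = F.scaled_dot_product_attention(
        q,
        k.repeat_interleave(n_rep, dim=1),
        v.repeat_interleave(n_rep, dim=1),
        attn_mask=mask,
    )
    out = decomposed_sdpa(q, k, v, mask, n_kv, n_rep, D)
    print(torch.allclose(ref, out, atol=1e-5))  # expected: True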