@@ -2477,31 +2477,41 @@ kernel void kernel_ssm_scan_f32(
24772477
24782478 threadgroup_barrier (mem_flags::mem_threadgroup);
24792479
2480- for (int t = 0 ; t < sgptg && i2 + t < n_t ; t++) {
2480+ // Phase 1: Compute states and s*C products for all tokens in batch
2481+ // Store partial products, delay reduction
2482+ const int batch_len = min ((int )sgptg, n_t - i2);
2483+ device const float * B_t = B;
2484+ device const float * C_t = C;
2485+
2486+ for (int t = 0 ; t < batch_len; t++) {
24812487 const float x_dt = shared_x_dt[t];
24822488 const float dA = exp (shared_dA[t] * A0);
24832489
2484- s = (s0 * dA) + (B [i0] * x_dt);
2490+ s = (s0 * dA) + (B_t [i0] * x_dt);
24852491
2486- const float sumf = simd_sum (s * C[i0]);
2492+ // Compute s * C and do SIMD reduction
2493+ const float sumf = simd_sum (s * C_t[i0]);
24872494
24882495 if (tiisg == 0 ) {
24892496 shared_sums[t*NW + sgitg] = sumf;
24902497 }
24912498
2492- // recurse
24932499 s0 = s;
2494-
2495- B += args.ns42 ;
2496- C += args.ns52 ;
2500+ B_t += args.ns42 ;
2501+ C_t += args.ns52 ;
24972502 }
24982503
2499- // Advance pointers for next batch
2504+ // Advance B, C pointers for next batch
2505+ B += batch_len * args.ns42 ;
2506+ C += batch_len * args.ns52 ;
2507+
2508+ // Advance x, dt pointers for next batch
25002509 x += sgptg * args.ns12 ;
25012510 dt += sgptg * args.ns21 ;
25022511
25032512 threadgroup_barrier (mem_flags::mem_threadgroup);
25042513
2514+ // Phase 2: Final reduction and output
25052515 const float sumf = simd_sum (shared_sums[sgitg*NW + tiisg]);
25062516
25072517 if (tiisg == 0 && i2 + sgitg < n_t ) {
@@ -2514,6 +2524,140 @@ kernel void kernel_ssm_scan_f32(
25142524 s_buff[i] = s;
25152525}
25162526
// SSD (state-space-duality) scan kernel for multi-token SSM processing.
//
// Per (state element i0, head_dim i1, head ir, sequence i3) the recurrence is:
//   dt'   = softplus(dt[t])                       (log1p(exp(.)) below 20, identity above)
//   s[t]  = exp(dt' * A) * s[t-1] + B[t] * (x[t] * dt')
//   y[t]  = sum over d_state of (s[t] * C[t])     (reduced with simd_sum + threadgroup memory)
//
// NOTE(review): an earlier version of this header described a work-efficient
// Blelloch parallel prefix scan with O(log n) depth. The implementation below
// does NOT do that: tokens are processed sequentially within batches of
// `sgptg` tokens, with x*dt and softplus(dt) precomputed into threadgroup
// memory once per batch. The description has been corrected to match the code.
//
// Dispatch (inferred from the indexing below — confirm against the host-side
// launch code):
//   - one threadgroup per (head_dim index, head, sequence): tgpig = (i1, ir, i3)
//   - threads per threadgroup == d_state; the code further assumes
//     d_state == NW * sgptg so that every thread owns exactly one state
//     element and `shared_sums[sgitg*NW + tiisg]` covers all partial sums.
//
// Threadgroup memory requirement: (sgptg*NW + 2*sgptg) floats
//   [partial sums | per-token x*dt | per-token softplus(dt)].
kernel void kernel_ssm_scan_ssd_f32(
        constant ggml_metal_kargs_ssm_scan & args,
        device const void * src0,
        device const void * src1,
        device const void * src2,
        device const void * src3,
        device const void * src4,
        device const void * src5,
        device const void * src6,
        device       float * dst,
        threadgroup float * shared [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgptg[[simdgroups_per_threadgroup]],
        uint3   tgpg[[threadgroups_per_grid]]) {

    constexpr short NW = N_SIMDWIDTH;

    const int32_t i0 = tpitg.x; // state index within d_state (one thread per state element)
    const int32_t i1 = tgpig.x; // head_dim index
    const int32_t ir = tgpig.y; // head index
    const int32_t i3 = tgpig.z; // sequence index

    const int32_t nc  = args.d_state;
    const int32_t nr  = args.d_inner; // head_dim
    const int32_t nh  = args.n_head;
    const int32_t ng  = args.n_group;
    const int32_t n_t = args.n_seq_tokens;

    const int32_t s_off = args.s_off;
    const int32_t g = ir / (nh / ng); // group index for B, C (heads share B/C within a group)

    device const int32_t * ids = (device const int32_t *) src6;

    // State buffers: input state is indirected through `ids` (per-sequence
    // slot mapping); output state is written at a fixed per-sequence slot
    // past `s_off` bytes in dst.
    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device       float * s_buff  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);

    const int32_t state_idx = i0 + i1*nc;

    // Load initial state for this thread's (state, head_dim) element.
    float s0 = s0_buff[state_idx];

    // A coefficient; `i0 % args.ne30` handles broadcast when A has fewer
    // elements than d_state (e.g. scalar-A-per-head variants).
    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
    const float A0 = A[i0 % args.ne30];

    // Input pointers positioned at token 0 of this sequence.
    device const float * x_base  = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13);
    device const float * dt_base = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22);
    device const float * B_base  = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43);
    device const float * C_base  = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53);

    // Output pointer: y laid out as [head_dim, head, token, seq] in floats.
    device float * y_base = dst + (i1 + ir*nr + i3*(n_t*nh*nr));

    // Shared memory layout:
    // - sgptg * NW floats for per-simdgroup partial sums
    // - sgptg floats for shared_x_dt (x[t] * softplus(dt[t]), one per batch token)
    // - sgptg floats for shared_dA  (softplus(dt[t]), one per batch token)
    threadgroup float * shared_sums = shared;
    threadgroup float * shared_x_dt = shared + sgptg * NW;
    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;

    // One slot per thread — relies on threads-per-group == sgptg*NW.
    shared_sums[tpitg.x] = 0.0f;

    float s = 0.0f;

    // Process tokens in batches of sgptg.
    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Pre-compute x*dt and softplus(dt) for this batch: the first sgptg
        // threads each handle one token. Divergent branch is safe — the
        // matching barrier is outside the `if`.
        if (i0 < sgptg && i2 + i0 < n_t) {
            device const float * x_t  = x_base  + i0 * args.ns12;
            device const float * dt_t = dt_base + i0 * args.ns21;

            const float dt0  = dt_t[0];
            // softplus(dt), switching to identity above 20 to avoid exp overflow
            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
            shared_x_dt[i0] = x_t[0] * dtsp;
            shared_dA[i0]   = dtsp;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Sequential recurrence over the batch's tokens (not a parallel scan).
        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
            const float x_dt = shared_x_dt[t];
            const float dA   = exp(shared_dA[t] * A0);

            s = (s0 * dA) + (B_base[i0] * x_dt);

            // Partial reduction of s*C within each simdgroup; lane 0 stashes
            // the per-simdgroup partial for the cross-simdgroup pass below.
            const float sumf = simd_sum(s * C_base[i0]);

            if (tiisg == 0) {
                shared_sums[t*NW + sgitg] = sumf;
            }

            // carry state to next token
            s0 = s;

            B_base += args.ns42;
            C_base += args.ns52;
        }

        // Advance x, dt for the next batch. Over-advances by up to sgptg-1
        // tokens on the final partial batch, but the loop terminates so the
        // pointers are never dereferenced past the end.
        x_base  += sgptg * args.ns12;
        dt_base += sgptg * args.ns21;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Cross-simdgroup reduction: simdgroup `sgitg` reduces the NW partials
        // of token (i2 + sgitg) and its lane 0 writes y for that token.
        // Tokens beyond n_t read stale shared_sums, but the guard below
        // discards those results.
        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);

        if (tiisg == 0 && i2 + sgitg < n_t) {
            y_base[sgitg*nh*nr] = sumf;
        }

        y_base += sgptg*nh*nr;
    }

    // Persist the final recurrent state (0.0f if n_t == 0 — TODO confirm the
    // host never dispatches with zero tokens, or that 0 is the desired state).
    s_buff[state_idx] = s;
}
2660+
25172661kernel void kernel_rwkv_wkv6_f32 (
25182662 device const float * k,
25192663 device const float * v,
0 commit comments