@@ -2394,6 +2394,7 @@ kernel void kernel_ssm_conv_f32_f32_4(
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// Optimized version: reduces redundant memory loads by having one thread per token pre-load shared values
 kernel void kernel_ssm_scan_f32(
         constant ggml_metal_kargs_ssm_scan & args,
         device const void * src0,
@@ -2413,7 +2414,15 @@ kernel void kernel_ssm_scan_f32(
         uint3 tgpg[[threadgroups_per_grid]]) {
     constexpr short NW = N_SIMDWIDTH;
 
-    shared[tpitg.x] = 0.0f;
+    // Shared memory layout:
+    //   [0..sgptg*NW-1]:                      partial sums for reduction (existing)
+    //   [sgptg*NW..sgptg*NW+sgptg-1]:         pre-computed x_dt values for each token in batch
+    //   [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
+    threadgroup float * shared_sums = shared;
+    threadgroup float * shared_x_dt = shared + sgptg * NW;
+    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;
+
+    shared_sums[tpitg.x] = 0.0f;
 
     const int32_t i0 = tpitg.x;
     const int32_t i1 = tgpig.x;
@@ -2453,32 +2462,47 @@ kernel void kernel_ssm_scan_f32(
     for (int i2 = 0; i2 < n_t; i2 += sgptg) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
-            const float dt0  = dt[0];
+        // Pre-compute x_dt and dA for this batch of tokens
+        // Only first sgptg threads do the loads and expensive math
+        if (i0 < sgptg && i2 + i0 < n_t) {
+            // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
+            device const float * x_t  = x  + i0 * args.ns12;
+            device const float * dt_t = dt + i0 * args.ns21;
+
+            const float dt0  = dt_t[0];
             const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
-            const float x_dt = x[0] * dtsp;
-            const float dA   = exp(dtsp * A0);
+            shared_x_dt[i0] = x_t[0] * dtsp;
+            shared_dA[i0]   = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+            const float x_dt = shared_x_dt[t];
+            const float dA   = exp(shared_dA[t] * A0);
 
             s = (s0 * dA) + (B[i0] * x_dt);
 
             const float sumf = simd_sum(s * C[i0]);
 
             if (tiisg == 0) {
-                shared[t*NW + sgitg] = sumf;
+                shared_sums[t*NW + sgitg] = sumf;
             }
 
             // recurse
             s0 = s;
 
-            x  += args.ns12;
-            dt += args.ns21;
             B += args.ns42;
             C += args.ns52;
         }
 
+        // Advance pointers for next batch
+        x  += sgptg * args.ns12;
+        dt += sgptg * args.ns21;
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
+        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
 
         if (tiisg == 0 && i2 + sgitg < n_t) {
             y[sgitg*nh*nr] = sumf;
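
For reference, the recurrence this kernel evaluates (per row/head pair, with one B/C vector of `d_state` elements per token) can be written as a plain scalar sketch. This is an illustrative reference under assumed names (`ssm_scan_ref`, `d_state`) and a flat `[n_t][d_state]` layout, not the actual ggml code; it treats `A0` as a single per-row coefficient and mirrors the kernel's softplus shortcut, state update, and output projection:

```cpp
#include <cmath>

// Minimal CPU sketch of the Mamba-2 scan above. Names and layout are
// assumptions for illustration; only the arithmetic mirrors the kernel.
void ssm_scan_ref(
        const float * x,   // [n_t]           per-token input
        const float * dt,  // [n_t]           per-token time step (pre-softplus)
        const float   A0,  //                 decay coefficient (per row/head here)
        const float * B,   // [n_t][d_state]  input projection
        const float * C,   // [n_t][d_state]  output projection
        float       * s,   // [d_state]       recurrent state, updated in place
        float       * y,   // [n_t]           per-token output
        const int n_t, const int d_state) {
    for (int t = 0; t < n_t; t++) {
        // These three values depend only on the token index t, not on the
        // state index i0: this is exactly the per-token work the diff hoists
        // into threadgroup memory.
        const float dt0  = dt[t];
        const float dtsp = dt0 <= 20.0f ? std::log(1.0f + std::exp(dt0)) : dt0;
        const float x_dt = x[t] * dtsp;
        const float dA   = std::exp(dtsp * A0);

        float sumf = 0.0f;
        for (int i0 = 0; i0 < d_state; i0++) {
            s[i0] = s[i0] * dA + B[t*d_state + i0] * x_dt; // state update
            sumf += s[i0] * C[t*d_state + i0];             // project through C
        }
        y[t] = sumf;
    }
}
```

Before this change, every thread in the threadgroup re-loaded `dt` and `x` and re-computed the softplus for each token; after it, the first `sgptg` threads each compute one token's `x_dt` and `dtsp` into threadgroup memory, and only `exp(dtsp * A0)` stays per-thread because `A0` differs between threads. Note that the new layout implies the threadgroup allocation must cover `sgptg*NW + 2*sgptg` floats rather than `sgptg*NW`; the corresponding host-side change is presumably outside these hunks.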