
Commit 65feb1c

feat: Optimized SSM_SCAN kernel for metal
This used Claude Code and resulted in a modest performance improvement while maintaining correctness.

Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 953bb62

2 files changed: +39 -10 lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 6 additions & 1 deletion
@@ -444,7 +444,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me
         res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res.smem = 32*sizeof(float)*nsg;
+    // Shared memory layout:
+    // - sgptg * NW floats for partial sums (nsg * 32)
+    // - sgptg floats for shared_x_dt (nsg)
+    // - sgptg floats for shared_dA (nsg)
+    // Total: nsg * (32 + 2) floats
+    res.smem = (32 + 2)*sizeof(float)*nsg;
 
     return res;
 }
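
For reference, a minimal standalone sketch (not part of the commit) of the threadgroup memory size the host code above requests. The value nsg = 2 is an assumed simdgroup count; NW = 32 is the Metal SIMD width (N_SIMDWIDTH):

    #include <cstddef>
    #include <cstdio>

    int main() {
        const int NW  = 32; // N_SIMDWIDTH: lanes per simdgroup
        const int nsg = 2;  // assumed simdgroups per threadgroup (sgptg)
        // nsg*NW floats of partial sums + nsg floats of x_dt + nsg floats of dA
        const std::size_t smem = (NW + 2) * sizeof(float) * nsg;
        std::printf("smem = %zu bytes\n", smem); // 272 bytes (was 256 before this change)
        return 0;
    }

The extra 2*nsg floats are the per-token x_dt and dA slots that the kernel change below introduces.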

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 33 additions & 9 deletions
@@ -2394,6 +2394,7 @@ kernel void kernel_ssm_conv_f32_f32_4(
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// Optimized version: reduces redundant memory loads by having one thread load shared values
 kernel void kernel_ssm_scan_f32(
         constant ggml_metal_kargs_ssm_scan & args,
         device const void * src0,
@@ -2413,7 +2414,15 @@ kernel void kernel_ssm_scan_f32(
         uint3 tgpg[[threadgroups_per_grid]]) {
     constexpr short NW = N_SIMDWIDTH;
 
-    shared[tpitg.x] = 0.0f;
+    // Shared memory layout:
+    // [0..sgptg*NW-1]: partial sums for reduction (existing)
+    // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
+    // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
+    threadgroup float * shared_sums = shared;
+    threadgroup float * shared_x_dt = shared + sgptg * NW;
+    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;
+
+    shared_sums[tpitg.x] = 0.0f;
 
     const int32_t i0 = tpitg.x;
     const int32_t i1 = tgpig.x;
@@ -2453,32 +2462,47 @@ kernel void kernel_ssm_scan_f32(
     for (int i2 = 0; i2 < n_t; i2 += sgptg) {
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
-            const float dt0 = dt[0];
+        // Pre-compute x_dt and dA for this batch of tokens
+        // Only first sgptg threads do the loads and expensive math
+        if (i0 < sgptg && i2 + i0 < n_t) {
+            // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
+            device const float * x_t  = x  + i0 * args.ns12;
+            device const float * dt_t = dt + i0 * args.ns21;
+
+            const float dt0 = dt_t[0];
             const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
-            const float x_dt = x[0] * dtsp;
-            const float dA = exp(dtsp * A0);
+            shared_x_dt[i0] = x_t[0] * dtsp;
+            shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+            const float x_dt = shared_x_dt[t];
+            const float dA = exp(shared_dA[t] * A0);
 
             s = (s0 * dA) + (B[i0] * x_dt);
 
             const float sumf = simd_sum(s * C[i0]);
 
             if (tiisg == 0) {
-                shared[t*NW + sgitg] = sumf;
+                shared_sums[t*NW + sgitg] = sumf;
             }
 
             // recurse
             s0 = s;
 
-            x += args.ns12;
-            dt += args.ns21;
             B += args.ns42;
             C += args.ns52;
         }
 
+        // Advance pointers for next batch
+        x += sgptg * args.ns12;
+        dt += sgptg * args.ns21;
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
+        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
 
         if (tiisg == 0 && i2 + sgitg < n_t) {
             y[sgitg*nh*nr] = sumf;
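
As a rough illustration of the per-token math the kernel now factors out, a standalone CPU sketch (assumptions, not the commit's code): softplus(dt) and x*dtsp are computed once per token and shared across threads, while exp(dtsp * A0) stays per-thread because A0 differs across state dimensions. All numeric values (A0, B0, dt, x) are made-up:

    #include <cmath>
    #include <cstdio>

    int main() {
        float s0 = 0.0f;                            // running state for one (head, state) element
        const float A0 = -1.0f;                     // assumed per-element decay parameter
        const float B0 = 0.7f;                      // assumed B projection for this element
        const float dt_raw[3] = {0.5f, 1.0f, 2.0f}; // assumed raw dt per token
        const float x_in[3]   = {0.1f, 0.2f, 0.3f}; // assumed input x per token

        for (int t = 0; t < 3; ++t) {
            // softplus with the same overflow guard as the kernel
            const float dtsp = dt_raw[t] <= 20.0f ? std::log(1.0f + std::exp(dt_raw[t])) : dt_raw[t];
            const float x_dt = x_in[t] * dtsp;      // computed once per token, shared in the kernel
            const float dA   = std::exp(dtsp * A0); // computed per-thread in the kernel
            const float s    = s0 * dA + B0 * x_dt; // the SSM scan recurrence
            std::printf("t=%d  s=%.6f\n", t, s);
            s0 = s;                                 // recurse
        }
        return 0;
    }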
