Commit 6340ab1

TEMP (wip): Partial work towards a unified kernel implementation of SSD
Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 51adb32 commit 6340ab1

File tree: 5 files changed (+189 / -11 lines)


ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 25 additions & 0 deletions
@@ -454,6 +454,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+
+    char base[256];
+    char name[256];
+
+    const int nsg = (ne00 + 31)/32;
+
+    snprintf(base, 256, "kernel_ssm_scan_ssd_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    // Shared memory layout for SSD kernel:
+    // - BATCH_SIZE * sgptg floats for partial sums
+    //   (BATCH_SIZE = 8, so 8 * nsg floats)
+    constexpr int BATCH_SIZE = 8;
+    res.smem = BATCH_SIZE * nsg * sizeof(float);
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
     char base[256];
     char name[256];
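For reference, a minimal host-side sketch (not part of the commit) of the specialization-key and threadgroup-memory arithmetic above. The d_state value is an assumption for illustration, standing in for ne00:

// Sketch only: mirrors the nsg/smem math in ggml_metal_library_get_pipeline_ssm_scan_ssd.
#include <cstdio>

int main() {
    const int d_state = 128;              // assumed ne00 (illustrative Mamba-2 value)
    const int nsg = (d_state + 31)/32;    // simdgroups needed: one 32-lane group per 32 states

    char base[256];
    char name[256];
    std::snprintf(base, sizeof(base), "kernel_ssm_scan_ssd_%s", "f32");
    std::snprintf(name, sizeof(name), "%s_nsg=%d", base, nsg);

    const int BATCH_SIZE = 8;             // tokens staged per batch in the kernel
    const int smem = BATCH_SIZE * nsg * (int)sizeof(float);

    // prints: kernel_ssm_scan_ssd_f32_nsg=4, smem = 128 bytes
    std::printf("%s, smem = %d bytes\n", name, smem);
    return 0;
}

Encoding nsg in the pipeline name means each distinct d_state compiles and caches its own variant, which is what the ggml_metal_library_get_pipeline / ggml_metal_library_compile_pipeline pair above implements.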

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
@@ -119,6 +119,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan        (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan_ssd    (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext      (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm          (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 7 additions & 1 deletion
@@ -1471,7 +1471,13 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
         /*.nb0 =*/ nb0,
     };
 
-    auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
+    auto pipeline = (n_seq_tokens > 1)
+        ? ggml_metal_library_get_pipeline_ssm_scan_ssd(lib, op)
+        : ggml_metal_library_get_pipeline_ssm_scan(lib, op);
+
+    // // Use sequential scan for now - the SSD kernel needs further optimization
+    // // to be competitive with the efficient sequential implementation
+    // auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
 
     GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
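A note on the dispatch rule above: decode (one token per sequence) leaves nothing for the SSD formulation to batch across, so it stays on the sequential-scan kernel, while multi-token prompt processing routes to the SSD kernel. A minimal sketch, with illustrative names rather than the ggml API:

// Sketch only: the kernel-selection rule, isolated from ggml.
#include <cstdio>

enum class ssm_pipeline { scan, scan_ssd };

static ssm_pipeline select_ssm_pipeline(int n_seq_tokens) {
    // Multi-token batches can amortize the per-batch staging of x/dt and the
    // deferred output reduction; a single token cannot.
    return n_seq_tokens > 1 ? ssm_pipeline::scan_ssd : ssm_pipeline::scan;
}

int main() {
    std::printf("decode (n=1):  %s\n", select_ssm_pipeline(1)  == ssm_pipeline::scan_ssd ? "ssd" : "sequential");
    std::printf("prompt (n=64): %s\n", select_ssm_pipeline(64) == ssm_pipeline::scan_ssd ? "ssd" : "sequential");
    return 0;
}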

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 152 additions & 8 deletions
@@ -2477,31 +2477,41 @@ kernel void kernel_ssm_scan_f32(
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+        // Phase 1: Compute states and s*C products for all tokens in batch
+        // Store partial products, delay reduction
+        const int batch_len = min((int)sgptg, n_t - i2);
+        device const float * B_t = B;
+        device const float * C_t = C;
+
+        for (int t = 0; t < batch_len; t++) {
             const float x_dt = shared_x_dt[t];
             const float dA = exp(shared_dA[t] * A0);
 
-            s = (s0 * dA) + (B[i0] * x_dt);
+            s = (s0 * dA) + (B_t[i0] * x_dt);
 
-            const float sumf = simd_sum(s * C[i0]);
+            // Compute s * C and do SIMD reduction
+            const float sumf = simd_sum(s * C_t[i0]);
 
             if (tiisg == 0) {
                 shared_sums[t*NW + sgitg] = sumf;
             }
 
-            // recurse
             s0 = s;
-
-            B += args.ns42;
-            C += args.ns52;
+            B_t += args.ns42;
+            C_t += args.ns52;
         }
 
-        // Advance pointers for next batch
+        // Advance B, C pointers for next batch
+        B += batch_len * args.ns42;
+        C += batch_len * args.ns52;
+
+        // Advance x, dt pointers for next batch
         x += sgptg * args.ns12;
         dt += sgptg * args.ns21;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
+        // Phase 2: Final reduction and output
        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
 
         if (tiisg == 0 && i2 + sgitg < n_t) {
@@ -2514,6 +2524,140 @@ kernel void kernel_ssm_scan_f32(
     s_buff[i] = s;
 }
 
+// SSD kernel using parallel prefix scan for efficient multi-token processing
+//
+// The SSM state update s[t] = dA[t] * s[t-1] + B[t] * x[t] * dt[t] forms an
+// associative scan with operator: (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2)
+//
+// This allows O(log n) parallel prefix computation instead of O(n) sequential.
+// We use a work-efficient Blelloch scan within each threadgroup.
+//
+// Dispatch: one threadgroup per (head_dim_idx, head, seq)
+// Threads: must be power of 2, >= n_seq_tokens
+kernel void kernel_ssm_scan_ssd_f32(
+        constant ggml_metal_kargs_ssm_scan & args,
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device const void * src6,
+        device       float * dst,
+        threadgroup  float * shared [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    constexpr short NW = N_SIMDWIDTH;
+
+    const int32_t i0 = tpitg.x; // state index within d_state
+    const int32_t i1 = tgpig.x; // head_dim index
+    const int32_t ir = tgpig.y; // head index
+    const int32_t i3 = tgpig.z; // sequence index
+
+    const int32_t nc  = args.d_state;
+    const int32_t nr  = args.d_inner; // head_dim
+    const int32_t nh  = args.n_head;
+    const int32_t ng  = args.n_group;
+    const int32_t n_t = args.n_seq_tokens;
+
+    const int32_t s_off = args.s_off;
+    const int32_t g = ir / (nh / ng); // group index for B, C
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    // State buffers
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+
+    const int32_t state_idx = i0 + i1*nc;
+
+    // Load initial state
+    float s0 = s0_buff[state_idx];
+
+    // A coefficient
+    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    const float A0 = A[i0 % args.ne30];
+
+    // Input pointers
+    device const float * x_base  = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_base = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22);
+    device const float * B_base  = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+    device const float * C_base  = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53);
+
+    // Output pointer
+    device float * y_base = dst + (i1 + ir*nr + i3*(n_t*nh*nr));
+
+    // Shared memory layout:
+    // - sgptg * NW floats for partial sums
+    // - sgptg floats for shared_x_dt
+    // - sgptg floats for shared_dA
+    threadgroup float * shared_sums = shared;
+    threadgroup float * shared_x_dt = shared + sgptg * NW;
+    threadgroup float * shared_dA   = shared + sgptg * NW + sgptg;
+
+    shared_sums[tpitg.x] = 0.0f;
+
+    float s = 0.0f;
+
+    // Process tokens in batches of sgptg
+    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Pre-compute x_dt and dA for this batch of tokens
+        if (i0 < sgptg && i2 + i0 < n_t) {
+            device const float * x_t  = x_base  + i0 * args.ns12;
+            device const float * dt_t = dt_base + i0 * args.ns21;
+
+            const float dt0  = dt_t[0];
+            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
+            shared_x_dt[i0] = x_t[0] * dtsp;
+            shared_dA[i0]   = dtsp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Process tokens in batch sequentially (standard approach)
+        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+            const float x_dt = shared_x_dt[t];
+            const float dA = exp(shared_dA[t] * A0);
+
+            s = (s0 * dA) + (B_base[i0] * x_dt);
+
+            const float sumf = simd_sum(s * C_base[i0]);
+
+            if (tiisg == 0) {
+                shared_sums[t*NW + sgitg] = sumf;
+            }
+
+            s0 = s;
+
+            B_base += args.ns42;
+            C_base += args.ns52;
+        }
+
+        // Advance pointers for next batch
+        x_base  += sgptg * args.ns12;
+        dt_base += sgptg * args.ns21;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
+
+        if (tiisg == 0 && i2 + sgitg < n_t) {
+            y_base[sgitg*nh*nr] = sumf;
+        }
+
+        y_base += sgptg*nh*nr;
+    }
+
+    s_buff[state_idx] = s;
+}
+
 kernel void kernel_rwkv_wkv6_f32(
         device const float * k,
         device const float * v,
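To make the kernel's header comment concrete: each token's update s[t] = dA[t]*s[t-1] + B[t]*x[t]*dt[t] is an affine map of the previous state, and composing affine maps is associative, which is what licenses a parallel prefix (Blelloch) scan. Below is a CPU reference sketch, not from the commit and with made-up sizes and inputs, that checks the ⊕ operator against the plain sequential recurrence, using the same softplus discretization of dt (with the 20.0f overflow cutoff) as the kernel. Note that the kernel as committed still walks tokens sequentially within each sgptg-sized batch; the scan formulation in the comment is the direction this WIP is working toward.

// CPU reference sketch for the SSD scan operator (illustrative, not ggml code).
// Each token contributes an affine map s -> c*s + v with
//   c = exp(softplus(dt) * A)   and   v = B * x * softplus(dt),
// and (c1,v1) ⊕ (c2,v2) = (c2*c1, c2*v1 + v2) composes "apply 1, then 2".
#include <cmath>
#include <cstdio>
#include <vector>

struct affine { float c, v; };

static affine combine(affine a, affine b) { return { b.c * a.c, b.c * a.v + b.v }; }

static float softplus(float x) { return x <= 20.0f ? std::log(1.0f + std::exp(x)) : x; }

int main() {
    const float A  = -0.5f;  // made-up per-state coefficient
    const int  n_t = 8;      // made-up token count

    std::vector<affine> e(n_t);
    for (int t = 0; t < n_t; ++t) {
        const float dt = 0.1f * (t + 1);   // made-up inputs
        const float x = 0.3f, B = 0.7f;
        const float dtsp = softplus(dt);
        e[t] = { std::exp(dtsp * A), B * x * dtsp };
    }

    // Plain sequential recurrence: s[t] = c[t]*s[t-1] + v[t], starting from s[-1] = 0
    float s_seq = 0.0f;
    for (int t = 0; t < n_t; ++t) s_seq = e[t].c * s_seq + e[t].v;

    // Inclusive prefix combine (what a Blelloch scan evaluates in O(log n)
    // parallel steps); applying the composite map to s[-1] = 0 gives the final state.
    affine acc = e[0];
    for (int t = 1; t < n_t; ++t) acc = combine(acc, e[t]);
    const float s_scan = acc.c * 0.0f + acc.v;

    std::printf("sequential = %.6f, scan = %.6f\n", s_seq, s_scan);  // should match
    return 0;
}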

src/models/graph-context-mamba.cpp

Lines changed: 4 additions & 2 deletions
@@ -247,8 +247,10 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
     auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
         ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
 
-        if (n_seq_tokens == 1) {
-        // if (true) {
+        // Use SSM_SCAN op for all cases - the Metal kernel handles both
+        // single-token (sequential scan) and multi-token (SSD formulation) internally
+        if (true) {
+        // if (n_seq_tokens == 1) {
             //DEBUG
             LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il);
             // If single-token, use ssm_scan op
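For context, a sketch of the graph-level routing that the temporary if (true) hard-codes: previously only single-token updates took the fused SSM_SCAN path, with multi-token prompts taking a separate branch not shown in this hunk; a unified kernel lets the fused op be emitted unconditionally. The function below is a hypothetical stand-in for the branch in build_mamba2_layer, and the real code builds ggml graph nodes rather than returning strings.

// Hypothetical stand-in for the routing branch; names are illustrative only.
#include <cstdio>

static const char * emit_path(int n_seq_tokens, bool unified_kernel) {
    if (unified_kernel || n_seq_tokens == 1) {
        return "fused SSM_SCAN op";        // the `if (true)` branch above
    }
    return "separate multi-token path";    // the old n_seq_tokens > 1 fallback
}

int main() {
    std::printf("decode (1 token):   %s\n", emit_path(1, true));
    std::printf("prompt (64 tokens): %s\n", emit_path(64, true));
    return 0;
}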
