Skip to content

Commit e5587cb

Browse files
committed
feat(ggml-metal): Efficient implementation of cumsum for metal
Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 78e137f commit e5587cb

File tree

2 files changed

+29
-37
lines changed

2 files changed

+29
-37
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,15 +330,16 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum(ggml_metal_library_
330330

331331
snprintf(name, 256, "%s", base);
332332

333+
// reuse existing precompiled pipeline, but allow memory size setting
333334
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
334-
if (res) {
335-
return res;
335+
if (!res) {
336+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
336337
}
337338

338-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
339-
340-
// shared memory buffer for a single simd group size
341-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
339+
// one shared memory element for each simd group in the threadgroup
340+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
341+
const int nsg = (ne00 + 31)/32;
342+
ggml_metal_pipeline_set_smem(res, nsg*sizeof(float));
342343

343344
return res;
344345
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,7 +1801,7 @@ kernel void kernel_cumsum(
18011801
constant ggml_metal_kargs_cumsum & args,
18021802
device const char * src0,
18031803
device const char * dst,
1804-
threadgroup float * shmem_f32 [[threadgroup(0)]],
1804+
threadgroup float * shmem_f32 [[threadgroup(0)]],
18051805
uint3 tgpig[[threadgroup_position_in_grid]],
18061806
ushort3 tpitg[[thread_position_in_threadgroup]],
18071807
ushort sgitg[[simdgroup_index_in_threadgroup]],
@@ -1822,40 +1822,31 @@ kernel void kernel_cumsum(
18221822
// threadgroup, so this will loop once for each index that this thread is
18231823
// responsible for
18241824
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1825-
//DEBUG -- This is the _very_ neive version
1826-
dst_row[i0] = src_row[i0];
1827-
for (int64_t j = 0; j < i0; ++j) {
1828-
dst_row[i0] = static_cast<T>(static_cast<float>(src_row[j]) + static_cast<float>(dst_row[i0]));
1829-
}
1830-
}
1831-
1832-
// if (sgitg == 0) {
1833-
// shmem_f32[tiisg] = 0.0f;
1834-
// }
1835-
1836-
1837-
// float sumf = 0;
1838-
1839-
// for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1840-
// sumf += src_row[i0];
1841-
// }
18421825

1843-
// sumf = simd_sum(sumf);
1826+
// Each thread does simd_prefix_inclusive_sum => every element of row
1827+
// now holds cumsum of the simd group
1828+
float sumf = static_cast<float>(src_row[i0]);
1829+
sumf = simd_prefix_inclusive_sum(sumf);
1830+
dst_row[i0] = static_cast<T>(sumf);
18441831

1845-
// threadgroup_barrier(mem_flags::mem_threadgroup);
1846-
1847-
// if (tiisg == 0) {
1848-
// shmem_f32[sgitg] = sumf;
1849-
// }
1850-
1851-
// threadgroup_barrier(mem_flags::mem_threadgroup);
1832+
// If this is the last element of the simd group, store its value in
1833+
// shared memory
1834+
if (tiisg == N_SIMDWIDTH - 1 || i0 == args.ne00 - 1) {
1835+
const ushort shmem_idx = i0 / N_SIMDWIDTH;
1836+
shmem_f32[shmem_idx] = sumf;
1837+
}
1838+
}
18521839

1853-
// sumf = shmem_f32[tiisg];
1854-
// sumf = simd_sum(sumf);
1840+
// Ensure all simd groups sync here before proceeding
1841+
threadgroup_barrier(mem_flags::mem_threadgroup);
18551842

1856-
// if (tpitg.x == 0) {
1857-
// dst_row[0] = norm ? sumf / args.ne00 : sumf;
1858-
// }
1843+
// Each element then adds the final value of all preceding simd groups
1844+
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1845+
const ushort shmem_idx = i0 / N_SIMDWIDTH;
1846+
for (ushort j = 0; j < shmem_idx; ++j) {
1847+
dst_row[i0] += static_cast<T>(shmem_f32[j]);
1848+
}
1849+
}
18591850
}
18601851

18611852
typedef decltype(kernel_cumsum<float>) kernel_cumsum_t;

0 commit comments

Comments
 (0)