@@ -1801,7 +1801,7 @@ kernel void kernel_cumsum(
18011801 constant ggml_metal_kargs_cumsum & args,
18021802 device const char * src0,
18031803 device const char * dst,
1804- threadgroup float * shmem_f32 [[threadgroup(0 )]],
1804+ threadgroup float * shmem_f32 [[threadgroup(0 )]],
18051805 uint3 tgpig[[threadgroup_position_in_grid]],
18061806 ushort3 tpitg[[thread_position_in_threadgroup]],
18071807 ushort sgitg[[simdgroup_index_in_threadgroup]],
@@ -1822,40 +1822,31 @@ kernel void kernel_cumsum(
18221822 // threadgroup, so this will loop once for each index that this thread is
18231823 // responsible for
18241824 for (int64_t i0 = tpitg.x ; i0 < args.ne00 ; i0 += ntg.x ) {
1825- // DEBUG -- This is the _very_ neive version
1826- dst_row[i0] = src_row[i0];
1827- for (int64_t j = 0 ; j < i0; ++j) {
1828- dst_row[i0] = static_cast <T>(static_cast <float >(src_row[j]) + static_cast <float >(dst_row[i0]));
1829- }
1830- }
1831-
1832- // if (sgitg == 0) {
1833- // shmem_f32[tiisg] = 0.0f;
1834- // }
1835-
1836-
1837- // float sumf = 0;
1838-
1839- // for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1840- // sumf += src_row[i0];
1841- // }
18421825
1843- // sumf = simd_sum(sumf);
1826+ // Each thread runs simd_prefix_inclusive_sum on its element, so every
1827+ // element now holds the inclusive prefix sum within its own simd group
1828+ float sumf = static_cast <float >(src_row[i0]);
1829+ sumf = simd_prefix_inclusive_sum (sumf);
1830+ dst_row[i0] = static_cast <T>(sumf);
18441831
1845- // threadgroup_barrier(mem_flags::mem_threadgroup);
1846-
1847- // if (tiisg == 0 ) {
1848- // shmem_f32[sgitg] = sumf ;
1849- // }
1850-
1851- // threadgroup_barrier(mem_flags::mem_threadgroup);
1832+ // If this is the last element of the simd group, store its value in
1833+ // shared memory
1834+ if (tiisg == N_SIMDWIDTH - 1 || i0 == args. ne00 - 1 ) {
1835+ const ushort shmem_idx = i0 / N_SIMDWIDTH ;
1836+ shmem_f32[shmem_idx] = sumf;
1837+ }
1838+ }
18521839
1853- // sumf = shmem_f32[tiisg];
1854- // sumf = simd_sum(sumf );
1840+ // Ensure all simd groups sync here before proceeding
1841+ threadgroup_barrier (mem_flags::mem_threadgroup );
18551842
1856- // if (tpitg.x == 0) {
1857- // dst_row[0] = norm ? sumf / args.ne00 : sumf;
1858- // }
1843+ // Each element then adds the final value of all preceding simd groups
1844+ for (int64_t i0 = tpitg.x ; i0 < args.ne00 ; i0 += ntg.x ) {
1845+ const ushort shmem_idx = i0 / N_SIMDWIDTH;
1846+ for (ushort j = 0 ; j < shmem_idx; ++j) {
1847+ dst_row[i0] += static_cast <T>(shmem_f32[j]);
1848+ }
1849+ }
18591850}
18601851
18611852typedef decltype (kernel_cumsum<float >) kernel_cumsum_t;
0 commit comments