@@ -1,18 +1,21 @@
 #include <algorithm>
 #include "cumsum.cuh"
 #include "convert.cuh"
+#include "ggml.h"
 
 #ifdef GGML_CUDA_USE_CUB
 #include <cub/device/device_scan.cuh>
+#endif
+
 
 template <typename T, int BLOCK_SIZE>
 static __global__ void cumsum_cub_kernel(
     const T* __restrict__ src,
     T* __restrict__ dst,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t nb01, const int64_t nb02, const int64_t nb03,
-    const int64_t nb1, const int64_t nb2, const int64_t nb3)
-{
+    const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+#ifdef GGML_CUDA_USE_CUB
     using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
 
     __shared__ typename BlockScan::TempStorage temp_storage;
@@ -61,17 +64,10 @@ static __global__ void cumsum_cub_kernel(
 
         __syncthreads();
     }
-}
 #else
-template <typename T, int BLOCK_SIZE>
-static __global__ void cumsum_cub_kernel(
-    const T* __restrict__ src,
-    T* __restrict__ dst,
-    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const int64_t nb01, const int64_t nb02, const int64_t nb03,
-    const int64_t nb1, const int64_t nb2, const int64_t nb3) {}
-// empty function to avoid triggering compilation errors on non-CUB paths, just in case compiler doesn't optimize away
-#endif // GGML_CUDA_USE_CUB
+    NO_DEVICE_CODE;
+#endif
+}
 
 // Fallback kernel implementation (original)
 template <typename T>
@@ -81,6 +77,8 @@ static __global__ void cumsum_kernel(
     const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
     const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
 
+    GGML_UNUSED_VARS(nb00, nb0);
+
     const int tid = threadIdx.x;
     const int lane = tid & (WARP_SIZE - 1);
     const int warp = tid / WARP_SIZE;
@@ -138,7 +136,7 @@ static __global__ void cumsum_kernel(
     float carry = *s_carry;
     float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
     if (idx < ne00) {
-        dst_row[idx] = static_cast<T>(final_val);
+        dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
     }
     __syncthreads();
 
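
For orientation, the CUB path above follows the standard tiled cub::BlockScan pattern: one block scans a row in BLOCK_SIZE-wide tiles, carrying the running total from tile to tile. Below is a minimal standalone sketch of that pattern, assuming a float-only row scan; the kernel name and launch shape are illustrative, not the ggml API.

// Sketch only: one block computes the inclusive prefix sum of a single row,
// tile by tile, the same shape as the cumsum_cub_kernel CUB path above.
#include <cub/block/block_scan.cuh>

template <int BLOCK_SIZE>
static __global__ void row_inclusive_sum(const float * src, float * dst, const int n) {
    using BlockScan = cub::BlockScan<float, BLOCK_SIZE>;
    __shared__ typename BlockScan::TempStorage temp_storage;

    float carry = 0.0f; // running total of all earlier tiles
    for (int base = 0; base < n; base += BLOCK_SIZE) {
        const int   i = base + threadIdx.x;
        const float v = i < n ? src[i] : 0.0f;

        float scanned;
        float tile_total; // block-wide sum of this tile, valid in every thread
        BlockScan(temp_storage).InclusiveSum(v, scanned, tile_total);

        if (i < n) {
            dst[i] = carry + scanned;
        }
        carry += tile_total;
        __syncthreads(); // temp_storage is reused on the next iteration
    }
}

// Hypothetical launch, one block per row:
//   row_inclusive_sum<256><<<1, 256, 0, stream>>>(d_src, d_dst, n);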
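Likewise, the tid/lane/warp decomposition in the fallback cumsum_kernel points at the classic shuffle-based warp scan. A hedged sketch of that building block follows; the helper name is made up for illustration, and 32 is written out in place of ggml's WARP_SIZE.

// Sketch only: shuffle-based inclusive prefix sum across one 32-lane warp
// (the Kogge-Stone step a warp-level fallback scan is typically built on).
__device__ float warp_inclusive_sum(float v) {
    const int lane = threadIdx.x & 31; // 31 == WARP_SIZE - 1
#pragma unroll
    for (int offset = 1; offset < 32; offset <<= 1) {
        const float up = __shfl_up_sync(0xffffffff, v, offset);
        if (lane >= offset) {
            v += up; // add the value held 'offset' lanes below
        }
    }
    return v; // lane i now holds the sum of lanes 0..i
}

Each warp's last lane then holds the warp total; scanning those totals in shared memory and adding a block-level carry is exactly the role of s_warp_sums and s_carry in the final-value hunk above.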