
Commit 069413a

Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS
1 parent bbe3743 commit 069413a
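
For context: this commit drops the earlier workaround of defining a second, empty cumsum_cub_kernel for non-CUB builds. Instead there is a single kernel definition whose CUB-based body is compiled only when GGML_CUDA_USE_CUB is defined; otherwise the body reduces to NO_DEVICE_CODE, ggml's marker for device code that must never actually run. GGML_UNUSED_VARS silences unused-parameter warnings for arguments that only some code paths need. The sketch below illustrates the pattern with simplified stand-ins for both macros; the names scan_kernel and unused_vars_impl, and the stand-in macro bodies, are assumptions for illustration only, not ggml's actual definitions (which live in ggml's CUDA headers and carry extra diagnostics).

// guard_pattern_sketch.cu -- illustrative only, not part of the commit.
// SIMPLIFIED stand-ins for ggml's macros:
#define NO_DEVICE_CODE __trap()

template <typename... Args>
static __host__ __device__ void unused_vars_impl(Args && ...) {}
#define GGML_UNUSED_VARS(...) unused_vars_impl(__VA_ARGS__)

// One kernel definition for both build flavours: the CUB-based body is compiled
// only when GGML_CUDA_USE_CUB is defined; otherwise the kernel is a stub that
// must never be launched, so it traps if it ever runs.
template <typename T, int BLOCK_SIZE>
static __global__ void scan_kernel(const T * src, T * dst, const int n) {
#ifdef GGML_CUDA_USE_CUB
    // The real cumsum_cub_kernel runs a cub::BlockScan here; a per-thread copy
    // stands in to keep the sketch short.
    const int i = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
#else
    GGML_UNUSED_VARS(src, dst, n);
    NO_DEVICE_CODE; // unreachable: the host side selects the fallback kernel instead
#endif
}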

File tree

1 file changed: +11 -13 lines changed


ggml/src/ggml-cuda/cumsum.cu

Lines changed: 11 additions & 13 deletions
@@ -1,18 +1,21 @@
 #include <algorithm>
 #include "cumsum.cuh"
 #include "convert.cuh"
+#include "ggml.h"
 
 #ifdef GGML_CUDA_USE_CUB
 # include <cub/device/device_scan.cuh>
+#endif
+
 
 template<typename T, int BLOCK_SIZE>
 static __global__ void cumsum_cub_kernel(
     const T* __restrict__ src,
     T* __restrict__ dst,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t nb01, const int64_t nb02, const int64_t nb03,
-    const int64_t nb1, const int64_t nb2, const int64_t nb3)
-{
+    const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+#ifdef GGML_CUDA_USE_CUB
     using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
 
     __shared__ typename BlockScan::TempStorage temp_storage;
@@ -61,17 +64,10 @@ static __global__ void cumsum_cub_kernel(
 
         __syncthreads();
     }
-}
 #else
-template<typename T, int BLOCK_SIZE>
-static __global__ void cumsum_cub_kernel(
-    const T* __restrict__ src,
-    T* __restrict__ dst,
-    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const int64_t nb01, const int64_t nb02, const int64_t nb03,
-    const int64_t nb1, const int64_t nb2, const int64_t nb3) {}
-// empty function to avoid triggering compilation errors on non-CUB paths, just in case compiler doesn't optimize away
-#endif // GGML_CUDA_USE_CUB
+    NO_DEVICE_CODE;
+#endif
+}
 
 // Fallback kernel implementation (original)
 template<typename T>
@@ -81,6 +77,8 @@ static __global__ void cumsum_kernel(
     const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
     const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
 
+    GGML_UNUSED_VARS(nb00, nb0);
+
     const int tid = threadIdx.x;
     const int lane = tid & (WARP_SIZE - 1);
     const int warp = tid / WARP_SIZE;
@@ -138,7 +136,7 @@ static __global__ void cumsum_kernel(
        float carry = *s_carry;
        float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
        if (idx < ne00) {
-           dst_row[idx] = static_cast<T>(final_val);
+           dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
        }
        __syncthreads();
 
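
The "last cast" in the commit title is the final store, switched from static_cast<T> to the ggml_cuda_cast helper from convert.cuh so that all float-to-T conversions in the file go through the same routine. The snippet below is a hypothetical, simplified stand-in (my_cuda_cast, store_converted are made-up names) written only to show why a dedicated cast helper can be preferable to static_cast when T may be a device type such as half; it is not ggml's implementation.

// Hypothetical stand-in -- ggml's real helper is ggml_cuda_cast in
// ggml/src/ggml-cuda/convert.cuh.
#include <cuda_fp16.h>

template <typename dst_t, typename src_t>
__device__ __forceinline__ dst_t my_cuda_cast(src_t x) {
    return static_cast<dst_t>(x); // generic path
}

// A specialization can route through the conversion intrinsic (__float2half),
// so the cast stays valid even in builds where implicit half<->float
// conversions are disabled (e.g. via __CUDA_NO_HALF_CONVERSIONS__).
template <>
__device__ __forceinline__ half my_cuda_cast<half, float>(float x) {
    return __float2half(x);
}

// Usage mirroring the changed line in the diff:
//     dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
template <typename T>
static __global__ void store_converted(const float * src, T * dst, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = my_cuda_cast<T, float>(src[i]);
    }
}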