Commit 08b3f2d

Use constexpr and call prefix_inclusive with warp_size template param
1 parent 579eba6 commit 08b3f2d

2 files changed: +21 additions, -15 deletions

ggml/src/ggml-cuda/common.cuh
Lines changed: 9 additions & 0 deletions

@@ -319,6 +319,15 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
 }
 
+static constexpr __host__ int ggml_cuda_get_physical_warp_size_host() {
+#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
+    return 64;
+#else
+    return 32;
+#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
+}
+
+
 // Maximum number of bytes that can be copied in a single instruction.
 static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
 #ifdef GGML_USE_HIP
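
The new __host__ overload mirrors the existing __device__ helper, so host-side launch code can compute the same compile-time warp size that the kernels are compiled against. Below is a minimal sketch of the intended usage pattern with a hypothetical kernel and launcher (example_kernel and example_launch are illustrations only, not part of this commit):

#include "ggml-cuda/common.cuh"

// Hypothetical kernel: takes the warp size as a template parameter, the same
// pattern cumsum.cu uses below. The body is a placeholder.
template <int warp_size>
static __global__ void example_kernel(const float * x, float * y, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = x[i];
    }
}

// Hypothetical host launcher: uses the new __host__ helper, since the
// __device__ variant cannot be called from host code.
static void example_launch(const float * x, float * y, const int n, cudaStream_t stream) {
    constexpr int warp_size = ggml_cuda_get_physical_warp_size_host();
    const int block_size = 4*warp_size;                        // 128 threads with 32-wide warps, 256 with 64-wide
    const int num_blocks = (n + block_size - 1) / block_size;  // round up to cover all n elements
    example_kernel<warp_size><<<num_blocks, block_size, 0, stream>>>(x, y, n);
}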

ggml/src/ggml-cuda/cumsum.cu
Lines changed: 12 additions & 15 deletions

@@ -1,14 +1,9 @@
 #include <algorithm>
 #include "cumsum.cuh"
 #include "convert.cuh"
+#include "ggml-cuda/common.cuh"
 #include "ggml.h"
 
-#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
-# define CUMSUM_WARP_SIZE 64
-#else
-# define CUMSUM_WARP_SIZE 32
-#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
-
 #ifdef GGML_CUDA_USE_CUB
 # include <cub/device/device_scan.cuh>
 #endif
@@ -85,9 +80,10 @@ static __global__ void cumsum_kernel(
     GGML_UNUSED_VARS(nb00, nb0);
 
     const int tid = threadIdx.x;
-    const int lane = tid & (CUMSUM_WARP_SIZE - 1);
-    const int warp = tid / CUMSUM_WARP_SIZE;
-    const int warps_per_block = blockDim.x / CUMSUM_WARP_SIZE;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int lane = tid & (warp_size - 1);
+    const int warp = tid / warp_size;
+    const int warps_per_block = blockDim.x / warp_size;
 
     extern __shared__ float smem[];
     float* s_vals = smem;
@@ -116,19 +112,19 @@ static __global__ void cumsum_kernel(
     float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
 
     // 1. Warp inclusive scan
-    val = warp_prefix_inclusive_sum(val);
+    val = warp_prefix_inclusive_sum<T, warp_size>(val);
     s_vals[tid] = val;
 
     // Store warp total
-    if (lane == CUMSUM_WARP_SIZE - 1) {
+    if (lane == warp_size - 1) {
         s_warp_sums[warp] = val;
     }
     __syncthreads();
 
     // 2. Exclusive scan of warp sums (warp 0 only)
     if (warp == 0) {
         float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
-        float inc = warp_prefix_inclusive_sum(w);
+        float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
         if (tid < warps_per_block) {
             s_warp_sums[tid] = inc - w; // exclusive sum
         }
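
Both scan calls in the hunk above now pass the warp size as a template argument. The helper itself, warp_prefix_inclusive_sum, is defined elsewhere in ggml's CUDA headers and does not appear in this diff. Purely to illustrate what a width-templated inclusive warp scan usually looks like, here is a CUDA-flavored sketch (name, signature, and details are assumptions, not the repository's implementation; HIP builds would use the corresponding shuffle intrinsics):

// Illustrative only: Hillis-Steele inclusive scan over a warp segment of
// compile-time width, using warp shuffles.
template <int width>
static __device__ __forceinline__ float inclusive_scan_sketch(float x) {
#pragma unroll
    for (int offset = 1; offset < width; offset <<= 1) {
        // Fetch the partial sum held by the lane `offset` positions lower.
        const float y = __shfl_up_sync(0xffffffff, x, offset, width);
        if ((threadIdx.x % width) >= offset) {
            x += y;
        }
    }
    return x; // lane i now holds the sum of lanes 0..i of its segment
}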
@@ -172,11 +168,12 @@ static void cumsum_cuda(
     }
 #endif // GGML_CUDA_USE_CUB
     dim3 grid_dims(ne01, ne02, ne03);
-    const int num_warps = (ne00 + CUMSUM_WARP_SIZE - 1) / CUMSUM_WARP_SIZE;
-    int block_size = num_warps * CUMSUM_WARP_SIZE;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size_host();
+    const int num_warps = (ne00 + warp_size - 1) / warp_size;
+    int block_size = num_warps * warp_size;
     block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
     dim3 block_dims(block_size, 1, 1);
-    const int warps_per_block = block_size / CUMSUM_WARP_SIZE;
+    const int warps_per_block = block_size / warp_size;
     const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
 
     if (use_cub) {
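
To make the launch arithmetic above concrete: with warp_size = 32, and assuming CUDA_CUMSUM_BLOCK_SIZE is 256 (that constant is defined elsewhere in the tree; 256 is only an illustrative value), a row of ne00 = 1000 elements gives num_warps = (1000 + 31) / 32 = 32, block_size = 32 * 32 = 1024 capped to 256, warps_per_block = 256 / 32 = 8, and shmem_size = (256 + 8 + 2) * sizeof(float) = 1064 bytes of shared memory per block.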
