@@ -1,14 +1,9 @@
 #include <algorithm>
 #include "cumsum.cuh"
 #include "convert.cuh"
+#include "ggml-cuda/common.cuh"
 #include "ggml.h"

-#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
-#    define CUMSUM_WARP_SIZE 64
-#else
-#    define CUMSUM_WARP_SIZE 32
-#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
-
 #ifdef GGML_CUDA_USE_CUB
 #    include <cub/device/device_scan.cuh>
 #endif
@@ -85,9 +80,10 @@ static __global__ void cumsum_kernel( |
     GGML_UNUSED_VARS(nb00, nb0);

     const int tid = threadIdx.x;
-    const int lane = tid & (CUMSUM_WARP_SIZE - 1);
-    const int warp = tid / CUMSUM_WARP_SIZE;
-    const int warps_per_block = blockDim.x / CUMSUM_WARP_SIZE;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int lane = tid & (warp_size - 1);
+    const int warp = tid / warp_size;
+    const int warps_per_block = blockDim.x / warp_size;

     extern __shared__ float smem[];
     float* s_vals = smem;
@@ -116,19 +112,19 @@ static __global__ void cumsum_kernel( |
     float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;

     // 1. Warp inclusive scan
-    val = warp_prefix_inclusive_sum(val);
+    val = warp_prefix_inclusive_sum<T, warp_size>(val);
     s_vals[tid] = val;

     // Store warp total
-    if (lane == CUMSUM_WARP_SIZE - 1) {
+    if (lane == warp_size - 1) {
         s_warp_sums[warp] = val;
     }
     __syncthreads();

     // 2. Exclusive scan of warp sums (warp 0 only)
     if (warp == 0) {
         float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
-        float inc = warp_prefix_inclusive_sum(w);
+        float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
         if (tid < warps_per_block) {
             s_warp_sums[tid] = inc - w; // exclusive sum
         }
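
The kernel composes a block-wide inclusive scan from warp-level primitives: each warp scans its own lanes, the last lane of every warp publishes its warp total to shared memory, and warp 0 converts those totals into per-warp offsets (the `inc - w` line above); each lane's warp-local prefix plus its warp's offset then gives the block-wide result. For reference, a shuffle-based warp scan of this shape typically looks like the minimal sketch below. It assumes a 32-lane CUDA warp and a fully active mask, and `warp_inclusive_scan_sketch` is a hypothetical stand-in, not the `warp_prefix_inclusive_sum` defined in `cumsum.cuh`:

// Hillis-Steele inclusive scan across one warp.
// Assumptions: 32-lane warp, all lanes active (mask 0xffffffff).
template <int warp_size = 32>
static __device__ float warp_inclusive_scan_sketch(float val) {
    const int lane = threadIdx.x % warp_size;
#pragma unroll
    for (int offset = 1; offset < warp_size; offset <<= 1) {
        // Fetch the partial sum held by the lane `offset` positions below.
        const float up = __shfl_up_sync(0xffffffff, val, offset, warp_size);
        if (lane >= offset) {
            val += up;
        }
    }
    return val; // lane i now holds the sum of lanes 0..i
}
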
@@ -172,11 +168,12 @@ static void cumsum_cuda( |
     }
 #endif // GGML_CUDA_USE_CUB
     dim3 grid_dims(ne01, ne02, ne03);
-    const int num_warps = (ne00 + CUMSUM_WARP_SIZE - 1) / CUMSUM_WARP_SIZE;
-    int block_size = num_warps * CUMSUM_WARP_SIZE;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size_host();
+    const int num_warps = (ne00 + warp_size - 1) / warp_size;
+    int block_size = num_warps * warp_size;
     block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
     dim3 block_dims(block_size, 1, 1);
-    const int warps_per_block = block_size / CUMSUM_WARP_SIZE;
+    const int warps_per_block = block_size / warp_size;
     const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);

     if (use_cub) {
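
To make the launch math concrete: the code above rounds the thread count up to a whole number of warps, clamps it to the block-size cap, and sizes shared memory for one float per thread, one float per warp, plus the two extra floats the kernel reserves. A small host-side illustration with assumed stand-in values for `warp_size` and the cap (the real constants come from `common.cuh` and `cumsum.cuh`):

#include <algorithm>
#include <cstdio>

int main() {
    const int ne00           = 1000; // row length to scan
    const int warp_size      = 32;   // assumed physical warp size
    const int max_block_size = 256;  // stand-in for CUDA_CUMSUM_BLOCK_SIZE

    const int num_warps = (ne00 + warp_size - 1) / warp_size; // ceil(1000/32) = 32
    int block_size = num_warps * warp_size;                   // 1024
    block_size = std::min(block_size, max_block_size);        // clamped to 256
    const int warps_per_block = block_size / warp_size;       // 8

    // s_vals (block_size floats) + s_warp_sums (warps_per_block floats)
    // + the 2 extra floats reserved in the shmem_size computation above.
    const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);

    std::printf("block=%d warps=%d shmem=%zu bytes\n",
                block_size, warps_per_block, shmem_size); // 256, 8, 1064
    return 0;
}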