|
3 | 3 | #include "convert.cuh" |
4 | 4 | #include "ggml.h" |
5 | 5 |
|
// CUMSUM_WARP_SIZE: warp width used by this file's cumsum code, both for the
// per-thread lane/warp indexing inside cumsum_kernel and for the block-size /
// shared-memory sizing in the cumsum_cuda launcher.
// It is 64 when GGML_USE_HIP is defined together with __GFX9__ or __GFX8__,
// and 32 in every other build.
// NOTE(review): the same macro is evaluated in host code (launch sizing) and in
// device code (kernel indexing). If __GFX9__/__GFX8__ are only defined during
// the device compilation pass, the host side would see 32 while the kernel sees
// 64, and warps_per_block / shmem_size would disagree - TODO confirm where
// these macros are defined.
| 6 | +#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) |
| 7 | +# define CUMSUM_WARP_SIZE 64 |
| 8 | +#else |
| 9 | +# define CUMSUM_WARP_SIZE 32 |
| 10 | +#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) |
| 11 | + |
6 | 12 | #ifdef GGML_CUDA_USE_CUB |
7 | 13 | # include <cub/device/device_scan.cuh> |
8 | 14 | #endif |
9 | 15 |
|
10 | | - |
11 | 16 | template<typename T, int BLOCK_SIZE> |
12 | 17 | static __global__ void cumsum_cub_kernel( |
13 | 18 | const T* __restrict__ src, |
@@ -80,9 +85,9 @@ static __global__ void cumsum_kernel( |
80 | 85 | GGML_UNUSED_VARS(nb00, nb0); |
81 | 86 |
|
82 | 87 | const int tid = threadIdx.x; |
83 | | - const int lane = tid & (WARP_SIZE - 1); |
84 | | - const int warp = tid / WARP_SIZE; |
85 | | - const int warps_per_block = blockDim.x / WARP_SIZE; |
| 88 | + const int lane = tid & (CUMSUM_WARP_SIZE - 1); |
| 89 | + const int warp = tid / CUMSUM_WARP_SIZE; |
| 90 | + const int warps_per_block = blockDim.x / CUMSUM_WARP_SIZE; |
86 | 91 |
|
87 | 92 | extern __shared__ float smem[]; |
88 | 93 | float* s_vals = smem; |
@@ -115,7 +120,7 @@ static __global__ void cumsum_kernel( |
115 | 120 | s_vals[tid] = val; |
116 | 121 |
|
117 | 122 | // Store warp total |
118 | | - if (lane == WARP_SIZE - 1) { |
| 123 | + if (lane == CUMSUM_WARP_SIZE - 1) { |
119 | 124 | s_warp_sums[warp] = val; |
120 | 125 | } |
121 | 126 | __syncthreads(); |
@@ -167,11 +172,11 @@ static void cumsum_cuda( |
167 | 172 | } |
168 | 173 | #endif // GGML_CUDA_USE_CUB |
169 | 174 | dim3 grid_dims(ne01, ne02, ne03); |
170 | | - const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; |
171 | | - int block_size = num_warps * WARP_SIZE; |
| 175 | + const int num_warps = (ne00 + CUMSUM_WARP_SIZE - 1) / CUMSUM_WARP_SIZE; |
| 176 | + int block_size = num_warps * CUMSUM_WARP_SIZE; |
172 | 177 | block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); |
173 | 178 | dim3 block_dims(block_size, 1, 1); |
174 | | - const int warps_per_block = block_size / WARP_SIZE; |
| 179 | + const int warps_per_block = block_size / CUMSUM_WARP_SIZE; |
175 | 180 | const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); |
176 | 181 |
|
177 | 182 | if (use_cub) { |
|
0 commit comments