Skip to content

Commit 5aa7438

Browse files
committed
Vary warp-size based on physical warp size
1 parent 069413a commit 5aa7438

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

ggml/src/ggml-cuda/cumsum.cu

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
#include "convert.cuh"
44
#include "ggml.h"
55

6+
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
7+
# define CUMSUM_WARP_SIZE 64
8+
#else
9+
# define CUMSUM_WARP_SIZE 32
10+
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
11+
612
#ifdef GGML_CUDA_USE_CUB
713
# include <cub/device/device_scan.cuh>
814
#endif
915

10-
1116
template<typename T, int BLOCK_SIZE>
1217
static __global__ void cumsum_cub_kernel(
1318
const T* __restrict__ src,
@@ -80,9 +85,9 @@ static __global__ void cumsum_kernel(
8085
GGML_UNUSED_VARS(nb00, nb0);
8186

8287
const int tid = threadIdx.x;
83-
const int lane = tid & (WARP_SIZE - 1);
84-
const int warp = tid / WARP_SIZE;
85-
const int warps_per_block = blockDim.x / WARP_SIZE;
88+
const int lane = tid & (CUMSUM_WARP_SIZE - 1);
89+
const int warp = tid / CUMSUM_WARP_SIZE;
90+
const int warps_per_block = blockDim.x / CUMSUM_WARP_SIZE;
8691

8792
extern __shared__ float smem[];
8893
float* s_vals = smem;
@@ -115,7 +120,7 @@ static __global__ void cumsum_kernel(
115120
s_vals[tid] = val;
116121

117122
// Store warp total
118-
if (lane == WARP_SIZE - 1) {
123+
if (lane == CUMSUM_WARP_SIZE - 1) {
119124
s_warp_sums[warp] = val;
120125
}
121126
__syncthreads();
@@ -167,11 +172,11 @@ static void cumsum_cuda(
167172
}
168173
#endif // GGML_CUDA_USE_CUB
169174
dim3 grid_dims(ne01, ne02, ne03);
170-
const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
171-
int block_size = num_warps * WARP_SIZE;
175+
const int num_warps = (ne00 + CUMSUM_WARP_SIZE - 1) / CUMSUM_WARP_SIZE;
176+
int block_size = num_warps * CUMSUM_WARP_SIZE;
172177
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
173178
dim3 block_dims(block_size, 1, 1);
174-
const int warps_per_block = block_size / WARP_SIZE;
179+
const int warps_per_block = block_size / CUMSUM_WARP_SIZE;
175180
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
176181

177182
if (use_cub) {

0 commit comments

Comments
 (0)