
Commit 98a264a

ikawrakow and Iwan Kawrakow authored
CUDA: better MoE implementation (#283)
* Make fused MoE reproducible

  As a bonus, peak performance at pp2048 with u_batch = 2048 is ~8% better.

* Slightly better

* Also do it for non-fused mul_mat_id

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent f9307d7 commit 98a264a
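
For readers of the diff below: the previous `k_copy_src1_to_contiguous` kernel claimed slots in the contiguous buffer with `atomicAdd`, so the order in which rows were gathered for each expert depended on GPU block scheduling and could differ between runs. This commit moves that bookkeeping to the host: `prepare_row_mappigs` counts the rows routed to each expert, turns the counts into prefix-sum offsets, and fills a deterministic `mmid_row_mapping` table that the simplified copy kernel only reads. The sketch below is a minimal standalone illustration of that counting/prefix-sum bucketing; the names (`RowMapping`, `build_row_mappings`) and the flat `ids` layout are hypothetical stand-ins for the ggml tensors used in the real code.

```cpp
// Hypothetical, simplified sketch of the deterministic expert-row bucketing that
// prepare_row_mappigs() performs on the host. Not the actual ggml-cuda code.
#include <cstdint>
#include <vector>

struct RowMapping { int32_t i1; int32_t i2; };   // (expert slot, token index), like mmid_row_mapping

// ids[t*n_ids + k] is the expert chosen for token t in slot k; n_as is the number of experts.
// Produces mappings grouped by expert plus per-expert counts and start offsets, in an order
// that depends only on the ids contents, never on GPU thread scheduling.
static void build_row_mappings(const std::vector<int32_t> & ids, int n_tokens, int n_ids, int n_as,
                               std::vector<RowMapping> & mapping,
                               std::vector<int> & counts, std::vector<int> & offsets) {
    counts.assign(n_as, 0);
    offsets.assign(n_as + 1, 0);

    // 1. Count how many (token, slot) pairs route to each expert, skipping invalid ids.
    for (int t = 0; t < n_tokens; ++t) {
        for (int k = 0; k < n_ids; ++k) {
            const int32_t e = ids[t*n_ids + k];
            if (e >= 0 && e < n_as) ++counts[e];
        }
    }

    // 2. Exclusive prefix sum: offsets[e] is where expert e's bucket starts.
    for (int e = 0; e < n_as; ++e) offsets[e+1] = offsets[e] + counts[e];

    // 3. Scatter the pairs into their buckets in fixed token-major order.
    mapping.resize(offsets[n_as]);
    std::vector<int> cursor(offsets.begin(), offsets.end() - 1);
    for (int t = 0; t < n_tokens; ++t) {
        for (int k = 0; k < n_ids; ++k) {
            const int32_t e = ids[t*n_ids + k];
            if (e >= 0 && e < n_as) mapping[cursor[e]++] = {k, t};
        }
    }
}
```

Because the bucket order is fixed by the host loop order over tokens and expert slots, every run gathers each expert's rows in the same sequence, which is what makes the fused MoE path reproducible; the per-expert kernel launches in the diff then index this table with `dev_row_mapping.get() + mapping_offset` instead of building it on the fly.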

File tree

1 file changed (+85, -84 lines)


ggml/src/ggml-cuda.cu

Lines changed: 85 additions & 84 deletions
@@ -2164,35 +2164,19 @@ struct mmid_row_mapping {
     int32_t i2;
 };
 
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
-        int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
-        const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-        int64_t ne11, int64_t ne10,
-        size_t nb11, size_t nb12) {
-    int32_t iid1 = blockIdx.x;
-    int32_t id = blockIdx.y;
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
-    if (row_id_i != i02) {
-        return;
-    }
-
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
+static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
+        const mmid_row_mapping * __restrict__ row_mapping,
+        int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
+    int32_t i = blockIdx.x;
 
-    __shared__ int src1_row;
-    if (threadIdx.x == 0) {
-        src1_row = atomicAdd(cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    __syncthreads();
+    const int32_t i11 = row_mapping[i].i1 % ne11;
+    const int32_t i12 = row_mapping[i].i2;
 
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+    float * src_row_contiguous = (float *)(src_contiguous + i*nb11);
+    const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12);
 
-    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
-        src1_row_contiguous[i] = src1_row_original[i];
+    for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
+        src_row_contiguous[j] = src_row_original[j];
     }
 }
 
@@ -2213,6 +2197,51 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original,
     }
 }
 
+static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
+        const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
+        ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
+
+    GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty());
+
+    auto stream = ctx.stream();
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    const char * ids_dev = (const char *) ids->data;
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    //CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    std::vector<mmid_row_mapping> rmapping(ids->ne[1]*n_ids);
+    moe_counts.resize(n_as, 0);
+    cum_moe_counts.resize(n_as + 1);
+
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
+        }
+    }
+    cum_moe_counts[0] = 0;
+    for (int i = 0; i < (int)n_as; ++i) {
+        cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i];
+    }
+
+    dev_row_mapping.alloc(cum_moe_counts[n_as]);
+
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            if (row_id_i >= 0 && row_id_i < n_as) {
+                rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1};
+            }
+        }
+    }
+
+    for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
+
+    CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
+
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -2273,10 +2302,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int64_t n_as = ne02;
     const int64_t n_ids = ids->ne[0];
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
+    //std::vector<char> ids_host(ggml_nbytes(ids));
+    //const char * ids_dev = (const char *) ids->data;
+    //CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    //CUDA_CHECK(cudaStreamSynchronize(stream));
 
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
@@ -2303,6 +2332,9 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     dst_row.nb[3] = nb1;
 
     if (ne12 == 1) {
+        std::vector<char> ids_host(ggml_nbytes(ids));
+        const char * ids_dev = (const char *) ids->data;
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
             for (int64_t id = 0; id < n_ids; id++) {
                 const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
@@ -2324,44 +2356,32 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             }
         }
     } else {
+
+        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
+        std::vector<int> moe_counts, cum_moe_counts;
+        prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
+
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
         src1_row.data = src1_contiguous.get();
         dst_row.data = dst_contiguous.get();
 
         for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
-
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                    if (row_id_i != i02) {
-                        continue;
-                    }
-
-                    num_src1_rows++;
-                }
-            }
+            int64_t num_src1_rows = moe_counts[i02];
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+            size_t mapping_offset = cum_moe_counts[i02];
 
             {
                 dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                dim3 grid_dims(ids->ne[1], n_ids);
-                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        src1_original, src1_contiguous.get(),
-                        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                        ids_dev, i02, ids->nb[1], ids->nb[0],
-                        ne11, ne10,
-                        nb11, nb12);
+                dim3 grid_dims(num_src1_rows);
+                k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12);
                 CUDA_CHECK(cudaGetLastError());
             }
 
@@ -2387,7 +2407,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 dim3 grid_dims(num_src1_rows);
                 k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                         dst_original, dst_contiguous.get(),
-                        dev_row_mapping.get(),
+                        dev_row_mapping.get() + mapping_offset,
                         ne0,
                         nb1, nb2);
                 CUDA_CHECK(cudaGetLastError());
@@ -2642,41 +2662,22 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
 
     bool first = false; //true;
 
-    for (int64_t i02 = 0; i02 < n_as; i02++) {
-        int64_t num_src1_rows = 0;
-
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i < 0 || row_id_i >= n_as) continue;
-                //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
-
-                if (row_id_i != i02) {
-                    continue;
-                }
+    ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
+    std::vector<int> moe_counts, cum_moe_counts;
 
-                num_src1_rows++;
-            }
-        }
+    prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
 
-        if (num_src1_rows == 0) {
-            continue;
-        }
+    for (int64_t i02 = 0; i02 < n_as; i02++) {
+        int64_t num_src1_rows = moe_counts[i02];
 
-        ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-        CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+        if (num_src1_rows == 0) continue;
+        size_t mapping_offset = cum_moe_counts[i02];
 
         {
             dim3 block_dims(std::min((unsigned int)ne10, 768u));
-            dim3 grid_dims(ids->ne[1], n_ids);
-            k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                    src1_original, src1_contiguous.get(),
-                    dev_cur_src1_row.get(), dev_row_mapping.get(),
-                    ids_dev, i02, ids->nb[1], ids->nb[0],
-                    ne11, ne10,
-                    nb11, nb12);
+            dim3 grid_dims(num_src1_rows);
+            k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12);
             CUDA_CHECK(cudaGetLastError());
         }
 
@@ -2733,7 +2734,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             dim3 grid_dims(num_src1_rows);
             k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                     (char *)next->data, final_dst_contiguous.get(),
-                    dev_row_mapping.get(),
+                    dev_row_mapping.get() + mapping_offset,
                     next->ne[0],
                     next->nb[1], next->nb[2]);
             CUDA_CHECK(cudaGetLastError());
@@ -2745,7 +2746,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             dim3 grid_dims(num_src1_rows);
             k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                     dst_original, dst_gate_contiguous.get(),
-                    dev_row_mapping.get(),
+                    dev_row_mapping.get() + mapping_offset,
                     ne0,
                     nb1, nb2);
             CUDA_CHECK(cudaGetLastError());

0 commit comments
