@@ -2164,35 +2164,19 @@ struct mmid_row_mapping {
     int32_t i2;
 };
 
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
-        int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
-        const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-        int64_t ne11, int64_t ne10,
-        size_t nb11, size_t nb12) {
-    int32_t iid1 = blockIdx.x;
-    int32_t id = blockIdx.y;
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
-    if (row_id_i != i02) {
-        return;
-    }
-
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
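+// Block i copies one src1 row into the contiguous buffer: the precomputed row_mapping[i]
+// supplies the expert slot (.i1) and the token index (.i2), so no atomics or per-expert
+// filtering are needed inside the kernel.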
+static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
+        const mmid_row_mapping * __restrict__ row_mapping,
+        int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {
+    int32_t i = blockIdx.x;
 
-    __shared__ int src1_row;
-    if (threadIdx.x == 0) {
-        src1_row = atomicAdd(cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    __syncthreads();
+    const int32_t i11 = row_mapping[i].i1 % ne11;
+    const int32_t i12 = row_mapping[i].i2;
 
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+    float * src_row_contiguous = (float *)(src_contiguous + i*nb11);
+    const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12);
 
-    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
-        src1_row_contiguous[i] = src1_row_original[i];
+    for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
+        src_row_contiguous[j] = src_row_original[j];
     }
 }
 
@@ -2213,6 +2197,51 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
     }
 }
 
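+// Host-side replacement for the old atomicAdd-based bookkeeping: count how many rows are
+// routed to each expert (moe_counts), build the prefix sums (cum_moe_counts), sort the
+// (expert slot, token) pairs by expert, and upload the resulting mapping to the device
+// in a single copy.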
+static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
+        const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
+        ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
+
+    GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty());
+
+    auto stream = ctx.stream();
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    const char * ids_dev = (const char *) ids->data;
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    // CUDA_CHECK(cudaStreamSynchronize(stream));
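+    // no explicit sync needed here: cudaMemcpyAsync into pageable host memory (ids_host)
+    // returns only after the device-to-host copy has completed, so ids_host can be read below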
+
+    std::vector<mmid_row_mapping> rmapping(ids->ne[1]*n_ids);
+    moe_counts.resize(n_as, 0);
+    cum_moe_counts.resize(n_as + 1);
+
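+    // first pass: count how many src1 rows are routed to each expert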
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
+        }
+    }
+    cum_moe_counts[0] = 0;
+    for (int i = 0; i < (int)n_as; ++i) {
+        cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i];
+    }
+
+    dev_row_mapping.alloc(cum_moe_counts[n_as]);
+
+    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+        for (int64_t id = 0; id < n_ids; id++) {
+            const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+            if (row_id_i >= 0 && row_id_i < n_as) {
+                rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1};
+            }
+        }
+    }
+
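+    // restore cum_moe_counts so that entry i is again the start offset of expert i's rows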
+    for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
+
+    CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
+
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -2273,10 +2302,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int64_t n_as = ne02;
     const int64_t n_ids = ids->ne[0];
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
+    // std::vector<char> ids_host(ggml_nbytes(ids));
+    // const char * ids_dev = (const char *) ids->data;
+    // CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    // CUDA_CHECK(cudaStreamSynchronize(stream));
 
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
@@ -2303,6 +2332,9 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     dst_row.nb[3] = nb1;
 
     if (ne12 == 1) {
+        std::vector<char> ids_host(ggml_nbytes(ids));
+        const char * ids_dev = (const char *) ids->data;
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
             for (int64_t id = 0; id < n_ids; id++) {
                 const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
@@ -2324,44 +2356,32 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             }
         }
     } else {
+
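+        // precompute the per-expert row counts and the expert-sorted row mapping once,
+        // instead of re-scanning ids and using atomics for every expert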
+        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
+        std::vector<int> moe_counts, cum_moe_counts;
+        prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
+
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
         src1_row.data = src1_contiguous.get();
         dst_row.data = dst_contiguous.get();
 
         for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
-
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                    if (row_id_i != i02) {
-                        continue;
-                    }
-
-                    num_src1_rows++;
-                }
-            }
+            int64_t num_src1_rows = moe_counts[i02];
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
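+            // rows belonging to expert i02 start at this offset in the precomputed mapping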
+            size_t mapping_offset = cum_moe_counts[i02];
 
             {
                 dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                dim3 grid_dims(ids->ne[1], n_ids);
-                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        src1_original, src1_contiguous.get(),
-                        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                        ids_dev, i02, ids->nb[1], ids->nb[0],
-                        ne11, ne10,
-                        nb11, nb12);
+                dim3 grid_dims(num_src1_rows);
+                k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12);
                 CUDA_CHECK(cudaGetLastError());
             }
 
@@ -2387,7 +2407,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 dim3 grid_dims(num_src1_rows);
                 k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                         dst_original, dst_contiguous.get(),
-                        dev_row_mapping.get(),
+                        dev_row_mapping.get() + mapping_offset,
                         ne0,
                         nb1, nb2);
                 CUDA_CHECK(cudaGetLastError());
@@ -2642,41 +2662,22 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
 
     bool first = false; // true;
 
-    for (int64_t i02 = 0; i02 < n_as; i02++) {
-        int64_t num_src1_rows = 0;
-
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i < 0 || row_id_i >= n_as) continue;
-                // GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
-
-                if (row_id_i != i02) {
-                    continue;
-                }
+    ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
+    std::vector<int> moe_counts, cum_moe_counts;
 
-                num_src1_rows++;
-            }
-        }
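+    // same precomputation as in ggml_cuda_mul_mat_id: per-expert counts and an expert-sorted row mapping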
+    prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
 
-        if (num_src1_rows == 0) {
-            continue;
-        }
+    for (int64_t i02 = 0; i02 < n_as; i02++) {
+        int64_t num_src1_rows = moe_counts[i02];
 
-        ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-        CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+        if (num_src1_rows == 0) continue;
+        size_t mapping_offset = cum_moe_counts[i02];
 
         {
             dim3 block_dims(std::min((unsigned int)ne10, 768u));
-            dim3 grid_dims(ids->ne[1], n_ids);
-            k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                    src1_original, src1_contiguous.get(),
-                    dev_cur_src1_row.get(), dev_row_mapping.get(),
-                    ids_dev, i02, ids->nb[1], ids->nb[0],
-                    ne11, ne10,
-                    nb11, nb12);
+            dim3 grid_dims(num_src1_rows);
+            k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12);
             CUDA_CHECK(cudaGetLastError());
         }
 
@@ -2733,7 +2734,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             dim3 grid_dims(num_src1_rows);
             k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                     (char *)next->data, final_dst_contiguous.get(),
-                    dev_row_mapping.get(),
+                    dev_row_mapping.get() + mapping_offset,
                     next->ne[0],
                     next->nb[1], next->nb[2]);
             CUDA_CHECK(cudaGetLastError());
@@ -2745,7 +2746,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             dim3 grid_dims(num_src1_rows);
             k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
                     dst_original, dst_gate_contiguous.get(),
-                    dev_row_mapping.get(),
+                    dev_row_mapping.get() + mapping_offset,
                     ne0,
                     nb1, nb2);
             CUDA_CHECK(cudaGetLastError());