From 8f5f93e6b11d039e24b7253b1883b13bfc75f41f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 18 Oct 2025 10:09:32 +0300 Subject: [PATCH 1/9] Fuse sigmoid+add+grouped_topk+get_rows (CPU) --- ggml/src/ggml.c | 12 +++- ggml/src/iqk/iqk_cpu_ops.cpp | 115 +++++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_cpu_ops.h | 2 + src/llama-build-context.cpp | 5 +- 4 files changed, 131 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 63c0b9950..5ca73ae54 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -22611,7 +22611,17 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_UNARY: { - ggml_compute_forward_unary(params, tensor); + const enum ggml_unary_op unary_op = ggml_get_unary_op(tensor); + if (unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+2]->op == GGML_OP_ADD && + cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && + cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) { + iqk_bailingmoev2_experts(cgraph->nodes[i+4], cgraph->nodes[i+3], params->ith, params->nth); + i += 4; + } else { + ggml_compute_forward_unary(params, tensor); + } } break; case GGML_OP_GLU: { diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 115f25ceb..5d0adcbac 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -5,6 +5,7 @@ // #include "iqk_cpu_ops.h" +#include "iqk_utils.h" #include "ggml.h" #include @@ -39,6 +40,49 @@ inline std::vector> & get_work_buffer(size_t size) { return buffer; } +#ifdef __ARM_NEON +inline float32x4_t v_biased_sigmoid(float32x4_t x, float32x4_t b) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + const float32x4_t neg_x = vsubq_f32(zero, x); + const float32x4_t exp_neg_x = v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vaddq_f32(b, vdivq_f32(one, one_plus_exp_neg_x)); +} +#endif +#ifdef __AVX2__ +inline __m256 v_biased_sigmoid(__m256 x, __m256 b) { + const __m256 one = _mm256_set1_ps(1); + const __m256 zero = _mm256_setzero_ps(); + const __m256 neg_x = _mm256_sub_ps(zero, x); + const __m256 exp_neg_x = v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_add_ps(b, _mm256_div_ps(one, one_plus_exp_neg_x)); +} +#endif +#if defined __AVX512F__ && defined __AVX512DQ__ +inline __m512 v_biased_sigmoid(__m512 x, __m512 b) { + const __m512 one = _mm512_set1_ps(1); + const __m512 zero = _mm512_setzero_ps(); + const __m512 neg_x = _mm512_sub_ps(zero, x); + const __m512 exp_neg_x = v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_add_ps(b, _mm512_div_ps(one, one_plus_exp_neg_x)); +} +#endif +inline void biased_sigmoid(int n, const float * x, const float * bias, float * y) { + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_biased_sigmoid(_mm512_loadu_ps(x + i), _mm512_loadu_ps(bias + i))); +#endif +#if defined __AVX2__ && defined __FMA__ + for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_biased_sigmoid(_mm256_loadu_ps(x + i), _mm256_loadu_ps(bias + i))); +#endif +#ifdef __ARM_NEON + for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_biased_sigmoid(vld1q_f32(x + i), vld1q_f32(bias + i))); +#endif + for (; i < n; ++i) y[i] = 1/(1 + expf(-x[i])) + bias[i]; +} } void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) { @@ -143,3 +187,74 @@ void 
iqk_argsort(ggml_tensor * dst, int ith, int nth) { } +void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * topk, int ith, int nth) { + auto topk_src = topk->src[0]; + auto probs = topk_src->src[0]->src[0]; + auto t_bias = topk_src->src[1]; + + auto nrows = ggml_nrows(probs); + auto npt = (nrows + nth - 1)/nth; + auto first = npt*ith; + auto last = std::min(first + npt, nrows); + if (last <= first) return; + + int n_groups = topk->op_params[0]; + int n_top_groups = topk->op_params[1]; + int nk = topk->op_params[2]; + + int ne00 = probs->ne[0]; + int ne0 = topk->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(t_bias->ne[1] == 1); + GGML_ASSERT(t_bias->ne[0] == probs->ne[0]); + GGML_ASSERT(ne0 == dst->ne[1]); + GGML_ASSERT(ne0 <= ne00); + GGML_ASSERT(ne00%n_groups == 0); + int n_per_group = ne00/n_groups; + GGML_ASSERT(nk <= n_per_group); + GGML_ASSERT(n_top_groups <= n_groups); + + size_t work_size = n_groups + n_per_group*n_top_groups + (ne00 + 1)/2; + auto& aux = get_work_buffer(work_size); + + auto groups = aux.data() + n_per_group*n_top_groups; + auto values = (float *)(groups + n_groups); + + auto bias = (const float *)t_bias->data; + + for (int ir = first; ir < last; ++ir) { + auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]); + biased_sigmoid(ne00, data, bias, values); + //for (int j = 0; j < ne00; ++j) values[j] = 1/(1 + expf(-data[j])) + bias[j]; + auto weights = (float *)((char *)dst->data + ir*dst->nb[2]); + auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]); + if (ne0 > n_per_group*n_top_groups) { + for (int j = 0; j < ne0; ++j) { + weights[j] = values[j]; + ids[j] = j; + } + continue; + } + if (n_top_groups < n_groups) { + for (int ig = 0; ig < n_groups; ++ig) { + groups[ig] = { group_score(n_per_group, nk, values + ig*n_per_group, (float *)aux.data()), ig }; + } + std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater>{}); + + for (int ig = 0; ig < n_top_groups; ++ig) { + int i0 = n_per_group * ig; + int j0 = n_per_group * groups[ig].second; + for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { values[j0 + j], j0 + j }; + } + } else { + for (int j = 0; j < ne00; ++j) aux[j] = { values[j], j }; + } + std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + n_top_groups*n_per_group, std::greater>{}); + for (int j = 0; j < ne0; ++j) { + weights[j] = aux[j].first; + ids[j] = aux[j].second; + } + + } +} + diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h index c83d80618..81c14fd59 100644 --- a/ggml/src/iqk/iqk_cpu_ops.h +++ b/ggml/src/iqk/iqk_cpu_ops.h @@ -18,6 +18,8 @@ void iqk_grouped_top_k(struct ggml_tensor * dst, int ith, int nth); void iqk_argsort(struct ggml_tensor * dst, int ith, int nth); +void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * topk, int ith, int nth); + #ifdef __cplusplus } #endif diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index b3208dcd6..e2fccfca6 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -827,8 +827,9 @@ llm_expert_gating_func_type gating_op, auto& hparams = lctx.model.hparams; selected_experts = ggml_grouped_topk(ctx, selection_probs, hparams.n_expert_groups, hparams.n_group_used, 2, n_expert_used); } else { - selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, - lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] + //selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, + 
// lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] + selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] } cb(selected_experts, "ffn_moe_topk", il); ggml_tensor * weights = ggml_get_rows(ctx, From 2c66dc86fcc1f9ddb56b7e64f0efd2ea9763aedd Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 18 Oct 2025 14:55:30 +0300 Subject: [PATCH 2/9] Fix CPU + CUDA but CUDA is somehow not 100% correct as I get a slightly different PPL (lower!) --- ggml/src/ggml-cuda.cu | 11 +- ggml/src/ggml-cuda/argsort.cu | 181 ++++++++++++++++++++++++--------- ggml/src/ggml-cuda/argsort.cuh | 2 + ggml/src/iqk/iqk_cpu_ops.cpp | 46 +++++---- 4 files changed, 175 insertions(+), 65 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index fd5a5cacd..12da791f1 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3173,7 +3173,16 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_relu(ctx, dst); break; case GGML_UNARY_OP_SIGMOID: - ggml_cuda_op_sigmoid(ctx, dst); + if (i + 4 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+2]->op == GGML_OP_ADD && + cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && + cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) { + cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]); + i += 4; + } else { + ggml_cuda_op_sigmoid(ctx, dst); + } break; case GGML_UNARY_OP_HARDSIGMOID: ggml_cuda_op_hardsigmoid(ctx, dst); diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 25c615888..3976c9d20 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -25,25 +25,8 @@ struct store { constexpr static bool has_thresh = false; }; -template -static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) { -// int min_experts, float thresh_experts) { - // bitonic sort - int col = threadIdx.x; - int row = blockIdx.y; - - if (col >= ncols_pad) { - return; - } - - const float * x_row = x + row * ncols; - extern __shared__ int dst_row[]; - - // initialize indices - dst_row[col] = col; - - __syncthreads(); - +template +static __device__ __forceinline__ void sort(int ncols_pad, int ncols, int col, const float * x_row, int * dst_row) { for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; @@ -69,6 +52,28 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n __syncthreads(); } } +} + +template +static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int ncols, int ncols_pad, int ntop, Store s) { +// int min_experts, float thresh_experts) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols_pad) { + return; + } + + const float * x_row = x + row * ncols; + extern __shared__ int dst_row[]; + + // initialize indices + dst_row[col] = col; + + __syncthreads(); + + sort(ncols_pad, ncols, col, x_row, dst_row); if constexpr (Store::has_thresh) { __syncthreads(); @@ -92,7 +97,8 @@ static __global__ void k_argsort_f32_T(const float * x, dst_t * dst, const int n } template -static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, int ncols_pad, int n_top_k) { +static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float * x, float * weights, int * ids, const int ncols, int ncols_pad, int ntop, + size_t nb_ids) { // bitonic sort int col = 
threadIdx.x; int row = blockIdx.y; @@ -101,7 +107,7 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, return; } - const float * x_row = x + row * ncols; + const float * x_row = x_biased + row * ncols; extern __shared__ int dst_row[]; // initialize indices @@ -109,32 +115,43 @@ static __global__ void k_topk_sum(const float * x, float * dst, const int ncols, __syncthreads(); - for (int k = 2; k <= ncols_pad; k *= 2) { - for (int j = k / 2; j > 0; j /= 2) { - int ixj = col ^ j; - if (ixj > col) { - if ((col & k) == 0) { - if (dst_row[col] >= ncols || - (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] > x_row[dst_row[ixj]] : - x_row[dst_row[col]] < x_row[dst_row[ixj]])) - ) { - ggml_cuda_swap(dst_row[col], dst_row[ixj]); - } - } else { - if (dst_row[ixj] >= ncols || - (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] < x_row[dst_row[ixj]] : - x_row[dst_row[col]] > x_row[dst_row[ixj]])) - ) { - ggml_cuda_swap(dst_row[col], dst_row[ixj]); - } - } - } - __syncthreads(); + sort(ncols_pad, ncols, col, x_row, dst_row); + + if (col < ntop) { + weights[row * ntop + col] = x[row * ncols + dst_row[col]]; + auto row_ids = (int *)((char *)ids + row*nb_ids); + row_ids[col] = dst_row[col]; + } +} + +template +static __global__ void k_topk_sum(float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols_pad) { + return; + } + + float * x_row = x + row * ncols; + extern __shared__ int dst_row[]; + + // initialize indices + dst_row[col] = col; + if (bias && x_p) { + float * x_p_row = x_p + row * ncols; + if (col < ncols) { + x_row[col] = 1/(1 + expf(-x_row[col])); + x_p_row[col] = x_row[col] + bias[col]; } + x_row = x_p_row; } + __syncthreads(); + + sort(ncols_pad, ncols, col, x_row, dst_row); + float val = col < n_top_k ? 
x_row[dst_row[col]] : 0; val = warp_reduce_sum(val); if (blockDim.x > WARP_SIZE) { @@ -208,6 +225,29 @@ static void argsort_f32_T_cuda(const float * x, dst_t * dst, const int ncols, co } } +static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, float * weights, int * ids, const int ncols, const int nrows, int ntop, + size_t nb_ids, ggml_sort_order order, cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + const dim3 block_dims(ncols_pad, 1, 1); + const dim3 block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * sizeof(int); + + // FIXME: this limit could be raised by ~2-4x on Ampere or newer + GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + + if (order == GGML_SORT_ORDER_ASC) { + k_argsort_f32_f32_i32<<>>(x_biased, x, weights, ids, + ncols, ncols_pad, ntop, nb_ids); + } else if (order == GGML_SORT_ORDER_DESC) { + k_argsort_f32_f32_i32<<>>(x_biased, x, weights, ids, + ncols, ncols_pad, ntop, nb_ids); + } else { + GGML_ABORT("fatal error"); + } +} + void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -246,7 +286,8 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream); } -static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, float * dst, int ncols, int nrows, int n_top_k) { +static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, float * src, const float * bias, float * src_p, float * dst, + int ncols, int nrows, int n_top_k) { GGML_ASSERT(n_top_k <= ncols); @@ -257,7 +298,7 @@ static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * const size_t shared_mem = std::max(ncols_pad, WARP_SIZE) * sizeof(int); GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); - k_topk_sum<<>>(src, dst, ncols, ncols_pad, n_top_k); + k_topk_sum<<>>(src, bias, src_p, dst, ncols, ncols_pad, n_top_k); } void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -291,7 +332,7 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds CUDA_CHECK(cudaGetLastError()); #else ggml_cuda_pool_alloc group_scores(ctx.pool(), nrows*n_groups); - ggml_cuda_op_topk_sum(ctx, (const float *)src->data, group_scores.get(), n_per_group, nrows*n_groups, nk); + ggml_cuda_op_topk_sum(ctx, (float *)src->data, nullptr, nullptr, group_scores.get(), n_per_group, nrows*n_groups, nk); CUDA_CHECK(cudaGetLastError()); #endif @@ -310,3 +351,49 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds argsort_f32_T_cuda((const float *)src->data, (int *)dst->data, ne00, nrows, ne0, GGML_SORT_ORDER_DESC, -1, 0.0f, ctx.stream()); } + +void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk) { + auto topk_src = topk->src[0]; + auto probs = topk_src->src[0]->src[0]; + auto bias = topk_src->src[1]; + + auto nrows = ggml_nrows(probs); + + int n_groups = topk->op_params[0]; + int n_top_groups = topk->op_params[1]; + int nk = topk->op_params[2]; + + int ne00 = probs->ne[0]; + int ne0 = topk->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(bias->ne[1] == 1); + GGML_ASSERT(bias->ne[0] == probs->ne[0]); + GGML_ASSERT(ne0 == 
dst->ne[1]); + GGML_ASSERT(ne0 <= ne00); + GGML_ASSERT(ne00%n_groups == 0); + int n_per_group = ne00/n_groups; + GGML_ASSERT(nk <= n_per_group); + GGML_ASSERT(n_top_groups <= n_groups); + int n_discarded_groups = n_groups - n_top_groups; + + ggml_cuda_pool_alloc group_scores(ctx.pool(), nrows*n_groups); + ggml_cuda_op_topk_sum(ctx, (float *)probs->data, (const float *)bias->data, (float *)topk_src->data, group_scores.get(), + n_per_group, nrows*n_groups, nk); + CUDA_CHECK(cudaGetLastError()); + + ggml_cuda_pool_alloc discarded_groups(ctx.pool(), nrows*n_discarded_groups); + argsort_f32_T_cuda(group_scores.get(), discarded_groups.get(), n_groups, nrows, n_discarded_groups, GGML_SORT_ORDER_ASC, -1, 0.0f, ctx.stream()); + CUDA_CHECK(cudaGetLastError()); + + { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(1, nrows, 1); + cudaStream_t stream = ctx.stream(); + k_apply_mask<<>>((float *)topk_src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00); + CUDA_CHECK(cudaGetLastError()); + } + + argsort_f32_f32_i32_cuda((const float *)topk_src->data, (const float *)probs->data, (float *)dst->data, (int *)topk->data, ne00, nrows, ne0, + topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); + +} diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 7bd28a1f0..e467abf0b 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -11,3 +11,5 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk); diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 5d0adcbac..f823d2825 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -41,47 +41,58 @@ inline std::vector> & get_work_buffer(size_t size) { } #ifdef __ARM_NEON -inline float32x4_t v_biased_sigmoid(float32x4_t x, float32x4_t b) { +inline float32x4_t v_sigmoid(float32x4_t x) { const float32x4_t one = vdupq_n_f32(1.0f); const float32x4_t zero = vdupq_n_f32(0.0f); const float32x4_t neg_x = vsubq_f32(zero, x); const float32x4_t exp_neg_x = v_expf(neg_x); const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); - return vaddq_f32(b, vdivq_f32(one, one_plus_exp_neg_x)); + return vdivq_f32(one, one_plus_exp_neg_x); } #endif #ifdef __AVX2__ -inline __m256 v_biased_sigmoid(__m256 x, __m256 b) { +inline __m256 v_sigmoid(__m256 x) { const __m256 one = _mm256_set1_ps(1); const __m256 zero = _mm256_setzero_ps(); const __m256 neg_x = _mm256_sub_ps(zero, x); const __m256 exp_neg_x = v_expf(neg_x); const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); - return _mm256_add_ps(b, _mm256_div_ps(one, one_plus_exp_neg_x)); + return _mm256_div_ps(one, one_plus_exp_neg_x); } #endif #if defined __AVX512F__ && defined __AVX512DQ__ -inline __m512 v_biased_sigmoid(__m512 x, __m512 b) { +inline __m512 v_sigmoid(__m512 x) { const __m512 one = _mm512_set1_ps(1); const __m512 zero = _mm512_setzero_ps(); const __m512 neg_x = _mm512_sub_ps(zero, x); const __m512 exp_neg_x = v_expf(neg_x); const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); - return _mm512_add_ps(b, _mm512_div_ps(one, one_plus_exp_neg_x)); + return _mm512_div_ps(one, one_plus_exp_neg_x); } #endif -inline void biased_sigmoid(int n, const float * x, const float * 
bias, float * y) { +inline void biased_sigmoid(int n, const float * x, const float * bias, float * y, float * z) { int i = 0; #if defined __AVX512F__ && defined __AVX512DQ__ - for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_biased_sigmoid(_mm512_loadu_ps(x + i), _mm512_loadu_ps(bias + i))); + for (; i + 15 < n; i += 16) { + auto v = v_sigmoid(_mm512_loadu_ps(x + i)); + _mm512_storeu_ps(y + i, _mm512_add_ps(v, _mm512_loadu_ps(bias + i))); + _mm512_storeu_ps(z + i, v); + } #endif #if defined __AVX2__ && defined __FMA__ - for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_biased_sigmoid(_mm256_loadu_ps(x + i), _mm256_loadu_ps(bias + i))); + for (; i + 7 < n; i += 8) { + auto v = v_sigmoid(_mm256_loadu_ps(x + i)); + _mm256_storeu_ps(y + i, _mm256_add_ps(v, _mm256_loadu_ps(bias + i))); + _mm256_storeu_ps(z + i, v); + } #endif #ifdef __ARM_NEON for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_biased_sigmoid(vld1q_f32(x + i), vld1q_f32(bias + i))); #endif - for (; i < n; ++i) y[i] = 1/(1 + expf(-x[i])) + bias[i]; + for (; i < n; ++i) { + z[i] = 1/(1 + expf(-x[i])); + y[i] = y[i] + bias[i]; + } } } @@ -214,17 +225,18 @@ void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * top GGML_ASSERT(nk <= n_per_group); GGML_ASSERT(n_top_groups <= n_groups); - size_t work_size = n_groups + n_per_group*n_top_groups + (ne00 + 1)/2; + size_t work_size = n_groups + n_per_group*n_top_groups + ne00; auto& aux = get_work_buffer(work_size); auto groups = aux.data() + n_per_group*n_top_groups; - auto values = (float *)(groups + n_groups); + auto biased_values = (float *)(groups + n_groups); + auto values = biased_values + ne00; auto bias = (const float *)t_bias->data; for (int ir = first; ir < last; ++ir) { auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]); - biased_sigmoid(ne00, data, bias, values); + biased_sigmoid(ne00, data, bias, biased_values, values); //for (int j = 0; j < ne00; ++j) values[j] = 1/(1 + expf(-data[j])) + bias[j]; auto weights = (float *)((char *)dst->data + ir*dst->nb[2]); auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]); @@ -237,21 +249,21 @@ void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * top } if (n_top_groups < n_groups) { for (int ig = 0; ig < n_groups; ++ig) { - groups[ig] = { group_score(n_per_group, nk, values + ig*n_per_group, (float *)aux.data()), ig }; + groups[ig] = { group_score(n_per_group, nk, biased_values + ig*n_per_group, (float *)aux.data()), ig }; } std::partial_sort(groups, groups + n_top_groups, groups + n_groups, std::greater>{}); for (int ig = 0; ig < n_top_groups; ++ig) { int i0 = n_per_group * ig; int j0 = n_per_group * groups[ig].second; - for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { values[j0 + j], j0 + j }; + for (int j = 0; j < n_per_group; ++j) aux[i0 + j] = { biased_values[j0 + j], j0 + j }; } } else { - for (int j = 0; j < ne00; ++j) aux[j] = { values[j], j }; + for (int j = 0; j < ne00; ++j) aux[j] = { biased_values[j], j }; } std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + n_top_groups*n_per_group, std::greater>{}); for (int j = 0; j < ne0; ++j) { - weights[j] = aux[j].first; + weights[j] = values[aux[j].second]; ids[j] = aux[j].second; } From f3ff1a5c4879493e84e55f30683e6c4625b77eb1 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 07:21:41 +0300 Subject: [PATCH 3/9] Minor --- ggml/src/ggml-cuda/argsort.cu | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu 
b/ggml/src/ggml-cuda/argsort.cu index 3976c9d20..3db45b3b1 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -118,14 +118,14 @@ static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float sort(ncols_pad, ncols, col, x_row, dst_row); if (col < ntop) { - weights[row * ntop + col] = x[row * ncols + dst_row[col]]; + weights[row * ntop + col] = 1/(1 + expf(-x[row * ncols + dst_row[col]])); auto row_ids = (int *)((char *)ids + row*nb_ids); row_ids[col] = dst_row[col]; } } template -static __global__ void k_topk_sum(float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) { +static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; @@ -134,7 +134,7 @@ static __global__ void k_topk_sum(float * x, const float * bias, float * x_p, fl return; } - float * x_row = x + row * ncols; + const float * x_row = x + row * ncols; extern __shared__ int dst_row[]; // initialize indices @@ -142,8 +142,7 @@ static __global__ void k_topk_sum(float * x, const float * bias, float * x_p, fl if (bias && x_p) { float * x_p_row = x_p + row * ncols; if (col < ncols) { - x_row[col] = 1/(1 + expf(-x_row[col])); - x_p_row[col] = x_row[col] + bias[col]; + x_p_row[col] = 1/(1 + expf(-x_row[col])) + bias[col]; } x_row = x_p_row; } @@ -156,7 +155,7 @@ static __global__ void k_topk_sum(float * x, const float * bias, float * x_p, fl val = warp_reduce_sum(val); if (blockDim.x > WARP_SIZE) { __syncthreads(); - float * s_sum = (float *)dst_row; + float * s_sum = (float *)(dst_row + ncols_pad); const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { @@ -286,7 +285,7 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * argsort_f32_T_cuda(src0_d, (int *)dst_d, ncols, nrows, ncols, GGML_SORT_ORDER_DESC, min_experts, thresh, stream); } -static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, float * src, const float * bias, float * src_p, float * dst, +static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, const float * src, const float * bias, float * src_p, float * dst, int ncols, int nrows, int n_top_k) { GGML_ASSERT(n_top_k <= ncols); @@ -295,7 +294,7 @@ static void ggml_cuda_op_topk_sum(ggml_backend_cuda_context & ctx, float * src, const dim3 block_dims(ncols_pad, 1, 1); const dim3 block_nums(1, nrows, 1); - const size_t shared_mem = std::max(ncols_pad, WARP_SIZE) * sizeof(int); + const size_t shared_mem = (ncols_pad + WARP_SIZE) * sizeof(int); GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); k_topk_sum<<>>(src, bias, src_p, dst, ncols, ncols_pad, n_top_k); @@ -377,7 +376,7 @@ void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * ds int n_discarded_groups = n_groups - n_top_groups; ggml_cuda_pool_alloc group_scores(ctx.pool(), nrows*n_groups); - ggml_cuda_op_topk_sum(ctx, (float *)probs->data, (const float *)bias->data, (float *)topk_src->data, group_scores.get(), + ggml_cuda_op_topk_sum(ctx, (const float *)probs->data, (const float *)bias->data, (float *)topk_src->data, group_scores.get(), n_per_group, nrows*n_groups, nk); CUDA_CHECK(cudaGetLastError()); @@ -388,12 +387,11 @@ void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * ds { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(1, nrows, 1); 
- cudaStream_t stream = ctx.stream(); k_apply_mask<<>>((float *)topk_src->data, discarded_groups.get(), n_discarded_groups, n_per_group, ne00); CUDA_CHECK(cudaGetLastError()); } - argsort_f32_f32_i32_cuda((const float *)topk_src->data, (const float *)probs->data, (float *)dst->data, (int *)topk->data, ne00, nrows, ne0, - topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); + argsort_f32_f32_i32_cuda((const float *)topk_src->data, (const float *)probs->data, (float *)dst->data, (int *)topk->data, + ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); } From 8fe2bb927a481e32af9b8f56d21fcdb7bc8ab163 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 09:13:34 +0300 Subject: [PATCH 4/9] Fuse sigmoid+add+topk+get_rows (CUDA) --- ggml/src/ggml-cuda.cu | 21 ++++++++- ggml/src/ggml-cuda/argsort.cu | 78 ++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/argsort.cuh | 2 + 3 files changed, 99 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 12da791f1..4372d99f5 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3173,12 +3173,29 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_relu(ctx, dst); break; case GGML_UNARY_OP_SIGMOID: - if (i + 4 < cgraph->n_nodes && + if (i + 5 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+2]->op == GGML_OP_ADD && + cgraph->nodes[i+3]->op == GGML_OP_ARGSORT && + cgraph->nodes[i+4]->op == GGML_OP_VIEW && + cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) { + cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]); + i += 5; + } + //else if (i + 5 < cgraph->n_nodes) { + // printf("sigmoid(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s)\n", dst->name, + // ggml_op_name(cgraph->nodes[i+1]->op), cgraph->nodes[i+1]->name, + // ggml_op_name(cgraph->nodes[i+2]->op), cgraph->nodes[i+2]->name, + // ggml_op_name(cgraph->nodes[i+3]->op), cgraph->nodes[i+3]->name, + // ggml_op_name(cgraph->nodes[i+4]->op), cgraph->nodes[i+4]->name, + // ggml_op_name(cgraph->nodes[i+5]->op), cgraph->nodes[i+5]->name); + //} + else if (i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS) { - cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]); + cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+4]); i += 4; } else { ggml_cuda_op_sigmoid(ctx, dst); diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 3db45b3b1..7c3c5e666 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -124,6 +124,35 @@ static __global__ void k_argsort_f32_f32_i32(const float * x_biased, const float } } +template +static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float * bias, float * weights, int * ids, const int ncols, int ncols_pad, int ntop, + size_t nb_ids) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols_pad) { + return; + } + + extern __shared__ int dst_row[]; + auto x_row = (float *)(dst_row + ncols_pad); + + // initialize indices + dst_row[col] = col; + x_row[col] = col < ncols ? 
1/(1 + expf(-x[row*ncols + col])) + bias[col] : -INFINITY; + + __syncthreads(); + + sort(ncols_pad, ncols, col, x_row, dst_row); + + if (col < ntop) { + weights[row * ntop + col] = 1/(1 + expf(-x[row * ncols + dst_row[col]])); + auto row_ids = (int *)((char *)ids + row*nb_ids); + row_ids[col] = dst_row[col]; + } +} + template static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) { // bitonic sort @@ -247,6 +276,29 @@ static void argsort_f32_f32_i32_cuda(const float * x_biased, const float * x, fl } } +static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias, float * weights, int * ids, const int ncols, const int nrows, int ntop, + size_t nb_ids, ggml_sort_order order, cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + const dim3 block_dims(ncols_pad, 1, 1); + const dim3 block_nums(1, nrows, 1); + const size_t shared_mem = ncols_pad * (sizeof(int) + sizeof(float)); + + // FIXME: this limit could be raised by ~2-4x on Ampere or newer + GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + + if (order == GGML_SORT_ORDER_ASC) { + k_argsort_biased_f32_f32_i32<<>>(x, bias, weights, ids, + ncols, ncols_pad, ntop, nb_ids); + } else if (order == GGML_SORT_ORDER_DESC) { + k_argsort_biased_f32_f32_i32<<>>(x, bias, weights, ids, + ncols, ncols_pad, ntop, nb_ids); + } else { + GGML_ABORT("fatal error"); + } +} + void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -395,3 +447,29 @@ void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * ds ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); } + +void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk_view) { + GGML_ASSERT(topk_view->op == GGML_OP_VIEW); + auto topk = topk_view->src[0]; + auto topk_src = topk->src[0]; + auto probs = topk_src->src[0]->src[0]; + auto bias = topk_src->src[1]; + + auto nrows = ggml_nrows(probs); + + int ne00 = probs->ne[0]; + int ne0 = topk_view->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(bias->ne[1] == 1); + GGML_ASSERT(bias->ne[0] == probs->ne[0]); + GGML_ASSERT(ne0 == dst->ne[1]); + GGML_ASSERT(ne0 <= ne00); + + //printf("probs: %ld x %ld x %ld x %ld. topk: %ld x %ld x %ld x %ld. 
dst: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n", + // probs->ne[0], probs->ne[1], probs->ne[2], probs->ne[3], topk->ne[0], topk->ne[1], topk->ne[2], topk->ne[3], + // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]); + + argsort_biased_f32_f32_i32_cuda((const float *)probs->data, (const float *)bias->data, (float *)dst->data, (int *)topk->data, + ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); + +} diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index e467abf0b..43987fbbe 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -13,3 +13,5 @@ void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk); + +void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk); From 18d9f4fc4dfe3b400e9a94e17e9fb45b7d045f95 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 10:03:51 +0300 Subject: [PATCH 5/9] Fuse sigmoid+add+topk+get_rows (CPU) --- ggml/src/ggml-cuda.cu | 8 ---- ggml/src/ggml.c | 11 +++++- ggml/src/iqk/iqk_cpu_ops.cpp | 77 +++++++++++++++++++++++++++++++++++- ggml/src/iqk/iqk_cpu_ops.h | 2 + 4 files changed, 88 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 4372d99f5..4e3b04134 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3182,14 +3182,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]); i += 5; } - //else if (i + 5 < cgraph->n_nodes) { - // printf("sigmoid(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s) -> %s(%s)\n", dst->name, - // ggml_op_name(cgraph->nodes[i+1]->op), cgraph->nodes[i+1]->name, - // ggml_op_name(cgraph->nodes[i+2]->op), cgraph->nodes[i+2]->name, - // ggml_op_name(cgraph->nodes[i+3]->op), cgraph->nodes[i+3]->name, - // ggml_op_name(cgraph->nodes[i+4]->op), cgraph->nodes[i+4]->name, - // ggml_op_name(cgraph->nodes[i+5]->op), cgraph->nodes[i+5]->name); - //} else if (i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5ca73ae54..a556b2a41 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -22612,7 +22612,16 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml case GGML_OP_UNARY: { const enum ggml_unary_op unary_op = ggml_get_unary_op(tensor); - if (unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes && + if (unary_op == GGML_UNARY_OP_SIGMOID && i + 5 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+2]->op == GGML_OP_ADD && + cgraph->nodes[i+3]->op == GGML_OP_ARGSORT && + cgraph->nodes[i+4]->op == GGML_OP_VIEW && + cgraph->nodes[i+5]->op == GGML_OP_GET_ROWS) { + iqk_glm45moe_experts(cgraph->nodes[i+5], cgraph->nodes[i+4], params->ith, params->nth); + i += 5; + } + else if (unary_op == GGML_UNARY_OP_SIGMOID && i + 4 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD && cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index f823d2825..51cdfcc8d 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ 
b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -87,13 +87,41 @@ inline void biased_sigmoid(int n, const float * x, const float * bias, float * y } #endif #ifdef __ARM_NEON - for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_biased_sigmoid(vld1q_f32(x + i), vld1q_f32(bias + i))); + for (; i + 3 < n; i += 4) { + auto v = v_sigmoid(vld1q_f32(x + i)); + vst1q_f32(y + i, vaddq_f32(v, vld1q_f32(bias + i))); + vst1q_f32(z + i, v); + } #endif for (; i < n; ++i) { z[i] = 1/(1 + expf(-x[i])); y[i] = y[i] + bias[i]; } } +inline void biased_sigmoid(int n, const float * x, const float * bias, float * y) { + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + for (; i + 15 < n; i += 16) { + auto v = v_sigmoid(_mm512_loadu_ps(x + i)); + _mm512_storeu_ps(y + i, _mm512_add_ps(v, _mm512_loadu_ps(bias + i))); + } +#endif +#if defined __AVX2__ && defined __FMA__ + for (; i + 7 < n; i += 8) { + auto v = v_sigmoid(_mm256_loadu_ps(x + i)); + _mm256_storeu_ps(y + i, _mm256_add_ps(v, _mm256_loadu_ps(bias + i))); + } +#endif +#ifdef __ARM_NEON + for (; i + 3 < n; i += 4) { + auto v = v_sigmoid(vld1q_f32(x + i)); + vst1q_f32(y + i, vaddq_f32(v, vld1q_f32(bias + i))); + } +#endif + for (; i < n; ++i) { + y[i] = 1/(1 + expf(-x[i])) + bias[i]; + } +} } void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) { @@ -270,3 +298,50 @@ void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * top } } +void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_view, int ith, int nth) { + GGML_ASSERT(topk_view->op == GGML_OP_VIEW); + auto topk = topk_view->src[0]; + auto topk_src = topk->src[0]; + auto probs = topk_src->src[0]->src[0]; + auto t_bias = topk_src->src[1]; + + auto nrows = ggml_nrows(probs); + auto npt = (nrows + nth - 1)/nth; + auto first = npt*ith; + auto last = std::min(first + npt, nrows); + if (last <= first) return; + + int ne00 = probs->ne[0]; + int ne0 = topk_view->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(t_bias->ne[1] == 1); + GGML_ASSERT(t_bias->ne[0] == probs->ne[0]); + GGML_ASSERT(ne0 == dst->ne[1]); + GGML_ASSERT(ne0 <= ne00); + + size_t work_size = 2*ne00; + auto& aux = get_work_buffer(work_size); + + auto biased_values = (float *)(aux.data() + ne00); + //auto values = biased_values + ne00; + + auto bias = (const float *)t_bias->data; + + for (int ir = first; ir < last; ++ir) { + auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]); + //biased_sigmoid(ne00, data, bias, biased_values, values); + biased_sigmoid(ne00, data, bias, biased_values); + auto weights = (float *)((char *)dst->data + ir*dst->nb[2]); + auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]); + for (int j = 0; j < ne00; ++j) aux[j] = { biased_values[j], j }; + if (ne0 < ne00) { + std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + ne00, std::greater>{}); + } else { + std::sort(aux.begin(), aux.begin() + ne00, std::greater>{}); + } + for (int j = 0; j < ne0; ++j) { + weights[j] = 1/(1 + expf(-data[aux[j].second])); + ids[j] = aux[j].second; + } + } +} diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h index 81c14fd59..2de3a5cb6 100644 --- a/ggml/src/iqk/iqk_cpu_ops.h +++ b/ggml/src/iqk/iqk_cpu_ops.h @@ -20,6 +20,8 @@ void iqk_argsort(struct ggml_tensor * dst, int ith, int nth); void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * topk, int ith, int nth); +void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_view, int ith, int nth); + #ifdef __cplusplus } #endif From 
c8ed4545644cf9bfc8e40be3279ba990b9b8f1b0 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 11:45:10 +0300 Subject: [PATCH 6/9] Fuse topk+view+get_rows+reshape+softmax (CPU) --- ggml/src/ggml.c | 12 +++++++++- ggml/src/iqk/iqk_cpu_ops.cpp | 44 ++++++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_cpu_ops.h | 2 ++ src/llama-build-context.cpp | 8 +++---- 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a556b2a41..d6b76a5b0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -22568,7 +22568,17 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_ARGSORT: { - ggml_compute_forward_argsort(params, tensor); + if (i + 5 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_VIEW && + cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS && + cgraph->nodes[i+3]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX && + cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) { + iqk_openai_experts(tensor, cgraph->nodes[i+4], params->ith, params->nth); + i += 5; + } else { + ggml_compute_forward_argsort(params, tensor); + } } break; case GGML_OP_ARGSORT_THRESH: { diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 51cdfcc8d..ff34abcf4 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -345,3 +345,47 @@ void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_vi } } } + +void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth) { + + auto probs = topk->src[0]; + + auto nrows = ggml_nrows(probs); + auto npt = (nrows + nth - 1)/nth; + auto first = npt*ith; + auto last = std::min(first + npt, nrows); + if (last <= first) return; + + int ne00 = probs->ne[0]; + int ne0 = softmax->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(ggml_is_contiguous(softmax)); + GGML_ASSERT(ne0 <= ne00); + //if (ith == 0) printf("%s: ne00 = %d, ne0 = %d, topk: %s, softmax: %s\n", __func__, ne00, ne0, ggml_type_name(topk->type), ggml_type_name(softmax->type)); + //if (ith == 0) printf("%s: ne00 = %d, ne0 = %d, topk: %s, %ld x %ld x %ld x %ld, %zu x %zu x %zu x %zu\n", __func__, ne00, ne0, ggml_type_name(topk->type), topk->ne[0], topk->ne[1], topk->ne[2], topk->ne[3], topk->nb[0], topk->nb[1], topk->nb[2], topk->nb[3]); + + size_t work_size = ne00; + auto& aux = get_work_buffer(work_size); + + for (int ir = first; ir < last; ++ir) { + auto data = (const float *)((const char *)probs->data + ir*probs->nb[1]); + for (int j = 0; j < ne00; ++j) aux[j] = { data[j], j }; + if (ne0 < ne00) { + std::partial_sort(aux.begin(), aux.begin() + ne0, aux.begin() + ne00, std::greater>{}); + } else { + std::sort(aux.begin(), aux.begin() + ne00, std::greater>{}); + } + auto weights = (float *)((char *)softmax->data + ir*softmax->nb[1]); + auto ids = (int32_t *)((char *)topk->data + ir*topk->nb[1]); + float max = aux.front().first; + float sum = 0; + for (int j = 0; j < ne0; ++j) { + weights[j] = expf(aux[j].first - max); + ids[j] = aux[j].second; + sum += weights[j]; + } + GGML_ASSERT(sum > 0); + float norm = 1/sum; + for (int j = 0; j < ne0; ++j) weights[j] *= norm; + } +} diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h index 2de3a5cb6..ef2bbe1ac 100644 --- a/ggml/src/iqk/iqk_cpu_ops.h +++ b/ggml/src/iqk/iqk_cpu_ops.h @@ -22,6 +22,8 @@ void iqk_bailingmoev2_experts(struct ggml_tensor * dst, struct ggml_tensor * top void iqk_glm45moe_experts(struct ggml_tensor * dst, struct 
ggml_tensor * topk_view, int ith, int nth); +void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth); + #ifdef __cplusplus } #endif diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index e2fccfca6..dd0b62be1 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -836,10 +836,6 @@ llm_expert_gating_func_type gating_op, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); - if (graph) { - ggml_build_forward_expand(graph, weights); - } - if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) { weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens] @@ -847,6 +843,10 @@ llm_expert_gating_func_type gating_op, cb(weights, "ffn_moe_weights_softmax", il); } + if (graph) { + ggml_build_forward_expand(graph, weights); + } + if (norm_w) { weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); From b79aad9d07f28a9121a50f22ce5a84d3d43ff699 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 13:00:37 +0300 Subject: [PATCH 7/9] Fuse topk+view+get_rows+reshape+softmax (CUDA) --- ggml/src/ggml-cuda.cu | 12 ++++- ggml/src/ggml-cuda/argsort.cu | 96 ++++++++++++++++++++++++++++++++-- ggml/src/ggml-cuda/argsort.cuh | 2 + 3 files changed, 105 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 4e3b04134..412ef16e7 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3336,7 +3336,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_sum_rows(ctx, dst); break; case GGML_OP_ARGSORT: - ggml_cuda_op_argsort(ctx, dst); + if (i + 5 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_VIEW && + cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS && + cgraph->nodes[i+3]->op == GGML_OP_RESHAPE && + cgraph->nodes[i+4]->op == GGML_OP_SOFT_MAX && + cgraph->nodes[i+5]->op == GGML_OP_RESHAPE) { + cuda_openai_experts(ctx, dst, cgraph->nodes[i+4]); + i += 5; + } else { + ggml_cuda_op_argsort(ctx, dst); + } break; case GGML_OP_ARGSORT_THRESH: ggml_cuda_op_argsort_thresh(ctx, dst); diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 7c3c5e666..99c0b7fe0 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -153,6 +153,53 @@ static __global__ void k_argsort_biased_f32_f32_i32(const float * x, const float } } +template +static __global__ void k_openai_f32_f32_i32(const float * x, float * weights, int * ids, const int ncols, int ncols_pad, int ntop, + size_t nb_ids) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols_pad) { + return; + } + + extern __shared__ int dst_row[]; + auto x_row = x + row*ncols; + + // initialize indices + dst_row[col] = col; + + __syncthreads(); + + sort(ncols_pad, ncols, col, x_row, dst_row); + + float max = x_row[dst_row[0]]; + float val = col < ntop ? 
expf(x_row[dst_row[col]] - max) : 0.0f; + float sum = warp_reduce_sum(val); + if (blockDim.x > WARP_SIZE) { + __syncthreads(); + float * s_sum = (float *)(dst_row + ncols_pad); + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = sum; + } + __syncthreads(); + sum = 0.0f; + if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { + sum = s_sum[lane_id]; + } + sum = warp_reduce_sum(sum); + } + float norm = 1/sum; + if (col < ntop) { + weights[row * ntop + col] = norm*val; + auto row_ids = (int *)((char *)ids + row*nb_ids); + row_ids[col] = dst_row[col]; + } +} + template static __global__ void k_topk_sum(const float * x, const float * bias, float * x_p, float * dst, const int ncols, int ncols_pad, int n_top_k) { // bitonic sort @@ -299,6 +346,29 @@ static void argsort_biased_f32_f32_i32_cuda(const float * x, const float * bias, } } +static void argsort_openai_f32_f32_i32_cuda(const float * x, float * weights, int * ids, const int ncols, const int nrows, int ntop, + size_t nb_ids, ggml_sort_order order, cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + const dim3 block_dims(ncols_pad, 1, 1); + const dim3 block_nums(1, nrows, 1); + const size_t shared_mem = (ncols_pad + ncols_pad > WARP_SIZE ? WARP_SIZE : 0) * sizeof(int); + + // FIXME: this limit could be raised by ~2-4x on Ampere or newer + GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + + if (order == GGML_SORT_ORDER_ASC) { + k_openai_f32_f32_i32<<>>(x, weights, ids, + ncols, ncols_pad, ntop, nb_ids); + } else if (order == GGML_SORT_ORDER_DESC) { + k_openai_f32_f32_i32<<>>(x, weights, ids, + ncols, ncols_pad, ntop, nb_ids); + } else { + GGML_ABORT("fatal error"); + } +} + void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -465,11 +535,29 @@ void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g GGML_ASSERT(ne0 == dst->ne[1]); GGML_ASSERT(ne0 <= ne00); - //printf("probs: %ld x %ld x %ld x %ld. topk: %ld x %ld x %ld x %ld. 
dst: %ld x %ld x %ld x %ld; %zu x %zu x %zu x %zu\n", - // probs->ne[0], probs->ne[1], probs->ne[2], probs->ne[3], topk->ne[0], topk->ne[1], topk->ne[2], topk->ne[3], - // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]); - argsort_biased_f32_f32_i32_cuda((const float *)probs->data, (const float *)bias->data, (float *)dst->data, (int *)topk->data, ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); } + +void cuda_openai_experts(ggml_backend_cuda_context & ctx, ggml_tensor * topk, ggml_tensor * softmax) { + + auto probs = topk->src[0]; + int ntop = topk->op_params[1]; + + auto nrows = ggml_nrows(probs); + int ne00 = probs->ne[0]; + int ne0 = softmax->ne[0]; + GGML_ASSERT(ggml_is_contiguous(probs)); + GGML_ASSERT(ggml_is_contiguous(softmax)); + GGML_ASSERT(ne0 <= ne00); + if (ntop != ne0) { + printf("Oops: ntop = %d, ne0 = %d\n", ntop, ne0); + GGML_ASSERT(false); + } + //GGML_ASSERT(ne0 == ntop); + + argsort_openai_f32_f32_i32_cuda((const float *)probs->data, (float *)softmax->data, (int *)topk->data, + ne00, nrows, ne0, topk->nb[1], GGML_SORT_ORDER_DESC, ctx.stream()); + +} diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 43987fbbe..331f373bf 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -15,3 +15,5 @@ void ggml_cuda_op_grouped_topk(ggml_backend_cuda_context & ctx, ggml_tensor * ds void cuda_bailingmoev2_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk); void cuda_glm45moe_experts(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * topk); + +void cuda_openai_experts(ggml_backend_cuda_context & ctx, ggml_tensor * topk, ggml_tensor * softmax); From 0fb9d4963f6588ecf46c45773084b9b98ef6bee3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 13:11:35 +0300 Subject: [PATCH 8/9] cpu: turn off the openai topk fusing for now Something is not right and I don't see the bug. On the CPU one doesn't gain much if anything, so not a big loss. 
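For reference, the per-row math the (now disabled) fused path is expected to
reproduce is: argsort the router scores descending, keep the first ne0 ids,
gather those scores, and softmax over just the gathered scores. A minimal
scalar sketch of that reference computation (illustrative only, not the ggml
code path; the function name is a placeholder, ne00/ne0 follow the tensor
sizes used in iqk_openai_experts):

    // scores: ne00 router logits of one row; outputs: ne0 ids and weights
    // needs <math.h> for expf()
    static void topk_softmax_row_ref(const float * scores, int ne00,
                                     int * ids, float * weights, int ne0) {
        // select the ne0 largest scores (O(ne00*ne0), fine for a reference)
        for (int k = 0; k < ne0; ++k) {
            int best = -1;
            for (int j = 0; j < ne00; ++j) {
                int taken = 0;
                for (int m = 0; m < k; ++m) if (ids[m] == j) taken = 1;
                if (!taken && (best < 0 || scores[j] > scores[best])) best = j;
            }
            ids[k] = best;
        }
        // softmax over the selected scores only
        float max = scores[ids[0]], sum = 0.0f;
        for (int k = 0; k < ne0; ++k) {
            weights[k] = expf(scores[ids[k]] - max);
            sum += weights[k];
        }
        for (int k = 0; k < ne0; ++k) weights[k] /= sum;
    }

Comparing iqk_openai_experts() against this per row should help pin down
where the fused CPU result diverges before re-enabling the fusion.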
--- ggml/src/ggml.c | 2 +- ggml/src/iqk/iqk_cpu_ops.cpp | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d6b76a5b0..af6547643 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -22568,7 +22568,7 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_ARGSORT: { - if (i + 5 < cgraph->n_nodes && + if (false && i + 5 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_VIEW && cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS && cgraph->nodes[i+3]->op == GGML_OP_RESHAPE && diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index ff34abcf4..869fd1fe0 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -209,15 +209,15 @@ void iqk_argsort(ggml_tensor * dst, int ith, int nth) { for (int j = 0; j < ne00; ++j) aux[j] = {data[j], j}; if (nk < ne00) { if (order == GGML_SORT_ORDER_DESC) { - std::partial_sort(aux.begin(), aux.begin() + nk, aux.end(), std::greater>{}); + std::partial_sort(aux.begin(), aux.begin() + nk, aux.begin() + ne00, std::greater>{}); } else { - std::partial_sort(aux.begin(), aux.begin() + nk, aux.end()); + std::partial_sort(aux.begin(), aux.begin() + nk, aux.begin() + ne00); } } else { if (order == GGML_SORT_ORDER_DESC) { - std::sort(aux.begin(), aux.end(), std::greater>{}); + std::sort(aux.begin(), aux.begin() + ne00, std::greater>{}); } else { - std::sort(aux.begin(), aux.end()); + std::sort(aux.begin(), aux.begin() + ne00); } } auto y = (int32_t *)((char *)dst->data + ir*dst->nb[1]); @@ -361,8 +361,6 @@ void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, GGML_ASSERT(ggml_is_contiguous(probs)); GGML_ASSERT(ggml_is_contiguous(softmax)); GGML_ASSERT(ne0 <= ne00); - //if (ith == 0) printf("%s: ne00 = %d, ne0 = %d, topk: %s, softmax: %s\n", __func__, ne00, ne0, ggml_type_name(topk->type), ggml_type_name(softmax->type)); - //if (ith == 0) printf("%s: ne00 = %d, ne0 = %d, topk: %s, %ld x %ld x %ld x %ld, %zu x %zu x %zu x %zu\n", __func__, ne00, ne0, ggml_type_name(topk->type), topk->ne[0], topk->ne[1], topk->ne[2], topk->ne[3], topk->nb[0], topk->nb[1], topk->nb[2], topk->nb[3]); size_t work_size = ne00; auto& aux = get_work_buffer(work_size); From 1d70b89d351e8346b53467c47f47cb08a01882a3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 19 Oct 2025 18:04:13 +0300 Subject: [PATCH 9/9] Also fuse sum_rows and div --- ggml/src/ggml-cuda.cu | 10 +++++++- ggml/src/ggml-cuda/sumrows.cu | 42 ++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/sumrows.cuh | 2 ++ ggml/src/ggml.c | 10 +++++++- ggml/src/iqk/iqk_cpu_ops.cpp | 22 ++++++++++++++++++ ggml/src/iqk/iqk_cpu_ops.h | 2 ++ 6 files changed, 86 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 412ef16e7..74597763d 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3333,7 +3333,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_pool2d(ctx, dst); break; case GGML_OP_SUM_ROWS: - ggml_cuda_op_sum_rows(ctx, dst); + if (i + 1 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_DIV && + cgraph->nodes[i+1]->src[1] == dst && + cgraph->nodes[i+1]->src[0] == dst->src[0]) { + ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+1]); + ++i; + } else { + ggml_cuda_op_sum_rows(ctx, dst); + } break; case GGML_OP_ARGSORT: if (i + 5 < cgraph->n_nodes && diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 
40be14cfc..4888bbfd3 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -16,12 +16,38 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc } } +static __global__ void k_sum_rows_div_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum); + + float norm = sum > 0 ? 1/sum : 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + dst[row * ncols + i] = x[row * ncols + i] * norm; + } + //for (int i = col; i < ncols; i += blockDim.x) { + // dst[row * ncols + i] = x[row * ncols + i] / sum; + //} +} + void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); k_sum_rows_f32<<>>(x, dst, ncols); } +static void sum_rows_div_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + k_sum_rows_div_f32<<>>(x, dst, ncols); +} + void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -38,3 +64,19 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream); } + +void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + sum_rows_div_f32_cuda(src0_d, dst_d, ncols, nrows, stream); +} diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh index 0c0f47838..b6c0dc26e 100644 --- a/ggml/src/ggml-cuda/sumrows.cuh +++ b/ggml/src/ggml-cuda/sumrows.cuh @@ -3,3 +3,5 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); + +void ggml_cuda_op_sum_rows_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index af6547643..9434c8be1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -22357,7 +22357,15 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml } break; case GGML_OP_SUM_ROWS: { - ggml_compute_forward_sum_rows(params, tensor); + if (i + 1 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_DIV && + cgraph->nodes[i+1]->src[1] == tensor && + cgraph->nodes[i+1]->src[0] == tensor->src[0]) { + iqk_sumrows_div(cgraph->nodes[i+1], params->ith, params->nth); + ++i; + } else { + ggml_compute_forward_sum_rows(params, tensor); + } } break; case GGML_OP_MEAN: { diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp index 869fd1fe0..a9bffadb6 100644 --- a/ggml/src/iqk/iqk_cpu_ops.cpp +++ b/ggml/src/iqk/iqk_cpu_ops.cpp @@ -124,6 +124,28 @@ inline void biased_sigmoid(int n, const float * x, const float * bias, float * y } } +void 
iqk_sumrows_div(struct ggml_tensor * div, int ith, int nth) { + auto src = div->src[0]; + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(div->type == GGML_TYPE_F32); + + int ne00 = src->ne[0]; + int nrows = ggml_nrows(src); + int npt = (nrows + nth - 1)/nth; + int first = ith*npt; + int last = std::min(first + npt, nrows); + if (last < first) return; + + for (int ir = first; ir < last; ++ir) { + auto values = (const float *)((const char *)src->data + ir*src->nb[1]); + float sum = 0; + for (int j = 0; j < ne00; ++j) sum += values[j]; + float norm = sum > 0 ? 1/sum : 0.0f; + auto result = (float *)((char *)div->data + ir*div->nb[1]); + for (int j = 0; j < ne00; ++j) result[j] = values[j]*norm; + } +} + void iqk_grouped_top_k(ggml_tensor * dst, int ith, int nth) { auto src = dst->src[0]; GGML_ASSERT(dst->type == GGML_TYPE_I32); diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h index ef2bbe1ac..e00157a5a 100644 --- a/ggml/src/iqk/iqk_cpu_ops.h +++ b/ggml/src/iqk/iqk_cpu_ops.h @@ -14,6 +14,8 @@ extern "C" { struct ggml_tensor; +void iqk_sumrows_div(struct ggml_tensor * div, int ith, int nth); + void iqk_grouped_top_k(struct ggml_tensor * dst, int ith, int nth); void iqk_argsort(struct ggml_tensor * dst, int ith, int nth);