From 3679d76080f4311f8a21b160761ecc0f02780577 Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001
Date: Thu, 20 Nov 2025 17:00:50 +0800
Subject: [PATCH 1/4] support noaux eplb

---
 custom_ops/gpu_ops/cpp_extensions.cc        |  17 +
 custom_ops/gpu_ops/noaux_tc_redundant.cu    | 105 ++++++
 custom_ops/gpu_ops/noauxtc_kernel.h         | 323 ++++++++++++++++++
 custom_ops/setup_ops.py                     |   2 +
 fastdeploy/model_executor/layers/moe/ep.py  |  41 ++-
 fastdeploy/model_executor/layers/moe/moe.py |  42 ++-
 .../model_executor/models/ernie4_5_moe.py   |   2 +-
 tests/operators/test_noaux_tc_redundant.py  | 111 ++++++
 8 files changed, 620 insertions(+), 23 deletions(-)
 create mode 100644 custom_ops/gpu_ops/noaux_tc_redundant.cu
 create mode 100644 tests/operators/test_noaux_tc_redundant.py

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 0e6853d9be5..ece563b671c 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -647,6 +647,19 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                     bool renormalize,
                                     float routed_scaling_factor);
 
+std::vector<paddle::Tensor> NoauxTcRedundant(
+    paddle::Tensor& scores,
+    paddle::Tensor& scores_with_bias,
+    paddle::Tensor& expert_id_to_ep_rank_array,
+    paddle::Tensor& expert_in_rank_num_list,
+    paddle::Tensor& tokens_per_expert_stats_list,
+    int n_group,
+    int topk_group,
+    int topk,
+    bool renormalize,
+    float routed_scaling_factor,
+    int redundant_ep_rank_num_plus_one);
+
 #ifdef ENABLE_FP8
 paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     const paddle::Tensor& x,
@@ -1485,6 +1498,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
 
+  m.def("noaux_tc_redunant",
+        &NoauxTcRedundant,
+        "noaux_tc_redundant for MoE compute");
+
 #ifdef ENABLE_FP8
   m.def("cutlass_fp8_fp8_half_gemm_fused",
         &cutlass_fp8_fp8_half_gemm_func,
diff --git a/custom_ops/gpu_ops/noaux_tc_redundant.cu b/custom_ops/gpu_ops/noaux_tc_redundant.cu
new file mode 100644
index 00000000000..1fd8c1c2892
--- /dev/null
+++ b/custom_ops/gpu_ops/noaux_tc_redundant.cu
@@ -0,0 +1,105 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include + +#include "helper.h" +#include "noauxtc_kernel.h" + +std::vector NoauxTcRedundant( + paddle::Tensor& scores, + paddle::Tensor& scores_with_bias, + paddle::Tensor& expert_id_to_ep_rank_array, + paddle::Tensor& expert_in_rank_num_list, + paddle::Tensor& tokens_per_expert_stats_list, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor, + int redundant_ep_rank_num_plus_one) { + auto input_shape = scores_with_bias.shape(); + PD_CHECK(input_shape.size() == 2); + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + auto input_type = scores_with_bias.dtype(); + auto place = scores_with_bias.place(); + auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place); + auto topk_values = paddle::empty({num_tokens, topk}, input_type, place); + auto topk_indices = + paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place); + auto stream = scores_with_bias.stream(); + + invokeNoAuxTcRedundant( + reinterpret_cast(scores.data()), + reinterpret_cast(group_scores.data()), + reinterpret_cast(topk_values.data()), + reinterpret_cast(topk_indices.data()), + reinterpret_cast(scores_with_bias.data()), + reinterpret_cast(expert_id_to_ep_rank_array.data()), + reinterpret_cast(expert_in_rank_num_list.data()), + reinterpret_cast(tokens_per_expert_stats_list.data()), + num_tokens, + num_experts, + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + redundant_ep_rank_num_plus_one, + stream); + + return {scores, topk_values, topk_indices}; +} + +std::vector NoauxTcRedundantInferDtype( + const paddle::DataType& scores_dtype, + const paddle::DataType& scores_with_bias_dtype) { + return {scores_dtype, scores_dtype, paddle::DataType::INT64}; +} + +std::vector> NoauxTcRedundantInferShape( + const std::vector& scores_shape, + const std::vector&, + const int topk) { + auto num_tokens = scores_shape[0]; + auto topk_values_shape = std::vector{num_tokens, topk}; + auto topk_indices_shape = std::vector{num_tokens, topk}; + return {scores_shape, topk_values_shape, topk_indices_shape}; +} + +PD_BUILD_STATIC_OP(noaux_tc_redundant) + .Inputs({"scores", + "scores_with_bias", + "expert_id_to_ep_rank_array", + "expert_in_rank_num_list", + "tokens_per_expert_stats_list"}) + .Outputs({"output_tensor", + "topk_values", + "topk_indices", + "tokens_per_expert_stats_list_out"}) + .Attrs({"n_group: int", + "topk_group: int", + "topk:int", + "renormalize: bool", + "routed_scaling_factor: float", + "redundant_ep_rank_num_plus_one:int"}) + .SetInplaceMap({{"tokens_per_expert_stats_list", + "tokens_per_expert_stats_list_out"}}) + .SetKernelFn(PD_KERNEL(NoauxTcRedundant)) + .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcRedundantInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcRedundantInferDtype)); diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index a3a52051bd0..0093ab9bdea 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -420,6 +420,13 @@ class WarpSelect : public WarpSort { }; // end class WarpSelect } // namespace warp_topk +inline __device__ unsigned int xorwow_moe(unsigned int& state) { + state ^= state >> 7; + state ^= state << 9; + state ^= state >> 13; + return state; +} + template __device__ void topk_with_k2(T* output, T const* input, @@ -656,6 +663,195 @@ __global__ void group_idx_and_topk_idx_kernel( #endif } +template +__global__ void group_idx_and_topk_idx_redundant_kernel( + T* scores, + T const* 
group_scores, + T* topk_values, + IdxT* topk_indices, + T* scores_with_bias, + int32_t* expert_id_to_ep_rank_array, + int32_t* expert_in_rank_num_list, + int32_t* tokens_per_expert_stats_list, + int64_t const num_tokens, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + int64_t const num_experts, + int64_t const num_experts_per_group, + double routed_scaling_factor, + int64_t const redundant_ep_rank_num_plus_one) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t case_id = + blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token + unsigned int state = case_id; + scores_with_bias += case_id * num_experts; + scores += case_id * num_experts; + group_scores += case_id * n_group; + topk_values += case_id * topk; + topk_indices += case_id * topk; + int32_t align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to + // store the target topk idx + int32_t* s_topk_idx = reinterpret_cast(smem_buf); + T* s_topk_value = + reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + + warp_id * topk; + s_topk_idx += warp_id * topk; + + T value = neg_inf(); + T topk_group_value = neg_inf(); + int32_t num_equalto_topkth_group; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before + // acqbulk because it's ptr arithmetic +#endif + + if (case_id < num_tokens) { + // calculate group_idx + int32_t target_num_min = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group && + (isfinite(cuda_cast( + group_scores[lane_id])))) // The check is necessary to avoid + // abnormal input + { + value = group_scores[lane_id]; + } + + int count_equal_to_top_value = WARP_SIZE - n_group; + int pre_count_equal_to_top_value = 0; + // Use loop to find the largset top_group + while (count_equal_to_top_value < target_num_min) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = neg_inf(); + } + pre_count_equal_to_top_value = count_equal_to_top_value; + count_equal_to_top_value = + __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf()))); + } + num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + } + __syncthreads(); + + warp_topk::WarpSelect + queue((int32_t)topk, neg_inf()); + + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = (topk_group_value != neg_inf()); + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = + (i < num_experts_per_group) && isfinite(cuda_cast( + scores_with_bias[offset + i])) + ? 
scores_with_bias[offset + i] + : neg_inf(); + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + // Get the topk_idx + queue.dumpIdx(s_topk_idx); + __syncwarp(); + } + + // Load the valid score value + // Calculate the summation + float topk_sum = 1e-20; + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; + i < warp_topk::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + T value = i < topk ? scores[s_topk_idx[i]] + : 0.0f; // Load the valid value of expert + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += + cg::reduce(tile, cuda_cast(value), cg::plus()); + } + } + + __syncthreads(); + // Note(ZKK): a little trick. + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; i < num_experts; i += WARP_SIZE) { + scores[i] = 0; + } + } + __syncwarp(); + + if (case_id < num_tokens) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = cuda_cast(s_topk_value[i]) / topk_sum * + routed_scaling_factor; + } else { + value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; + } + scores[s_topk_idx[i]] = value; + + int expert_topk = s_topk_idx[i]; + int len = expert_in_rank_num_list[expert_topk]; + int select = (int)xorwow_moe(state) % len; + // int select = 0; + int selected_rank = + expert_id_to_ep_rank_array[expert_topk * + redundant_ep_rank_num_plus_one + + select]; + atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1); + topk_indices[i] = (IdxT)selected_rank; + topk_values[i] = cuda_cast(value); + } + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + int expert_topk = i; + int len = expert_in_rank_num_list[expert_topk]; + int select = (int)xorwow_moe(state) % len; + // int select = 0; + int selected_rank = + expert_id_to_ep_rank_array[expert_topk * + redundant_ep_rank_num_plus_one + + select]; + atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1); + topk_indices[i] = (IdxT)selected_rank; + topk_values[i] = cuda_cast(1.0f / topk); + } + } + // Note: when if_proceed_next_topk==false, choose the first 8 experts as the + // default result. 
+ } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + template void invokeNoAuxTc(T* scores, T* group_scores, @@ -752,6 +948,111 @@ void invokeNoAuxTc(T* scores, #endif } +template +void invokeNoAuxTcRedundant(T* scores, + T* group_scores, + T* topk_values, + IdxT* topk_indices, + T* scores_with_bias, + int32_t* expert_id_to_ep_rank_array, + int32_t* expert_in_rank_num_list, + int32_t* tokens_per_expert_stats_list, + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double const routed_scaling_factor, + int64_t const redundant_ep_rank_num_plus_one, + cudaStream_t const stream) { + int64_t num_cases = num_tokens * n_group; + int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + topk_with_k2_kernel<<>>( + group_scores, + scores_with_bias, + num_cases, + n_group, + num_experts / n_group); +#else + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchConfig_t config; + config.gridDim = topk_with_k2_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = false; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, + kernel_instance1, + group_scores, + scores_with_bias, + num_cases, + n_group, + num_experts / n_group); +#endif + + int64_t topk_with_k_group_num_blocks = + (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; + size_t dynamic_smem_in_bytes = + warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, + topk); + +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + group_idx_and_topk_idx_redundant_kernel + <<>>(scores, + group_scores, + topk_values, + topk_indices, + scores_with_bias, + expert_id_to_ep_rank_array, + expert_in_rank_num_list, + tokens_per_expert_stats_list, + num_tokens, + n_group, + topk_group, + topk, + renormalize, + num_experts, + num_experts / n_group, + routed_scaling_factor, + redundant_ep_rank_num_plus_one); +#else + auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; + config.gridDim = topk_with_k_group_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = dynamic_smem_in_bytes; + config.stream = stream; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = false; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, + kernel_instance2, + scores, + group_scores, + topk_values, + topk_indices, + scores_with_bias, + num_tokens, + n_group, + topk_group, + topk, + num_experts, + num_experts / n_group, + renormalize, + routed_scaling_factor); +#endif +} + #define INSTANTIATE_NOAUX_TC(T, IdxT) \ template void invokeNoAuxTc(T * scores, \ T * group_scores, \ @@ -768,3 +1069,25 @@ void invokeNoAuxTc(T* scores, cudaStream_t const stream); INSTANTIATE_NOAUX_TC(float, int32_t); + +#define INSTANTIATE_NOAUX_TC_Redundant(T, IdxT) \ + template void invokeNoAuxTcRedundant( \ + T * scores, \ + T * group_scores, \ + T * topk_values, \ + IdxT * topk_indices, \ + T * scores_with_bias, \ + int32_t* expert_id_to_ep_rank_array, \ + int32_t* expert_in_rank_num_list, \ + int32_t* tokens_per_expert_stats_list, \ + int64_t const num_tokens, \ + int64_t const num_experts, \ + int64_t const n_group, \ + int64_t 
const topk_group, \ + int64_t const topk, \ + bool const renormalize, \ + double const routed_scaling_factor, \ + int64_t const redundant_ep_rank_num_plus_one, \ + cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC_Redundant(float, int32_t); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 3460b077ec4..40900b18771 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -301,6 +301,7 @@ def find_end_files(directory, end_str): "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", + "gpu_ops/noaux_tc_redundant.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length_v1.cu", @@ -614,6 +615,7 @@ def find_end_files(directory, end_str): "gpu_ops/share_external_data.cu", "gpu_ops/recover_decode_task.cu", "gpu_ops/noaux_tc.cu", + "gpu_ops/noaux_tc_redundant.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_index_out.cu", diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 54bf6ef10cd..ef37863e5e9 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -430,18 +430,37 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): expert_in_rank_num_list, tokens_per_expert_stats_list, ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx) + if layer.is_rearrange is False: + expert_id_to_ep_rank_array = paddle.arange(layer.num_experts).cast("int32") - topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select( - gating_logits=gate_out, - expert_id_to_ep_rank_array=expert_id_to_ep_rank_array, - expert_in_rank_num_list=expert_in_rank_num_list, - tokens_per_expert_stats_list=tokens_per_expert_stats_list, - bias=layer.gate_correction_bias, - moe_topk=self.top_k, - apply_norm_weight=True, - enable_softmax_top_k_fused=False, - redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1, - ) + if layer.topk_method == "noaux_tc": + from .moe import get_moe_scores + + score, topk_weights, topk_idx = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + expert_id_to_ep_rank_array=expert_id_to_ep_rank_array, + expert_in_rank_num_list=expert_in_rank_num_list, + tokens_per_expert_stats_list=tokens_per_expert_stats_list, + redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1, + ) + else: + topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select( + gating_logits=gate_out, + expert_id_to_ep_rank_array=expert_id_to_ep_rank_array, + expert_in_rank_num_list=expert_in_rank_num_list, + tokens_per_expert_stats_list=tokens_per_expert_stats_list, + bias=layer.gate_correction_bias, + moe_topk=self.top_k, + apply_norm_weight=True, + enable_softmax_top_k_fused=False, + redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1, + ) else: if layer.topk_method == "noaux_tc": from fastdeploy.model_executor.layers.moe.moe import get_moe_scores diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index ede87972185..352a4b657ef 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -27,7 +27,7 
@@ from fastdeploy.worker.experts_manager import RedundantExpertManger try: - from fastdeploy.model_executor.ops.gpu import noaux_tc + from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant except: logger.warning("import noaux_tc Failed!") import numpy as np @@ -74,6 +74,10 @@ def get_moe_scores( routed_scaling_factor, e_score_correction_bias, renormalize: bool = False, + expert_id_to_ep_rank_array: paddle.Tensor = None, + expert_in_rank_num_list: paddle.Tensor = None, + tokens_per_expert_stats_list: paddle.Tensor = None, + redundant_ep_rank_num_plus_one: int = 1, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. @@ -81,15 +85,30 @@ def get_moe_scores( scores = paddle.nn.functional.sigmoid(gating_output) assert e_score_correction_bias is not None, "e_score_correction_bias is none!" scores_with_bias = scores + e_score_correction_bias - scores, topk_values, topk_idx = noaux_tc( - scores, - scores_with_bias, - n_group if n_group > 0 else 1, - topk_group if topk_group > 0 else 1, - top_k, - renormalize, - routed_scaling_factor, - ) + if expert_id_to_ep_rank_array is None: + scores, topk_values, topk_idx = noaux_tc( + scores, + scores_with_bias, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + renormalize, + routed_scaling_factor, + ) + else: + scores, topk_values, topk_idx, _ = noaux_tc_redundant( + scores, + scores_with_bias, + expert_id_to_ep_rank_array, + expert_in_rank_num_list, + tokens_per_expert_stats_list, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + renormalize, + routed_scaling_factor, + redundant_ep_rank_num_plus_one, + ) return scores, topk_values, topk_idx @@ -196,6 +215,7 @@ def __init__( self.quant_method = get_moe_method() assert self.quant_method is not None, "self.quant_method should not be None" self.redundant_table_manger = redundant_table_manger + self.is_rearrange = False if self.ep_size > 1: self.quant_method.init_ep(self) @@ -438,7 +458,7 @@ def load_experts_weight( ) ] ep_rank_to_expert_id_list = [i for i in range(self.num_experts)] - if self.redundant_table_manger is not None: + if self.redundant_table_manger is not None and is_rearrange is True: ( ep_rank_to_expert_id_list, expert_id_to_ep_rank_array, diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 75947590be8..9ff0a218510 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -211,7 +211,7 @@ def load_state_dict(self, state_dict): self.shared_experts.load_state_dict(state_dict) def update_state_dict(self, state_dict): - self.fused_moe.load_state_dict(state_dict, True) + self.experts.load_state_dict(state_dict, True) def forward(self, hidden_states: paddle.Tensor): out = self.experts(hidden_states, self.gate) diff --git a/tests/operators/test_noaux_tc_redundant.py b/tests/operators/test_noaux_tc_redundant.py new file mode 100644 index 00000000000..1afa24aabb8 --- /dev/null +++ b/tests/operators/test_noaux_tc_redundant.py @@ -0,0 +1,111 @@ +import unittest + +import paddle + +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores + + +class TestMoeRouting(unittest.TestCase): + def setUp(self): + paddle.seed(2024) + print(paddle.device.cuda.get_device_properties()) + print(paddle.__git_commit__) + + def native_group_topk( + self, + gating_output: paddle.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + 
routed_scaling_factor: float, + e_score_correction_bias: paddle.Tensor, + ): + original_scores = paddle.nn.functional.sigmoid(gating_output) + if len(e_score_correction_bias.shape) == 1: + e_score_correction_bias = e_score_correction_bias.unsqueeze(0) + scores = original_scores + e_score_correction_bias + + num_token, n_experts = scores.shape + group_scores = scores.reshape([num_token, num_expert_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores) # [n, n_group] + group_mask.put_along_axis_(group_idx, 1.0, axis=-1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand([num_token, num_expert_group, n_experts // num_expert_group]) + .reshape([num_token, -1]) + ) + tmp_scores = scores.masked_fill(~score_mask.astype(paddle.bool), float("-inf")) + + topk_ids = paddle.topk(tmp_scores, topk, axis=1)[1] + topk_weights = paddle.take_along_axis(original_scores, topk_ids, axis=1) + + if renormalize: + topk_weights = topk_weights / paddle.sum(topk_weights, axis=1, keepdim=True) + + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor + + return topk_weights, topk_ids + + def test_group_topk(self): + + renormalize = True + + test_cases = [ + # (num_experts, n_group, topk_group, top_k, routed_scaling_factor) + (128, 1, 1, 8, 1.0), # glm45-air + (256, 8, 4, 8, 2.5), # deepseek + ] + + for case_tuple in test_cases: + num_experts, n_group, topk_group, top_k, routed_scaling_factor = case_tuple + for num_tokens in [1, 32, 64, 128]: + gating_output = paddle.rand([num_tokens, num_experts]) + e_score_correction_bias = paddle.rand([1, num_experts]) + expert_id_to_ep_rank_array = paddle.arange(num_experts, dtype="int32").reshape([num_experts, 1]) + expert_in_rank_num_list = paddle.ones([num_experts, 1], dtype="int32") + tokens_per_expert_stats_list = paddle.arange(num_experts, dtype="int32").reshape([num_experts, 1]) + + ref_topk_values, ref_topk_idx = self.native_group_topk( + gating_output=gating_output, + topk=top_k, + renormalize=renormalize, + num_expert_group=n_group, + topk_group=topk_group, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + new_score, topk_values, topk_idx = get_moe_scores( + gating_output=gating_output, + n_group=n_group, + topk_group=topk_group, + top_k=top_k, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + renormalize=renormalize, + expert_id_to_ep_rank_array=expert_id_to_ep_rank_array, + expert_in_rank_num_list=expert_in_rank_num_list, + tokens_per_expert_stats_list=tokens_per_expert_stats_list, + ) + + equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item() + equal_topk_ids = paddle.allclose( + topk_idx.cast("int32"), ref_topk_idx.cast("int32"), atol=0.0, rtol=0.0 + ).item() + print( + f"Test Case[{case_tuple}], num_tokens = {num_tokens}, equal_topk_value: {equal_topk_value}, equal_topk_ids: {equal_topk_ids}" + ) + if not equal_topk_value: + print(f"ref_topk_values = {ref_topk_values}") + print(f"topk_values = {topk_values}") + if not equal_topk_ids: + print(f"ref_topk_idx = {ref_topk_idx}") + print(f"topk_idx = {topk_idx}") + assert equal_topk_value and equal_topk_ids + + +if __name__ == "__main__": + unittest.main() From 0b9f82b029af6e1ce19900eaf954627eb240994f Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Thu, 20 Nov 2025 17:41:50 +0800 Subject: 
[PATCH 2/4] noaux_eplb --- custom_ops/gpu_ops/noauxtc_kernel.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index 0093ab9bdea..7ac3bb17491 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -1077,9 +1077,9 @@ INSTANTIATE_NOAUX_TC(float, int32_t); T * topk_values, \ IdxT * topk_indices, \ T * scores_with_bias, \ - int32_t* expert_id_to_ep_rank_array, \ - int32_t* expert_in_rank_num_list, \ - int32_t* tokens_per_expert_stats_list, \ + int32_t * expert_id_to_ep_rank_array, \ + int32_t * expert_in_rank_num_list, \ + int32_t * tokens_per_expert_stats_list, \ int64_t const num_tokens, \ int64_t const num_experts, \ int64_t const n_group, \ From 3c17044031627cfafba286b6e51f3ebd41342e53 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Fri, 21 Nov 2025 08:10:26 +0800 Subject: [PATCH 3/4] noaux_eplb --- fastdeploy/model_executor/layers/moe/ep.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index ef37863e5e9..a1dcda67f7e 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -430,8 +430,6 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): expert_in_rank_num_list, tokens_per_expert_stats_list, ) = layer.redundant_table_manger.get_ep_rank_to_expert_id_list_by_layer(layer.layer_idx) - if layer.is_rearrange is False: - expert_id_to_ep_rank_array = paddle.arange(layer.num_experts).cast("int32") if layer.topk_method == "noaux_tc": from .moe import get_moe_scores From 106ebec0b989bc649f9f700e2e361ead195024b3 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 Date: Fri, 21 Nov 2025 12:04:16 +0800 Subject: [PATCH 4/4] noaux_eplb --- custom_ops/gpu_ops/cpp_extensions.cc | 2 +- custom_ops/gpu_ops/noaux_tc_redundant.cu | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index ece563b671c..5deaa6bc94f 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -1498,7 +1498,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); - m.def("noaux_tc_redunant", + m.def("noaux_tc_redundant", &NoauxTcRedundant, "noaux_tc_redundant for MoE compute"); diff --git a/custom_ops/gpu_ops/noaux_tc_redundant.cu b/custom_ops/gpu_ops/noaux_tc_redundant.cu index 1fd8c1c2892..1fcb09d269c 100644 --- a/custom_ops/gpu_ops/noaux_tc_redundant.cu +++ b/custom_ops/gpu_ops/noaux_tc_redundant.cu @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include #include
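
Reviewer note (not part of the patch): from the indexing in group_idx_and_topk_idx_redundant_kernel, expert_id_to_ep_rank_array is a flattened [num_experts, redundant_ep_rank_num_plus_one] table whose row i lists the EP ranks holding replicas of logical expert i, expert_in_rank_num_list[i] is the replica count for expert i, and tokens_per_expert_stats_list[i] is incremented in place once per routed token. Below is a minimal Python sketch of the per-token replica selection; names mirror the kernel arguments, the xorwow update copies xorwow_moe() from noauxtc_kernel.h, and this is an illustration of the routing logic only, not the CUDA implementation.

    def xorwow_moe(state):
        # 32-bit xorshift update, as in noauxtc_kernel.h
        state ^= state >> 7
        state = (state ^ (state << 9)) & 0xFFFFFFFF
        state ^= state >> 13
        return state

    def select_ep_ranks(topk_ids, case_id, expert_id_to_ep_rank_array,
                        expert_in_rank_num_list, tokens_per_expert_stats_list,
                        redundant_ep_rank_num_plus_one):
        """Map one token's top-k logical experts to concrete EP ranks."""
        selected = []
        for expert_id in topk_ids:
            num_replicas = int(expert_in_rank_num_list[expert_id])
            # The kernel seeds the generator with the token's case_id, so every
            # pick for this token uses the same pseudo-random draw.
            select = xorwow_moe(case_id) % num_replicas
            rank = int(expert_id_to_ep_rank_array[
                expert_id * redundant_ep_rank_num_plus_one + select])
            tokens_per_expert_stats_list[expert_id] += 1  # atomicAdd in the kernel
            selected.append(rank)
        return selected

With no redundant experts (as in the unit test, where expert_in_rank_num_list is all ones and expert_id_to_ep_rank_array is the identity mapping), every expert has a single replica, so the op reduces to the plain noaux_tc expert ids, which is exactly what tests/operators/test_noaux_tc_redundant.py compares against the native reference.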