diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 9e912c6d810d..3493abac73f2 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -10,7 +10,7 @@ from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.config.compilation import CompilationMode, PassConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.logger import _print_warning_once
@@ -235,6 +235,70 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
 
+def test_moe_splitting_ops_deepep_ht_piecewise():
+    # Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
+    # should add MoE ops to splitting_ops on top of attention ops.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+        ),
+    )
+    splitting_ops = config.compilation_config.splitting_ops
+    assert splitting_ops is not None
+    assert "vllm::moe_forward" in splitting_ops
+    assert "vllm::moe_forward_shared" in splitting_ops
+
+
+def test_moe_splitting_ops_deepep_ht_inductor_partition():
+    # Inductor partition case: user-provided splitting_ops (which already
+    # include the MoE ops) should be preserved as-is for DeepEP HT with dp>1.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            use_inductor_graph_partition=True,
+            splitting_ops=[
+                "vllm::unified_attention",
+                "vllm::moe_forward",
+                "vllm::moe_forward_shared",
+            ],
+        ),
+    )
+    splitting_ops = config.compilation_config.splitting_ops
+    assert splitting_ops == [
+        "vllm::unified_attention",
+        "vllm::moe_forward",
+        "vllm::moe_forward_shared",
+    ]
+
+
+def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
+    # Pure attn-fusion case without inductor partition: even with
+    # DeepEP HT and dp>1, we should not re-enable piecewise compilation
+    # or add MoE ops into splitting_ops.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        ),
+    )
+    assert config.compilation_config.splitting_ops == []
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
+
+
 def test_should_split():
     import torch
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 963b091939e0..20c535306ad4 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -939,7 +939,9 @@ def post_init_cudagraph_sizes(self) -> None:
         # May get recomputed in the model runner if adjustment is needed for spec-decode
         self.compute_bs_to_padded_graph_size()
 
-    def set_splitting_ops_for_v1(self):
+    def set_splitting_ops_for_v1(
+        self, all2all_backend: str | None = None, data_parallel_size: int | None = None
+    ):
         # To compatible with OOT hardware plugin platform (for example vllm-ascend)
         # which currently only supports sequence parallelism in eager mode.
         if self.mode != CompilationMode.VLLM_COMPILE:
@@ -954,45 +956,83 @@ def set_splitting_ops_for_v1(self):
             "mode is CompilationMode.VLLM_COMPILE"
         )
 
-        if self.use_inductor_graph_partition:
-            self.set_splitting_ops_for_inductor_graph_partition()
-            return
+        added_default_splitting_ops = False
 
-        if self.pass_config.fuse_attn_quant:
-            # here use_inductor_graph_partition is False
+        if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
             self.set_splitting_ops_for_attn_fusion()
-            return
-
-        if self.splitting_ops is None:
-            # NOTE: When using full cudagraph, instead of setting an empty
-            # list and capture the full cudagraph inside the flattened fx
-            # graph, we keep the piecewise fx graph structure but capture
-            # the full cudagraph outside the fx graph. This reduces some
-            # cpu overhead when the runtime batch_size is not cudagraph
-            # captured. see https://github.com/vllm-project/vllm/pull/20059
-            # for details. Make a copy to avoid mutating the class-level
-            # list via reference.
-            self.splitting_ops = list(self._attention_ops)
-        elif len(self.splitting_ops) == 0:
-            logger.warning_once("Using piecewise compilation with empty splitting_ops")
-            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+        else:
+            if self.use_inductor_graph_partition:
+                if self.splitting_ops is None:
+                    added_default_splitting_ops = True
+                self.set_splitting_ops_for_inductor_graph_partition()
+            elif self.splitting_ops is None:
+                # NOTE: When using full cudagraph, instead of setting an empty
+                # list and capture the full cudagraph inside the flattened fx
+                # graph, we keep the piecewise fx graph structure but capture
+                # the full cudagraph outside the fx graph. This reduces some
+                # cpu overhead when the runtime batch_size is not cudagraph
+                # captured. see https://github.com/vllm-project/vllm/pull/20059
+                # for details. Make a copy to avoid mutating the class-level
+                # list via reference.
+                self.splitting_ops = list(self._attention_ops)
+                added_default_splitting_ops = True
+            elif len(self.splitting_ops) == 0:
                 logger.warning_once(
-                    "Piecewise compilation with empty splitting_ops do not"
-                    "contains piecewise cudagraph. Setting cudagraph_"
-                    "mode to NONE. Hint: If you are using attention backends "
-                    "that support cudagraph, consider manually setting "
-                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
-                    "full cudagraphs."
+                    "Using piecewise compilation with empty splitting_ops"
                 )
-                self.cudagraph_mode = CUDAGraphMode.NONE
-            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
-                logger.warning_once(
-                    "Piecewise compilation with empty splitting_ops do not "
-                    "contains piecewise cudagraph. Setting cudagraph_mode "
-                    "to FULL."
+                if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+                    logger.warning_once(
+                        "Piecewise compilation with empty splitting_ops does not "
+                        "contain piecewise cudagraph. Setting cudagraph_"
+                        "mode to NONE. Hint: If you are using attention "
+                        "backends that support cudagraph, consider manually "
+                        "setting cudagraph_mode to FULL or FULL_DECODE_ONLY "
+                        "to enable full cudagraphs."
+                    )
+                    self.cudagraph_mode = CUDAGraphMode.NONE
+                elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+                    logger.warning_once(
+                        "Piecewise compilation with empty splitting_ops does "
+                        "not contain piecewise cudagraph. Setting "
+                        "cudagraph_mode to FULL."
+                    )
+                    self.cudagraph_mode = CUDAGraphMode.FULL
+                self.splitting_ops = []
+
+        # split MoE ops for cudagraph
+        moe_ops = [
+            "vllm::moe_forward",
+            "vllm::moe_forward_shared",
+        ]
+        backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
+        dp_size = data_parallel_size if data_parallel_size is not None else 1
+        need_moe_splitting = (
+            backend == "deepep_high_throughput"
+            and dp_size > 1
+            # pure attn-fusion without inductor partition deliberately disables
+            # piecewise graphs and MoE splitting.
+            and not (
+                self.pass_config.fuse_attn_quant
+                and not self.use_inductor_graph_partition
+            )
+        )
+
+        if need_moe_splitting and self.splitting_ops is not None:
+            # if we just initialized default splitting_ops for this config,
+            # automatically append the MoE ops
+            if added_default_splitting_ops:
+                for op in moe_ops:
+                    if op not in self.splitting_ops:
+                        self.splitting_ops.append(op)
+
+            # make sure MoE ops are split out
+            if not any(op in self.splitting_ops for op in moe_ops):
+                raise ValueError(
+                    "DeepEP high throughput backend with data_parallel_size > 1 "
+                    "requires splitting MoE ops from cudagraphs. Please ensure "
+                    "'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
+                    "present in CompilationConfig.splitting_ops."
                 )
-                self.cudagraph_mode = CUDAGraphMode.FULL
-        self.splitting_ops = []
 
     def set_splitting_ops_for_inductor_graph_partition(self):
         assert self.use_inductor_graph_partition
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 735b0afbaaeb..68a0cec64790 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -809,7 +809,10 @@ def has_blocked_weights():
         ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
 
         # Do this after all the updates to compilation_config.mode
-        self.compilation_config.set_splitting_ops_for_v1()
+        self.compilation_config.set_splitting_ops_for_v1(
+            all2all_backend=self.parallel_config.all2all_backend,
+            data_parallel_size=self.parallel_config.data_parallel_size,
+        )
 
         if self.compilation_config.pass_config.enable_sp:
             # With pipeline parallelism or dynamo partitioning,
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 4bf9401b6b05..f2ed3e60270b 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -229,27 +229,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLASparse backend."
                 )
-        # lazy import to avoid circular import
-        from vllm.config import CUDAGraphMode
-
-        compilation_config = vllm_config.compilation_config
-        if (
-            parallel_config.all2all_backend == "deepep_high_throughput"
-            and parallel_config.data_parallel_size > 1
-            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-        ):
-            # TODO: Piecewise Cuda graph might be enabled
-            # if torch compile cache key issue fixed
-            # See https://github.com/vllm-project/vllm/pull/25093
-            logger.info(
-                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
-                "kernels are optimized for prefill and are incompatible with "
-                "CUDA Graphs. "
-                "In order to use CUDA Graphs for decode-optimized workloads, "
-                "use --all2all-backend with another option, such as "
-                "deepep_low_latency, pplx, or allgather_reducescatter."
-            )
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
     @classmethod
     def get_current_memory_usage(
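
For reference, a minimal sketch (not part of the patch) of how the new validation path is expected to behave once the changes above are applied: with `use_inductor_graph_partition=True` and user-provided `splitting_ops`, the MoE ops are not auto-appended, so omitting them under `deepep_high_throughput` with `data_parallel_size > 1` should raise the new `ValueError` from `set_splitting_ops_for_v1`. The config fields and values mirror the tests above, the error-message match string is taken from the patch, and the sketch assumes an environment where inductor graph partition is available.

```python
# Hypothetical usage sketch (assumes the patch above is applied): leaving the
# MoE ops out of a user-pinned splitting_ops list while using the DeepEP
# high-throughput backend with dp > 1 should trip the new validation in
# CompilationConfig.set_splitting_ops_for_v1().
import pytest

from vllm.config import CompilationConfig, ParallelConfig, VllmConfig
from vllm.config.compilation import CompilationMode

with pytest.raises(ValueError, match="requires splitting MoE ops"):
    VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            # MoE ops deliberately omitted -> validation error expected
            splitting_ops=["vllm::unified_attention"],
        ),
    )
```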