From 0c58593f34570fd9895956b56c386f4112a9f015 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Wed, 26 Nov 2025 16:14:19 -0800
Subject: [PATCH 1/6] enable cuda graph for deepepHT

Signed-off-by: yewentao256
---
 vllm/config/compilation.py | 16 +++++++++++++++-
 vllm/config/vllm.py        |  5 ++++-
 vllm/platforms/cuda.py     | 21 ---------------------
 3 files changed, 19 insertions(+), 23 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 865d045676d1..434031940bcc 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -812,7 +812,9 @@ def post_init_cudagraph_sizes(self) -> None:
         # May get recomputed in the model runner if adjustment is needed for spec-decode
         self.compute_bs_to_padded_graph_size()
 
-    def set_splitting_ops_for_v1(self):
+    def set_splitting_ops_for_v1(
+        self, all2all_backend: str | None = None, data_parallel_size: int | None = None
+    ):
         # NOTE: this function needs to be called only when mode is
         # CompilationMode.VLLM_COMPILE
         assert self.mode == CompilationMode.VLLM_COMPILE, (
@@ -860,6 +862,18 @@ def set_splitting_ops_for_v1(self):
             self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
 
+        # split moe op for cudagraph
+        backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
+        dp_size = data_parallel_size if data_parallel_size is not None else 1
+        if backend == "deepep_high_throughput" and dp_size > 1 and self.splitting_ops:
+            moe_ops = [
+                "vllm::moe_forward",
+                "vllm::moe_forward_shared",
+            ]
+            for op in moe_ops:
+                if op not in self.splitting_ops:
+                    self.splitting_ops.append(op)
+
     def set_splitting_ops_for_inductor_graph_partition(self):
         assert self.use_inductor_graph_partition
         if self.splitting_ops is None:
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 9342564aa3d3..7f5badfd1ee4 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -649,7 +649,10 @@ def __post_init__(self):
 
         # Do this after all the updates to compilation_config.mode
         if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
-            self.compilation_config.set_splitting_ops_for_v1()
+            self.compilation_config.set_splitting_ops_for_v1(
+                all2all_backend=self.parallel_config.all2all_backend,
+                data_parallel_size=self.parallel_config.data_parallel_size,
+            )
 
         if self.compilation_config.pass_config.enable_sequence_parallelism:
             # With pipeline parallelism or dynamo partitioning,
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index e8e14387bb7f..b659560d251b 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -231,27 +231,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLASparse backend."
                 )
-        # lazy import to avoid circular import
-        from vllm.config import CUDAGraphMode
-
-        compilation_config = vllm_config.compilation_config
-        if (
-            parallel_config.all2all_backend == "deepep_high_throughput"
-            and parallel_config.data_parallel_size > 1
-            and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-        ):
-            # TODO: Piecewise Cuda graph might be enabled
-            # if torch compile cache key issue fixed
-            # See https://github.com/vllm-project/vllm/pull/25093
-            logger.info(
-                "WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
-                "kernels are optimized for prefill and are incompatible with "
-                "CUDA Graphs. "
-                "In order to use CUDA Graphs for decode-optimized workloads, "
-                "use --all2all-backend with another option, such as "
-                "deepep_low_latency, pplx, or allgather_reducescatter."
-            )
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
     @classmethod
     def get_current_memory_usage(

From 3faec2f97837995dc9e7bce548c31a1b4b32df527 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Mon, 1 Dec 2025 15:04:32 -0800
Subject: [PATCH 2/6] fix inductor

Signed-off-by: yewentao256
---
 vllm/config/compilation.py | 77 +++++++++++++++++++++-----------------
 1 file changed, 42 insertions(+), 35 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index d304c48c69cf..78862b44c45d 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -873,48 +873,55 @@ def set_splitting_ops_for_v1(
 
         if self.use_inductor_graph_partition:
             self.set_splitting_ops_for_inductor_graph_partition()
-            return
-
-        if self.pass_config.enable_attn_fusion:
+        elif self.pass_config.enable_attn_fusion:
             # here use_inductor_graph_partition is False
             self.set_splitting_ops_for_attn_fusion()
-            return
-
-        if self.splitting_ops is None:
-            # NOTE: When using full cudagraph, instead of setting an empty
-            # list and capture the full cudagraph inside the flattened fx
-            # graph, we keep the piecewise fx graph structure but capture
-            # the full cudagraph outside the fx graph. This reduces some
-            # cpu overhead when the runtime batch_size is not cudagraph
-            # captured. see https://github.com/vllm-project/vllm/pull/20059
-            # for details. Make a copy to avoid mutating the class-level
-            # list via reference.
-            self.splitting_ops = list(self._attention_ops)
-        elif len(self.splitting_ops) == 0:
-            logger.warning_once("Using piecewise compilation with empty splitting_ops")
-            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
-                logger.warning_once(
-                    "Piecewise compilation with empty splitting_ops do not"
-                    "contains piecewise cudagraph. Setting cudagraph_"
-                    "mode to NONE. Hint: If you are using attention backends "
-                    "that support cudagraph, consider manually setting "
-                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
-                    "full cudagraphs."
-                )
-                self.cudagraph_mode = CUDAGraphMode.NONE
-            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+        else:
+            if self.splitting_ops is None:
+                # NOTE: When using full cudagraph, instead of setting an empty
+                # list and capture the full cudagraph inside the flattened fx
+                # graph, we keep the piecewise fx graph structure but capture
+                # the full cudagraph outside the fx graph. This reduces some
+                # cpu overhead when the runtime batch_size is not cudagraph
+                # captured. see https://github.com/vllm-project/vllm/pull/20059
+                # for details. Make a copy to avoid mutating the class-level
+                # list via reference.
+                self.splitting_ops = list(self._attention_ops)
+            elif len(self.splitting_ops) == 0:
                 logger.warning_once(
-                    "Piecewise compilation with empty splitting_ops do not "
-                    "contains piecewise cudagraph. Setting cudagraph_mode "
-                    "to FULL."
+                    "Using piecewise compilation with empty splitting_ops"
                 )
-                self.cudagraph_mode = CUDAGraphMode.FULL
-                self.splitting_ops = []
+                if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+                    logger.warning_once(
+                        "Piecewise compilation with empty splitting_ops does "
+                        "not contain a piecewise cudagraph. Setting cudagraph_"
+                        "mode to NONE. Hint: If you are using attention "
+                        "backends that support cudagraph, consider manually "
+                        "setting cudagraph_mode to FULL or FULL_DECODE_ONLY "
+                        "to enable full cudagraphs."
+                    )
+                    self.cudagraph_mode = CUDAGraphMode.NONE
+                elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+                    logger.warning_once(
+                        "Piecewise compilation with empty splitting_ops does "
+                        "not contain a piecewise cudagraph. Setting "
+                        "cudagraph_mode to FULL."
+                    )
+                    self.cudagraph_mode = CUDAGraphMode.FULL
+                    self.splitting_ops = []
 
-        # split moe op for cudagraph
+        # split MoE ops for cudagraph
         backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
         dp_size = data_parallel_size if data_parallel_size is not None else 1
-        if backend == "deepep_high_throughput" and dp_size > 1 and self.splitting_ops:
+        if (
+            backend == "deepep_high_throughput"
+            and dp_size > 1
+            and self.splitting_ops
+            and (
+                not self.pass_config.enable_attn_fusion
+                or self.use_inductor_graph_partition
+            )
+        ):
             moe_ops = [
                 "vllm::moe_forward",
                 "vllm::moe_forward_shared",

From 2db7a0b05b8d3b49c01e280d49ee0b453fb8d8e9 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Mon, 1 Dec 2025 15:09:22 -0800
Subject: [PATCH 3/6] add unit test

Signed-off-by: yewentao256
---
 tests/compile/test_config.py | 61 +++++++++++++++++++++++++++++++++++-
 1 file changed, 60 insertions(+), 1 deletion(-)

diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index a9e5ccee520e..f48bde252ef6 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -9,7 +9,7 @@
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.config.compilation import CompilationMode
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -233,6 +233,65 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
 
+def test_moe_splitting_ops_deepep_ht_piecewise():
+    # Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
+    # should add MoE ops to splitting_ops on top of attention ops.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+        ),
+    )
+    splitting_ops = config.compilation_config.splitting_ops
+    assert splitting_ops is not None
+    assert "vllm::moe_forward" in splitting_ops
+    assert "vllm::moe_forward_shared" in splitting_ops
+
+
+def test_moe_splitting_ops_deepep_ht_inductor_partition():
+    # Inductor partition case: user-provided splitting_ops should be
+    # preserved and MoE ops should be appended for DeepEP HT with dp>1.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            use_inductor_graph_partition=True,
+            splitting_ops=["vllm::unified_attention"],
+        ),
+    )
+    splitting_ops = config.compilation_config.splitting_ops
+    assert splitting_ops is not None
+    assert "vllm::unified_attention" in splitting_ops
+    assert "vllm::moe_forward" in splitting_ops
+    assert "vllm::moe_forward_shared" in splitting_ops
+
+
+def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
+    # Pure attn-fusion case without inductor partition: even with
+    # DeepEP HT and dp>1, we should not re-enable piecewise compilation
+    # or add MoE ops into splitting_ops.
+    config = VllmConfig(
+        parallel_config=ParallelConfig(
+            all2all_backend="deepep_high_throughput",
+            data_parallel_size=8,
+        ),
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        ),
+    )
+    assert config.compilation_config.splitting_ops == []
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
+
+
 def test_should_split():
     import torch
 

From 95aba4e9a23079cbe8ba3791b7c5e2109d8f047f Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Mon, 1 Dec 2025 15:11:28 -0800
Subject: [PATCH 4/6] update

Signed-off-by: yewentao256
---
 vllm/config/vllm.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 728880da1bae..b276a4622503 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -797,11 +797,10 @@ def has_blocked_weights():
             ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
 
         # Do this after all the updates to compilation_config.mode
-        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
-            self.compilation_config.set_splitting_ops_for_v1(
-                all2all_backend=self.parallel_config.all2all_backend,
-                data_parallel_size=self.parallel_config.data_parallel_size,
-            )
+        self.compilation_config.set_splitting_ops_for_v1(
+            all2all_backend=self.parallel_config.all2all_backend,
+            data_parallel_size=self.parallel_config.data_parallel_size,
+        )
 
         if self.compilation_config.pass_config.enable_sequence_parallelism:
             # With pipeline parallelism or dynamo partitioning,

From 41ffce7282d6d1bc50ef0ab4b2dec5af8cdb85a9 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Wed, 3 Dec 2025 22:21:41 +0000
Subject: [PATCH 5/6] address comments

Signed-off-by: yewentao256
---
 tests/compile/test_config.py | 15 ++++++----
 vllm/config/compilation.py   | 55 ++++++++++++++++++++++++------------
 2 files changed, 47 insertions(+), 23 deletions(-)

diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index a128f392683c..3493abac73f2 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -264,14 +264,19 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             use_inductor_graph_partition=True,
-            splitting_ops=["vllm::unified_attention"],
+            splitting_ops=[
+                "vllm::unified_attention",
+                "vllm::moe_forward",
+                "vllm::moe_forward_shared",
+            ],
         ),
     )
     splitting_ops = config.compilation_config.splitting_ops
-    assert splitting_ops is not None
-    assert "vllm::unified_attention" in splitting_ops
-    assert "vllm::moe_forward" in splitting_ops
-    assert "vllm::moe_forward_shared" in splitting_ops
+    assert splitting_ops == [
+        "vllm::unified_attention",
+        "vllm::moe_forward",
+        "vllm::moe_forward_shared",
+    ]
 
 
 def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index fbf1acf70d62..20c535306ad4 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -956,13 +956,16 @@ def set_splitting_ops_for_v1(
             "mode is CompilationMode.VLLM_COMPILE"
         )
 
-        if self.use_inductor_graph_partition:
-            self.set_splitting_ops_for_inductor_graph_partition()
-        elif self.pass_config.fuse_attn_quant:
-            # here use_inductor_graph_partition is False
+        added_default_splitting_ops = False
+
+        if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
             self.set_splitting_ops_for_attn_fusion()
         else:
-            if self.splitting_ops is None:
+            if self.use_inductor_graph_partition:
+                if self.splitting_ops is None:
+                    added_default_splitting_ops = True
+                self.set_splitting_ops_for_inductor_graph_partition()
+            elif self.splitting_ops is None:
                 # NOTE: When using full cudagraph, instead of setting an empty
                 # list and capture the full cudagraph inside the flattened fx
                 # graph, we keep the piecewise fx graph structure but capture
                 # the full cudagraph outside the fx graph. This reduces some
                 # cpu overhead when the runtime batch_size is not cudagraph
                 # captured. see https://github.com/vllm-project/vllm/pull/20059
                 # for details. Make a copy to avoid mutating the class-level
                 # list via reference.
                 self.splitting_ops = list(self._attention_ops)
+                added_default_splitting_ops = True
             elif len(self.splitting_ops) == 0:
                 logger.warning_once(
                     "Using piecewise compilation with empty splitting_ops"
                 )
@@ -996,24 +1000,39 @@ def set_splitting_ops_for_v1(
                     self.cudagraph_mode = CUDAGraphMode.FULL
                     self.splitting_ops = []
 
         # split MoE ops for cudagraph
+        moe_ops = [
+            "vllm::moe_forward",
+            "vllm::moe_forward_shared",
+        ]
         backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
         dp_size = data_parallel_size if data_parallel_size is not None else 1
-        if (
+        need_moe_splitting = (
             backend == "deepep_high_throughput"
             and dp_size > 1
-            and self.splitting_ops
-            and (
-                not self.pass_config.enable_attn_fusion
-                or self.use_inductor_graph_partition
+            # pure attn-fusion without inductor partition deliberately disables
+            # piecewise graphs and MoE splitting.
+            and not (
+                self.pass_config.fuse_attn_quant
+                and not self.use_inductor_graph_partition
             )
-        ):
-            moe_ops = [
-                "vllm::moe_forward",
-                "vllm::moe_forward_shared",
-            ]
-            for op in moe_ops:
-                if op not in self.splitting_ops:
-                    self.splitting_ops.append(op)
+        )
+
+        if need_moe_splitting and self.splitting_ops is not None:
+            # if we just initialized default splitting_ops for this config,
+            # automatically append the MoE ops
+            if added_default_splitting_ops:
+                for op in moe_ops:
+                    if op not in self.splitting_ops:
+                        self.splitting_ops.append(op)
+
+            # make sure MoE ops are split out
+            if not any(op in self.splitting_ops for op in moe_ops):
+                raise ValueError(
+                    "DeepEP high throughput backend with data_parallel_size > 1 "
+                    "requires splitting MoE ops from cudagraphs. Please ensure "
+                    "'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
+                    "present in CompilationConfig.splitting_ops."
+                )
 
     def set_splitting_ops_for_inductor_graph_partition(self):
         assert self.use_inductor_graph_partition

From 656b16704c6d0f0b6a96099d607ce1c59f1c2d63 Mon Sep 17 00:00:00 2001
From: yewentao256
Date: Fri, 5 Dec 2025 15:44:52 -0800
Subject: [PATCH 6/6] update

Signed-off-by: yewentao256
---
 vllm/config/compilation.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index c3825d2d61da..b79200f0e477 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -988,11 +988,7 @@ def set_splitting_ops_for_v1(
         if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
             self.set_splitting_ops_for_attn_fusion()
         else:
-            if self.use_inductor_graph_partition:
-                if self.splitting_ops is None:
-                    added_default_splitting_ops = True
-                self.set_splitting_ops_for_inductor_graph_partition()
-            elif self.splitting_ops is None:
+            if self.splitting_ops is None:
                 # NOTE: When using full cudagraph, instead of setting an empty
                 # list and capture the full cudagraph inside the flattened fx
                 # graph, we keep the piecewise fx graph structure but capture
@@ -1044,7 +1040,7 @@ def set_splitting_ops_for_v1(
             )
         )
 
-        if need_moe_splitting and self.splitting_ops is not None:
+        if need_moe_splitting and self.cudagraph_mode != CUDAGraphMode.NONE:
             # if we just initialized default splitting_ops for this config,
             # automatically append the MoE ops
             if added_default_splitting_ops:
@@ -1054,17 +1050,16 @@ def set_splitting_ops_for_v1(
 
             # make sure MoE ops are split out
             if not any(op in self.splitting_ops for op in moe_ops):
-                raise ValueError(
+                self.cudagraph_mode = CUDAGraphMode.NONE
+                logger.warning_once(
                     "DeepEP high throughput backend with data_parallel_size > 1 "
                     "requires splitting MoE ops from cudagraphs. Please ensure "
                     "'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
                     "present in CompilationConfig.splitting_ops."
                 )
-
-    def set_splitting_ops_for_inductor_graph_partition(self):
-        assert self.use_inductor_graph_partition
-        if self.splitting_ops is None:
-            self.splitting_ops = list(self._attention_ops)
+            elif self.cudagraph_mode.has_full_cudagraphs():
+                # fall back to piecewise when MoE splitting is required.
+                self.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.fuse_attn_quant
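
For reviewers, a minimal sketch of the intended end state of this series. It only mirrors what the unit tests added in PATCH 3/6 assert (the config values below come from those tests and are not additional requirements): with the DeepEP high-throughput all2all backend and data_parallel_size > 1, the MoE forward custom ops end up in splitting_ops, so CUDA graphs can stay enabled piecewise around the DeepEP kernels instead of being force-disabled as before this series.

    from vllm.config import CompilationConfig, ParallelConfig, VllmConfig
    from vllm.config.compilation import CompilationMode

    # DeepEP high-throughput with dp > 1: the MoE ops are split out of the
    # compiled graph, matching test_moe_splitting_ops_deepep_ht_piecewise.
    config = VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(mode=CompilationMode.VLLM_COMPILE),
    )
    assert "vllm::moe_forward" in config.compilation_config.splitting_ops
    assert "vllm::moe_forward_shared" in config.compilation_config.splitting_ops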