66 changes: 65 additions & 1 deletion tests/compile/test_config.py
@@ -10,7 +10,7 @@

from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.logger import _print_warning_once
@@ -235,6 +235,70 @@ def test_splitting_ops_dynamic():
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


def test_moe_splitting_ops_deepep_ht_piecewise():
# Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
# should add MoE ops to splitting_ops on top of attention ops.
config = VllmConfig(
parallel_config=ParallelConfig(
all2all_backend="deepep_high_throughput",
data_parallel_size=8,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
)
splitting_ops = config.compilation_config.splitting_ops
assert splitting_ops is not None
assert "vllm::moe_forward" in splitting_ops
assert "vllm::moe_forward_shared" in splitting_ops


def test_moe_splitting_ops_deepep_ht_inductor_partition():
# Inductor partition case: user-provided splitting_ops should be
# preserved and MoE ops should be appended for DeepEP HT with dp>1.
config = VllmConfig(
parallel_config=ParallelConfig(
all2all_backend="deepep_high_throughput",
data_parallel_size=8,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=[
"vllm::unified_attention",
"vllm::moe_forward",
"vllm::moe_forward_shared",
],
),
)
splitting_ops = config.compilation_config.splitting_ops
assert splitting_ops == [
"vllm::unified_attention",
"vllm::moe_forward",
"vllm::moe_forward_shared",
]


def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
# Pure attn-fusion case without inductor partition: even with
# DeepEP HT and dp>1, we should not re-enable piecewise compilation
# or add MoE ops into splitting_ops.
config = VllmConfig(
parallel_config=ParallelConfig(
all2all_backend="deepep_high_throughput",
data_parallel_size=8,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
),
)
assert config.compilation_config.splitting_ops == []
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL


def test_should_split():
import torch

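Not part of this diff, but the new validation path in set_splitting_ops_for_v1 suggests one more case worth covering: user-supplied splitting_ops that omit the MoE ops under DeepEP high throughput with dp > 1 should now raise. A minimal sketch of such a test, reusing only names that appear in the diff (the test name is illustrative, and pydantic may surface the error as a ValidationError, which subclasses ValueError):

import pytest

from vllm.config import CompilationConfig, ParallelConfig, VllmConfig
from vllm.config.compilation import CompilationMode


def test_moe_splitting_ops_deepep_ht_missing_moe_ops_raises():
    # Explicit splitting_ops without the MoE ops should be rejected when
    # DeepEP high throughput is combined with data_parallel_size > 1.
    with pytest.raises(ValueError, match="requires splitting MoE ops"):
        VllmConfig(
            parallel_config=ParallelConfig(
                all2all_backend="deepep_high_throughput",
                data_parallel_size=8,
            ),
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                splitting_ops=["vllm::unified_attention"],
            ),
        )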
110 changes: 75 additions & 35 deletions vllm/config/compilation.py
@@ -939,7 +939,9 @@ def post_init_cudagraph_sizes(self) -> None:
# May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size()

def set_splitting_ops_for_v1(self):
def set_splitting_ops_for_v1(
self, all2all_backend: str | None = None, data_parallel_size: int | None = None
):
# To compatible with OOT hardware plugin platform (for example vllm-ascend)
# which currently only supports sequence parallelism in eager mode.
if self.mode != CompilationMode.VLLM_COMPILE:
@@ -954,45 +956,83 @@ def set_splitting_ops_for_v1(self):
"mode is CompilationMode.VLLM_COMPILE"
)

if self.use_inductor_graph_partition:
self.set_splitting_ops_for_inductor_graph_partition()
return
added_default_splitting_ops = False

if self.pass_config.fuse_attn_quant:
# pure attn-fusion case without inductor graph partition
if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
self.set_splitting_ops_for_attn_fusion()
return

if self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture
# the full cudagraph outside the fx graph. This reduces some
# cpu overhead when the runtime batch_size is not cudagraph
# captured. see https://github.com/vllm-project/vllm/pull/20059
# for details. Make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty splitting_ops")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
else:
if self.use_inductor_graph_partition:
if self.splitting_ops is None:
added_default_splitting_ops = True
self.set_splitting_ops_for_inductor_graph_partition()
elif self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture
# the full cudagraph outside the fx graph. This reduces some
# cpu overhead when the runtime batch_size is not cudagraph
# captured. see https://github.com/vllm-project/vllm/pull/20059
# for details. Make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
added_default_splitting_ops = True
elif len(self.splitting_ops) == 0:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not"
"contains piecewise cudagraph. Setting cudagraph_"
"mode to NONE. Hint: If you are using attention backends "
"that support cudagraph, consider manually setting "
"cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
"full cudagraphs."
"Using piecewise compilation with empty splitting_ops"
)
self.cudagraph_mode = CUDAGraphMode.NONE
elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not "
"contains piecewise cudagraph. Setting cudagraph_mode "
"to FULL."
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do not"
"contains piecewise cudagraph. Setting cudagraph_"
"mode to NONE. Hint: If you are using attention "
"backends that support cudagraph, consider manually "
"setting cudagraph_mode to FULL or FULL_DECODE_ONLY "
"to enable full cudagraphs."
)
self.cudagraph_mode = CUDAGraphMode.NONE
elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
logger.warning_once(
"Piecewise compilation with empty splitting_ops do "
"not contains piecewise cudagraph. Setting "
"cudagraph_mode to FULL."
)
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []

# split MoE ops for cudagraph
moe_ops = [
"vllm::moe_forward",
"vllm::moe_forward_shared",
]
backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
dp_size = data_parallel_size if data_parallel_size is not None else 1
need_moe_splitting = (
backend == "deepep_high_throughput"
and dp_size > 1
# pure attn-fusion without inductor partition deliberately disables
# piecewise graphs and MoE splitting.
and not (
self.pass_config.fuse_attn_quant
and not self.use_inductor_graph_partition
)
)

if need_moe_splitting and self.splitting_ops is not None:
# if we just initialized default splitting_ops for this config,
# automatically append the MoE ops
if added_default_splitting_ops:
for op in moe_ops:
if op not in self.splitting_ops:
self.splitting_ops.append(op)

# make sure MoE ops are split out
if not any(op in self.splitting_ops for op in moe_ops):
raise ValueError(
"DeepEP high throughput backend with data_parallel_size > 1 "
"requires splitting MoE ops from cudagraphs. Please ensure "
"'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
"present in CompilationConfig.splitting_ops."
)
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []

def set_splitting_ops_for_inductor_graph_partition(self):
assert self.use_inductor_graph_partition
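For readers tracing the branching above, the new MoE-splitting condition reduces to a single predicate. The following standalone restatement is illustrative only (it is not code from this PR; the helper name is made up, and in the real method a None backend falls back to envs.VLLM_ALL2ALL_BACKEND):

def needs_moe_splitting(
    all2all_backend: str | None,
    data_parallel_size: int,
    fuse_attn_quant: bool,
    use_inductor_graph_partition: bool,
) -> bool:
    # Pure attn-fusion without inductor graph partition deliberately keeps the
    # graph unsplit, so MoE splitting is skipped in that case.
    attn_fusion_only = fuse_attn_quant and not use_inductor_graph_partition
    return (
        all2all_backend == "deepep_high_throughput"
        and data_parallel_size > 1
        and not attn_fusion_only
    )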
5 changes: 4 additions & 1 deletion vllm/config/vllm.py
@@ -809,7 +809,10 @@ def has_blocked_weights():
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."

# Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1()
self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,
data_parallel_size=self.parallel_config.data_parallel_size,
)

if self.compilation_config.pass_config.enable_sp:
# With pipeline parallelism or dynamo partitioning,
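The point of threading these two fields through is that the splitting decision now follows the resolved ParallelConfig rather than only the VLLM_ALL2ALL_BACKEND environment variable, which set_splitting_ops_for_v1 still falls back to when no backend is passed. A minimal sketch of the direct call, assuming CompilationConfig can be instantiated standalone as in the unit tests:

from vllm.config import CompilationConfig
from vllm.config.compilation import CompilationMode

cfg = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
# Explicit arguments, as now passed from VllmConfig, take precedence over the
# VLLM_ALL2ALL_BACKEND environment-variable fallback inside the method.
cfg.set_splitting_ops_for_v1(
    all2all_backend="deepep_high_throughput",
    data_parallel_size=8,
)
assert "vllm::moe_forward" in cfg.splitting_ops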
21 changes: 0 additions & 21 deletions vllm/platforms/cuda.py
@@ -229,27 +229,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode

compilation_config = vllm_config.compilation_config
if (
parallel_config.all2all_backend == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
# TODO: Piecewise Cuda graph might be enabled
# if torch compile cache key issue fixed
# See https://github.com/vllm-project/vllm/pull/25093
logger.info(
"WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
"kernels are optimized for prefill and are incompatible with "
"CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, "
"use --all2all-backend with another option, such as "
"deepep_low_latency, pplx, or allgather_reducescatter."
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE

@classmethod
def get_current_memory_usage(
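With this platform-level override gone, DeepEP high throughput with dp > 1 no longer forces cudagraph_mode to NONE; the compilation config instead keeps cudagraphs and splits the MoE ops out of them. The removed log message's advice for decode-optimized workloads still holds, sketched below (backend names taken from that message; treat the snippet as illustrative):

from vllm.config import ParallelConfig

# Decode-optimized deployments that want full CUDA graphs can pick one of the
# other all2all backends named in the removed log message.
parallel_config = ParallelConfig(
    all2all_backend="deepep_low_latency",  # or "pplx", "allgather_reducescatter"
    data_parallel_size=8,
)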