16 changes: 15 additions & 1 deletion vllm/config/compilation.py
@@ -812,7 +812,9 @@ def post_init_cudagraph_sizes(self) -> None:
# May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size()

def set_splitting_ops_for_v1(self):
def set_splitting_ops_for_v1(
self, all2all_backend: str | None = None, data_parallel_size: int | None = None
):
# NOTE: this function needs to be called only when mode is
# CompilationMode.VLLM_COMPILE
assert self.mode == CompilationMode.VLLM_COMPILE, (
@@ -860,6 +862,18 @@ def set_splitting_ops_for_v1(self):
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []

# split moe op for cudagraph
Collaborator commented:

This will not apply to the cases above (inductor partition or attn fusion)

Member Author replied:

Great point, fixed!
Also added the unit test, please take a look again @ProExpertProg

backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
dp_size = data_parallel_size if data_parallel_size is not None else 1
if backend == "deepep_high_throughput" and dp_size > 1 and self.splitting_ops:
moe_ops = [
"vllm::moe_forward",
"vllm::moe_forward_shared",
]
for op in moe_ops:
if op not in self.splitting_ops:
self.splitting_ops.append(op)

def set_splitting_ops_for_inductor_graph_partition(self):
assert self.use_inductor_graph_partition
if self.splitting_ops is None:
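For reference, a minimal sketch of the kind of unit test mentioned in the reply above (hypothetical, not the PR's actual test; it assumes CompilationConfig can be constructed directly with mode=CompilationMode.VLLM_COMPILE and that the default path populates splitting_ops before the MoE branch runs):

```python
# Hypothetical sketch of a test for the new MoE splitting behavior.
# Assumes CompilationConfig(mode=CompilationMode.VLLM_COMPILE) is a valid
# construction and that the default splitting ops are populated first.
from vllm.config import CompilationConfig, CompilationMode


def test_moe_ops_split_for_deepep_high_throughput():
    config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    config.set_splitting_ops_for_v1(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=2,
    )
    # With DP > 1 and non-empty splitting_ops, the MoE ops are appended.
    assert "vllm::moe_forward" in config.splitting_ops
    assert "vllm::moe_forward_shared" in config.splitting_ops


def test_moe_ops_not_split_for_single_dp_rank():
    config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    config.set_splitting_ops_for_v1(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=1,
    )
    # dp_size == 1 skips the MoE branch entirely.
    assert "vllm::moe_forward" not in (config.splitting_ops or [])
```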
5 changes: 4 additions & 1 deletion vllm/config/vllm.py
@@ -649,7 +649,10 @@ def __post_init__(self):

# Do this after all the updates to compilation_config.mode
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1()
self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,
data_parallel_size=self.parallel_config.data_parallel_size,
)

if self.compilation_config.pass_config.enable_sequence_parallelism:
# With pipeline parallelism or dynamo partitioning,
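As context for the wiring above, a rough end-to-end sketch of how a user would reach this path (illustrative only; the model name is a placeholder, it assumes data_parallel_size is accepted as an engine argument, and it assumes the parallel config's all2all_backend defaults to the VLLM_ALL2ALL_BACKEND environment variable that the new code also falls back to):

```python
# Illustrative only: the MoE-splitting path is hit when the all2all backend
# is deepep_high_throughput and more than one data-parallel rank is used.
import os

# Assumption: the parallel config picks this up as its all2all_backend;
# set_splitting_ops_for_v1 also falls back to it when no backend is passed.
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"

from vllm import LLM

llm = LLM(
    model="some/moe-model",   # placeholder model name
    data_parallel_size=2,     # DP > 1 triggers the MoE op splitting
    enforce_eager=False,      # keep torch.compile / CUDA graphs enabled
)
```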
21 changes: 0 additions & 21 deletions vllm/platforms/cuda.py
@@ -231,27 +231,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode

compilation_config = vllm_config.compilation_config
if (
parallel_config.all2all_backend == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
# TODO: Piecewise Cuda graph might be enabled
# if torch compile cache key issue fixed
# See https://github.com/vllm-project/vllm/pull/25093
logger.info(
"WideEP: Disabling CUDA Graphs since DeepEP high-throughput "
"kernels are optimized for prefill and are incompatible with "
"CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, "
"use --all2all-backend with another option, such as "
"deepep_low_latency, pplx, or allgather_reducescatter."
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE

@classmethod
def get_current_memory_usage(
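With the workaround above removed, a hedged sketch of the new expectation (not code from the PR; names mirror the deleted block, and it only holds if CUDA graphs were not explicitly disabled elsewhere):

```python
# Hypothetical post-condition check, not part of the PR: for the
# deepep_high_throughput + DP > 1 combination, the CUDA platform hook no
# longer forces CUDAGraphMode.NONE; the MoE ops are excluded from graph
# capture via splitting_ops instead.
from vllm.config import CUDAGraphMode


def assert_cudagraphs_not_force_disabled(vllm_config) -> None:
    parallel_config = vllm_config.parallel_config
    compilation_config = vllm_config.compilation_config
    if (
        parallel_config.all2all_backend == "deepep_high_throughput"
        and parallel_config.data_parallel_size > 1
    ):
        # Holds only if the user has not disabled CUDA graphs themselves.
        assert compilation_config.cudagraph_mode != CUDAGraphMode.NONE
```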