
[Bug]: vLLM cold start on MOE models not optimal #29992

@zou3519

Description

Your current environment

main

🐛 Describe the bug

Previously, for dense models, FX graph splitting in vLLM (the model is split at the attention operator) ends up producing ~50 graph pieces, but only 3 of them are unique, so we only needed to compile 3 graphs out of the 50.
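
To make the dedup picture concrete, here is a minimal sketch of compiling only the structurally unique pieces; the helper names and the code-hash key are assumptions for illustration, not vLLM's actual implementation:

```python
import hashlib
from typing import Callable

import torch
import torch.fx as fx


def piece_key(gm: fx.GraphModule) -> str:
    # Use the generated Python code as a structural fingerprint of the piece.
    return hashlib.sha256(gm.code.encode()).hexdigest()


def compile_pieces(pieces: list[fx.GraphModule]) -> list[Callable]:
    cache: dict[str, Callable] = {}
    compiled = []
    for gm in pieces:
        key = piece_key(gm)
        if key not in cache:
            # Only structurally unique pieces pay the compilation cost.
            cache[key] = torch.compile(gm)
        compiled.append(cache[key])
    return compiled
```

For a dense model, the ~50 pieces collapse to roughly 3 distinct keys (the piece before the first attention, the repeated middle piece, and the piece after the last attention), so the compile cost stays small.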

Looking at a tlparse for llama4 maverick, which is an MOE model:

  • every other layer has an MOE block (instead of an nn.Linear feedforward)
  • this means there should be at most 6 unique graphs

However, all of the graph splits that contain an MOE operator (e.g. torch.ops.vllm.moe_forward) are actually unique, so there are at least 25 unique graphs that need to be compiled.

The only difference between the graphs is the name of the MOE layer.
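
Concretely, the printed code of two MOE pieces might differ only in a line like the following (the exact signature and argument names of torch.ops.vllm.moe_forward are assumed here for illustration):

```python
# piece containing the MOE of layer 2 (illustrative FX-generated code)
moe_out = torch.ops.vllm.moe_forward(hidden_states, router_logits, 'model.layers.2.feed_forward')

# piece containing the MOE of layer 4: identical except for the string constant
moe_out = torch.ops.vllm.moe_forward(hidden_states, router_logits, 'model.layers.4.feed_forward')
```

Because the layer name is baked into the graph as a constant, any structural dedup (hashing the graph or its generated code) treats each piece as unique.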

We should hide the name of the MOE layer so that it does not appear in the graph (maybe via a context) to avoid this and bring the number of unique graphs back down to <= 6 for MOE models.
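
One possible shape for that fix, purely as a sketch under assumptions (the MoEForwardContext below is hypothetical, not vLLM's actual forward-context API): resolve the layer inside the op body from a context that the model runner manages outside the compiled graphs, so no per-layer constant shows up in the traced code.

```python
from dataclasses import dataclass, field

import torch


@dataclass
class MoEForwardContext:
    # Populated by the model runner, outside of the compiled graphs.
    layers: list[torch.nn.Module] = field(default_factory=list)
    next_idx: int = 0


_moe_ctx = MoEForwardContext()


def set_moe_forward_context(layers: list[torch.nn.Module]) -> None:
    # Called once per forward pass, before any compiled piece runs.
    _moe_ctx.layers = list(layers)
    _moe_ctx.next_idx = 0


def moe_forward(hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
    # No layer name appears in the graph: the op resolves the current layer
    # from the context; MOE layers are visited in registration order.
    layer = _moe_ctx.layers[_moe_ctx.next_idx]
    _moe_ctx.next_idx += 1
    # `forward_impl` is a hypothetical per-layer entry point for this sketch.
    return layer.forward_impl(hidden_states, router_logits)
```

Whether a running counter or a lookup keyed by something outside the graph is the right mechanism is a design question; the point is only that the traced call no longer carries a per-layer string, so all MOE pieces hash to the same graph.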


cc @ProExpertProg

Also, this potentially has implications for switching to inductor graph partition. Depending on which model we were actually benchmarking (I hope we were benchmarking a dense model?), the compile-time speedup/slowdown numbers might change after this.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
