TP > 1 with Ray Serve: Use Multiprocessing Executor (Not Ray Executor) #30016

@dmvevents

Description

Summary

When deploying vLLM with tensor_parallel_size > 1 on Ray Serve, use the multiprocessing executor (distributed_executor_backend="mp") instead of the Ray executor. This avoids placement group context issues with vLLM v1's subprocess architecture.

Problem Description

Attempting to use tensor_parallel_size > 1 with Ray Serve and the Ray executor (distributed_executor_backend="ray") results in worker initialization failures:

# This FAILS with TP > 1 on Ray Serve:
@serve.deployment(ray_actor_options={"num_gpus": 2})
class VLLMDeployment:
    def __init__(self):
        engine_args = AsyncEngineArgs(
            model="Qwen/Qwen2-7B-Instruct",
            tensor_parallel_size=2,
            distributed_executor_backend="ray",  # ❌ Fails
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

Error symptoms:

  • Workers cannot find or access Ray's placement group
  • "Placement group not found" or similar initialization errors
  • Timeout waiting for workers to initialize

Root Cause

vLLM v1 architecture:

  1. Ray Serve creates an actor with a placement group
  2. vLLM spawns EngineCore as a subprocess
  3. The subprocess loses Ray's placement group context (stored in thread-local storage)
  4. When EngineCore tries to spawn Ray worker actors for TP, they cannot access the placement group
  5. Worker initialization fails

This is an architectural interaction between:

  • Ray's placement group context (thread-local)
  • vLLM v1's subprocess-based EngineCore
  • Ray Serve's placement group creation
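
The interaction can be illustrated outside vLLM. Below is a minimal, hypothetical repro sketch (not vLLM code): a Ray task scheduled into a placement group can see it via ray.util.get_current_placement_group(), while a plain subprocess spawned from that task cannot - which is effectively what happens when vLLM v1 starts EngineCore.

# Hypothetical repro sketch (not vLLM code): the placement group is visible inside
# the Ray task it is attached to, but not inside a plain subprocess that the task
# spawns -- which is effectively what vLLM v1 does when it starts EngineCore.
import subprocess
import sys

import ray
from ray.util import get_current_placement_group
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote(num_cpus=1)
def check_pg_context():
    # Inside the Ray task, the placement group context is available.
    visible_in_task = get_current_placement_group() is not None

    # A fresh subprocess is a new interpreter with no Ray worker context at all.
    child_code = (
        "from ray.util import get_current_placement_group\n"
        "try:\n"
        "    print(get_current_placement_group() is None)\n"
        "except Exception:\n"
        "    print(True)  # not even connected to Ray\n"
    )
    result = subprocess.run([sys.executable, "-c", child_code],
                            capture_output=True, text=True)
    lost_in_child = result.stdout.strip() == "True"
    return visible_in_task, lost_in_child


if __name__ == "__main__":
    ray.init()
    pg = placement_group([{"CPU": 1}])
    ray.get(pg.ready())
    strategy = PlacementGroupSchedulingStrategy(placement_group=pg)
    # Expected (sketch): (True, True) -- context visible in the task, gone in the child.
    print(ray.get(check_pg_context.options(scheduling_strategy=strategy).remote()))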

✅ Solution: Multiprocessing Executor

Use the multiprocessing executor instead:

@serve.deployment(
    name="vllm-tp2",
    num_replicas=1,
    ray_actor_options={"num_gpus": 2},
)
@serve.ingress(app)
class VLLMDeployment:
    def __init__(self):
        engine_args = AsyncEngineArgs(
            model="Qwen/Qwen2-7B-Instruct",
            tensor_parallel_size=2,
            distributed_executor_backend="mp",  # ✅ Use multiprocessing!
            trust_remote_code=True,
            enforce_eager=True,
            gpu_memory_utilization=0.4,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

Why This Works

  1. Ray Serve allocates GPUs - ray_actor_options={"num_gpus": 2} gives the actor 2 GPUs
  2. vLLM inherits GPU visibility - Actor has CUDA_VISIBLE_DEVICES=0,1
  3. Multiprocessing spawns workers - vLLM creates worker processes using Python multiprocessing
  4. Workers inherit GPU environment - Each worker process gets access to the GPUs
  5. Each worker binds to a different GPU - vLLM assigns each worker process a distinct device (by local rank) from the visible GPUs
  6. NCCL enables communication - Workers coordinate via NCCL for tensor parallelism

No placement groups needed - Everything stays within the Ray Serve actor's resource allocation.
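
As a quick sanity check of steps 1-2, the deployment's __init__ can log what the replica actually sees before vLLM spawns its workers. The helper below is a hypothetical sketch, not part of vLLM or Ray Serve:

# Hypothetical sanity-check helper: confirm the Serve replica was granted both GPUs
# and that the CUDA visibility the worker processes will inherit looks right.
import os

import ray


def log_gpu_visibility() -> None:
    # GPU IDs Ray assigned to this actor (two entries for num_gpus=2).
    print("ray.get_gpu_ids():", ray.get_gpu_ids())
    # The environment that EngineCore and the multiprocessing workers inherit.
    print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))

Calling it at the top of __init__, before constructing AsyncEngineArgs, makes a misconfigured ray_actor_options setting obvious in the replica logs.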

Complete Working Example

#!/usr/bin/env python3
import time
import ray
from ray import serve
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm import AsyncEngineArgs, SamplingParams
from fastapi import FastAPI
from fastapi.responses import JSONResponse

ray.init(address="auto")
app = FastAPI()

@serve.deployment(
    name="vllm-tp2",
    num_replicas=1,
    ray_actor_options={"num_gpus": 2},
    max_ongoing_requests=5,
)
@serve.ingress(app)
class VLLMWithTP2:
    def __init__(self):
        engine_args = AsyncEngineArgs(
            model="Qwen/Qwen2-7B-Instruct",
            tensor_parallel_size=2,
            trust_remote_code=True,
            distributed_executor_backend="mp",  # KEY: multiprocessing
            enforce_eager=True,
            gpu_memory_utilization=0.4,
            max_model_len=2048,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    @app.post("/v1/chat/completions")
    async def chat(self, request: dict):
        messages = request.get("messages", [])
        max_tokens = min(request.get("max_tokens", 50), 100)

        prompt = "\n".join([
            f"{msg.get('role', 'user')}: {msg.get('content', '')}"
            for msg in messages
        ]) + "\nassistant:"

        sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens)
        request_id = f"req_{int(time.time()*1000)}"

        results_generator = self.engine.generate(prompt, sampling_params, request_id)
        final_output = None
        async for output in results_generator:
            final_output = output

        text = final_output.outputs[0].text if final_output else ""

        return JSONResponse({
            "id": request_id,
            "model": "Qwen/Qwen2-7B-Instruct",
            "tp_size": 2,
            "executor": "multiprocessing",
            "choices": [{
                "message": {"role": "assistant", "content": text.strip()},
                "finish_reason": "stop",
            }],
        })

    @app.get("/health")
    async def health(self):
        return {"status": "healthy", "tp_size": 2, "executor": "multiprocessing"}

# Deploy
serve.run(VLLMWithTP2.bind(), name="vllm-tp2", route_prefix="/")

Test Results

Deployment logs:

Worker_TP0 pid=4701 (GPU 0)
Worker_TP1 pid=4702 (GPU 1)
NCCL version 2.27.3+cuda12.9
comm rank 0 nRanks 2 localRanks 2
Channel 00-23/24 configured
Loading checkpoint shards: 100% | 4/4
Application ready at http://127.0.0.1:8000/

Inference test:

$ curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages":[{"role":"user","content":"What is 2+2?"}]}'

{
  "id": "req_1764813981407",
  "model": "Qwen/Qwen2-7B-Instruct",
  "tp_size": 2,
  "executor": "multiprocessing",
  "choices": [{
    "message": {
      "role": "assistant",
      "content": "2 + 2 equals 4."
    }
  }]
}

Status: Working perfectly with TP=2

Architecture Comparison

Ray Executor (Fails with Ray Serve)

Ray Serve Actor
  └─> Placement Group (2 GPUs)
       └─> EngineCore Subprocess ⚠️ (loses PG context)
            └─> Ray Worker Actor 0 ❌ (can't find PG)
            └─> Ray Worker Actor 1 ❌ (can't find PG)

Multiprocessing Executor (Works)

Ray Serve Actor (num_gpus=2)
  └─> CUDA_VISIBLE_DEVICES=0,1
       └─> EngineCore Subprocess (inherits CUDA vars)
            └─> Worker Process 0 (GPU 0) ✅
            └─> Worker Process 1 (GPU 1) ✅
                 └─> NCCL Communication ✅

When to Use Each Executor

Use Multiprocessing Executor ("mp") When:

  • ✅ Deploying with Ray Serve
  • ✅ Single-node TP (2-8 GPUs on one node)
  • ✅ Want simple, reliable deployment
  • ✅ Don't need multi-node TP

Use Ray Executor ("ray") When:

  • Multi-node TP is required (TP across nodes)
  • NOT deploying with Ray Serve (use standalone vLLM)
  • Need Ray's advanced scheduling features

Limitations

  • Single node only - the multiprocessing executor keeps all workers on one node
  • Multi-node TP - would require the Ray executor, which is not compatible with Ray Serve in this setup
  • Recommended TP values - 2-8 GPUs on a single node

For multi-node scenarios, use standalone vLLM deployment instead of Ray Serve.
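
For reference, a minimal standalone sketch of that path (assuming an existing multi-node Ray cluster and the same model as above; the TP value is illustrative):

# Minimal standalone sketch (no Ray Serve): the Ray executor can place TP workers
# across nodes of an existing Ray cluster. Values here are illustrative.
import ray
from vllm import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

ray.init(address="auto")  # connect to the running Ray cluster

engine_args = AsyncEngineArgs(
    model="Qwen/Qwen2-7B-Instruct",
    tensor_parallel_size=8,               # may span nodes with the Ray executor
    distributed_executor_backend="ray",   # vLLM manages its own placement group here
)
engine = AsyncLLMEngine.from_engine_args(engine_args)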

Environment

  • vLLM: v0.11.0
  • Ray: 2.40+
  • Ray Serve: Enabled
  • GPUs: NVIDIA H100 (AWS p5.48xlarge)
  • NCCL: 2.27.3+cuda12.9
  • CUDA: 12.6

Notes

This documents a working solution for TP > 1 with Ray Serve. No code changes needed - just use the multiprocessing executor.

If multi-node TP with Ray Serve is needed in the future, that would require architectural changes to preserve placement group context across subprocess boundaries.


TL;DR: Use distributed_executor_backend="mp" for TP > 1 with Ray Serve. It works perfectly.
