@@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
import importlib

import vllm.envs as envs
from vllm.attention.backends.registry import _Backend

importlib.reload(envs)

# Convert string backend to enum if provided
backend_enum = None
if selected_backend:
backend_enum = getattr(_Backend, selected_backend)
backend_enum = getattr(AttentionBackendEnum, selected_backend)

# Get the backend class path
from vllm.platforms.rocm import RocmPlatform
@@ -253,7 +252,6 @@ def test_mla_backend_selection(
import importlib

import vllm.envs as envs
from vllm.attention.backends.registry import _Backend

importlib.reload(envs)

@@ -269,7 +267,7 @@ def test_mla_backend_selection(
# Convert string backend to enum if provided
backend_enum = None
if selected_backend:
backend_enum = getattr(_Backend, selected_backend)
backend_enum = getattr(AttentionBackendEnum, selected_backend)

from vllm.platforms.rocm import RocmPlatform

@@ -301,7 +299,6 @@ def test_mla_backend_selection(

def test_aiter_fa_requires_gfx9(mock_vllm_config):
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
from vllm.attention.backends.registry import _Backend
from vllm.platforms.rocm import RocmPlatform

# Mock on_gfx9 to return False
@@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
),
):
RocmPlatform.get_attn_backend_cls(
selected_backend=_Backend.ROCM_AITER_FA,
selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
head_size=128,
dtype=torch.float16,
kv_cache_dtype="auto",
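The tests above now resolve backend names through AttentionBackendEnum instead of the removed _Backend enum. Below is a minimal, self-contained sketch of that lookup pattern; the enum members are illustrative stand-ins, not the full set defined in vllm.attention.backends.registry.

```python
import enum


class AttentionBackendEnum(enum.Enum):
    # Illustrative members only; the real enum in
    # vllm.attention.backends.registry defines the full backend list.
    FLASH_ATTN = enum.auto()
    TRITON_ATTN = enum.auto()
    ROCM_AITER_FA = enum.auto()


def to_backend_enum(selected_backend: str | None) -> AttentionBackendEnum | None:
    """Mirror the test helper logic: None stays None, a name becomes a member."""
    if not selected_backend:
        return None
    # getattr raises AttributeError for unknown names, surfacing typos early
    # instead of silently falling back to a default backend.
    return getattr(AttentionBackendEnum, selected_backend)


assert to_backend_enum("ROCM_AITER_FA") is AttentionBackendEnum.ROCM_AITER_FA
assert to_backend_enum(None) is None
```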
tests/v1/kv_connector/unit/test_backwards_compatibility.py (6 changes: 3 additions & 3 deletions)
@@ -14,6 +14,7 @@

import pytest

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
@@ -24,7 +25,6 @@
from .utils import create_scheduler, create_vllm_config

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -68,7 +68,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
@@ -119,7 +119,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
vllm/attention/backends/abstract.py (11 changes: 4 additions & 7 deletions)
@@ -6,11 +6,10 @@

import torch

from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey

if TYPE_CHECKING:
from vllm.config.cache import CacheDType
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType

@@ -178,8 +177,6 @@ def supports_attn_type(cls, attn_type: str) -> bool:
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from vllm.attention.backends.abstract import AttentionType

return attn_type == AttentionType.DECODER

@classmethod
@@ -360,7 +357,7 @@ def forward(
) -> torch.Tensor:
raise NotImplementedError

def fused_output_quant_supported(self, quant_key: QuantKey):
def fused_output_quant_supported(self, quant_key: "QuantKey"):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
@@ -412,7 +409,7 @@ def __init__(
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None,
) -> None:
raise NotImplementedError
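A note on the abstract.py hunks above: the ColumnParallelLinear and QuantKey imports move under TYPE_CHECKING and the annotations that use them become quoted forward references, so the types stay visible to static checkers without paying the runtime import (and its potential circular-import cost). A minimal sketch of that pattern, using placeholder names rather than vLLM's real modules:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so a heavy or
    # circular import costs nothing when this module is actually loaded.
    from heavy_module import ColumnParallelLinear  # hypothetical module


class ExampleMLAImpl:
    def __init__(self, kv_b_proj: "ColumnParallelLinear") -> None:
        # The quoted annotation is a forward reference: type checkers resolve
        # it through the TYPE_CHECKING import, while the interpreter just
        # stores the string and never touches heavy_module.
        self.kv_b_proj = kv_b_proj
```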
vllm/attention/layers/chunked_local_attention.py (3 changes: 1 addition & 2 deletions)
@@ -5,6 +5,7 @@
import torch

from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
@@ -22,8 +23,6 @@
KVCacheSpec,
)

from ..layer import Attention


@functools.lru_cache
def create_chunked_local_attention_backend(
vllm/config/model.py (3 changes: 1 addition & 2 deletions)
@@ -14,6 +14,7 @@
from transformers.configuration_utils import ALLOWED_LAYER_TYPES

import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
@@ -53,15 +54,13 @@

import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.v1.sample.logits_processor import LogitsProcessor
else:
PretrainedConfig = Any

AttentionBackendEnum = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
)
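For context on the model.py hunk above: AttentionBackendEnum is now imported eagerly, while heavier modules such as the quantization package are still pulled in through a LazyLoader. The following is a generic illustration of what such lazy loading amounts to; it is an assumption-level sketch, not vLLM's actual LazyLoader implementation.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Import the wrapped module only when an attribute is first accessed."""

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self._loaded = None

    def __getattr__(self, item: str):
        # Called only when normal lookup fails, i.e. for names that live on
        # the real module rather than on this proxy.
        if self._loaded is None:
            self._loaded = importlib.import_module(self.__name__)
        return getattr(self._loaded, item)


# The real module is imported only once an attribute such as .dumps is touched.
lazy_json = LazyModule("json")
print(lazy_json.dumps({"lazy": True}))
```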
vllm/config/multimodal.py (11 changes: 2 additions & 9 deletions)
@@ -2,19 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
from typing import Any, Literal, TypeAlias

from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash

if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
else:
AttentionBackendEnum = Any


@dataclass
class BaseDummyOptions:
@@ -170,9 +166,6 @@ def _validate_limit_per_prompt(
def _validate_mm_encoder_attn_backend(
cls, value: str | AttentionBackendEnum | None
) -> AttentionBackendEnum | None:
# We need to import the real type here (deferred to avoid circular import).
from vllm.attention.backends.registry import AttentionBackendEnum

if isinstance(value, str) and value.upper() == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
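With the module-level import in place, the multimodal.py validator above can reference AttentionBackendEnum directly instead of re-importing it inside the function. Below is a hedged sketch of what that validation logic appears to do, with an illustrative two-member enum standing in for the real registry and a simplified error message.

```python
import enum


class AttentionBackendEnum(enum.Enum):
    # Two illustrative members; the real enum lives in
    # vllm.attention.backends.registry.
    FLASH_ATTN = enum.auto()
    TRITON_ATTN = enum.auto()


def validate_mm_encoder_attn_backend(
    value: str | AttentionBackendEnum | None,
) -> AttentionBackendEnum | None:
    """Accept an enum member, a backend name, or None; reject removed names."""
    if value is None or isinstance(value, AttentionBackendEnum):
        return value
    name = value.upper()
    if name == "XFORMERS":
        # Mirrors the diff's rejection of the removed XFORMERS backend.
        raise ValueError(
            "Attention backend 'XFORMERS' has been removed; "
            "choose a supported backend instead."
        )
    try:
        return AttentionBackendEnum[name]
    except KeyError as exc:
        raise ValueError(f"Unknown attention backend: {value!r}") from exc


assert validate_mm_encoder_attn_backend("flash_attn") is AttentionBackendEnum.FLASH_ATTN
assert validate_mm_encoder_attn_backend(None) is None
```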
vllm/distributed/kv_transfer/kv_connector/v1/base.py (4 changes: 2 additions & 2 deletions)
@@ -42,12 +42,12 @@

import torch

from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -239,7 +239,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
return

def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
"""
Initialize with a single KV cache tensor used by all layers.
@@ -36,6 +36,7 @@

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
@@ -45,7 +46,6 @@
from vllm.utils.math_utils import cdiv

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -117,7 +117,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
# This connector doesn't save KV cache (benchmarking only)
@@ -7,6 +7,7 @@
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -17,7 +18,6 @@
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -91,7 +91,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""
@@ -29,6 +29,7 @@
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
from lmcache.v1.plugin.plugin_launcher import PluginLauncher

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -50,7 +51,6 @@
from vllm.version import __version__ as VLLM_VERSION

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -915,7 +915,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""Start saving the a layer of KV cache from vLLM's paged buffer
@@ -10,6 +10,7 @@
from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -26,7 +27,6 @@
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -490,7 +490,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@@ -27,7 +28,6 @@
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -216,7 +216,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
for c in self._connectors:
@@ -20,7 +20,7 @@
import zmq

from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
@@ -51,7 +51,6 @@
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@@ -308,7 +307,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""NixlConnector does not save explicitly."""
@@ -7,6 +7,7 @@
import regex as re
import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
@@ -22,7 +23,6 @@
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -243,7 +243,7 @@ def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer