From c0a7a0a8c55836f0d97b319e8d543c84df737681 Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Wed, 26 Nov 2025 15:03:04 -0500
Subject: [PATCH 1/3] update attention imports

Signed-off-by: Matthew Bonanni
---
 .../test_rocm_attention_backends_selection.py         |  9 +++------
 .../kv_connector/unit/test_backwards_compatibility.py |  6 +++---
 vllm/attention/backends/abstract.py                   |  2 --
 vllm/attention/backends/registry.py                   | 10 ++++------
 vllm/attention/layers/chunked_local_attention.py      |  3 +--
 vllm/config/model.py                                  |  3 +--
 vllm/config/multimodal.py                             | 11 ++---------
 vllm/distributed/kv_transfer/kv_connector/v1/base.py  |  4 ++--
 .../kv_connector/v1/decode_bench_connector.py         |  4 ++--
 .../kv_transfer/kv_connector/v1/lmcache_connector.py  |  4 ++--
 .../v1/lmcache_integration/vllm_v1_adapter.py         |  4 ++--
 .../kv_connector/v1/lmcache_mp_connector.py           |  4 ++--
 .../kv_transfer/kv_connector/v1/multi_connector.py    |  4 ++--
 .../kv_transfer/kv_connector/v1/nixl_connector.py     |  5 ++---
 .../kv_connector/v1/p2p/p2p_nccl_connector.py         |  4 ++--
 .../kv_connector/v1/shared_storage_connector.py       |  4 ++--
 vllm/forward_context.py                               |  8 +++-----
 vllm/model_executor/layers/attention_layer_base.py    |  7 ++-----
 vllm/model_executor/layers/mamba/abstract.py          |  7 ++-----
 .../compressed_tensors/compressed_tensors.py          |  3 +--
 vllm/model_executor/layers/quantization/fp8.py        |  4 +---
 vllm/model_executor/layers/quantization/modelopt.py   |  3 +--
 vllm/model_executor/layers/quantization/mxfp4.py      |  3 +--
 vllm/model_executor/layers/quantization/petit.py      |  3 +--
 vllm/model_executor/layers/quantization/ptpc_fp8.py   |  3 +--
 .../model_executor/layers/quantization/quark/quark.py |  3 +--
 vllm/platforms/cpu.py                                 |  5 +----
 vllm/platforms/cuda.py                                | 10 ++--------
 vllm/platforms/interface.py                           |  5 +----
 vllm/platforms/rocm.py                                |  6 +-----
 vllm/platforms/tpu.py                                 |  5 +----
 vllm/platforms/xpu.py                                 |  7 +------
 vllm/v1/attention/backends/cpu_attn.py                |  2 --
 vllm/v1/attention/backends/flash_attn.py              |  2 --
 vllm/v1/attention/backends/flex_attention.py          |  2 --
 vllm/v1/attention/backends/utils.py                   |  7 +++++--
 vllm/v1/kv_offload/spec.py                            |  4 ++--
 vllm/v1/spec_decode/eagle.py                          |  3 +--
 vllm/v1/worker/utils.py                               |  7 ++-----
 39 files changed, 63 insertions(+), 127 deletions(-)

diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py
index 80158d4b7278..77790be6f892 100644
--- a/tests/v1/attention/test_rocm_attention_backends_selection.py
+++ b/tests/v1/attention/test_rocm_attention_backends_selection.py
@@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
     import importlib
 
     import vllm.envs as envs
-    from vllm.attention.backends.registry import _Backend
 
     importlib.reload(envs)
 
     # Convert string backend to enum if provided
     backend_enum = None
     if selected_backend:
-        backend_enum = getattr(_Backend, selected_backend)
+        backend_enum = getattr(AttentionBackendEnum, selected_backend)
 
     # Get the backend class path
     from vllm.platforms.rocm import RocmPlatform
@@ -253,7 +252,6 @@ def test_mla_backend_selection(
     import importlib
 
     import vllm.envs as envs
-    from vllm.attention.backends.registry import _Backend
 
     importlib.reload(envs)
 
@@ -269,7 +267,7 @@ def test_mla_backend_selection(
     # Convert string backend to enum if provided
     backend_enum = None
     if selected_backend:
-        backend_enum = getattr(_Backend, selected_backend)
+        backend_enum = getattr(AttentionBackendEnum, selected_backend)
 
     from vllm.platforms.rocm import RocmPlatform
 
@@ -301,7 +299,6 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
     """Test that ROCM_AITER_FA requires gfx9 architecture."""
-    from vllm.attention.backends.registry import _Backend
     from vllm.platforms.rocm import RocmPlatform
 
     # Mock on_gfx9 to return False
@@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
         ),
     ):
         RocmPlatform.get_attn_backend_cls(
-            selected_backend=_Backend.ROCM_AITER_FA,
+            selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
             head_size=128,
             dtype=torch.float16,
             kv_cache_dtype="auto",
diff --git a/tests/v1/kv_connector/unit/test_backwards_compatibility.py b/tests/v1/kv_connector/unit/test_backwards_compatibility.py
index f51001a6ec12..7cd23805c599 100644
--- a/tests/v1/kv_connector/unit/test_backwards_compatibility.py
+++ b/tests/v1/kv_connector/unit/test_backwards_compatibility.py
@@ -14,6 +14,7 @@
 
 import pytest
 
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorBase_V1,
@@ -24,7 +25,6 @@
 from .utils import create_scheduler, create_vllm_config
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -68,7 +68,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> None:
         pass
@@ -119,7 +119,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> None:
         pass
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index a321167b8090..b1518b2ca2ef 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -178,8 +178,6 @@ def supports_attn_type(cls, attn_type: str) -> bool:
         By default, only supports decoder attention.
         Backends should override this to support other attention types.
         """
-        from vllm.attention.backends.abstract import AttentionType
-
         return attn_type == AttentionType.DECODER
 
     @classmethod
diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 125e4e382774..98671f66aee7 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -4,14 +4,12 @@
 
 from collections.abc import Callable
 from enum import Enum, EnumMeta
-from typing import TYPE_CHECKING, cast
+from typing import cast
 
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.utils.import_utils import resolve_obj_by_qualname
 
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-
 logger = init_logger(__name__)
 
 
@@ -98,7 +96,7 @@ def get_path(self, include_classname: bool = True) -> str:
             path = path.rsplit(".", 1)[0]
         return path
 
-    def get_class(self) -> "type[AttentionBackend]":
+    def get_class(self) -> type[AttentionBackend]:
         """Get the backend class (respects overrides).
 
         Returns:
@@ -160,7 +158,7 @@ def get_path(self, include_classname: bool = True) -> str:
             path = path.rsplit(".", 1)[0]
         return path
 
-    def get_class(self) -> "type[AttentionBackend]":
+    def get_class(self) -> type[AttentionBackend]:
         """Get the backend class (respects overrides).
 
         Returns:
diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py
index 48fcc6fa736b..0ced0028ded9 100644
--- a/vllm/attention/layers/chunked_local_attention.py
+++ b/vllm/attention/layers/chunked_local_attention.py
@@ -5,6 +5,7 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
 from vllm.config.vllm import VllmConfig
@@ -22,8 +23,6 @@
     KVCacheSpec,
 )
 
-from ..layer import Attention
-
 
 @functools.lru_cache
 def create_chunked_local_attention_backend(
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 25972f097f53..0fe51760289e 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -14,6 +14,7 @@
 from transformers.configuration_utils import ALLOWED_LAYER_TYPES
 
 import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
 from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
@@ -53,7 +54,6 @@
     import vllm.model_executor.layers.quantization as me_quant
     import vllm.model_executor.models as me_models
 
-    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config.load import LoadConfig
     from vllm.config.parallel import ParallelConfig
     from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -61,7 +61,6 @@
 else:
     PretrainedConfig = Any
 
-    AttentionBackendEnum = Any
     me_quant = LazyLoader(
         "model_executor", globals(), "vllm.model_executor.layers.quantization"
     )
diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py
index 590bc4dcd076..8a2936de96d6 100644
--- a/vllm/config/multimodal.py
+++ b/vllm/config/multimodal.py
@@ -2,19 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any, Literal, TypeAlias
+from typing import Any, Literal, TypeAlias
 
 from pydantic import ConfigDict, Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config.utils import config
 from vllm.utils.hashing import safe_hash
 
-if TYPE_CHECKING:
-    from vllm.attention.backends.registry import AttentionBackendEnum
-else:
-    AttentionBackendEnum = Any
-
 
 @dataclass
 class BaseDummyOptions:
@@ -170,9 +166,6 @@ def _validate_limit_per_prompt(
     def _validate_mm_encoder_attn_backend(
         cls, value: str | AttentionBackendEnum | None
     ) -> AttentionBackendEnum | None:
-        # We need to import the real type here (deferred to avoid circular import).
-        from vllm.attention.backends.registry import AttentionBackendEnum
-
        if isinstance(value, str) and value.upper() == "XFORMERS":
            raise ValueError(
                "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 74f09278b7bb..cac45425bb7a 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -42,12 +42,12 @@
 
 import torch
 
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.logger import init_logger
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import KVConnectorOutput
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.distributed.kv_events import KVCacheEvent
     from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -239,7 +239,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         return
 
     def register_cross_layers_kv_cache(
-        self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
+        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
     ):
         """
         Initialize with a single KV cache tensor used by all layers.
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
index 9cd7d93c92fa..e9b2bd392b0e 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
@@ -36,6 +36,7 @@
 
 import torch
 
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorBase_V1,
     KVConnectorRole,
@@ -45,7 +46,6 @@
 from vllm.utils.math_utils import cdiv
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -117,7 +117,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer: torch.Tensor,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs: Any,
     ) -> None:
         # This connector doesn't save KV cache (benchmarking only)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
index 0c24a53fb754..30da424ddcca 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
@@ -7,6 +7,7 @@
     LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
 )
 
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
@@ -17,7 +18,6 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -91,7 +91,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer: torch.Tensor,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs: Any,
     ) -> None:
         """
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
index 94572b02fa87..15ac5b049fce 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
@@ -29,6 +29,7 @@
 from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
 from lmcache.v1.plugin.plugin_launcher import PluginLauncher
 
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
@@ -50,7 +51,6 @@
 from vllm.version import __version__ as VLLM_VERSION
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
     from vllm.multimodal.inputs import PlaceholderRange
     from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -915,7 +915,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer: torch.Tensor,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> None:
         """Start saving the a layer of KV cache from vLLM's paged buffer
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
index d1d3e475cc88..a4bddf5e0316 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
@@ -10,6 +10,7 @@
 from lmcache.integration.vllm.utils import mla_enabled
 from lmcache.utils import init_logger as lmcache_init_logger
 
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
@@ -26,7 +27,6 @@
 from vllm.v1.utils import ConstantList
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.distributed.kv_events import KVCacheEvent
     from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -490,7 +490,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer: torch.Tensor,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs: Any,
     ) -> None:
         """
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index c9d08e9b78ed..f47e8ca7e6c5 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -7,6 +7,7 @@
 
 import torch
 
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.config.kv_transfer import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@@ -27,7 +28,6 @@
 from vllm.v1.outputs import KVConnectorOutput
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.distributed.kv_events import KVCacheEvent
     from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -216,7 +216,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer: torch.Tensor,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> None:
         for c in self._connectors:
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index ff51840b84b1..62e039b35a48 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -20,7 +20,7 @@
 import zmq
 
 from vllm import envs
-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
@@ -51,7 +51,6 @@
 from vllm.v1.worker.block_table import BlockTable
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.kv_cache_interface import KVCacheConfig
     from vllm.v1.request import Request
@@ -308,7 +307,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer: torch.Tensor,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> None:
         """NixlConnector does not save explicitly."""
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index a124a0d519db..8f3a62d7bcdb 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -7,6 +7,7 @@
 import regex as re
 import torch
 
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
@@ -22,7 +23,6 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -243,7 +243,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer: torch.Tensor,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs: Any,
     ) -> None:
         """Start saving the KV cache of the layer from vLLM's paged buffer
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index 4611b4d1ff7b..ed641cfc43dd 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -7,6 +7,7 @@
 import safetensors
 import torch
 
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
@@ -19,7 +20,6 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -211,7 +211,7 @@ def save_kv_layer(
         self,
         layer_name: str,
         kv_layer: torch.Tensor,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: AttentionMetadata,
         **kwargs: Any,
     ) -> None:
         """Start saving the KV cache of the layer from vLLM's paged buffer
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 635419bc7cad..173d366267e8 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -5,19 +5,17 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import Any, NamedTuple
 
 import torch
 
 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.ubatch_utils import UBatchSlices
 
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
-
 logger = init_logger(__name__)
 
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -195,7 +193,7 @@ class ForwardContext:
     for each microbatch. Set dynamically for each forward pass
     """
 
-    attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]]
+    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py
index ffbef470b186..a60cf787135c 100644
--- a/vllm/model_executor/layers/attention_layer_base.py
+++ b/vllm/model_executor/layers/attention_layer_base.py
@@ -3,14 +3,11 @@
 """Base class for attention-like layers."""
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
 
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.kv_cache_interface import KVCacheSpec
 
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-
 
 class AttentionLayerBase(ABC):
     """
@@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
     """
 
     @abstractmethod
-    def get_attn_backend(self) -> type["AttentionBackend"]:
+    def get_attn_backend(self) -> type[AttentionBackend]:
         """Get the attention backend class for this layer."""
         pass
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index aa919d6fdc35..74f4383e9c23 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -2,18 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import abstractmethod
 from collections.abc import Iterable
-from typing import TYPE_CHECKING
 
 import torch
 
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.selector import get_mamba_attn_backend
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
 
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
-
 
 class MambaBase(AttentionLayerBase):
     """
@@ -66,6 +63,6 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
             ),
         )
 
-    def get_attn_backend(self) -> type["AttentionBackend"]:
+    def get_attn_backend(self) -> type[AttentionBackend]:
         """Get the attention backend class for this Mamba layer."""
         return get_mamba_attn_backend(self.mamba_type)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 2800f90ce0b6..ddeb781c6216 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -18,6 +18,7 @@
 from compressed_tensors.transform import TransformConfig
 
 import vllm.envs as envs
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
@@ -131,8 +132,6 @@ def get_quant_method(
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         if isinstance(layer, LinearBase):
             # collect schemes
             quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e033032903e8..7dfc8a9c36c3 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -14,6 +14,7 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
+from vllm.attention.layer import Attention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
@@ -277,7 +278,6 @@ def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
     def get_xpu_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention
         from vllm.model_executor.layers.quantization.ipex_quant import (
             XPUFp8LinearMethod,
             XPUFp8MoEMethod,
@@ -307,8 +307,6 @@ def get_xpu_quant_method(
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         if current_platform.is_xpu():
             return self.get_xpu_quant_method(layer, prefix)
         if isinstance(layer, LinearBase):
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 2cf7089e0ff9..80f8e3a03e7c 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -12,6 +12,7 @@
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
@@ -149,8 +150,6 @@ def is_layer_excluded(self, prefix: str) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         # handle kv-cache first so we can focus only on weight quantization thereafter
         if isinstance(layer, Attention):
             return self.KVCacheMethodCls(self)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index d975131f7cff..bc241ac692e2 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -8,6 +8,7 @@
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm.attention.layer import Attention
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
@@ -184,8 +185,6 @@ def get_config_filenames(cls) -> list[str]:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         if isinstance(layer, LinearBase):
             if self.ignored_layers and is_layer_skipped(
                 prefix=prefix,
diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py
index 402cebc38c21..5ccc73166361 100644
--- a/vllm/model_executor/layers/quantization/petit.py
+++ b/vllm/model_executor/layers/quantization/petit.py
@@ -8,6 +8,7 @@
 import torch
 from torch.nn.parameter import Parameter
 
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -159,8 +160,6 @@ def is_layer_excluded(self, prefix: str, exclude_modules: list[str]) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         exclude = self.require_exclude_modules()
 
         if isinstance(layer, LinearBase):
diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py
index 26ba8e5b16bc..ed8a2c7fa084 100644
--- a/vllm/model_executor/layers/quantization/ptpc_fp8.py
+++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py
@@ -7,6 +7,7 @@
 from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -65,8 +66,6 @@ def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index f59e5e2a0af7..3640e5c45278 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -6,6 +6,7 @@
 
 import torch
 
+from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
@@ -102,8 +103,6 @@ def apply_vllm_mapper(  # noqa: B027
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
         # Check if the layer is skipped for quantization.
         exclude_layers = cast(list[str], self.quant_config.get("exclude"))
         if should_ignore_layer(
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index ed655912d396..5f9561366e0d 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -14,6 +14,7 @@
 import torch
 
 from vllm import envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 
 from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -21,10 +22,8 @@
 logger = init_logger(__name__)
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import VllmConfig
 else:
-    AttentionBackendEnum = None
     VllmConfig = None
 
 
@@ -135,8 +134,6 @@ def get_attn_backend_cls(
         use_sparse: bool,
         attn_type: str | None = None,
     ) -> str:
-        from vllm.attention.backends.registry import AttentionBackendEnum
-
         if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         if use_mla:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index e8e14387bb7f..d5c3a177d9c2 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -15,6 +15,8 @@
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.utils.import_utils import import_pynvml
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -22,11 +24,9 @@
 from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import VllmConfig
     from vllm.config.cache import CacheDType
 else:
-    AttentionBackendEnum = None
     VllmConfig = None
     CacheDType = None
 
@@ -48,8 +48,6 @@ def _get_backend_priorities(
     device_capability: DeviceCapability,
 ) -> list[AttentionBackendEnum]:
     """Get backend priorities with lazy import to avoid circular dependency."""
-    from vllm.attention.backends.registry import AttentionBackendEnum
-
     if use_mla:
         if device_capability.major == 10:
             return [
@@ -265,8 +263,6 @@ def get_current_memory_usage(
     def get_vit_attn_backend(
         cls, head_size: int, dtype: torch.dtype
    ) -> "AttentionBackendEnum":
-        from vllm.attention.backends.registry import AttentionBackendEnum
-
         # Try FlashAttention first
         try:
             backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
@@ -335,8 +331,6 @@ def get_attn_backend_cls(
         use_sparse: bool,
         attn_type: str | None = None,
     ) -> str:
-        from vllm.attention.backends.abstract import AttentionType
-
         if attn_type is None:
             attn_type = AttentionType.DECODER
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 1e6b53021f88..27c6fac09f49 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -12,12 +12,12 @@
 import numpy as np
 import torch
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
     from torch.distributed import PrefixStore, ProcessGroup
 
-    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import VllmConfig
     from vllm.config.cache import CacheDType
     from vllm.inputs import ProcessorInputs, PromptType
@@ -226,9 +226,6 @@ def import_kernels(cls) -> None:
     def get_vit_attn_backend(
         cls, head_size: int, dtype: torch.dtype
     ) -> "AttentionBackendEnum":
-        # Import AttentionBackendEnum here to avoid circular import.
-        from vllm.attention.backends.registry import AttentionBackendEnum
-
         return AttentionBackendEnum.TORCH_SDPA
 
     @classmethod
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 0483f6c06ada..ccf3446a3a6e 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -8,16 +8,14 @@
 import torch
 
 import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 from vllm.utils.torch_utils import cuda_device_count_stateless
 
 from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import VllmConfig
-else:
-    AttentionBackendEnum = None
 
 logger = init_logger(__name__)
 
@@ -196,7 +194,6 @@ def get_vit_attn_backend(
         from importlib.util import find_spec
 
         from vllm._aiter_ops import rocm_aiter_ops
-        from vllm.attention.backends.registry import AttentionBackendEnum
 
         if rocm_aiter_ops.is_mha_enabled():
             # Note: AITER FA is only supported for Qwen-VL models.
@@ -222,7 +219,6 @@ def get_attn_backend_cls(
         attn_type: str | None = None,
     ) -> str:
         from vllm._aiter_ops import rocm_aiter_ops
-        from vllm.attention.backends.registry import AttentionBackendEnum
 
         if use_sparse:
             if kv_cache_dtype.startswith("fp8"):
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 944344a22957..2a2b80000be3 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -7,6 +7,7 @@
 import torch
 from tpu_info import device
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 
@@ -15,7 +16,6 @@
 if TYPE_CHECKING:
     from typing import TypeAlias
 
-    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import VllmConfig
     from vllm.config.cache import BlockSize
     from vllm.pooling_params import PoolingParams
@@ -26,7 +26,6 @@
     BlockSize = None
     VllmConfig = None
     PoolingParams = None
-    AttentionBackendEnum = None
     ParamsType = None
 
 logger = init_logger(__name__)
@@ -67,8 +66,6 @@ def get_attn_backend_cls(
         use_sparse,
         attn_type: str | None = None,
     ) -> str:
-        from vllm.attention.backends.registry import AttentionBackendEnum
-
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on TPU.")
         if selected_backend != AttentionBackendEnum.PALLAS:
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 18a3186b142f..768714fb1672 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -8,16 +8,15 @@
 import torch
 
 import vllm.envs as envs
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.logger import init_logger
 
 from .interface import DeviceCapability, Platform, PlatformEnum
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import VllmConfig
 else:
     VllmConfig = None
-    AttentionBackendEnum = None
 
 logger = init_logger(__name__)
 
@@ -60,8 +59,6 @@ def get_attn_backend_cls(
                 "only NHD layout is supported by XPU attention kernels."
             )
 
-        from vllm.attention.backends.registry import AttentionBackendEnum
-
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on XPU.")
         if selected_backend == AttentionBackendEnum.TRITON_ATTN:
@@ -116,8 +113,6 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
     def get_vit_attn_backend(
         cls, head_size: int, dtype: torch.dtype
     ) -> "AttentionBackendEnum":
-        from vllm.attention.backends.registry import AttentionBackendEnum
-
         return AttentionBackendEnum.FLASH_ATTN
 
     @classmethod
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index d0b1f8c1b807..fed7dcdf293b 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -51,8 +51,6 @@ def get_name() -> str:
     @classmethod
     def supports_attn_type(cls, attn_type: str) -> bool:
         """CPU attention supports decoder and encoder-only attention."""
-        from vllm.attention.backends.abstract import AttentionType
-
         return attn_type in (
             AttentionType.DECODER,
             AttentionType.ENCODER,
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 0fc57cfb1f9d..07d560cf3fa5 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -84,8 +84,6 @@ def get_name() -> str:
     @classmethod
     def supports_attn_type(cls, attn_type: str) -> bool:
         """FlashAttention supports all attention types."""
-        from vllm.attention.backends.abstract import AttentionType
-
         return attn_type in (
             AttentionType.DECODER,
             AttentionType.ENCODER,
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 3869f1f4164c..8de0a0a11471 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -87,8 +87,6 @@ def get_name() -> str:
     @classmethod
     def supports_attn_type(cls, attn_type: str) -> bool:
         """FlexAttention supports both decoder and encoder-only attention."""
-        from vllm.attention.backends.abstract import AttentionType
-
         return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
 
     @staticmethod
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 18e91fd4fd6a..da931b9679a3 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -24,12 +24,15 @@
 from vllm.utils.math_utils import cdiv
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionImpl
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
 import vllm.envs as envs
-from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionImpl,
+    AttentionMetadata,
+)
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout,
 )
diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py
index 3afce5589075..2cdd5ba5ffe5 100644
--- a/vllm/v1/kv_offload/spec.py
+++ b/vllm/v1/kv_offload/spec.py
@@ -6,12 +6,12 @@
 
 import torch
 
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
 from vllm.v1.kv_offload.worker.worker import OffloadingHandler
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
     from vllm.config import VllmConfig
 
 logger = init_logger(__name__)
@@ -51,7 +51,7 @@ def get_manager(self) -> OffloadingManager:
     def get_handlers(
         self,
         kv_caches: dict[str, torch.Tensor],
-        attn_backends: dict[str, type["AttentionBackend"]],
+        attn_backends: dict[str, type[AttentionBackend]],
     ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
         """
         Get offloading handlers along with their respective src and dst types.
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 7b9037c03d4f..7600df48150a 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -8,6 +8,7 @@
 import torch
 import torch.nn as nn
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
     CompilationMode,
     CUDAGraphMode,
@@ -157,8 +158,6 @@ def __init__(
         )
 
         # Determine allowed attention backends once during initialization.
-        from vllm.attention.backends.registry import AttentionBackendEnum
-
         self.allowed_attn_types: tuple | None = None
         if current_platform.is_rocm():
             rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 92e4ce3abdba..bd88cb1b253f 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -2,11 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import Attention
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
@@ -17,9 +17,6 @@
 from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
 
-if TYPE_CHECKING:
-    from vllm.attention.layer import Attention
-
 
 class MultiModalBudget:
     """Helper class to calculate budget information for multi-modal models."""
@@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups(
 
 def bind_kv_cache(
     kv_caches: dict[str, torch.Tensor],
-    forward_context: dict[str, "Attention"],
+    forward_context: dict[str, Attention],
     runner_kv_caches: list[torch.Tensor],
     num_attn_module: int = 1,
 ) -> None:

From 6620e9ec6c48871f6cc43ab17b594abddb24056a Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Wed, 26 Nov 2025 15:11:27 -0500
Subject: [PATCH 2/3] fix circular import

Signed-off-by: Matthew Bonanni
---
 vllm/attention/backends/registry.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py
index 98671f66aee7..125e4e382774 100644
--- a/vllm/attention/backends/registry.py
+++ b/vllm/attention/backends/registry.py
@@ -4,12 +4,14 @@
 
 from collections.abc import Callable
 from enum import Enum, EnumMeta
-from typing import cast
+from typing import TYPE_CHECKING, cast
 
-from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.utils.import_utils import resolve_obj_by_qualname
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
 logger = init_logger(__name__)
 
 
@@ -98,7 +96,7 @@ def get_path(self, include_classname: bool = True) -> str:
             path = path.rsplit(".", 1)[0]
         return path
 
-    def get_class(self) -> type[AttentionBackend]:
+    def get_class(self) -> "type[AttentionBackend]":
         """Get the backend class (respects overrides).
 
         Returns:
@@ -160,7 +158,7 @@ def get_path(self, include_classname: bool = True) -> str:
             path = path.rsplit(".", 1)[0]
         return path
 
-    def get_class(self) -> type[AttentionBackend]:
+    def get_class(self) -> "type[AttentionBackend]":
         """Get the backend class (respects overrides).
 
         Returns:

From b3fd814951658ce8ec5f16df47e36ac521ba681b Mon Sep 17 00:00:00 2001
From: Matthew Bonanni
Date: Wed, 26 Nov 2025 16:10:01 -0500
Subject: [PATCH 3/3] fix circular import

Signed-off-by: Matthew Bonanni
---
 vllm/attention/backends/abstract.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index b1518b2ca2ef..c290670eeacb 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -6,11 +6,10 @@
 
 import torch
 
-from vllm.model_executor.layers.linear import ColumnParallelLinear
-from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
-
 if TYPE_CHECKING:
     from vllm.config.cache import CacheDType
+    from vllm.model_executor.layers.linear import ColumnParallelLinear
+    from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
     from vllm.platforms.interface import DeviceCapability
     from vllm.v1.attention.backends.utils import KVCacheLayoutType
 
@@ -358,7 +357,7 @@ def forward(
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def fused_output_quant_supported(self, quant_key: QuantKey):
+    def fused_output_quant_supported(self, quant_key: "QuantKey"):
         """
         Does this attention implementation support fused output quantization.
         This is used by the AttnFusionPass to only fuse output quantization
@@ -410,7 +409,7 @@ def __init__(
         qk_rope_head_dim: int,
         qk_head_dim: int,
         v_head_dim: int,
-        kv_b_proj: ColumnParallelLinear,
+        kv_b_proj: "ColumnParallelLinear",
         indexer: object | None = None,
     ) -> None:
         raise NotImplementedError
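
Note: the sketch below is illustrative only and is not part of the patch series; module and class names (pkg_a, Backend, get_backend_class) are hypothetical. It shows the general pattern patches 2 and 3 apply to break the import cycle: keep the import under typing.TYPE_CHECKING, quote the annotation, and defer any runtime import into the function body.

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only for static analysis; never executed at runtime,
        # so it cannot participate in an import cycle.
        from pkg_a.backend import Backend


    def get_backend_class() -> "type[Backend]":
        # Deferred runtime import: resolved lazily, after both modules
        # have finished initializing, which avoids the circular import.
        from pkg_a.backend import Backend

        return Backend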