9 changes: 7 additions & 2 deletions vllm_ascend/platform.py
@@ -31,7 +31,8 @@
init_ascend_config)
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
delete_torchair_cache_file)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD,
COMPRESSED_TENSORS_METHOD, enable_sp, is_310p,
prefill_context_parallel_enable,
update_aclgraph_sizes,
update_cudagraph_capture_sizes, vllm_version_is)
@@ -55,7 +56,9 @@ class NPUPlatform(Platform):
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
dispatch_key: str = "PrivateUse1"

supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD]
supported_quantization: list[str] = [
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
]

def is_sleep_mode_available(self) -> bool:
return True
@@ -78,6 +81,8 @@ def pre_register_and_update(cls,
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)

from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \
AscendCompressedTensorsConfig # noqa: F401
from vllm_ascend.quantization.quant_config import \
AscendQuantConfig # noqa: F401

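Note: the new import of AscendCompressedTensorsConfig in pre_register_and_update is what triggers its @register_quantization_config(COMPRESSED_TENSORS_METHOD) decorator, and extending supported_quantization is what lets the NPU platform accept quantization="compressed-tensors"; a minimal usage sketch is included at the end of this diff.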
Empty file.
300 changes: 300 additions & 0 deletions vllm_ascend/quantization/compressed_tensors/compressed_tensors.py
@@ -0,0 +1,300 @@
from typing import TYPE_CHECKING, Any, Optional, cast

import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \
CompressedTensorsScheme
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)

from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD

from .schemes.compressed_tensors_w8a8 import CompressedTensorsW8A8
from .schemes.compressed_tensors_w8a8_dynamic import \
CompressedTensorsW8A8Dynamic

if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]


def remove_quantization_method():
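# vLLM already registers its own "compressed-tensors" config; drop that entry
# first so the @register_quantization_config decorator below can register the
# same name for this Ascend-specific implementation without clashing with the
# upstream entry.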
if COMPRESSED_TENSORS_METHOD in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.remove(COMPRESSED_TENSORS_METHOD)


remove_quantization_method()


@register_quantization_config(COMPRESSED_TENSORS_METHOD)
class AscendCompressedTensorsConfig(QuantizationConfig):

def __init__(
self,
target_scheme_map: dict[str, Any],
ignore: list[str],
quant_format: str,
config: Optional[dict[str, Any]] = None,
):
super().__init__()
self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
self.quant_description = config

def get_name(self) -> str:
return "compressed-tensors"

@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.int8, torch.float16, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"Ascend hardware dose not support \"get_min_capability\" feature.")

@classmethod
def get_config_filenames(cls) -> list[str]:
return []

@classmethod
def from_config(cls, config: dict[str,
Any]) -> "AscendCompressedTensorsConfig":
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)

return cls(
target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
config=config,
)

@classmethod
def _quantization_scheme_map_from_config(
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: dict[str, Any] = dict()
quant_format = cast(str, config.get("format"))

# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
# and a list of targets under the `targets` key, dictating which
# layers are impacted by the quantization details. The quantization
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
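# An illustrative (hypothetical) entry for the dynamic W8A8 scheme handled
# below might look like:
#   "config_groups": {
#       "group_0": {
#           "targets": ["Linear"],
#           "weights": {"num_bits": 8, "type": "int", "symmetric": true,
#                       "strategy": "channel", "dynamic": false},
#           "input_activations": {"num_bits": 8, "type": "int",
#                                 "symmetric": true, "strategy": "token",
#                                 "dynamic": true}
#       }
#   }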

config_groups = config.get("config_groups", dict())
for _, quant_config in config_groups.items():
targets = quant_config.get("targets")
for target in targets:
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.model_validate(
quant_config.get("weights"))

target_scheme_map[target]["input_activations"] = None
target_scheme_map[target]["format"] = quant_config.get(
"format")
format = target_scheme_map[target].get("format")
# If no per-group format is defined, fall back to the global format in the config
act_quant_format = (
is_activation_quantization_format(format)
if format is not None else
is_activation_quantization_format(quant_format))
input_activations = quant_config.get("input_activations")
if act_quant_format and input_activations is not None:
target_scheme_map[target]["input_activations"] = (
QuantizationArgs.model_validate(
quant_config.get("input_activations")))
return target_scheme_map

def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)

# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
quant_method = AscendCompressedTensorsLinearMethod(self)
return quant_method
return None

def get_scheme(self,
layer: torch.nn.Module,
layer_name: Optional[str] = None
) -> Optional["CompressedTensorsScheme"]:
"""
compressed-tensors supports non-uniform quantization in the following way:

targets of config_groups: There can be N config_groups which each
have a quantization scheme. Each config_group has a list of targets
which can be a full layer_name, a regex for a layer_name, or
an nn.Module name.

Detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target
to select the CompressedTensorsScheme used for inference.
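
Illustrative example: a target of "Linear" matches every nn.Linear
module, while a regex target such as "re:.*q_proj" would only match
query projection layers.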
"""

# Find the "target" in the compressed-tensors config
# that our layer conforms to.
if should_ignore_layer(layer_name,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return None

# Will be empty for models with only sparsity
weight_quant = input_quant = None
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping,
)

scheme_dict = self.target_scheme_map[matched_target]
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")

if weight_quant is None:
logger.warning_once("Acceleration for non-quantized schemes is "
"not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod")
return None

else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts(
weight_quant=weight_quant,
input_quant=input_quant,
)
return scheme

def _get_scheme_from_parts(
self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> "CompressedTensorsScheme":
act_quant_format = is_activation_quantization_format(self.quant_format)
if act_quant_format and input_quant is not None:
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8(strategy=weight_quant.strategy)

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Dynamic()

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")

def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_static = not weight_quant.dynamic and not input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric

# Only symmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and is_symmetric and is_static

def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
self.target_scheme_map)
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)


class AscendCompressedTensorsLinearMethod(LinearMethodBase):

def __init__(self, quantization_config: AscendCompressedTensorsConfig):
self.quantization_config = quantization_config

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer)

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
layer.scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader,
)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input. See LinearMethodBase for param details

"""
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from .compressed_tensors_w8a8 import CompressedTensorsW8A8
from .compressed_tensors_w8a8_dynamic import CompressedTensorsW8A8Dynamic

__all__ = ["CompressedTensorsW8A8", "CompressedTensorsW8A8Dynamic"]