diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index d7550bf11b..ecc9343166 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -31,7 +31,8 @@
                                         init_ascend_config)
 from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
                                         delete_torchair_cache_file)
-from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
+from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD,
+                               COMPRESSED_TENSORS_METHOD, enable_sp, is_310p,
                                prefill_context_parallel_enable,
                                update_aclgraph_sizes,
                                update_cudagraph_capture_sizes, vllm_version_is)
@@ -55,7 +56,9 @@ class NPUPlatform(Platform):
 
     device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
     dispatch_key: str = "PrivateUse1"
-    supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD]
+    supported_quantization: list[str] = [
+        ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
+    ]
 
     def is_sleep_mode_available(self) -> bool:
         return True
@@ -78,6 +81,8 @@ def pre_register_and_update(cls,
             if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
                 quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
 
+        from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \
+            AscendCompressedTensorsConfig  # noqa: F401
         from vllm_ascend.quantization.quant_config import \
             AscendQuantConfig  # noqa: F401
 
diff --git a/vllm_ascend/quantization/compressed_tensors/__init__.py b/vllm_ascend/quantization/compressed_tensors/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py b/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py
new file mode 100644
index 0000000000..557375be4b
--- /dev/null
+++ b/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py
@@ -0,0 +1,300 @@
+from typing import TYPE_CHECKING, Any, Optional, cast
+
+import torch
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import (
+    QUANTIZATION_METHODS, register_quantization_config)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \
+    CompressedTensorsScheme
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    find_matched_target, is_activation_quantization_format,
+    should_ignore_layer)
+
+from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
+
+from .schemes.compressed_tensors_w8a8 import CompressedTensorsW8A8
+from .schemes.compressed_tensors_w8a8_dynamic import \
+    CompressedTensorsW8A8Dynamic
+
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
+logger = init_logger(__name__)
+
+QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
+
+
+def remove_quantization_method():
+    if COMPRESSED_TENSORS_METHOD in QUANTIZATION_METHODS:
+        QUANTIZATION_METHODS.remove(COMPRESSED_TENSORS_METHOD)
+
+
+remove_quantization_method()
+
+
+@register_quantization_config(COMPRESSED_TENSORS_METHOD)
+class AscendCompressedTensorsConfig(QuantizationConfig):
+
+    def __init__(
+        self,
+        target_scheme_map: dict[str, Any],
+        ignore: list[str],
+        quant_format: str,
+        config: Optional[dict[str, Any]] = None,
+    ):
+        super().__init__()
+        self.ignore = ignore
+        self.quant_format = quant_format
+        # Map from [target -> scheme]
+        self.target_scheme_map = target_scheme_map
+        self.quant_description = config
+
+    def get_name(self) -> str:
+        return "compressed-tensors"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.int8, torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise NotImplementedError(
+            "Ascend hardware does not support the \"get_min_capability\" feature.")
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: dict[str,
+                                      Any]) -> "AscendCompressedTensorsConfig":
+        ignore: list[str] = cast(list[str], config.get("ignore", []))
+        quant_format = cast(str, config.get("format"))
+        target_scheme_map = cls._quantization_scheme_map_from_config(
+            config=config)
+
+        return cls(
+            target_scheme_map=target_scheme_map,
+            ignore=ignore,
+            quant_format=quant_format,
+            config=config,
+        )
+
+    @classmethod
+    def _quantization_scheme_map_from_config(
+            cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
+        """
+        :param config: The `quantization_config` dictionary from config.json
+        :return: A dictionary mapping target layer names to their corresponding
+            quantization_args for weights and input activations
+        """
+        target_scheme_map: dict[str, Any] = dict()
+        quant_format = cast(str, config.get("format"))
+
+        # The quant_config has multiple config_groups, each containing
+        # an input_activations key with details about how the activations are
+        # quantized, a weights key indicating how the weights are quantized,
+        # and a list of targets under the `targets` key, dictating which
+        # layers are impacted by the quantization details. The quantization
+        # details follow the structure defined by the QuantizationArgs
+        # pydantic model, which is used to verify the structure of the
+        # quant_config and also store the details for later use.
+
+        config_groups = config.get("config_groups", dict())
+        for _, quant_config in config_groups.items():
+            targets = quant_config.get("targets")
+            for target in targets:
+                target_scheme_map[target] = {}
+                target_scheme_map[target][
+                    "weights"] = QuantizationArgs.model_validate(
+                        quant_config.get("weights"))
+
+                target_scheme_map[target]["input_activations"] = None
+                target_scheme_map[target]["format"] = quant_config.get(
+                    "format")
+                format = target_scheme_map[target].get("format")
+                # If no per-group format is defined, fall back to the global
+                # format from the config.
+                act_quant_format = (
+                    is_activation_quantization_format(format)
+                    if format is not None else
+                    is_activation_quantization_format(quant_format))
+                input_activations = quant_config.get("input_activations")
+                if act_quant_format and input_activations is not None:
+                    target_scheme_map[target]["input_activations"] = (
+                        QuantizationArgs.model_validate(
+                            quant_config.get("input_activations")))
+        return target_scheme_map
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            # Collect the scheme for this layer.
+            quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
+
+            # Choose the quantization method.
+            quant_method: LinearMethodBase = UnquantizedLinearMethod()
+            if quant_scheme is not None:
+                layer.scheme = quant_scheme
+                quant_method = AscendCompressedTensorsLinearMethod(self)
+            return quant_method
+        return None
+
+    def get_scheme(self,
+                   layer: torch.nn.Module,
+                   layer_name: Optional[str] = None
+                   ) -> Optional["CompressedTensorsScheme"]:
+        """
+        compressed-tensors supports non-uniform quantization in the
+        following way:
+
+        targets of config_groups: There can be N config_groups, each with its
+        own quantization scheme. Each config_group has a list of targets,
+        which can be a full layer_name, a regex for a layer_name, or
+        an nn.Module name.
+
+        Detect whether a layer_name is found in any target and
+        use the quantization scheme corresponding to the matched target
+        to select the CompressedTensorsScheme used for inference.
+        """
+
+        # Find the "target" in the compressed-tensors config
+        # that our layer conforms to.
+        if should_ignore_layer(layer_name,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
+            return None
+
+        # Will be empty for models with only sparsity
+        weight_quant = input_quant = None
+        if self.target_scheme_map:
+            matched_target = find_matched_target(
+                layer_name=layer_name,
+                module=layer,
+                targets=self.target_scheme_map.keys(),
+                fused_mapping=self.packed_modules_mapping,
+            )
+
+            scheme_dict = self.target_scheme_map[matched_target]
+            weight_quant = scheme_dict.get("weights")
+            input_quant = scheme_dict.get("input_activations")
+
+            if weight_quant is None:
+                logger.warning_once("Acceleration for non-quantized schemes is "
+                                    "not supported by Compressed Tensors. "
+                                    "Falling back to UnquantizedLinearMethod")
+                return None
+
+            else:
+                # Find the quant_scheme
+                scheme = self._get_scheme_from_parts(
+                    weight_quant=weight_quant,
+                    input_quant=input_quant,
+                )
+                return scheme
+
+    def _get_scheme_from_parts(
+            self, weight_quant: QuantizationArgs,
+            input_quant: QuantizationArgs) -> "CompressedTensorsScheme":
+        act_quant_format = is_activation_quantization_format(self.quant_format)
+        if act_quant_format and input_quant is not None:
+            if self._is_static_tensor_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8(strategy=weight_quant.strategy)
+
+            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8Dynamic()
+
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")
+
+    def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
+        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_tensor = (weight_strategy and input_quant.strategy
+                     == QuantizationStrategy.TENSOR.value)
+        is_static = not weight_quant.dynamic and not input_quant.dynamic
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+
+        # Only symmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return is_8_bits and is_tensor and is_symmetric and is_static
+
+    def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
+        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_token = (weight_strategy and input_quant.strategy
+                    == QuantizationStrategy.TOKEN.value)
+        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
+
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
+            self.target_scheme_map)
+        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
+
+
+class AscendCompressedTensorsLinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: AscendCompressedTensorsConfig):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.scheme.process_weights_after_loading(layer)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        """
+        Use the CompressedTensorsScheme associated with each layer to create
+        the necessary parameters for the layer.
+        See LinearMethodBase for param details.
+        """
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        layer.scheme.create_weights(
+            layer=layer,
+            input_size=input_size,
+            input_size_per_partition=input_size_per_partition,
+            output_partition_sizes=output_partition_sizes,
+            output_size=output_size,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """
+        Use the output of create_weights and the CompressedTensorsScheme
+        associated with the layer to apply the forward pass with the
+        layer input. See LinearMethodBase for param details.
+        """
+        scheme = layer.scheme
+        if scheme is None:
+            raise ValueError("A scheme must be defined for each layer")
+        return scheme.apply_weights(layer, x, bias=bias)
diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/__init__.py b/vllm_ascend/quantization/compressed_tensors/schemes/__init__.py
new file mode 100644
index 0000000000..7f334daf71
--- /dev/null
+++ b/vllm_ascend/quantization/compressed_tensors/schemes/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from .compressed_tensors_w8a8 import CompressedTensorsW8A8
+from .compressed_tensors_w8a8_dynamic import CompressedTensorsW8A8Dynamic
+
+__all__ = ["CompressedTensorsW8A8", "CompressedTensorsW8A8Dynamic"]
\ No newline at end of file
diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py b/vllm_ascend/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
new file mode 100644
index 0000000000..490becf127
--- /dev/null
+++ b/vllm_ascend/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
@@ -0,0 +1,153 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional
+
+import torch
+import torch_npu
+from compressed_tensors.quantization import QuantizationStrategy
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \
+    CompressedTensorsScheme
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           ModelWeightParameter,
+                                           PerTensorScaleParameter)
+
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz
+
+logger = init_logger(__name__)
+
+
+def quant_per_tensor(in_tensor: torch.Tensor,
+                     input_scale: torch.Tensor,
+                     input_offset: torch.Tensor,
+                     function=False):
+    return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
+                                  torch.qint8, -1, function)
+
+
+class CompressedTensorsW8A8(CompressedTensorsScheme):
+
+    def __init__(self, strategy: str) -> None:
+        self.strategy = strategy
+        # The aclnn quant matmul requires matrix B to be transposed, so this
+        # is enabled by default (except on 310P).
+        self.transpose_weight = not is_310p()
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise NotImplementedError(
+            "Ascend hardware does not support the \"get_min_capability\" feature.")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        self.output_partition_sizes = output_partition_sizes
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(output_size_per_partition,
+                             input_size_per_partition,
+                             dtype=torch.int8),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((output_size_per_partition, 1),
+                                 dtype=params_dtype),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=params_dtype),
+                                                   weight_loader=weight_loader)
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE
+        input_scale = BasevLLMParameter(data=torch.empty(1,
+                                                         dtype=params_dtype),
+                                        weight_loader=weight_loader)
+        layer.register_parameter("input_scale", input_scale)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        if x.dtype != torch.int8:
+            x = quant_per_tensor(
+                x,
+                layer.aclnn_input_scale_reciprocal,
+                None,
+            )
+
+        if is_310p():
+            # On the 300I Duo platform we need to transpose again if the
+            # weight is stored in NZ format. This transpose can be skipped
+            # in torchair.
+            output = torch_npu.npu_quant_matmul(
+                x,
+                layer.weight.data.transpose(1, 0),
+                layer.deq_scale,
+                bias=bias,
+                output_dtype=layer.params_dtype,
+            )
+        else:
+            output = torch_npu.npu_quant_matmul(
+                x,
+                layer.weight,
+                layer.deq_scale,
+                bias=bias,
+                output_dtype=layer.params_dtype,
+            )
+        return output
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        expanding_factor = layer.weight.data.shape[1]
+        layer.aclnn_input_scale = torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor),
+            requires_grad=False)
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor),
+            requires_grad=False)
+        if self.transpose_weight:
+            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        if is_enable_nz():
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
+        if self.strategy == QuantizationStrategy.TENSOR:
+            deq_scale = layer.input_scale.data * torch.repeat_interleave(
+                layer.weight_scale.data,
+                torch.tensor(self.output_partition_sizes,
+                             dtype=torch.int,
+                             device=layer.weight_scale.data.device))
+        else:
+            deq_scale = layer.input_scale.data * layer.weight_scale.data
+        layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False)
diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamic.py b/vllm_ascend/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamic.py
new file mode 100644
index 0000000000..f4a1bc0c21
--- /dev/null
+++ b/vllm_ascend/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamic.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import List, Optional
+
+import torch
+import torch_npu
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \
+    CompressedTensorsScheme
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           ModelWeightParameter)
+
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+
+logger = init_logger(__name__)
+
+
+class CompressedTensorsW8A8Dynamic(CompressedTensorsScheme):
+
+    def __init__(self) -> None:
+        # The aclnn quant matmul requires matrix B to be transposed, so this
+        # is enabled by default.
+        self.transpose_weight = True
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise NotImplementedError(
+            "Ascend hardware does not support the \"get_min_capability\" feature.")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(output_size_per_partition,
+                             input_size_per_partition,
+                             dtype=torch.int8),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((output_size_per_partition, 1),
+                             dtype=params_dtype),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        if not isinstance(x, tuple):
+            output_dtype = x.dtype
+            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        else:
+            output_dtype = layer.weight_scale.dtype
+            quantized_x, dynamic_scale = x
+
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
+            layer.weight,
+            layer.weight_scale,
+            pertoken_scale=dynamic_scale,
+            bias=bias,
+            output_dtype=output_dtype,
+        )
+        return output
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.transpose_weight:
+            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        # Cast quantized weight tensors to NZ format for higher inference speed
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
+                                                      ACL_FORMAT_FRACTAL_NZ)
+        layer.weight_scale.data = layer.weight_scale.data.flatten()
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 5960d2f857..28ed7ed2bb 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -92,7 +92,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
                                      user_quant) -> Optional[str]:
-        if torch.npu.is_available():
+        quant_method = hf_quant_cfg.get("quant_method", None)
+        if quant_method is None and torch.npu.is_available():
             return ASCEND_QUANTIZATION_METHOD
         return None
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index e1afd24a08..24cd16ccd3 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -41,6 +41,7 @@
     VllmConfig = None
 
 ASCEND_QUANTIZATION_METHOD = "ascend"
+COMPRESSED_TENSORS_METHOD = "compressed-tensors"
 SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
 
 REGISTERED_ASCEND_OPS = {}
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index e8729925fa..0932ec238f 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -152,6 +152,8 @@ def __init__(
         # FixMe: this is a patch to fix the issue cause by https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170
         from vllm.model_executor.layers.linear import \
             WEIGHT_LOADER_V2_SUPPORTED
+        WEIGHT_LOADER_V2_SUPPORTED.append(
+            "AscendCompressedTensorsLinearMethod")
         if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED:
             WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
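
A minimal usage sketch, assuming a checkpoint already exported to the compressed-tensors W8A8 format (static per-tensor or dynamic per-token); the model path, prompt, and sampling settings below are placeholders, not part of this patch:

from vllm import LLM, SamplingParams

# Placeholder path to a W8A8 compressed-tensors checkpoint (for example one
# produced by llm-compressor). The quantization method is auto-detected from
# the checkpoint's config.json; "compressed-tensors" may also be passed
# explicitly, which this change adds to supported_quantization on NPU.
llm = LLM(model="/path/to/w8a8-compressed-tensors-model",
          quantization="compressed-tensors")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)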