diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index e742852e91..fd53f88efc 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -157,6 +157,17 @@ def is_layer_skipped_ascend(
                         f"Detected some but not all shards of {prefix} "
                         "are quantized. All shards of fused layers "
                         "to have the same precision.")
+        elif "experts" in prefix:
+            # For the experts' prefix (e.g., "model.layers.3.mlp.experts")
+            # Assume all experts within the same MLP use the same quantization method
+            experts_quant_description = [
+                self.quant_description[layer]
+                for layer in self.quant_description if prefix in layer
+            ]
+            is_skipped = any(
+                quantization == "FLOAT"
+                for quantization in experts_quant_description
+            )
         else:
             is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"
 
diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py
index 6d914c0dad..749f5b257d 100644
--- a/vllm_ascend/quantization/utils.py
+++ b/vllm_ascend/quantization/utils.py
@@ -52,6 +52,16 @@ def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
                     f"Not all shards of {prefix} are quantized with same quant type."
                     f"Shard {proj_name} uses {shard_quant_type}, but another shard"
                     f"use {quant_type}. Please check quantization config.")
+    elif "experts" in prefix:
+        # For the experts' prefix (e.g., "model.layers.3.mlp.experts")
+        # Assume all experts within the same MLP use the same quantization method
+        experts_quant_description = set(
+            quant_description[layer]
+            for layer in quant_description if prefix in layer
+        )
+        if not len(experts_quant_description) == 1:
+            raise RuntimeError(f"{prefix} has different quantization type: {experts_quant_description}.")
+        quant_type = experts_quant_description.pop()
     else:
         quant_type = quant_description[prefix + '.weight']
     return quant_type
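
For reference, below is a minimal standalone sketch (not part of the patch) of how the new "experts" branch in get_linear_quant_type resolves a single quantization type from the quant_description mapping. The key names and the "W8A8" quant-type string are illustrative assumptions, not taken from a real checkpoint.

# Illustrative sketch only: mirrors the new "experts" branch of
# get_linear_quant_type; keys and quant-type strings are hypothetical.
from typing import Any, Dict


def experts_quant_type(quant_description: Dict[str, Any], prefix: str) -> str:
    # Collect the quant type recorded for every key under the experts prefix;
    # all experts in one MLP are assumed to share one quantization method.
    experts_quant_description = set(
        quant_description[layer]
        for layer in quant_description if prefix in layer
    )
    if len(experts_quant_description) != 1:
        raise RuntimeError(
            f"{prefix} has different quantization type: {experts_quant_description}.")
    return experts_quant_description.pop()


# Hypothetical quant_description entries for one MoE layer.
quant_description = {
    "model.layers.3.mlp.experts.0.gate_proj.weight": "W8A8",
    "model.layers.3.mlp.experts.0.down_proj.weight": "W8A8",
    "model.layers.3.mlp.experts.1.gate_proj.weight": "W8A8",
    "model.layers.3.mlp.experts.1.down_proj.weight": "W8A8",
}

print(experts_quant_type(quant_description, "model.layers.3.mlp.experts"))  # W8A8

If the entries disagreed (e.g. one expert marked "FLOAT"), the set would contain more than one element and the RuntimeError path would fire, matching the behavior added in utils.py.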