diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py index cf3382318dd..e19ab1e115c 100644 --- a/tests/e2e/multicard/test_qwen3_next.py +++ b/tests/e2e/multicard/test_qwen3_next.py @@ -20,17 +20,9 @@ Run `pytest tests/e2e/multicard/test_qwen3_next.py`. """ -import os -from unittest.mock import patch - from tests.e2e.conftest import VllmRunner -# NZ will cause precision error in Qwen3-Next -# When it is fixed, this set-up can be removed -_IS_ENABLE_NZ = "VLLM_ASCEND_ENABLE_NZ" - -@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"}) def test_models_distributed_Qwen3_NEXT_TP4(): example_prompts = [ "Hello, my name is", @@ -46,7 +38,6 @@ def test_models_distributed_Qwen3_NEXT_TP4(): del vllm_model -@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"}) def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY(): example_prompts = [ "Hello, my name is", @@ -66,7 +57,6 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY(): del vllm_model -@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"}) def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY(): example_prompts = [ "Hello, my name is", diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 8d15bcaab16..102b6febe9d 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -416,6 +416,7 @@ def test_q_proj_and_k_up_proj(self): self.assertEqual(q_pe.shape[1], self.impl.num_heads) self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim) + @patch('vllm_ascend.utils._ENABLE_NZ', True) @patch('torch_npu.npu_format_cast') def test_process_weights_after_loading(self, mock_format_cast): layer = MagicMock(spec=LinearBase) diff --git a/tests/ut/models/test_qwen2_5_vl.py b/tests/ut/models/test_qwen2_5_vl.py index 7111aaed6c8..b4f06803706 100644 --- a/tests/ut/models/test_qwen2_5_vl.py +++ b/tests/ut/models/test_qwen2_5_vl.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest import torch import torch.nn.functional as F @@ -365,6 +367,7 @@ def test_pad_qkv_bias(self, mocker: MockerFixture): res = attention.pad_qkv_bias(torch.rand((300))) assert res.shape[0] == 384 + @patch('vllm_ascend.utils._ENABLE_NZ', True) def test_pad_qkv_weight(self, mocker: MockerFixture): attention = self.init_vision_transformer(mocker) mocker.patch("torch.nn.Module.__setattr__") @@ -377,6 +380,7 @@ def test_pad_qkv_weight(self, mocker: MockerFixture): res = attention.pad_qkv_weight(torch.rand((300, 300))) assert res.shape == (384, 300) + @patch('vllm_ascend.utils._ENABLE_NZ', True) def test_pad_proj_weight(self, mocker: MockerFixture): attention = self.init_vision_transformer(mocker) mocker.patch("torch.nn.Module.__setattr__") diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index 2116b0c1688..42c3c933625 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -260,6 +260,7 @@ def build_layer(self, requires_grad=False) return layer + @patch('vllm_ascend.utils._ENABLE_NZ', False) @patch('torch_npu.npu_format_cast') @patch('torch_npu.npu_quantize') @patch('torch.Tensor.npu') diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 147e8378ddc..8d34547bfb7 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -46,12 +46,18 @@ def test_is_310p(self): self.assertFalse(utils.is_310p()) def test_is_enable_nz(self): - with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ", - 1): - self.assertTrue(utils.is_enable_nz()) - with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ", - 0): - self.assertFalse(utils.is_enable_nz()) + # Case when _ENABLE_NZ is already set + utils._ENABLE_NZ = True + self.assertTrue(utils.is_enable_nz()) + + utils._ENABLE_NZ = False + self.assertFalse(utils.is_enable_nz()) + + # Case when _ENABLE_NZ is None and vllm_config is not provided + utils._ENABLE_NZ = None + with self.assertRaises(ValueError) as context: + utils.is_enable_nz() + self.assertIn("vllm_config must be provided", str(context.exception)) def test_sleep_mode_enabled(self): utils._SLEEP_MODE_ENABLED = None diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index 2fbad2f8102..c3c5d462955 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -20,7 +20,13 @@ def setUp(self): self.model_config_mock = MagicMock(spec=ModelConfig) self.model_config_mock.dtype = torch.float16 self.model_config_mock.trust_remote_code = False - self.model_config_mock.hf_config = None + + self.hf_config_mock = MagicMock() + self.hf_config_mock.model_type = "test_model" + if hasattr(self.hf_config_mock, 'index_topk'): + delattr(self.hf_config_mock, 'index_topk') + + self.model_config_mock.hf_config = self.hf_config_mock self.parallel_config_mock = MagicMock(spec=ParallelConfig) @@ -272,9 +278,9 @@ def test_sleep_mode_disabled_raises_error(self, mock_sleep_mode_enabled): self.assertIn("Sleep mode is not enabled", str(cm.exception)) + @patch('vllm_ascend.utils._ENABLE_NZ', False) @patch("vllm_ascend.worker.worker_v1.sleep_mode_enabled") @patch("vllm_ascend.worker.worker_v1.CaMemAllocator") - @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}) def test_wake_up_mode_enabled(self, mock_allocator_class, mock_sleep_mode_enabled): """Test wake_up method when sleep mode is enabled""" diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index ec3c6f03bf9..8dd7c6bcf1c 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -59,6 +59,7 @@ _IS_MOE_MODEL = None _ENABLE_SP = None _HAS_LAYER_IDX = None +_ENABLE_NZ = None def is_310p(): @@ -69,8 +70,14 @@ def is_310p(): return _IS_310P -def is_enable_nz(): - return envs_ascend.VLLM_ASCEND_ENABLE_NZ +def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool: + global _ENABLE_NZ + if _ENABLE_NZ is None: + if not vllm_config: + raise ValueError( + "vllm_config must be provided when _ENABLE_NZ is None") + _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next" + return _ENABLE_NZ def sleep_mode_enabled(): diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index a90883cdcb8..145f38a14dc 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -87,6 +87,7 @@ def __init__( # register patch for vllm from vllm_ascend.utils import adapt_patch adapt_patch() + is_enable_nz(vllm_config) # Register ops when worker init. from vllm_ascend import ops ops.register_dummy_fusion_op()