Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions tests/e2e/multicard/test_qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,9 @@

Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
"""
import os
from unittest.mock import patch

from tests.e2e.conftest import VllmRunner

# NZ will cause precision error in Qwen3-Next
# When it is fixed, this set-up can be removed
_IS_ENABLE_NZ = "VLLM_ASCEND_ENABLE_NZ"


@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_TP4():
example_prompts = [
"Hello, my name is",
Expand All @@ -46,7 +38,6 @@ def test_models_distributed_Qwen3_NEXT_TP4():
del vllm_model


@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
example_prompts = [
"Hello, my name is",
Expand All @@ -66,7 +57,6 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
del vllm_model


@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
example_prompts = [
"Hello, my name is",
Expand Down
1 change: 1 addition & 0 deletions tests/ut/attention/test_mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def test_q_proj_and_k_up_proj(self):
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)

@patch('vllm_ascend.utils._ENABLE_NZ', True)
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading(self, mock_format_cast):
layer = MagicMock(spec=LinearBase)
Expand Down
4 changes: 4 additions & 0 deletions tests/ut/models/test_qwen2_5_vl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import pytest
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -365,6 +367,7 @@ def test_pad_qkv_bias(self, mocker: MockerFixture):
res = attention.pad_qkv_bias(torch.rand((300)))
assert res.shape[0] == 384

@patch('vllm_ascend.utils._ENABLE_NZ', True)
def test_pad_qkv_weight(self, mocker: MockerFixture):
attention = self.init_vision_transformer(mocker)
mocker.patch("torch.nn.Module.__setattr__")
Expand All @@ -377,6 +380,7 @@ def test_pad_qkv_weight(self, mocker: MockerFixture):
res = attention.pad_qkv_weight(torch.rand((300, 300)))
assert res.shape == (384, 300)

@patch('vllm_ascend.utils._ENABLE_NZ', True)
def test_pad_proj_weight(self, mocker: MockerFixture):
attention = self.init_vision_transformer(mocker)
mocker.patch("torch.nn.Module.__setattr__")
Expand Down
1 change: 1 addition & 0 deletions tests/ut/quantization/test_w4a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def build_layer(self,
requires_grad=False)
return layer

@patch('vllm_ascend.utils._ENABLE_NZ', False)
@patch('torch_npu.npu_format_cast')
@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
Expand Down
18 changes: 12 additions & 6 deletions tests/ut/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,18 @@ def test_is_310p(self):
self.assertFalse(utils.is_310p())

def test_is_enable_nz(self):
    """Check ``utils.is_enable_nz`` for cached and uninitialised states.

    ``utils._ENABLE_NZ`` is swapped in with ``mock.patch.object`` so the
    module global is restored when each ``with`` block exits; the values
    set here therefore cannot leak into other tests.
    """
    # A cached value is returned as-is, without needing a vllm_config.
    with mock.patch.object(utils, "_ENABLE_NZ", True):
        self.assertTrue(utils.is_enable_nz())

    with mock.patch.object(utils, "_ENABLE_NZ", False):
        self.assertFalse(utils.is_enable_nz())

    # With nothing cached and no vllm_config supplied, the call must raise.
    with mock.patch.object(utils, "_ENABLE_NZ", None):
        with self.assertRaises(ValueError) as context:
            utils.is_enable_nz()
    self.assertIn("vllm_config must be provided", str(context.exception))

def test_sleep_mode_enabled(self):
utils._SLEEP_MODE_ENABLED = None
Expand Down
10 changes: 8 additions & 2 deletions tests/ut/worker/test_worker_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ def setUp(self):
self.model_config_mock = MagicMock(spec=ModelConfig)
self.model_config_mock.dtype = torch.float16
self.model_config_mock.trust_remote_code = False
self.model_config_mock.hf_config = None

self.hf_config_mock = MagicMock()
self.hf_config_mock.model_type = "test_model"
if hasattr(self.hf_config_mock, 'index_topk'):
delattr(self.hf_config_mock, 'index_topk')

self.model_config_mock.hf_config = self.hf_config_mock

self.parallel_config_mock = MagicMock(spec=ParallelConfig)

Expand Down Expand Up @@ -272,9 +278,9 @@ def test_sleep_mode_disabled_raises_error(self, mock_sleep_mode_enabled):

self.assertIn("Sleep mode is not enabled", str(cm.exception))

@patch('vllm_ascend.utils._ENABLE_NZ', False)
@patch("vllm_ascend.worker.worker_v1.sleep_mode_enabled")
@patch("vllm_ascend.worker.worker_v1.CaMemAllocator")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
def test_wake_up_mode_enabled(self, mock_allocator_class,
mock_sleep_mode_enabled):
"""Test wake_up method when sleep mode is enabled"""
Expand Down
11 changes: 9 additions & 2 deletions vllm_ascend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
_IS_MOE_MODEL = None
_ENABLE_SP = None
_HAS_LAYER_IDX = None
_ENABLE_NZ = None


def is_310p():
Expand All @@ -69,8 +70,14 @@ def is_310p():
return _IS_310P


def is_enable_nz():
return envs_ascend.VLLM_ASCEND_ENABLE_NZ
def is_enable_nz(vllm_config: Optional[VllmConfig] = None) -> bool:
    """Return whether NZ is enabled, caching the decision on first use.

    The result is computed once and stored in the module-level
    ``_ENABLE_NZ``; later calls return the cached value and ignore
    ``vllm_config``.

    NZ is reported as enabled when ``envs_ascend.VLLM_ASCEND_ENABLE_NZ`` is
    truthy and the model's ``hf_config.model_type`` is not ``"qwen3_next"``.

    Args:
        vllm_config: Configuration used to look up the model type. Only
            required while ``_ENABLE_NZ`` has not been initialised.

    Returns:
        ``True`` if NZ is enabled, ``False`` otherwise.

    Raises:
        ValueError: If ``_ENABLE_NZ`` is still ``None`` and no
            ``vllm_config`` was supplied.
    """
    global _ENABLE_NZ
    if _ENABLE_NZ is None:
        if vllm_config is None:
            raise ValueError(
                "vllm_config must be provided when _ENABLE_NZ is None")
        # A missing model_config / hf_config means the model type is
        # unknown; in that case only the environment switch decides.
        hf_config = getattr(vllm_config.model_config, "hf_config", None)
        model_type = getattr(hf_config, "model_type", None)
        # bool() keeps the cached value a real bool even when the env
        # setting is an int such as 0 or 1.
        _ENABLE_NZ = (bool(envs_ascend.VLLM_ASCEND_ENABLE_NZ)
                      and model_type != "qwen3_next")
    return _ENABLE_NZ


def sleep_mode_enabled():
Expand Down
1 change: 1 addition & 0 deletions vllm_ascend/worker/worker_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
# register patch for vllm
from vllm_ascend.utils import adapt_patch
adapt_patch()
is_enable_nz(vllm_config)
# Register ops when worker init.
from vllm_ascend import ops
ops.register_dummy_fusion_op()
Expand Down
Loading