diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py
index 9da7eb5b1ad..9b1736c42b4 100644
--- a/tests/ut/ops/test_layernorm.py
+++ b/tests/ut/ops/test_layernorm.py
@@ -1,4 +1,5 @@
 import unittest
+from unittest.mock import patch
 
 import pytest
 import torch
@@ -41,7 +42,9 @@ def context(self, mocker: MockerFixture):
     # Test case for the most common and basic scenario
     @pytest.mark.parametrize(
         "residual", [None, torch.randn(4, 8, dtype=torch.float16)])
-    def test_forward_oot_basic(self, residual):
+    @patch("torch.ops.vllm.maybe_chunk_residual")
+    def test_forward_oot_basic(self, mock_maybe_chunk_residual, residual):
+        mock_maybe_chunk_residual.side_effect = lambda x, residual: residual
         layer = RMSNorm(hidden_size=8, eps=1e-05)
         x = torch.randn(4, 8, dtype=torch.float16)
         if residual is not None:
@@ -105,6 +108,8 @@ def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
             mock_forward_context.num_hidden_layers = num_hidden_layers
             mock_forward_context.fusion_linear = "gate_up_dense"
             mock_forward_context.weight_prefetch_method = None
+            mocker.patch("torch.ops.vllm.maybe_chunk_residual",
+                         lambda x, residual: residual)
 
             # Ensure fusion and layer_idx increment are handled correctly
             x = torch.randn(4, 8, dtype=torch.float16)
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index b0973b15c29..e5092f05a04 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -69,6 +69,10 @@ def __init__(self, vllm_config):
         self.enable_shared_expert_dp = additional_config.get(
             "enable_shared_expert_dp", False
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
+        if self.enable_shared_expert_dp:
+            from vllm_ascend.utils import enable_sp
+            assert enable_sp(
+                vllm_config), "shared_expert_dp requires enable_sp=True."
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
         self.recompute_scheduler_enable = additional_config.get(
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 177d91bc8a5..3b21697206a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1248,6 +1248,8 @@ def forward(
         forward_context = get_forward_context()
         if (self.enable_mlapo and
                 (attn_metadata is None or not forward_context.with_prefill)):
+            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
                 hidden_states, kv_cache, attn_metadata)
         else:
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index 6b89f4a5c71..50e62976999 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -109,6 +109,7 @@ def forward_oot(
         import torch_npu
 
         if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             x, residual = _addrmsnorm_forward_oot(
                 self, x, residual, self.next_need_quant_fusion_linear,
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index 69e220ea6e8..ba0f0e92dea 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -2,6 +2,7 @@
 import torch.nn.functional as F
 import torch_npu
 from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce,
@@ -15,6 +16,27 @@
 from vllm_ascend.utils import npu_stream_switch, prefetch_stream
 
 
+def _maybe_chunk_residual_impl(x: torch.Tensor,
+                               residual: torch.Tensor) -> torch.Tensor:
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        return residual
+
+    if x.size(0) != residual.size(0):
+        sp_enabled = forward_context.sp_enabled
+        assert sp_enabled is True, ("Currently, this situation only occurs "
+                                    "when sp is enabled")
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            residual = F.pad(residual, (0, 0, 0, pad_size))
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
+
+    return residual
+
+
 def _maybe_all_gather_and_maybe_unpad_impl(
         x: torch.Tensor,
         label: bool,
@@ -260,6 +282,12 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
     return output
 
 
+direct_register_custom_op(op_name="maybe_chunk_residual",
+                          op_func=_maybe_chunk_residual_impl,
+                          fake_impl=lambda x, residual: x,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
 direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                           op_func=_maybe_all_gather_and_maybe_unpad_impl,
                           fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 449c3b07535..e3777b8b235 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -284,7 +284,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if parallel_config and parallel_config.worker_cls == "auto":
             # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
             os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
-            if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
+            if ascend_config.torchair_graph_config.enabled:
                 parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
             else:
                 parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -337,8 +337,6 @@ def get_attn_backend_cls(
         ascend_config = get_ascend_config()
 
         if use_mla and ascend_config.enable_shared_expert_dp:
-            if use_mla and not use_sparse:
-                return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
             if use_mla and use_sparse:
                 return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
 
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index d8b25e8cd03..8f7a4120a33 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -185,6 +185,11 @@ def dummy_run(self,
                                kv_caches=self.runner.kv_caches[-1:],
                                spec_step_idx=0)
             else:
+                positions = positions.unsqueeze(-1)
+                positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
+                positions = positions.squeeze(-1)
+                previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+                    previous_hidden_states)
                 self.model(input_ids=input_ids,
                            positions=positions,
                            hidden_states=previous_hidden_states)
@@ -470,11 +475,21 @@ def _propose(
                     spec_step_idx=0,
                     **model_kwargs)
             else:
-                hidden_states = self.model(
-                    input_ids=self.input_ids[:num_input_tokens],
-                    positions=self.positions[:num_input_tokens],
-                    hidden_states=self.hidden_states[:num_input_tokens]
-                )
+                input_ids = self.input_ids[:num_input_tokens]
+                positions = self.positions[:num_input_tokens]
+                hidden_states = self.hidden_states[:num_input_tokens]
+
+                # positions [N] -> [N, 1] for padding
+                positions = positions.unsqueeze(-1)
+                positions = torch.ops.vllm.maybe_pad_and_reduce(
+                    positions)
+                positions = positions.squeeze(-1)
+
+                hidden_states = self.model(input_ids=input_ids,
+                                           positions=positions,
+                                           hidden_states=hidden_states)
+                hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    hidden_states.contiguous(), True)
 
         num_indices = last_token_indices.shape[0]
         if lmhead_tp_enable():
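Note (illustrative, not part of the patch): the new `_maybe_chunk_residual_impl` op brings the residual back in line with a hidden_states tensor that sequence parallelism has already padded and split along dim 0 across TP ranks. The sketch below reproduces only that pad-then-chunk shape math with plain CPU torch; the num_tokens/pad_size/tp_size/tp_rank values are made up for the example, whereas in the op they come from the forward context and the TP group.

# Standalone sketch of the pad-then-chunk step in _maybe_chunk_residual_impl.
# The concrete numbers below are illustrative only.
import torch
import torch.nn.functional as F

num_tokens, hidden_size = 10, 8
tp_size, tp_rank, pad_size = 4, 1, 2  # (num_tokens + pad_size) % tp_size == 0

residual = torch.randn(num_tokens, hidden_size)
# (0, 0, 0, pad_size) pads nothing on the last dim and appends pad_size
# zero rows along dim 0, mirroring the padding already applied to x.
if pad_size > 0:
    residual = F.pad(residual, (0, 0, 0, pad_size))
# Each TP rank keeps only its slice of rows, so the residual matches the
# chunked hidden_states it will be added to.
local_residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]
assert local_residual.shape == ((num_tokens + pad_size) // tp_size, hidden_size)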