7 changes: 6 additions & 1 deletion tests/ut/ops/test_layernorm.py
@@ -1,4 +1,5 @@
import unittest
from unittest.mock import patch

import pytest
import torch
@@ -41,7 +42,9 @@ def context(self, mocker: MockerFixture):
# Test case for the most common and basic scenario
@pytest.mark.parametrize(
"residual", [None, torch.randn(4, 8, dtype=torch.float16)])
def test_forward_oot_basic(self, residual):
@patch("torch.ops.vllm.maybe_chunk_residual")
def test_forward_oot_basic(self, mock_maybe_chunk_residual, residual):
mock_maybe_chunk_residual.side_effect = lambda x, residual: residual
layer = RMSNorm(hidden_size=8, eps=1e-05)
x = torch.randn(4, 8, dtype=torch.float16)
if residual is not None:
@@ -105,6 +108,8 @@ def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
mock_forward_context.num_hidden_layers = num_hidden_layers
mock_forward_context.fusion_linear = "gate_up_dense"
mock_forward_context.weight_prefetch_method = None
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
lambda x, residual: residual)

# Ensure fusion and layer_idx increment are handled correctly
x = torch.randn(4, 8, dtype=torch.float16)
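Outside of sequence parallelism the new custom op is expected to be a pass-through on residual, which is why both tests replace it with an identity lambda. A minimal standalone sketch of that mocking pattern (create=True stands in for the registration that importing vllm_ascend would normally perform):

import torch
from unittest.mock import patch

x = torch.randn(4, 8, dtype=torch.float16)
residual = torch.randn(4, 8, dtype=torch.float16)

# Replace the op with a pass-through, mirroring the tests above.
with patch("torch.ops.vllm.maybe_chunk_residual",
           side_effect=lambda x, residual: residual,
           create=True):
    out = torch.ops.vllm.maybe_chunk_residual(x, residual)

assert out is residual  # identity whenever no chunking is needed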
4 changes: 4 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -69,6 +69,10 @@ def __init__(self, vllm_config):
self.enable_shared_expert_dp = additional_config.get(
"enable_shared_expert_dp", False
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
if self.enable_shared_expert_dp:
from vllm_ascend.utils import enable_sp
assert enable_sp(
vllm_config), "shared_expert_dp requires enable_sp=True."
self.multistream_overlap_shared_expert = additional_config.get(
"multistream_overlap_shared_expert", False)
self.recompute_scheduler_enable = additional_config.get(
2 changes: 2 additions & 0 deletions vllm_ascend/attention/mla_v1.py
@@ -1248,6 +1248,8 @@ def forward(
forward_context = get_forward_context()
if (self.enable_mlapo and
(attn_metadata is None or not forward_context.with_prefill)):
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv)
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
hidden_states, kv_cache, attn_metadata)
else:
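The added maybe_all_gather_and_maybe_unpad call restores the full, unpadded token dimension of hidden_states before the fused decode preprocess. A single-process sketch of that shape bookkeeping (the collective is simulated with torch.cat; tp_size and the tensor sizes are illustrative, not taken from vLLM-Ascend):

import torch
import torch.nn.functional as F

tp_size, num_tokens, hidden = 2, 5, 8
full = torch.randn(num_tokens, hidden)
pad_size = (-num_tokens) % tp_size                   # rows added for equal chunks
padded = F.pad(full, (0, 0, 0, pad_size))            # [6, 8]
chunks = list(torch.chunk(padded, tp_size, dim=0))   # one [3, 8] slice per rank

gathered = torch.cat(chunks, dim=0)                  # what all-gather yields
restored = gathered[:num_tokens]                     # drop the padding rows
assert torch.equal(restored, full)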
1 change: 1 addition & 0 deletions vllm_ascend/ops/layernorm.py
@@ -109,6 +109,7 @@ def forward_oot(
import torch_npu

if residual is not None:
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
x, residual = _addrmsnorm_forward_oot(
self, x, residual, self.next_need_quant_fusion_linear,
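maybe_chunk_residual only touches residual when its token dimension no longer matches x, i.e. when sequence parallelism has already split and padded x across tensor-parallel ranks. A standalone sketch of the pad-and-chunk step, with tp_size and tp_rank hardcoded for illustration:

import torch
import torch.nn.functional as F

tp_size, tp_rank = 4, 1
residual = torch.randn(10, 8)                  # full (unsplit) residual
pad_size = (-residual.size(0)) % tp_size       # rows needed for equal chunks

if pad_size > 0:
    residual = F.pad(residual, (0, 0, 0, pad_size))        # pad token dim to 12
local_residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]

assert local_residual.shape == (3, 8)          # matches this rank's slice of x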
28 changes: 28 additions & 0 deletions vllm_ascend/ops/register_custom_ops.py
@@ -2,6 +2,7 @@
import torch.nn.functional as F
import torch_npu
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
@@ -15,6 +16,27 @@
from vllm_ascend.utils import npu_stream_switch, prefetch_stream


def _maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
return residual

if x.size(0) != residual.size(0):
sp_enabled = forward_context.sp_enabled
assert sp_enabled is True, ("Currently, this situation only occurs "
"when sp is enabled")
pad_size = forward_context.pad_size
if pad_size > 0:
residual = F.pad(residual, (0, 0, 0, pad_size))
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
residual = torch.chunk(residual, tp_size, dim=0)[tp_rank]

return residual


def _maybe_all_gather_and_maybe_unpad_impl(
x: torch.Tensor,
label: bool,
@@ -260,6 +282,12 @@ def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
return output


direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
fake_impl=lambda x, residual: x,
Contributor review comment (severity: high):
The fake_impl for the maybe_chunk_residual custom operator is incorrect. It returns x, but the actual function returns a modified residual. While x might have the correct shape for shape inference, this is semantically wrong and breaks the convention used by other fake implementations in this file, which return placeholder tensors (e.g., from torch.empty). This can lead to subtle bugs or incorrect behavior when using torch.compile's meta backend, as it misrepresents the data flow.

A more correct approach would be to return a placeholder tensor that has the correct shape and dtype, similar to other fake implementations in this file.

Suggested change
fake_impl=lambda x, residual: x,
fake_impl=lambda x, residual: torch.empty_like(x, dtype=residual.dtype),

mutates_args=[],
dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
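The review comment above concerns the meta ("fake") implementation used during tracing: it should describe the op's output, not echo an input for convenience. A sketch of the same idea using the stock torch.library API (PyTorch >= 2.4) rather than vLLM's direct_register_custom_op helper; the demo:: namespace and the identity body are illustrative only:

import torch

@torch.library.custom_op("demo::maybe_chunk_residual", mutates_args=())
def maybe_chunk_residual(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # The real op would pad/chunk residual; an identity copy is enough here.
    return residual.clone()

@maybe_chunk_residual.register_fake
def _(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Placeholder carrying the output's shape and dtype, as the review suggests.
    return torch.empty_like(x, dtype=residual.dtype)

x, res = torch.randn(3, 8), torch.randn(3, 8)
print(torch.ops.demo.maybe_chunk_residual(x, res).shape)  # torch.Size([3, 8])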
4 changes: 1 addition & 3 deletions vllm_ascend/platform.py
@@ -284,7 +284,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if parallel_config and parallel_config.worker_cls == "auto":
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
if ascend_config.torchair_graph_config.enabled:
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -337,8 +337,6 @@ def get_attn_backend_cls(
ascend_config = get_ascend_config()

if use_mla and ascend_config.enable_shared_expert_dp:
if use_mla and not use_sparse:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and use_sparse:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"

25 changes: 20 additions & 5 deletions vllm_ascend/spec_decode/mtp_proposer.py
@@ -185,6 +185,11 @@ def dummy_run(self,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
positions = positions.unsqueeze(-1)
positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
positions = positions.squeeze(-1)
previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
previous_hidden_states)
self.model(input_ids=input_ids,
positions=positions,
hidden_states=previous_hidden_states)
@@ -470,11 +475,21 @@ def _propose(
spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens]
)
input_ids = self.input_ids[:num_input_tokens]
positions = self.positions[:num_input_tokens]
hidden_states = self.hidden_states[:num_input_tokens]

# positions [N] -> [N, 1] for padding
positions = positions.unsqueeze(-1)
positions = torch.ops.vllm.maybe_pad_and_reduce(
positions)
positions = positions.squeeze(-1)

hidden_states = self.model(input_ids=input_ids,
positions=positions,
hidden_states=hidden_states)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), True)

num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
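The unsqueeze/squeeze around maybe_pad_and_reduce exists purely for shape reasons: the pad-and-reduce path works on 2-D [tokens, hidden] tensors, so the 1-D positions vector is viewed as [N, 1] first and flattened back afterwards. A sketch of that bookkeeping (the reduce-scatter is approximated with a chunk, since only shapes matter here; tp_size and tp_rank are illustrative):

import torch
import torch.nn.functional as F

tp_size, tp_rank = 2, 0
positions = torch.arange(5)                          # 1-D [N]

p2d = positions.unsqueeze(-1)                        # [N] -> [N, 1]
pad_size = (-p2d.size(0)) % tp_size
p2d = F.pad(p2d, (0, 0, 0, pad_size))                # pad token dim to a multiple
local = torch.chunk(p2d, tp_size, dim=0)[tp_rank]    # this rank's slice
local_positions = local.squeeze(-1)                  # back to 1-D

assert local_positions.shape == (3,)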