diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index c55234bc3d9..9851e51ce33 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -82,7 +82,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)
 
         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -103,6 +104,8 @@
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)
 
 
 class TestAscendMLADecodeMetadata(TestBase):
@@ -428,6 +431,7 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
 
         prefill_meta = MagicMock()
diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py
index 3dd1d2f7f6a..1f108b3eb06 100644
--- a/tests/ut/torchair/test_torchair_mla.py
+++ b/tests/ut/torchair/test_torchair_mla.py
@@ -86,7 +86,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)
 
         metadata = AscendMLATorchairPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -107,6 +108,8 @@
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)
 
 
 class TestAscendMLATorchairDecodeMetadata(TestBase):
@@ -661,6 +664,7 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
 
         prefill_meta = MagicMock()
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 177d91bc8a5..4044126d312 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -80,6 +80,7 @@ class ChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -371,6 +372,7 @@ def build(
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens_npu=chunk_seq_lens.npu(),
                 workspace=self.chunked_prefill_workspace,
             )
         prefill_input_positions = input_positions[tokens_start:]
@@ -766,7 +768,8 @@ def _compute_prefill_context(
 
         iters = len(prefill_metadata.chunked_context.seq_tot)
 
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -774,8 +777,11 @@
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -791,7 +797,7 @@
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(q_nope.device),
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index 32543a84d80..3ffcdfbaf1c 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -72,6 +72,7 @@ class TorchairChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -462,6 +463,7 @@ def build(
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                 chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens_npu=chunk_seq_lens.npu(),
                 workspace=self.chunked_prefill_workspace,
             )
         prefill_input_positions = input_positions[tokens_start:]
@@ -777,7 +779,8 @@ def _compute_prefill_context(
         q_pe = query[..., self.qk_nope_head_dim:]
         q_nope = query[..., :self.qk_nope_head_dim]
 
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -785,8 +788,11 @@
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -802,7 +808,7 @@
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
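
Note on the pattern (not part of the diff): seq_len2.to(q_nope.device) / seq_len2.to(query.device) inside the chunk loop issued a host-to-device copy on every iteration of _compute_prefill_context. The diff hoists that transfer into build(), storing a device-resident chunk_seq_lens_npu next to the host chunk_seq_lens, so each iteration only indexes an already-transferred tensor. Below is a minimal sketch of the pattern under stated assumptions: the names ChunkedContextMetadataSketch, build_metadata and compute_prefill_context are hypothetical, and .to(device) stands in for .npu(), which requires torch_npu.

from dataclasses import dataclass

import torch


@dataclass
class ChunkedContextMetadataSketch:
    # Hypothetical stand-in for the two ChunkedContextMetadata fields
    # touched by the diff.
    chunk_seq_lens: torch.Tensor  # host tensor, used by torch.stack below
    chunk_seq_lens_npu: torch.Tensor  # device copy, handed straight to kernels


def build_metadata(chunk_seq_lens: torch.Tensor,
                   device: torch.device) -> ChunkedContextMetadataSketch:
    # One host-to-device transfer at metadata-build time, amortized over
    # all chunk iterations; the real code calls chunk_seq_lens.npu() here.
    return ChunkedContextMetadataSketch(
        chunk_seq_lens=chunk_seq_lens,
        chunk_seq_lens_npu=chunk_seq_lens.to(device))


def compute_prefill_context(meta: ChunkedContextMetadataSketch,
                            query_lens: list[int]) -> None:
    current_seq_len = torch.tensor(query_lens, dtype=torch.int32)
    for i in range(meta.chunk_seq_lens.size(0)):
        context_seq_len = meta.chunk_seq_lens[i]
        # Already resident on the device: no per-iteration .to() transfer,
        # unlike the removed seq_len2.to(...) call.
        context_seq_len_npu = meta.chunk_seq_lens_npu[i]
        seq_len = torch.stack([current_seq_len, context_seq_len])
        # seq_len feeds the host-side attention bookkeeping, while
        # context_seq_len_npu would be passed to the paged-cache load kernel.
        del seq_len, context_seq_len_npu  # sketch only: no kernel call here


if __name__ == "__main__":
    meta = build_metadata(torch.tensor([[8, 4]], dtype=torch.int32),
                          torch.device("cpu"))
    compute_prefill_context(meta, query_lens=[2, 3])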