From d6ba512b68166373b51e74d7272470183600a5c4 Mon Sep 17 00:00:00 2001
From: underfituu
Date: Thu, 6 Nov 2025 11:17:27 +0800
Subject: [PATCH 1/4] dev_bugfix_prefix_cache_pref

Signed-off-by: underfituu
---
 vllm_ascend/attention/mla_v1.py      | 10 +++++-----
 vllm_ascend/torchair/torchair_mla.py | 10 +++++-----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 177d91bc8a5..74bd349b9d5 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -370,7 +370,7 @@ def build(
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -766,7 +766,7 @@ def _compute_prefill_context(
 
         iters = len(prefill_metadata.chunked_context.seq_tot)
 
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        seq_len_base = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -774,8 +774,8 @@ def _compute_prefill_context(
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len = torch.stack([seq_len_base, seq_len_chunk])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -791,7 +791,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(q_nope.device),
+                seq_len_chunk,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index 32543a84d80..a1a6b035e77 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -461,7 +461,7 @@ def build(
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens=chunk_seq_lens.cpu(),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -777,7 +777,7 @@ def _compute_prefill_context(
         q_pe = query[..., self.qk_nope_head_dim:]
         q_nope = query[..., :self.qk_nope_head_dim]
 
-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        seq_len_base = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -785,8 +785,8 @@ def _compute_prefill_context(
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len = torch.stack([seq_len_base, seq_len_chunk])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -802,7 +802,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                seq_len_chunk,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
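Note on [PATCH 1/4]: the change hoists the per-chunk host-to-device copy of the chunk lengths out of the chunked-prefill loop by materializing chunk_seq_lens on the device once at metadata-build time, so the old seq_len2.to(...) call inside _compute_prefill_context is no longer needed. A minimal sketch of the pattern in plain PyTorch (hypothetical helper names, not part of vllm-ascend):

    import torch

    def chunk_lengths_per_iter(chunk_seq_lens: torch.Tensor, device: torch.device):
        # Before: one host-to-device copy on every loop iteration.
        for i in range(chunk_seq_lens.size(0)):
            yield chunk_seq_lens[i].to(device)

    def chunk_lengths_hoisted(chunk_seq_lens: torch.Tensor, device: torch.device):
        # After: a single copy up front, then cheap indexing of the device tensor.
        chunk_seq_lens_dev = chunk_seq_lens.to(device)
        for i in range(chunk_seq_lens_dev.size(0)):
            yield chunk_seq_lens_dev[i]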
From ba2d09dd6ca710004477ef481533e15fc3e353f3 Mon Sep 17 00:00:00 2001
From: underfituu
Date: Thu, 6 Nov 2025 12:43:23 +0800
Subject: [PATCH 2/4] fix_stack_bug

Signed-off-by: underfituu
---
 vllm_ascend/attention/mla_v1.py      | 7 +++++--
 vllm_ascend/torchair/torchair_mla.py | 7 +++++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 74bd349b9d5..e22bb75e82b 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -80,6 +80,7 @@ class ChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -370,7 +371,8 @@ def build(
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    chunk_seq_lens=chunk_seq_lens.npu(),
+                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -775,6 +777,7 @@ def _compute_prefill_context(
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len_chunk_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[i]
             seq_len = torch.stack([seq_len_base, seq_len_chunk])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
@@ -791,7 +794,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len_chunk,
+                seq_len_chunk_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index a1a6b035e77..69245c56446 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -72,6 +72,7 @@ class TorchairChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -461,7 +462,8 @@ def build(
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    chunk_seq_lens=chunk_seq_lens.cpu(),
+                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -786,6 +788,7 @@ def _compute_prefill_context(
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len_chunk_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[i]
             seq_len = torch.stack([seq_len_base, seq_len_chunk])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
@@ -802,7 +805,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len_chunk,
+                seq_len_chunk_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
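Note on [PATCH 2/4] (fix_stack_bug): torch.stack only accepts tensors that live on the same device, and seq_len_base is a CPU tensor built from query_lens, so it cannot be stacked with the device-resident chunk lengths introduced in patch 1. The metadata therefore carries both a CPU chunk_seq_lens (for the stack) and an NPU chunk_seq_lens_npu (for the paged cache-load call). A small, hedged illustration of the failure mode in generic PyTorch ("cuda" stands in here for any accelerator; the real target is the Ascend NPU via torch_npu):

    import torch

    seq_len_base = torch.tensor([4, 4], dtype=torch.int32)    # CPU, from query_lens
    chunk_seq_lens = torch.tensor([8, 8], dtype=torch.int32)  # CPU copy kept for stacking

    seq_len = torch.stack([seq_len_base, chunk_seq_lens])     # fine: both inputs on CPU

    if torch.cuda.is_available():
        chunk_seq_lens_dev = chunk_seq_lens.to("cuda")        # device copy for the kernel call
        # torch.stack([seq_len_base, chunk_seq_lens_dev]) raises a RuntimeError
        # because the inputs are on different devices.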
From 007c0c27228808c69adaaab99c188d3164006eb9 Mon Sep 17 00:00:00 2001
From: underfituu
Date: Thu, 6 Nov 2025 14:13:42 +0800
Subject: [PATCH 3/4] fix_lint

Signed-off-by: underfituu
---
 vllm_ascend/attention/mla_v1.py      | 8 +++++---
 vllm_ascend/torchair/torchair_mla.py | 6 ++++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index e22bb75e82b..26ec81a064b 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -768,7 +768,8 @@ def _compute_prefill_context(
 
         iters = len(prefill_metadata.chunked_context.seq_tot)
 
-        seq_len_base = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        seq_len_base = torch.tensor(prefill_metadata.query_lens,
+                                    dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -776,8 +777,9 @@ def _compute_prefill_context(
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
-            seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len_chunk_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[i]
+            seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len_chunk_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
             seq_len = torch.stack([seq_len_base, seq_len_chunk])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index 69245c56446..3c6968dc27e 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -779,7 +779,8 @@ def _compute_prefill_context(
         q_pe = query[..., self.qk_nope_head_dim:]
         q_nope = query[..., :self.qk_nope_head_dim]
 
-        seq_len_base = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        seq_len_base = torch.tensor(prefill_metadata.query_lens,
+                                    dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -788,7 +789,8 @@ def _compute_prefill_context(
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len_chunk_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[i]
+            seq_len_chunk_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
             seq_len = torch.stack([seq_len_base, seq_len_chunk])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
From 57de64061130bd9afb3754ec8b30429142ff2f65 Mon Sep 17 00:00:00 2001
From: underfituu
Date: Thu, 6 Nov 2025 15:22:02 +0800
Subject: [PATCH 4/4] fix_ut

Signed-off-by: underfituu
---
 tests/ut/attention/test_mla_v1.py      |  6 +++++-
 tests/ut/torchair/test_torchair_mla.py |  6 +++++-
 vllm_ascend/attention/mla_v1.py        | 13 +++++++------
 vllm_ascend/torchair/torchair_mla.py   | 13 +++++++------
 4 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index c55234bc3d9..9851e51ce33 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -82,7 +82,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)
 
         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -103,6 +104,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)
 
 
 class TestAscendMLADecodeMetadata(TestBase):
@@ -428,6 +431,7 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
 
         prefill_meta = MagicMock()
diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py
index 3dd1d2f7f6a..1f108b3eb06 100644
--- a/tests/ut/torchair/test_torchair_mla.py
+++ b/tests/ut/torchair/test_torchair_mla.py
@@ -86,7 +86,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens)
 
         metadata = AscendMLATorchairPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -107,6 +108,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
         self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
         self.assertIs(metadata.chunked_context.workspace, workspace)
         self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+        self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+                      chunk_seq_lens)
 
 
 class TestAscendMLATorchairDecodeMetadata(TestBase):
@@ -661,6 +664,7 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
         chunk_ctx = MagicMock()
         chunk_ctx.seq_tot = [8]
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+        chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
 
         prefill_meta = MagicMock()
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 26ec81a064b..4044126d312 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -768,8 +768,8 @@ def _compute_prefill_context(
 
         iters = len(prefill_metadata.chunked_context.seq_tot)
 
-        seq_len_base = torch.tensor(prefill_metadata.query_lens,
-                                    dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -777,10 +777,11 @@ def _compute_prefill_context(
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
-            seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len_chunk_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
                 i]
-            seq_len = torch.stack([seq_len_base, seq_len_chunk])
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -796,7 +797,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len_chunk_npu,
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index 3c6968dc27e..3ffcdfbaf1c 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -779,8 +779,8 @@ def _compute_prefill_context(
         q_pe = query[..., self.qk_nope_head_dim:]
         q_nope = query[..., :self.qk_nope_head_dim]
 
-        seq_len_base = torch.tensor(prefill_metadata.query_lens,
-                                    dtype=torch.int32)
+        current_seq_len = torch.tensor(prefill_metadata.query_lens,
+                                       dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
@@ -788,10 +788,11 @@ def _compute_prefill_context(
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
-            seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len_chunk_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+                i]
+            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
                 i]
-            seq_len = torch.stack([seq_len_base, seq_len_chunk])
+            seq_len = torch.stack([current_seq_len, context_seq_len])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -807,7 +808,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len_chunk_npu,
+                context_seq_len_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
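Note on [PATCH 4/4] (fix_ut): besides renaming the locals to current_seq_len / context_seq_len(_npu), the series means that any test or caller building the chunked-context metadata by hand must now supply chunk_seq_lens_npu next to chunk_seq_lens: the CPU copy feeds torch.stack, the device copy feeds the paged cache-load op. A rough sketch of the fixture shape the updated unit tests rely on (it mirrors the MagicMock setup in the diffs above; the prefill_meta wiring is an assumption, and real runs on Ascend hardware would place the *_npu tensors on the NPU):

    import torch
    from unittest.mock import MagicMock

    chunk_ctx = MagicMock()
    chunk_ctx.seq_tot = [8]
    chunk_ctx.chunk_seq_lens = [torch.tensor([8])]      # host-side lengths, stacked with query_lens
    chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]  # device-side lengths for the cache load
    chunk_ctx.starts = [torch.tensor([0])]

    prefill_meta = MagicMock()
    prefill_meta.chunked_context = chunk_ctx  # assumed wiring into the prefill metadata mock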