Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tests/ut/attention/test_mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens)
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens)

metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
Expand All @@ -103,6 +104,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
chunk_seq_lens)


class TestAscendMLADecodeMetadata(TestBase):
Expand Down Expand Up @@ -428,6 +431,7 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]

prefill_meta = MagicMock()
Expand Down
6 changes: 5 additions & 1 deletion tests/ut/torchair/test_torchair_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens)
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens)

metadata = AscendMLATorchairPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
Expand All @@ -107,6 +108,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
chunk_seq_lens)


class TestAscendMLATorchairDecodeMetadata(TestBase):
Expand Down Expand Up @@ -661,6 +664,7 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]

prefill_meta = MagicMock()
Expand Down
14 changes: 10 additions & 4 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class ChunkedContextMetadata:
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor

attn_mask: torch.Tensor
query_lens: torch.Tensor
Expand Down Expand Up @@ -371,6 +372,7 @@ def build(
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
Expand Down Expand Up @@ -766,16 +768,20 @@ def _compute_prefill_context(

iters = len(prefill_metadata.chunked_context.seq_tot)

seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
current_seq_len = torch.tensor(prefill_metadata.query_lens,
dtype=torch.int32)
cache_kv_c = kv_c_and_k_pe_cache[0]
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]

seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([seq_len1, seq_len2])
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
i]
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
i]
seq_len = torch.stack([current_seq_len, context_seq_len])
kv_c_normed = torch.empty(toks,
num_heads,
latent_kv_dim,
Expand All @@ -791,7 +797,7 @@ def _compute_prefill_context(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
seq_len2.to(q_nope.device),
context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
Expand Down
14 changes: 10 additions & 4 deletions vllm_ascend/torchair/torchair_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class TorchairChunkedContextMetadata:
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor

attn_mask: torch.Tensor
query_lens: torch.Tensor
Expand Down Expand Up @@ -462,6 +463,7 @@ def build(
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
Expand Down Expand Up @@ -777,16 +779,20 @@ def _compute_prefill_context(
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]

seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
current_seq_len = torch.tensor(prefill_metadata.query_lens,
dtype=torch.int32)
cache_kv_c = kv_c_and_k_pe_cache[0]
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]

seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([seq_len1, seq_len2])
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
i]
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
i]
seq_len = torch.stack([current_seq_len, context_seq_len])
kv_c_normed = torch.empty(toks,
num_heads,
latent_kv_dim,
Expand All @@ -802,7 +808,7 @@ def _compute_prefill_context(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
seq_len2.to(query.device),
context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
Expand Down