vllm_ascend/attention/mla_v1.py (5 additions, 5 deletions)
@@ -370,7 +370,7 @@ def build(
                 starts=chunk_starts.to(device, non_blocking=True),
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens=chunk_seq_lens.npu(),
                 workspace=self.chunked_prefill_workspace,
             )
         prefill_input_positions = input_positions[tokens_start:]
@@ -766,16 +766,16 @@ def _compute_prefill_context(

         iters = len(prefill_metadata.chunked_context.seq_tot)

-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        seq_len_base = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
         latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len = torch.stack([seq_len_base, seq_len_chunk])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -791,7 +791,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(q_nope.device),
+                seq_len_chunk,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
vllm_ascend/torchair/torchair_mla.py (5 additions, 5 deletions)
@@ -461,7 +461,7 @@ def build(
                 starts=chunk_starts.to(device, non_blocking=True),
                 seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                 max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens=chunk_seq_lens.cpu(),
                 workspace=self.chunked_prefill_workspace,
             )
         prefill_input_positions = input_positions[tokens_start:]
@@ -777,16 +777,16 @@ def _compute_prefill_context(
         q_pe = query[..., self.qk_nope_head_dim:]
         q_nope = query[..., :self.qk_nope_head_dim]

-        seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+        seq_len_base = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
         cache_kv_c = kv_c_and_k_pe_cache[0]
         cache_k_pe = kv_c_and_k_pe_cache[1]
         num_heads = cache_k_pe.size(2)
         latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]

-            seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len1, seq_len2])
+            seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len = torch.stack([seq_len_base, seq_len_chunk])
Review comment (Contributor, severity: high):

The tensor seq_len is constructed on the CPU and subsequently passed to the torch_npu.atb.npu_ring_mla kernel at line 824. This might lead to implicit synchronization and performance degradation, which is similar to the issue this PR aims to fix for npu_paged_cache_load.

In vllm_ascend/attention/mla_v1.py, a similar pattern was addressed by ensuring the tensor is on the NPU device before being used in the kernel. To maintain consistency and prevent potential performance bottlenecks, seq_len should be moved to the NPU.

I suggest moving the tensor to the correct device during its creation.

Suggested change:
             seq_len_chunk = prefill_metadata.chunked_context.chunk_seq_lens[i]
-            seq_len = torch.stack([seq_len_base, seq_len_chunk])
+            seq_len = torch.stack([seq_len_base, seq_len_chunk]).to(q_nope.device, non_blocking=True)

             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -802,7 +802,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(query.device),
+                seq_len_chunk,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
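
For readers unfamiliar with why the reviewer rates the device-placement issue as high severity, the sketch below illustrates the general pattern under discussion: a sequence-length tensor built on the host and handed to a device kernel forces an implicit host-to-device copy on every call, whereas stacking and moving it once with `.to(device, non_blocking=True)` at creation time avoids that per-iteration cost. This is a minimal, hedged example, not code from the PR: `fake_ring_mla_kernel`, `DEVICE`, and the toy `query_lens`/`chunk_seq_lens` values are made up for illustration (the real kernel is `torch_npu.atb.npu_ring_mla`, which is not reproduced here); only the stack-then-move pattern mirrors the reviewer's suggestion.

```python
# Minimal sketch of the host-vs-device placement pattern from the review comment.
# Assumptions: torch_npu (if installed) implies a usable Ascend NPU; otherwise the
# sketch falls back to CPU so it still runs. The kernel below is a stand-in only.
import torch

try:
    import torch_npu  # noqa: F401  # optional: only available on Ascend machines
    DEVICE = torch.device("npu")
except ImportError:
    DEVICE = torch.device("cpu")  # fallback so the sketch runs anywhere


def fake_ring_mla_kernel(q: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a device kernel that consumes per-sequence lengths."""
    # A real kernel expects seq_len on the same device as q; passing a host tensor
    # forces a synchronizing host-to-device copy on every call.
    return q * seq_len.to(q.device).sum()


# Toy inputs standing in for prefill_metadata.query_lens and chunk_seq_lens.
query_lens = [3, 5, 2]
chunk_seq_lens = torch.tensor([[4, 1, 6]], dtype=torch.int32)

seq_len_base = torch.tensor(query_lens, dtype=torch.int32)  # host tensor
seq_len_chunk = chunk_seq_lens[0]                            # host tensor

# Anti-pattern: seq_len stays on the host, so the copy happens inside every kernel call.
seq_len_host = torch.stack([seq_len_base, seq_len_chunk])

# Suggested pattern: move the stacked tensor to the compute device once, asynchronously,
# when it is created (this mirrors the reviewer's proposed change).
seq_len_dev = torch.stack([seq_len_base, seq_len_chunk]).to(DEVICE, non_blocking=True)

q = torch.randn(2, 3, device=DEVICE)
out_slow = fake_ring_mla_kernel(q, seq_len_host)  # implicit H2D copy inside the call
out_fast = fake_ring_mla_kernel(q, seq_len_dev)   # lengths already resident on DEVICE
assert torch.allclose(out_slow, out_fast)
```

Moving the lengths once at creation keeps the per-chunk loop free of blocking copies, which is the same class of overhead this PR removes for `npu_paged_cache_load`.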