|
16 | 16 | from vllm.distributed.parallel_state import get_tp_group, get_world_group |
17 | 17 | from vllm.platforms import current_platform |
18 | 18 | from vllm.v1.core.sched.output import SchedulerOutput |
19 | | -from vllm.v1.request import Request |
20 | 19 |
|
21 | 20 | from ucm.logger import init_logger |
22 | 21 | from ucm.shared.metrics import ucmmonitor |
|
29 | 28 | from vllm.attention.backends.abstract import AttentionMetadata |
30 | 29 | from vllm.forward_context import ForwardContext |
31 | 30 | from vllm.v1.core.kv_cache_manager import KVCacheBlocks |
| 31 | + from vllm.v1.request import Request |
32 | 32 |
|
33 | 33 | logger = init_logger(__name__) |
34 | 34 |
|
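The `Request` import moves out of the module's top-level imports and into the same block as `AttentionMetadata`, `ForwardContext`, and `KVCacheBlocks`, which appears to be a type-checking-only import section. A minimal sketch of that pattern, assuming the standard `if TYPE_CHECKING:` guard (the `handle_request` function is hypothetical, for illustration only):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by static type checkers; avoids importing
    # vllm.v1.request at runtime, which sidesteps circular-import
    # and startup-cost issues.
    from vllm.v1.request import Request


def handle_request(request: "Request") -> None:
    # The annotation is a string, so no runtime import is required.
    ...
```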
@@ -178,11 +178,12 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): |
178 | 178 | self.metrics_config, |
179 | 179 | ) |
180 | 180 | self.monitor = ucmmonitor.StatsMonitor.get_instance() |
181 | | - self.synchronize = ( |
182 | | - torch.cuda.synchronize |
183 | | - if current_platform.is_cuda_alike() |
184 | | - else torch.npu.synchronize |
185 | | - ) |
| 181 | + |
| 182 | + self.synchronize = ( |
| 183 | + torch.cuda.synchronize |
| 184 | + if current_platform.is_cuda_alike() |
| 185 | + else torch.npu.synchronize |
| 186 | + ) |
186 | 187 |
|
187 | 188 | # invalid block ids due to load errors
188 | 189 | self._invalid_block_ids: set[int] = set() |
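The hunk above resolves the device-wide synchronize once in `__init__`, so the hot path can call `self.synchronize()` without re-checking the platform on every invocation. A self-contained sketch of that dispatch; the `pick_synchronize` helper is illustrative, not part of the connector, and the NPU branch assumes the `torch_npu` package is installed:

```python
import torch


def pick_synchronize(is_cuda_alike: bool):
    """Return the device-wide synchronize for the current platform."""
    if is_cuda_alike:
        # CUDA and CUDA-compatible (e.g. ROCm) builds expose torch.cuda.
        return torch.cuda.synchronize
    # On Ascend, torch_npu patches a torch.npu namespace into torch that
    # mirrors torch.cuda; this branch assumes that package is present.
    return torch.npu.synchronize
```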
@@ -558,7 +559,9 @@ def wait_for_save(self) -> None: |
558 | 559 | # TODO support PP |
559 | 560 | if (self.is_mla or self.is_dsa) and self.global_rank != 0: |
560 | 561 | return |
561 | | - if self.metrics_config: |
| 562 | + if self.metrics_config or current_platform.device_type == "npu": |
| 563 | + # When using vllm_ascend, we must synchronize here; otherwise an accuracy problem arises.
| 564 | + # This is already fixed on the latest main branch of vllm_ascend, so this synchronize will no longer be needed in future versions.
562 | 565 | self.synchronize() |
563 | 566 |
|
564 | 567 | metadata = self._get_connector_metadata() |
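The reason the extra synchronize matters on NPU: device work is queued asynchronously, so the host side can read a save destination before the copy has actually landed, producing silently wrong results rather than a crash. A minimal sketch of that failure mode and its fix, with `save_block` as a hypothetical stand-in for the connector's real save path:

```python
import torch


def save_block(src: torch.Tensor, host_buf: torch.Tensor) -> None:
    # Device-to-host copies are queued asynchronously on the current stream
    # (host_buf should be pinned for the copy to be truly non-blocking).
    host_buf.copy_(src, non_blocking=True)
    # Without this barrier, the caller may read host_buf before the copy
    # completes; that is the accuracy problem the NPU branch guards against.
    torch.cuda.synchronize()  # torch.npu.synchronize() on Ascend
```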
|