diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py
index 1c5c0a92608..86303fb61a6 100644
--- a/vllm_ascend/distributed/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py
@@ -50,6 +50,8 @@
 
 GET_META_MSG = b"get_meta_msg"
 DONE_SENDING_MSG = b"done_sending_msg"
+HEARTBEAT_MSG = b"hb_ping"
+HEARTBEAT_ACK = b"hb_pong"
 
 
 class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
@@ -323,6 +325,12 @@ def run(self):
             if msg[0] == GET_META_MSG:
                 logger.info("Got GET META INFO for request %s", msg[0])
                 sock.send_multipart((identity, b"", encoded_data))
+            elif msg[0] == HEARTBEAT_MSG:
+                # Heartbeat: reply immediately
+                try:
+                    sock.send_multipart((identity, b"", HEARTBEAT_ACK))
+                except Exception:
+                    pass
             elif msg[0] == DONE_SENDING_MSG:
                 logger.debug("Got DONE_RECVING_MSG for request %s",
                              msg[1])
@@ -337,6 +345,148 @@ def run(self):
             logger.error("Failed to decode message: %s", e)
 
 
+class HeartbeatMonitor(threading.Thread):
+    """
+    Maintain heartbeat to multiple remote (host, port) REQ endpoints.
+    Non-blocking API: add_target(), is_alive(), remove_target(), stop().
+    """
+
+    def __init__(self, poll_interval: float = 360, timeout: float = 1.0):
+        super().__init__(daemon=True, name="HeartbeatMonitor")
+        self.poll_interval = poll_interval
+        self.timeout = timeout
+        self._targets_lock = threading.Lock()
+        # Book-keeping below is keyed by (host, port) and only mutated
+        # while holding _targets_lock.
+        self._targets: dict[tuple[str, int], zmq.Socket] = {}
+        self._alive: dict[tuple[str, int], float] = {}
+        self._status: dict[tuple[str, int], bool] = {}
+        self._counters: dict[tuple[str, int], dict[str, int]] = {}
+        self._stop = threading.Event()
+        self._poller = zmq.Poller()  # type: ignore
+        self._ctx = zmq.Context()  # type: ignore
+
+    def add_target(self, host: str, port: int):
+        """Start heart-beating (host, port); idempotent."""
+        key = (host, port)
+        with self._targets_lock:
+            if key in self._targets:
+                return
+            path = make_zmq_path("tcp", host, port)
+            sock = make_zmq_socket(
+                ctx=self._ctx,
+                path=path,
+                socket_type=zmq.REQ,  # type: ignore
+                bind=False)
+            # avoid blocking forever
+            sock.setsockopt(zmq.SNDTIMEO,
+                            int(self.timeout * 1000))  # type: ignore
+            sock.setsockopt(zmq.RCVTIMEO,
+                            int(self.timeout * 1000))  # type: ignore
+            # REQ enforces strict send/recv alternation: a single missed
+            # pong would wedge the socket in EFSM on the next ping.
+            # RELAXED allows re-sending; CORRELATE drops stale replies.
+            sock.setsockopt(zmq.REQ_RELAXED, 1)  # type: ignore
+            sock.setsockopt(zmq.REQ_CORRELATE, 1)  # type: ignore
+            self._targets[key] = sock
+            self._poller.register(sock, zmq.POLLIN)  # type: ignore
+            # Grace period: count the target as alive until the first
+            # miss window expires, so a freshly added peer is not purged
+            # before it ever had a chance to answer.
+            self._alive[key] = time.time()
+            self._status[key] = False
+            self._counters[key] = {"sent": 0, "recv": 0, "timeout": 0}
+            logger.info(f"[HB] add_target host={host} port={port}")
+            logger.debug(
+                f"[HB] counters init for {host}:{port} -> {self._counters[key]}"
+            )
+
+    def is_alive(self, host: str, port: int) -> bool:
+        """True if a pong was seen within roughly 3 poll intervals."""
+        key = (host, port)
+        with self._targets_lock:
+            last_ok = self._alive.get(key, 0.0)
+        # consider alive if seen within ~3 intervals
+        return (time.time() - last_ok) < (3 * self.poll_interval +
+                                          self.timeout)
+
+    def remove_target(self, host: str, port: int):
+        """Remove a target from heartbeat monitor and close its socket."""
+        key = (host, port)
+        with self._targets_lock:
+            sock = self._targets.pop(key, None)
+            if sock is not None:
+                try:
+                    self._poller.unregister(sock)  # type: ignore
+                except Exception:
+                    pass
+                try:
+                    sock.close(linger=0)  # type: ignore
+                except Exception:
+                    pass
+            # clean book-keeping
+            self._alive.pop(key, None)
+            self._status.pop(key, None)
+            self._counters.pop(key, None)
+        logger.info(f"[HB] removed target {host}:{port}")
+
+    def stop(self):
+        """Signal run() to exit; it closes sockets and terms the ctx."""
+        self._stop.set()
+
+    def run(self):
+        encoder = msgspec.msgpack.Encoder()
+        # The ping payload never changes; encode it once.
+        ping = encoder.encode((HEARTBEAT_MSG, b""))
+        while not self._stop.is_set():
+            with self._targets_lock:
+                items = list(self._targets.items())
+            # send pings
+            for key, sock in items:
+                try:
+                    sock.send(ping, flags=zmq.DONTWAIT)  # type: ignore
+                    with self._targets_lock:
+                        cnt = self._counters.setdefault(
+                            key, {"sent": 0, "recv": 0, "timeout": 0})
+                        cnt["sent"] += 1
+                        sent = cnt["sent"]
+                    host, port = key
+                    logger.debug(f"[HB->] ping {host}:{port} sent={sent}")
+                except Exception:
+                    # send error: keep _alive as is; next round will retry
+                    pass
+            try:
+                events = dict(self._poller.poll(int(self.timeout * 1000)))
+            except Exception:
+                events = {}
+            now = time.time()
+            for key, sock in items:
+                if sock not in events:
+                    continue
+                try:
+                    reply = sock.recv(flags=zmq.DONTWAIT)  # type: ignore
+                except Exception:
+                    continue
+                if reply == HEARTBEAT_ACK:
+                    with self._targets_lock:
+                        self._alive[key] = now
+                        if key in self._counters:
+                            self._counters[key]["recv"] += 1
+            # Event.wait so stop() takes effect without waiting out a
+            # full poll interval.
+            self._stop.wait(self.poll_interval)
+        # cleanup -- reachable now that the loop honours _stop
+        with self._targets_lock:
+            for sock in self._targets.values():
+                try:
+                    sock.close(linger=0)  # type: ignore
+                except Exception:
+                    pass
+        try:
+            self._ctx.term()  # type: ignore
+        except Exception:
+            pass
+
+
 class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
 
     def __init__(self):
@@ -705,6 +855,40 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
                                                               deque)
         self.remote_poller = zmq.Poller()  # type: ignore
         self.timeout = 1.0  # seconds
+        self._peer_meta_lock = threading.Lock()
+        # --- Heartbeat monitor ---
+        self.heartbeat = HeartbeatMonitor(poll_interval=1.0,
+                                          timeout=self.timeout)
+        self.heartbeat.start()
+        # Track current peer (host, port) per remote_engine_id to detect switches
+        self._peer_lock = threading.Lock()
+        self.current_peer: dict[str, tuple[str, int]] = {}
+        # Background watcher: log alive status every _watch_interval seconds
+        self._watch_stop = threading.Event()
+        self._watch_interval = 60.0  # seconds
+        self._watch_thread = threading.Thread(target=self._peer_watch_loop,
+                                              name="PeerWatcher",
+                                              daemon=True)
+        self._watch_thread.start()
+        logger.info(
+            f"[WATCH] Peer watcher started, interval={self._watch_interval}s")
+
+    def _register_peer(self, eng: Optional[str], host: Optional[str],
+                       port: Optional[int]) -> None:
+        # Record the (host, port) serving remote engine `eng` and start
+        # heart-beating it.  Safe to call repeatedly for the same peer.
+        if eng is None or host is None or port is None:
+            return
+        with self._peer_lock:
+            self.current_peer[eng] = (host, port)
+            try:
+                self.heartbeat.add_target(host, port)  # type: ignore[arg-type]
+            except Exception as _e:
+                logger.warning(
+                    f"[WATCH] add_target failed for {host}:{port}: {_e}")
+            else:
+                logger.info(
+                    f"[WATCH] registered peer engine={eng} {host}:{port}")
 
     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill tp and dp size from extra config
@@ -989,6 +1173,55 @@ def _get_remote_socket(
         self.remote_poller.register(sock, zmq.POLLIN)  # type: ignore
         return sock
 
+    def _peer_watch_loop(self):
+        # Periodically log peer liveness and purge metadata for peers
+        # whose heartbeat has gone silent.
+        while not self._watch_stop.is_set():
+            try:
+                with self._peer_lock:
+                    snapshot = list(self.current_peer.items())
+                if not snapshot:
+                    logger.info("[WATCH] no peers registered yet")
+                for eng, (host, port) in snapshot:
+                    try:
+                        alive = self.heartbeat.is_alive(host, port)
+                    except Exception as _e:
+                        logger.warning(
+                            f"[WATCH] is_alive exception for {host}:{port}: {_e}"
+                        )
+                        alive = False
+                    logger.info(
+                        f"[WATCH] engine={eng} peer {host}:{port} alive={alive}"
+                    )
+                    if not alive:
+                        try:
+                            with self._peer_meta_lock:
+                                self.remote_te_port.get(eng,
+                                                        {}).pop(port, None)
+                                self.remote_kv_caches_base_addr.get(
+                                    eng, {}).pop(port, None)
+                            logger.warning(
+                                f"[WATCH] purged meta for down peer: engine={eng} {host}:{port}"
+                            )
+                        except Exception as _e:
+                            logger.warning(
+                                f"[WATCH] purge meta error for {host}:{port}: {_e}"
+                            )
+                        try:
+                            self.heartbeat.remove_target(host, port)
+                        except Exception as _e:
+                            logger.warning(
+                                f"[WATCH] heartbeat.remove_target failed for {host}:{port}: {_e}"
+                            )
+                        # Forget the peer so a dead endpoint is purged once,
+                        # not re-reported on every tick.
+                        with self._peer_lock:
+                            self.current_peer.pop(eng, None)
+            except Exception as _e:
+                logger.warning(f"[WATCH] loop exception: {_e}")
+            # wait() instead of sleep() so shutdown is prompt.
+            self._watch_stop.wait(self._watch_interval)
+
     def update_decoder_info(self, req_id, req_meta):
         req_meta_update = copy.deepcopy(req_meta)
         if self.pd_tp_ratio > 1:
@@ -1018,6 +1251,9 @@ def update_decoder_info(self, req_id, req_meta):
         logger.info(
             f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
         )
+        self._register_peer(req_meta_update.remote_engine_id,
+                            req_meta_update.remote_host,
+                            req_meta_update.remote_port)
         req_meta_update.remote_te_rpc_port = self.remote_te_port[
             req_meta_update.remote_engine_id][req_meta_update.remote_port]
         req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[