-
Notifications
You must be signed in to change notification settings - Fork 617
[P/D]Add a heartbeat mechanism to PD separation #4071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,6 +50,8 @@ | |
|
|
||
| GET_META_MSG = b"get_meta_msg" | ||
| DONE_SENDING_MSG = b"done_sending_msg" | ||
| HEARTBEAT_MSG = b"hb_ping" | ||
| HEARTBEAT_ACK = b"hb_pong" | ||
|
|
||
|
|
||
| class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): | ||
|
|
@@ -323,6 +325,12 @@ | |
| if msg[0] == GET_META_MSG: | ||
| logger.info("Got GET META INFO for request %s", msg[0]) | ||
| sock.send_multipart((identity, b"", encoded_data)) | ||
| elif msg[0] == HEARTBEAT_MSG: | ||
| # Heartbeat: reply immediately | ||
| try: | ||
| sock.send_multipart((identity, b"", HEARTBEAT_ACK)) | ||
| except Exception: | ||
| pass | ||
| elif msg[0] == DONE_SENDING_MSG: | ||
| logger.debug("Got DONE_RECVING_MSG for request %s", | ||
| msg[1]) | ||
|
|
@@ -337,6 +345,137 @@ | |
| logger.error("Failed to decode message: %s", e) | ||
|
|
||
|
|
||
class HeartbeatMonitor(threading.Thread):
    """
    Maintain heartbeat to multiple remote (host, port) REQ endpoints.
    Non-blocking API: add_target(), remove_target(), is_alive(), stop().

    A daemon thread sends ``HEARTBEAT_MSG`` to every registered target once
    per round, waits up to ``timeout`` seconds for ``HEARTBEAT_ACK`` replies
    and then sleeps ``poll_interval`` seconds before the next round.
    """

    def __init__(self, poll_interval: float = 360, timeout: float = 1.0):
        """Create the monitor; call ``start()`` to begin pinging.

        Args:
            poll_interval: Seconds to sleep between heartbeat rounds.
            timeout: Seconds to wait for replies in a round; also used as
                the send/receive timeout of every target socket.
        """
        super().__init__(daemon=True, name="HeartbeatMonitor")
        self.poll_interval = poll_interval
        self.timeout = timeout
        # Guards the four bookkeeping dicts below.
        self._targets_lock = threading.Lock()
        self._targets: dict[tuple[str, int],
                            zmq.Socket] = {}  # type: ignore
        # time.monotonic() of the last valid ack; -inf means "never seen".
        self._alive: dict[tuple[str, int], float] = {}
        # Outcome of the most recent round for each target.
        self._status: dict[tuple[str, int], bool] = {}
        self._counters: dict[tuple[str, int], dict[str, int]] = {}
        # Deliberately NOT named `_stop`: threading.Thread defines an
        # internal `_stop()` method on some CPython versions, and shadowing
        # it with an Event breaks join()/is_alive() once the thread exits.
        self._stop_event = threading.Event()
        self._ctx = zmq.Context()  # type: ignore

    def stop(self) -> None:
        """Ask the monitor thread to exit after the current round."""
        self._stop_event.set()

    def add_target(self, host: str, port: int):
        """Start monitoring ``host:port``; no-op if it is already tracked.

        The target starts out as not alive until its first ack arrives.
        """
        key = (host, port)
        with self._targets_lock:
            if key in self._targets:
                return
            path = make_zmq_path("tcp", host, port)
            sock = make_zmq_socket(
                ctx=self._ctx,
                path=path,
                socket_type=zmq.REQ,  # type: ignore
                bind=False)
            try:
                timeout_ms = int(self.timeout * 1000)
                # avoid blocking forever
                sock.setsockopt(zmq.SNDTIMEO, timeout_ms)  # type: ignore
                sock.setsockopt(zmq.RCVTIMEO, timeout_ms)  # type: ignore
                # A strict REQ socket refuses every further send once a
                # single reply is lost; relaxed mode lets the next ping go
                # out so a recovered peer is detected again.
                sock.setsockopt(zmq.REQ_RELAXED, 1)  # type: ignore
            except Exception:
                # Do not leak a half-configured socket.
                sock.close(linger=0)  # type: ignore
                raise
            self._targets[key] = sock
            # pessimistic init
            self._alive[key] = float("-inf")
            self._status[key] = False
            self._counters[key] = {"sent": 0, "recv": 0, "timeout": 0}
            counters = dict(self._counters[key])
        logger.info("[HB] add_target host=%s port=%s", host, port)
        logger.debug("[HB] counters init for %s:%s -> %s", host, port,
                     counters)

    def is_alive(  # type: ignore[override]
            self,
            host: Optional[str] = None,
            port: Optional[int] = None) -> bool:
        """Report liveness of a target, or of the monitor thread itself.

        Called without arguments this behaves exactly like
        ``threading.Thread.is_alive()``, which the threading module itself
        relies on. Called with ``host`` and ``port`` it returns True when a
        valid ack from that target was seen within roughly three poll
        intervals; unknown targets are reported as not alive.
        """
        if host is None and port is None:
            return super().is_alive()
        key = (host, port)
        with self._targets_lock:
            last_ok = self._alive.get(key, float("-inf"))
        # consider alive if seen within ~3 intervals
        return (time.monotonic() - last_ok) < (3 * self.poll_interval +
                                               self.timeout)

    def remove_target(self, host: str, port: int):
        """Remove a target from heartbeat monitor and close its socket."""
        key = (host, port)
        with self._targets_lock:
            sock = self._targets.pop(key, None)
            # clean book-keeping
            self._alive.pop(key, None)
            self._status.pop(key, None)
            self._counters.pop(key, None)
        if sock is not None:
            try:
                sock.close(linger=0)  # type: ignore
            except Exception as e:
                logger.warning("[HB] failed to close socket for %s:%s: %s",
                               host, port, e)
        logger.info("[HB] removed target %s:%s", host, port)

    def run(self):
        """Ping all targets in rounds until ``stop()`` is called."""
        # The ping payload never changes, so encode it once.
        ping = msgspec.msgpack.Encoder().encode((HEARTBEAT_MSG, b""))
        try:
            while not self._stop_event.is_set():
                try:
                    with self._targets_lock:
                        targets = list(self._targets.items())
                    pending = self._send_pings(targets, ping)
                    self._collect_replies(pending)
                except Exception as e:
                    # Keep the monitor alive; one bad round must not end it.
                    logger.warning("[HB] heartbeat round failed: %s", e)
                # Event.wait() doubles as an interruptible sleep.
                if self._stop_event.wait(self.poll_interval):
                    break
        finally:
            self._shutdown()

    def _bump(self, key, sock, counter: str) -> Optional[int]:
        """Increment ``counter`` for ``key`` if ``sock`` is still tracked.

        Returns the new value, or None when the target was removed (or
        replaced) in the meantime, so removed targets are not resurrected.
        """
        with self._targets_lock:
            if self._targets.get(key) is not sock:
                return None
            counters = self._counters.setdefault(key, {
                "sent": 0,
                "recv": 0,
                "timeout": 0
            })
            counters[counter] += 1
            return counters[counter]

    def _send_pings(self, targets, ping: bytes) -> dict:
        """Send one ping per target; return ``{sock: key}`` awaiting acks."""
        pending = {}
        for key, sock in targets:
            host, port = key
            try:
                sock.send(ping, flags=zmq.DONTWAIT)  # type: ignore
            except Exception as e:
                # send error: keep _alive as is; next round will try again
                logger.warning("[HB->] failed to send ping to %s:%s: %s",
                               host, port, e)
                continue
            pending[sock] = key
            sent = self._bump(key, sock, "sent")
            if sent is not None:
                logger.debug("[HB->] ping %s:%s sent=%s", host, port, sent)
        return pending

    def _collect_replies(self, pending: dict) -> None:
        """Wait up to ``timeout`` seconds for acks on the pinged sockets."""
        if not pending:
            return
        # A per-round poller holds only sockets with a ping in flight, and
        # is never touched by add_target()/remove_target() callers.
        poller = zmq.Poller()  # type: ignore
        for sock in pending:
            poller.register(sock, zmq.POLLIN)  # type: ignore
        deadline = time.monotonic() + self.timeout
        # poll() returns as soon as any socket is readable, so keep
        # polling until every reply arrived or the deadline passed.
        while pending:
            remaining_ms = int((deadline - time.monotonic()) * 1000)
            if remaining_ms <= 0:
                break
            try:
                ready = dict(poller.poll(remaining_ms))
            except Exception as e:
                logger.warning("[HB<-] poll failed: %s", e)
                break
            if not ready:
                break
            for sock in ready:
                key = pending.pop(sock, None)
                if key is None:
                    continue
                poller.unregister(sock)
                self._handle_reply(key, sock)
        # Whatever is still pending did not answer in time.
        for sock, key in pending.items():
            self._mark_timeout(key, sock)

    def _handle_reply(self, key, sock) -> None:
        """Read one reply from ``sock`` and record a valid ack."""
        host, port = key
        try:
            reply = sock.recv(flags=zmq.DONTWAIT)  # type: ignore
        except Exception as e:
            logger.warning("[HB<-] failed to receive pong from %s:%s: %s",
                           host, port, e)
            return
        if reply != HEARTBEAT_ACK:
            logger.warning("[HB<-] unexpected reply from %s:%s: %r", host,
                           port, reply)
            return
        now = time.monotonic()
        with self._targets_lock:
            if self._targets.get(key) is not sock:
                return  # target was removed while the ping was in flight
            self._alive[key] = now
            self._status[key] = True
            counters = self._counters.setdefault(key, {
                "sent": 0,
                "recv": 0,
                "timeout": 0
            })
            counters["recv"] += 1

    def _mark_timeout(self, key, sock) -> None:
        """Record that ``key`` did not ack within this round."""
        host, port = key
        with self._targets_lock:
            if self._targets.get(key) is not sock:
                return
            self._status[key] = False
        missed = self._bump(key, sock, "timeout")
        if missed is not None:
            logger.debug("[HB<-] no pong from %s:%s timeout=%s", host, port,
                         missed)

    def _shutdown(self) -> None:
        """Close every socket and terminate the ZMQ context."""
        with self._targets_lock:
            targets = list(self._targets.items())
            self._targets.clear()
            self._alive.clear()
            self._status.clear()
            self._counters.clear()
        for (host, port), sock in targets:
            try:
                sock.close(linger=0)  # type: ignore
            except Exception as e:
                logger.warning("[HB] failed to close socket for %s:%s: %s",
                               host, port, e)
        try:
            self._ctx.term()  # type: ignore
        except Exception as e:
            logger.warning("[HB] failed to terminate ZMQ context: %s", e)
|
|
||
|
|
||
| class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata): | ||
|
|
||
| def __init__(self): | ||
|
|
@@ -705,6 +844,37 @@ | |
| deque) | ||
| self.remote_poller = zmq.Poller() # type: ignore | ||
| self.timeout = 1.0 # seconds | ||
| self._peer_meta_lock = threading.Lock() | ||
| # --- Heartbeat monitor --- | ||
| self.heartbeat = HeartbeatMonitor(poll_interval=1.0, | ||
| timeout=self.timeout) | ||
| self.heartbeat.start() | ||
| # Track current peer (host, port) per remote_engine_id to detect switches | ||
| self._peer_lock = threading.Lock() | ||
| self.current_peer: dict[str, tuple[str, int]] = {} | ||
| # Background watcher: print alive status every 5s | ||
| self._watch_stop = threading.Event() | ||
| self._watch_interval = 60.0 # seconds | ||
| self._watch_thread = threading.Thread(target=self._peer_watch_loop, | ||
| name="PeerWatcher", | ||
| daemon=True) | ||
| self._watch_thread.start() | ||
| logger.info( | ||
| f"[WATCH] Peer watcher started, interval={self._watch_interval}s") | ||
|
|
||
def _register_peer(self, eng: Optional[str], host: Optional[str],
                   port: Optional[int]) -> None:
    """Remember ``(host, port)`` as the peer of ``eng`` and start its heartbeat.

    Does nothing when any of the three arguments is None. A failure to add
    the heartbeat target is logged as a warning and not re-raised; the
    peer mapping has already been recorded by then.
    """
    if any(value is None for value in (eng, host, port)):
        return

    peer = (host, port)
    with self._peer_lock:
        self.current_peer[eng] = peer

    try:
        self.heartbeat.add_target(host, port)  # type: ignore[arg-type]
    except Exception as _e:
        logger.warning(
            f"[WATCH] add_target failed for {host}:{port}: {_e}")
        return

    # Only reached when the heartbeat target was added without error.
    logger.info(f"[WATCH] registered peer engine={eng} {host}:{port}")
|
|
||
| def _get_prefill_decode_size(self, vllm_config: VllmConfig): | ||
| # get prefill tp and dp size from extra config | ||
|
|
@@ -989,6 +1159,48 @@ | |
| self.remote_poller.register(sock, zmq.POLLIN) # type: ignore | ||
| return sock | ||
|
|
||
def _peer_watch_loop(self):
    """Periodically log peer liveness and purge state of peers that are down.

    Every ``self._watch_interval`` seconds, until ``self._watch_stop`` is
    set, each registered ``engine -> (host, port)`` peer is checked through
    ``self.heartbeat.is_alive``. For a peer reported as not alive, its
    cached TE port and KV-cache base address are dropped, its heartbeat
    target is removed, and it is forgotten as the engine's current peer.
    """
    while not self._watch_stop.is_set():
        try:
            # Work on a snapshot so the lock is not held while logging
            # or while touching heartbeat / metadata state.
            with self._peer_lock:
                snapshot = list(self.current_peer.items())
            if not snapshot:
                logger.info("[WATCH] no peers registered yet")
            for eng, (host, port) in snapshot:
                try:
                    alive = self.heartbeat.is_alive(host, port)
                except Exception as _e:
                    logger.warning(
                        f"[WATCH] is_alive exception for {host}:{port}: {_e}"
                    )
                    alive = False
                logger.info(
                    f"[WATCH] engine={eng} peer {host}:{port} alive={alive}"
                )
                if alive:
                    continue
                try:
                    with self._peer_meta_lock:
                        self.remote_te_port.get(eng, {}).pop(port, None)
                        self.remote_kv_caches_base_addr.get(
                            eng, {}).pop(port, None)
                    logger.warning(
                        f"[WATCH] purged meta for down peer: engine={eng} {host}:{port}"
                    )
                except Exception as _e:
                    logger.warning(
                        f"[WATCH] purge meta error for {host}:{port}: {_e}"
                    )
                try:
                    self.heartbeat.remove_target(host, port)
                except Exception as _e:
                    logger.warning(
                        f"[WATCH] heartbeat.remove_target failed for {host}:{port}: {_e}"
                    )
                # Forget the purged peer; otherwise it is reported dead and
                # purged again on every tick. Skip if the engine has been
                # re-pointed to a different peer in the meantime.
                with self._peer_lock:
                    if self.current_peer.get(eng) == (host, port):
                        del self.current_peer[eng]
        except Exception as _e:
            logger.warning(f"[WATCH] loop exception: {_e}")
        # Interruptible sleep: returns early once _watch_stop is set.
        self._watch_stop.wait(self._watch_interval)
|
|
||
| def update_decoder_info(self, req_id, req_meta): | ||
| req_meta_update = copy.deepcopy(req_meta) | ||
| if self.pd_tp_ratio > 1: | ||
|
|
@@ -1018,6 +1230,9 @@ | |
| logger.info( | ||
| f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}" | ||
| ) | ||
| self._register_peer(req_meta_update.remote_engine_id, | ||
| req_meta_update.remote_host, | ||
| req_meta_update.remote_port) | ||
| req_meta_update.remote_te_rpc_port = self.remote_te_port[ | ||
| req_meta_update.remote_engine_id][req_meta_update.remote_port] | ||
| req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Silently swallowing all exceptions with `except Exception: pass` is dangerous. If sending the heartbeat acknowledgment fails, the `HeartbeatMonitor` on the other side will not receive a reply and may incorrectly assume this peer is down. This could lead to connection issues and failures that are very hard to debug. The exception should be logged to provide visibility into network or socket problems.