Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions vllm_ascend/distributed/mooncake_layerwise_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@

GET_META_MSG = b"get_meta_msg"
DONE_SENDING_MSG = b"done_sending_msg"
HEARTBEAT_MSG = b"hb_ping"
HEARTBEAT_ACK = b"hb_pong"


class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
Expand Down Expand Up @@ -323,6 +325,12 @@
if msg[0] == GET_META_MSG:
logger.info("Got GET META INFO for request %s", msg[0])
sock.send_multipart((identity, b"", encoded_data))
elif msg[0] == HEARTBEAT_MSG:
# Heartbeat: reply immediately
try:
sock.send_multipart((identity, b"", HEARTBEAT_ACK))
except Exception:
pass
Comment on lines +330 to +333
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Silently swallowing all exceptions with except Exception: pass is dangerous. If sending the heartbeat acknowledgment fails, the HeartbeatMonitor on the other side will not receive a reply and may incorrectly assume this peer is down. This could lead to connection issues and failures that are very hard to debug. The exception should be logged to provide visibility into network or socket problems.

Suggested change
try:
sock.send_multipart((identity, b"", HEARTBEAT_ACK))
except Exception:
pass
try:
sock.send_multipart((identity, b"", HEARTBEAT_ACK))
except Exception as e:
logger.warning(f"Failed to send heartbeat ACK to {identity!r}: {e}")

elif msg[0] == DONE_SENDING_MSG:
logger.debug("Got DONE_RECVING_MSG for request %s",
msg[1])
Expand All @@ -337,6 +345,137 @@
logger.error("Failed to decode message: %s", e)


class HeartbeatMonitor(threading.Thread):
    """Daemon thread that pings remote (host, port) endpoints over ZMQ.

    Every registered target gets its own REQ socket. Once per
    ``poll_interval`` seconds the thread sends ``HEARTBEAT_MSG`` to every
    target and then waits up to ``timeout`` seconds for ``HEARTBEAT_ACK``
    replies. The wall-clock time of the last ACK per target is what
    :meth:`is_alive` consults.

    Non-blocking API: add_target(), remove_target(), is_alive(), stop().

    NOTE(review): ``add_target``/``remove_target`` touch the poller and the
    sockets from the caller's thread while ``run`` uses them from the
    monitor thread. ZMQ sockets are not documented as thread-safe; errors
    caused by such races are caught and logged here, but confirm whether
    the mutation should instead be handed over to the monitor thread.
    """

    def __init__(self, poll_interval: float = 360, timeout: float = 1.0):
        """Create the monitor; call ``start()`` to begin pinging.

        Args:
            poll_interval: Seconds to wait between two ping rounds.
            timeout: Seconds to wait for replies in each round; also used
                as the send/receive timeout of every target socket.
        """
        super().__init__(daemon=True, name="HeartbeatMonitor")
        self.poll_interval = poll_interval
        self.timeout = timeout
        # Guards every dict below.
        self._targets_lock = threading.Lock()
        self._targets: dict[tuple[str, int], zmq.Socket] = {}  # type: ignore
        # Wall-clock time (time.time()) of the last ACK; 0.0 means "never".
        self._alive: dict[tuple[str, int], float] = {}
        # Outcome of the most recent round: True if an ACK was received.
        self._status: dict[tuple[str, int], bool] = {}
        self._counters: dict[tuple[str, int], dict[str, int]] = {}
        # Deliberately NOT named ``_stop``: threading.Thread defines a
        # private ``_stop()`` method in CPython releases up to 3.12, and
        # join()/is_alive() call it once the thread has finished.
        self._stop_event = threading.Event()
        self._poller = zmq.Poller()  # type: ignore
        self._ctx = zmq.Context()  # type: ignore

    def stop(self) -> None:
        """Ask the ping loop to exit; ``run`` closes the sockets itself."""
        self._stop_event.set()

    def add_target(self, host: str, port: int):
        """Start monitoring ``host:port``; a no-op if already registered."""
        key = (host, port)
        with self._targets_lock:
            if key in self._targets:
                return
            path = make_zmq_path("tcp", host, port)
            sock = make_zmq_socket(
                ctx=self._ctx,
                path=path,
                socket_type=zmq.REQ,  # type: ignore
                bind=False)
            # avoid blocking forever
            timeout_ms = int(self.timeout * 1000)
            sock.setsockopt(zmq.SNDTIMEO, timeout_ms)  # type: ignore
            sock.setsockopt(zmq.RCVTIMEO, timeout_ms)  # type: ignore
            # NOTE(review): a REQ socket must receive a reply before it can
            # send again, so a target that misses one reply keeps failing
            # to send until it is removed and re-added. Consider
            # zmq.REQ_RELAXED if targets should recover in place.
            self._targets[key] = sock
            self._poller.register(sock, zmq.POLLIN)  # type: ignore
            # pessimistic init: not alive until the first ACK arrives
            self._alive[key] = 0.0
            self._status[key] = False
            counters = {"sent": 0, "recv": 0, "timeout": 0}
            self._counters[key] = counters
        logger.info("[HB] add_target host=%s port=%s", host, port)
        logger.debug("[HB] counters init for %s:%s -> %s", host, port,
                     counters)

    def is_alive(self,
                 host: Optional[str] = None,
                 port: Optional[int] = None) -> bool:
        """Report liveness of a target, or of the thread itself.

        With ``host`` and ``port`` given, return True when an ACK from that
        target was seen within roughly three poll intervals. Called with no
        arguments it behaves like ``threading.Thread.is_alive()``, which
        keeps the override compatible with callers that expect the base
        class signature.
        """
        if host is None and port is None:
            return super().is_alive()
        key = (host, port)
        with self._targets_lock:
            last_ok = self._alive.get(key, 0.0)
        # consider alive if seen within ~3 intervals
        return (time.time() - last_ok) < (3 * self.poll_interval +
                                          self.timeout)

    def remove_target(self, host: str, port: int):
        """Remove a target from heartbeat monitor and close its socket."""
        key = (host, port)
        with self._targets_lock:
            sock = self._targets.pop(key, None)
            # clean book-keeping
            self._alive.pop(key, None)
            self._status.pop(key, None)
            self._counters.pop(key, None)
            if sock is not None:
                try:
                    self._poller.unregister(sock)  # type: ignore
                except Exception as e:
                    logger.warning(
                        "[HB] failed to unregister socket for %s:%s: %s",
                        host, port, e)
                self._close_socket(key, sock)
        logger.info(f"[HB] removed target {host}:{port}")

    def run(self):
        """Ping all targets once per ``poll_interval`` until stopped."""
        # The ping payload never changes, so encode it once.
        ping = msgspec.msgpack.Encoder().encode((HEARTBEAT_MSG, b""))
        while not self._stop_event.is_set():
            with self._targets_lock:
                items = list(self._targets.items())
            for key, sock in items:
                self._send_ping(key, sock, ping)
            self._collect_replies(items)
            # Event.wait doubles as the sleep and wakes up early on stop().
            self._stop_event.wait(self.poll_interval)
        # cleanup
        with self._targets_lock:
            for key, sock in self._targets.items():
                self._close_socket(key, sock)
            self._targets.clear()
        try:
            self._ctx.term()  # type: ignore
        except Exception as e:
            logger.warning("[HB] failed to terminate ZMQ context: %s", e)

    def _send_ping(self, key: tuple[str, int], sock, ping: bytes) -> None:
        """Send one ping to ``key`` and bump its ``sent`` counter."""
        host, port = key
        try:
            sock.send(ping, flags=zmq.DONTWAIT)  # type: ignore
        except Exception as e:
            # send error: keep _alive as is; next round will try again
            logger.warning("[HB->] failed to send ping to %s:%s: %s", host,
                           port, e)
            return
        with self._targets_lock:
            counters = self._counters.get(key)
            if counters is None:
                # Target was removed after the snapshot was taken; do not
                # resurrect its book-keeping.
                return
            counters["sent"] += 1
            sent = counters["sent"]
        logger.debug(f"[HB->] ping {host}:{port} sent={sent}")

    def _collect_replies(self, items) -> None:
        """Receive replies for ``items`` until all answered or timed out."""
        awaiting = {sock: key for key, sock in items}
        deadline = time.monotonic() + self.timeout
        # Each pass removes at least one socket from ``awaiting`` or breaks,
        # so the loop runs at most len(items) + 1 times.
        while awaiting:
            remaining_ms = max(0,
                               int((deadline - time.monotonic()) * 1000))
            try:
                events = dict(self._poller.poll(remaining_ms))
            except Exception as e:
                logger.warning("[HB] poll failed: %s", e)
                break
            ready = [sock for sock in events if sock in awaiting]
            if not ready:
                # Timed out, or only sockets outside this round are ready.
                break
            for sock in ready:
                self._handle_reply(awaiting.pop(sock), sock)
        for key in awaiting.values():
            self._record_timeout(key)

    def _handle_reply(self, key: tuple[str, int], sock) -> None:
        """Read one reply from ``sock`` and record an ACK for ``key``."""
        host, port = key
        try:
            reply = sock.recv(flags=zmq.DONTWAIT)  # type: ignore
        except Exception as e:
            logger.warning("[HB<-] failed to receive from %s:%s: %s", host,
                           port, e)
            return
        if reply != HEARTBEAT_ACK:
            logger.warning("[HB<-] unexpected reply from %s:%s: %r", host,
                           port, reply)
            return
        now = time.time()
        with self._targets_lock:
            if key not in self._targets:
                # Removed concurrently; drop the late ACK.
                return
            self._alive[key] = now
            self._status[key] = True
            self._counters[key]["recv"] += 1

    def _record_timeout(self, key: tuple[str, int]) -> None:
        """Mark that ``key`` produced no reply in the current round."""
        with self._targets_lock:
            counters = self._counters.get(key)
            if counters is None:
                return
            counters["timeout"] += 1
            self._status[key] = False

    def _close_socket(self, key: tuple[str, int], sock) -> None:
        """Close ``sock`` without lingering, logging instead of raising."""
        host, port = key
        try:
            sock.close(linger=0)  # type: ignore
        except Exception as e:
            logger.warning("[HB] failed to close socket for %s:%s: %s", host,
                           port, e)


class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):

def __init__(self):
Expand Down Expand Up @@ -705,6 +844,37 @@
deque)
self.remote_poller = zmq.Poller() # type: ignore
self.timeout = 1.0 # seconds
self._peer_meta_lock = threading.Lock()
# --- Heartbeat monitor ---
self.heartbeat = HeartbeatMonitor(poll_interval=1.0,
timeout=self.timeout)
self.heartbeat.start()
# Track current peer (host, port) per remote_engine_id to detect switches
self._peer_lock = threading.Lock()
self.current_peer: dict[str, tuple[str, int]] = {}
# Background watcher: log peer alive status every `_watch_interval` seconds
self._watch_stop = threading.Event()
self._watch_interval = 60.0 # seconds
self._watch_thread = threading.Thread(target=self._peer_watch_loop,
name="PeerWatcher",
daemon=True)
self._watch_thread.start()
logger.info(
f"[WATCH] Peer watcher started, interval={self._watch_interval}s")

def _register_peer(self, eng: Optional[str], host: Optional[str],
                   port: Optional[int]) -> None:
    """Record ``(host, port)`` as the peer of ``eng`` and heartbeat it.

    Does nothing when any argument is None. A failure to add the
    heartbeat target is logged as a warning and not raised.
    """
    if any(value is None for value in (eng, host, port)):
        return

    peer = (host, port)
    with self._peer_lock:
        self.current_peer[eng] = peer

    try:
        self.heartbeat.add_target(host, port)  # type: ignore[arg-type]
    except Exception as _e:
        logger.warning(
            f"[WATCH] add_target failed for {host}:{port}: {_e}")
        return
    logger.info(f"[WATCH] registered peer engine={eng} {host}:{port}")

def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
Expand Down Expand Up @@ -989,6 +1159,48 @@
self.remote_poller.register(sock, zmq.POLLIN) # type: ignore
return sock

def _peer_watch_loop(self):
    """Periodically log peer liveness and purge metadata of dead peers.

    Runs until ``_watch_stop`` is set. Each round takes a snapshot of
    ``current_peer``, asks the heartbeat monitor whether every peer is
    alive, and for peers reported down drops their entries from
    ``remote_te_port`` / ``remote_kv_caches_base_addr`` and removes them
    from the heartbeat monitor. Errors are logged, never raised.

    NOTE(review): a purged peer stays in ``current_peer``, so it is
    purged (and logged) again every round until it is re-registered —
    confirm whether that repetition is intended.
    """
    while not self._watch_stop.is_set():
        try:
            with self._peer_lock:
                snapshot = list(self.current_peer.items())
            if not snapshot:
                logger.info("[WATCH] no peers registered yet")
            for eng, (host, port) in snapshot:
                try:
                    alive = self.heartbeat.is_alive(host, port)
                except Exception as _e:
                    logger.warning(
                        f"[WATCH] is_alive exception for {host}:{port}: {_e}"
                    )
                    alive = False
                logger.info(
                    f"[WATCH] engine={eng} peer {host}:{port} alive={alive}"
                )
                if alive:
                    continue
                try:
                    with self._peer_meta_lock:
                        self.remote_te_port.get(eng, {}).pop(port, None)
                        self.remote_kv_caches_base_addr.get(
                            eng, {}).pop(port, None)
                    logger.warning(
                        f"[WATCH] purged meta for down peer: engine={eng} {host}:{port}"
                    )
                except Exception as _e:
                    logger.warning(
                        f"[WATCH] purge meta error for {host}:{port}: {_e}"
                    )
                try:
                    self.heartbeat.remove_target(host, port)
                except Exception as _e:
                    logger.warning(
                        f"[WATCH] heartbeat.remove_target failed for {host}:{port}: {_e}"
                    )
        except Exception as _e:
            logger.warning(f"[WATCH] loop exception: {_e}")
        # Event.wait() replaces time.sleep() so that setting _watch_stop
        # ends the loop immediately instead of after a full interval.
        if self._watch_stop.wait(self._watch_interval):
            break

def update_decoder_info(self, req_id, req_meta):
req_meta_update = copy.deepcopy(req_meta)
if self.pd_tp_ratio > 1:
Expand Down Expand Up @@ -1018,6 +1230,9 @@
logger.info(
f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
)
self._register_peer(req_meta_update.remote_engine_id,
req_meta_update.remote_host,
req_meta_update.remote_port)
req_meta_update.remote_te_rpc_port = self.remote_te_port[
req_meta_update.remote_engine_id][req_meta_update.remote_port]
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[
Expand Down
Loading