diff --git a/libs/opsqueue_python/tests/conftest.py b/libs/opsqueue_python/tests/conftest.py index b4b10bf..625f722 100644 --- a/libs/opsqueue_python/tests/conftest.py +++ b/libs/opsqueue_python/tests/conftest.py @@ -6,6 +6,8 @@ import subprocess import uuid import os +from libs.opsqueue_python.tests.util import wait_for_server +import psutil import pytest from dataclasses import dataclass from pathlib import Path @@ -52,13 +54,10 @@ def opsqueue() -> Generator[OpsqueueProcess, None, None]: @contextmanager def opsqueue_service( - *, port: int | None = None + *, port: int = 0, ) -> Generator[OpsqueueProcess, None, None]: global test_opsqueue_port_offset - if port is None: - port = random_free_port() - temp_dbname = f"/tmp/opsqueue_tests-{uuid.uuid4()}.db" command = [ @@ -72,7 +71,8 @@ def opsqueue_service( if env.get("RUST_LOG") is None: env["RUST_LOG"] = "off" - with subprocess.Popen(command, cwd=PROJECT_ROOT, env=env) as process: + with psutil.Popen(command, cwd=PROJECT_ROOT, env=env) as process: + _host, port = wait_for_server(process) try: wrapper = OpsqueueProcess(port=port, process=process) yield wrapper diff --git a/libs/opsqueue_python/tests/util.py b/libs/opsqueue_python/tests/util.py new file mode 100644 index 0000000..1b2d76f --- /dev/null +++ b/libs/opsqueue_python/tests/util.py @@ -0,0 +1,57 @@ +import logging + +import psutil +from opnieuw import retry +from psutil._common import pconn + +LOGGER = logging.getLogger(__name__) + + +@retry( + retry_on_exceptions=ValueError, + max_calls_total=5, + retry_window_after_first_call_in_seconds=5, +) +def wait_for_server(proc: psutil.Popen) -> tuple[str, int]: + """ + Wait for a process to be listening on a single port. + If the process is listening on no ports, a ValueError is thrown and this is retried. + If multiple ports are listening, a RuntimeError is thrown. + """ + if not proc.is_running(): + raise ValueError(f"Process {proc} is not running") + + try: + # Try to get the connections of the main process first, if that fails try the children. + # Processes wrapped with `timeout` do not have connections themselves. + connections: list[pconn] = ( + proc.net_connections() + or [ + child_conn + for child in proc.children(recursive=False) + for child_conn in child.net_connections() + ] + or [ + child_conn + for child in proc.children(recursive=True) + for child_conn in child.net_connections() + ] + ) + except psutil.AccessDenied as e: + match proc.status(): + case psutil.STATUS_ZOMBIE | psutil.STATUS_DEAD | psutil.STATUS_STOPPED: + raise RuntimeError(f"Process {proc} has exited unexpectedly") from e + case _: + raise RuntimeError( + f"Could not get `net_connections` for process {proc}, access denied " + ) from e + + ports = [x for x in connections if x.status == psutil.CONN_LISTEN] + listen_count = len(ports) + + if listen_count == 0: + raise ValueError(f"Process {proc} is not listening on any ports") + if listen_count == 1: + return ports[0].laddr + + raise RuntimeError(f"Process {proc} is listening on multiple ports")