From 2c012d363b4b507e8a2372b73ffe4ebdebc806b6 Mon Sep 17 00:00:00 2001 From: James Douglass Date: Tue, 8 Jul 2025 14:35:53 -0700 Subject: [PATCH 01/14] Adding a test to reproduce the issue. RE:#109 --- tests/test_task.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_task.py b/tests/test_task.py index 397946c..46943d3 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -8,6 +8,7 @@ import pickle import re import shutil +import signal import sqlite3 import subprocess import tempfile @@ -140,6 +141,18 @@ def _log_from_another_process(logger_name, log_message): logger.info(log_message) +def _kill_current_process(): + """Kill the current process. + + Must be run within a taskgraph task process. + """ + if __name__ == '__main__': + raise AssertionError( + "This function is only supposed to be called in a subprocess") + + os.kill(os.getpid(), signal.SIGKILL) + + class TaskGraphTests(unittest.TestCase): """Tests for the taskgraph.""" @@ -1480,6 +1493,20 @@ def test_mtime_mismatch(self): with open(target_path) as target_file: self.assertEqual(target_file.read(), content) + def test_multiprocessing_deadlock(self): + """Verify that the graph is shut down in case of deadlock. + + This test will deadlock if the functionality it is testing for (graph + shutdown when a task process is killed) is not available. + + See https://github.com/natcap/taskgraph/issues/109 + """ + task_graph = taskgraph.TaskGraph(self.workspace_dir, n_workers=1) + task = task_graph.add_task(_kill_current_process) + + task_graph.join() + task_graph.close() + def Fail(n_tries, result_path): """Create a function that fails after ``n_tries``.""" From 8bd2ed2bd01931392f7d2d27b8ac7b8f18fb492f Mon Sep 17 00:00:00 2001 From: James Douglass Date: Tue, 8 Jul 2025 15:06:10 -0700 Subject: [PATCH 02/14] Working on initializing a monitor thread. RE:#109 --- taskgraph/Task.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/taskgraph/Task.py b/taskgraph/Task.py index 4ab8379..e0fb28b 100644 --- a/taskgraph/Task.py +++ b/taskgraph/Task.py @@ -14,6 +14,7 @@ import sqlite3 import threading import time + try: from importlib.metadata import PackageNotFoundError from importlib.metadata import version @@ -100,6 +101,19 @@ def _null_func(): return None +def _initialize_process_pool(logging_queue): + """A function to chain together multiple initialization functions. + + Args: + logging_queue (multiprocessing.Queue): The queue to use for passing + log records back to the main process. + + Returns: + ``None`` + """ + _initialize_logging_to_queue(logging_queue) + + def _initialize_logging_to_queue(logging_queue): """Add a synchronized queue to a new process. @@ -142,6 +156,21 @@ def _logging_queue_monitor(logging_queue): LOGGER.debug('_logging_queue_monitor shutting down') +def _process_pool_monitor(parent_pid, starting_pids_set): + LOGGER.debug("Starting the process pool PID monitor") + parent_process = psutil.Process(parent_pid) + while True: + child_processes = set( + proc.pid for proc in parent_process.children(recursive=False)) + print(child_processes) + if child_processes != starting_pids_set: + print("Change in PIDs!") + print(child_processes) + print(starting_pids_set) + time.sleep(1) + pass + + def _create_taskgraph_table_schema(taskgraph_database_path): """Create database exists and/or ensures it is compatible and recreate. 
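
A minimal standalone sketch of the hang that patch 01's test reproduces, assuming only the standard library (the timeout exists just so the demo terminates): multiprocessing.Pool quietly replaces a worker that dies, but the dead worker's task never completes, so an untimed result.get() would block forever:

    # repro_hang.py (hypothetical, not part of the patch).
    import multiprocessing
    import os
    import signal


    def _kill_current_process():
        os.kill(os.getpid(), signal.SIGTERM)


    if __name__ == '__main__':
        pool = multiprocessing.Pool(1)
        # _pool is a private attribute; patch 02 peeks at it the same way.
        starting_pids = set(proc.pid for proc in pool._pool)
        result = pool.apply_async(_kill_current_process)
        try:
            result.get(timeout=5)  # without the timeout, this never returns
        except multiprocessing.TimeoutError:
            current_pids = set(proc.pid for proc in pool._pool)
            print(f'worker PIDs changed: {starting_pids} -> {current_pids}')
        pool.terminate()
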
@@ -390,14 +419,19 @@ def __init__( if n_workers > 0: self._logging_queue = multiprocessing.Queue() self._worker_pool = NonDaemonicPool( - n_workers, initializer=_initialize_logging_to_queue, + n_workers, initializer=_initialize_process_pool, initargs=(self._logging_queue,)) self._logging_monitor_thread = threading.Thread( target=_logging_queue_monitor, args=(self._logging_queue,)) + self._process_pool_monitor_thread = threading.Thread( + target=_process_pool_monitor, + args=(os.getpid(), set(proc.pid for proc in self._worker_pool._pool),)) self._logging_monitor_thread.daemon = True self._logging_monitor_thread.start() + self._process_pool_monitor_thread.daemon = True + self._process_pool_monitor_thread.start() if HAS_PSUTIL: parent = psutil.Process() parent.nice(PROCESS_LOW_PRIORITY) From 68bea713b2367dd18e934d74eaea8bc44b12859d Mon Sep 17 00:00:00 2001 From: James Douglass Date: Tue, 8 Jul 2025 15:15:19 -0700 Subject: [PATCH 03/14] psutil is now required. Psutil is so much easier to use than the python stdlib equivalents relating to getting the PIDs of child processes. This should not be a problem to install on conda-forge or even from PyPI these days. RE:#109 --- taskgraph/Task.py | 44 ++++++++++++++++++++------------------------ 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/taskgraph/Task.py b/taskgraph/Task.py index e0fb28b..4a77ba8 100644 --- a/taskgraph/Task.py +++ b/taskgraph/Task.py @@ -23,6 +23,7 @@ from importlib_metadata import PackageNotFoundError from importlib_metadata import version +import psutil import retrying try: @@ -35,19 +36,14 @@ _VALID_PATH_TYPES = (str, pathlib.PurePath) _TASKGRAPH_DATABASE_FILENAME = 'taskgraph_data.db' -try: - import psutil - HAS_PSUTIL = True - if psutil.WINDOWS: - # Windows' scheduler doesn't use POSIX niceness. - PROCESS_LOW_PRIORITY = psutil.BELOW_NORMAL_PRIORITY_CLASS - else: - # On POSIX, use system niceness. - # -20 is high priority, 0 is normal priority, 19 is low priority. - # 10 here is an arbitrary selection that's probably nice enough. - PROCESS_LOW_PRIORITY = 10 -except ImportError: - HAS_PSUTIL = False +if psutil.WINDOWS: + # Windows' scheduler doesn't use POSIX niceness. + PROCESS_LOW_PRIORITY = psutil.BELOW_NORMAL_PRIORITY_CLASS +else: + # On POSIX, use system niceness. + # -20 is high priority, 0 is normal priority, 19 is low priority. + # 10 here is an arbitrary selection that's probably nice enough. + PROCESS_LOW_PRIORITY = 10 LOGGER = logging.getLogger(__name__) _MAX_TIMEOUT = 5.0 # amount of time to wait for threads to terminate @@ -432,17 +428,17 @@ def __init__( self._logging_monitor_thread.start() self._process_pool_monitor_thread.daemon = True self._process_pool_monitor_thread.start() - if HAS_PSUTIL: - parent = psutil.Process() - parent.nice(PROCESS_LOW_PRIORITY) - for child in parent.children(): - try: - child.nice(PROCESS_LOW_PRIORITY) - except psutil.NoSuchProcess: - LOGGER.warning( - "NoSuchProcess exception encountered when trying " - "to nice %s. This might be a bug in `psutil` so " - "it should be okay to ignore.") + + parent = psutil.Process() + parent.nice(PROCESS_LOW_PRIORITY) + for child in parent.children(): + try: + child.nice(PROCESS_LOW_PRIORITY) + except psutil.NoSuchProcess: + LOGGER.warning( + "NoSuchProcess exception encountered when trying " + "to nice %s. 
This might be a bug in `psutil` so " + "it should be okay to ignore.") def __del__(self): """Ensure all threads have been joined for cleanup.""" From 257277b46e1c9214e4b3dfd9773e3885a1d430c5 Mon Sep 17 00:00:00 2001 From: James Douglass Date: Tue, 8 Jul 2025 15:35:59 -0700 Subject: [PATCH 04/14] Moving monitor function to within the Graph object. This allows for easier access to the state of the Graph. RE:#109 --- taskgraph/Task.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/taskgraph/Task.py b/taskgraph/Task.py index 4a77ba8..aa25325 100644 --- a/taskgraph/Task.py +++ b/taskgraph/Task.py @@ -152,21 +152,6 @@ def _logging_queue_monitor(logging_queue): LOGGER.debug('_logging_queue_monitor shutting down') -def _process_pool_monitor(parent_pid, starting_pids_set): - LOGGER.debug("Starting the process pool PID monitor") - parent_process = psutil.Process(parent_pid) - while True: - child_processes = set( - proc.pid for proc in parent_process.children(recursive=False)) - print(child_processes) - if child_processes != starting_pids_set: - print("Change in PIDs!") - print(child_processes) - print(starting_pids_set) - time.sleep(1) - pass - - def _create_taskgraph_table_schema(taskgraph_database_path): """Create database exists and/or ensures it is compatible and recreate. @@ -420,9 +405,11 @@ def __init__( self._logging_monitor_thread = threading.Thread( target=_logging_queue_monitor, args=(self._logging_queue,)) + + self._process_pool_monitor_wait_event = threading.Event() self._process_pool_monitor_thread = threading.Thread( - target=_process_pool_monitor, - args=(os.getpid(), set(proc.pid for proc in self._worker_pool._pool),)) + target=self._process_pool_monitor, + args=(self._process_pool_monitor_wait_event,)) self._logging_monitor_thread.daemon = True self._logging_monitor_thread.start() @@ -793,6 +780,26 @@ def _execution_monitor(self, monitor_wait_event): (time.time() - start_time)) % self._reporting_interval) LOGGER.debug("_execution monitor shutting down") + def _process_pool_monitor(self, pool_monitor_wait_event): + starting_pool_pids = set(proc.pid for proc in self._worker_pool._pool) + + while True: + if self._terminated: + break + + current_pids = set( + proc.pid for proc in self._worker_pool._pool) + + if current_pids != starting_pool_pids: + LOGGER.error( + "A change in process pool PIDs has been detected! " + "Shutting down the task graph. " + f"{starting_pool_pids} changed to {current_pids }") + self._terminate() + + # Wait 0.5s before looping. + pool_monitor_wait_event.wait(timeout=0.5) + def join(self, timeout=None): """Join all threads in the graph. From cfe554dbfc50584189bdc42bd685fcd6fbb911fb Mon Sep 17 00:00:00 2001 From: James Douglass Date: Tue, 8 Jul 2025 15:44:54 -0700 Subject: [PATCH 05/14] Documentation. RE:#109 --- taskgraph/Task.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/taskgraph/Task.py b/taskgraph/Task.py index aa25325..f8e5302 100644 --- a/taskgraph/Task.py +++ b/taskgraph/Task.py @@ -781,6 +781,21 @@ def _execution_monitor(self, monitor_wait_event): LOGGER.debug("_execution monitor shutting down") def _process_pool_monitor(self, pool_monitor_wait_event): + """Monitor the state of the multiprocessing pool's workers. + + Python's multiprocessing.Pool has a bunch of logic to make sure that + the pool always has the same number of workers, and it can even limit + the lifespan of the pool's worker processes. 
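
The respawn behavior described in this docstring is easy to observe in isolation; a short sketch, assuming only the standard library (maxtasksperchild is just one convenient way to force a worker swap):

    import multiprocessing
    import os

    if __name__ == '__main__':
        # maxtasksperchild=1 retires each worker after a single task, so the
        # pool's maintenance logic must spawn a replacement with a new PID.
        with multiprocessing.Pool(1, maxtasksperchild=1) as pool:
            first_pid = pool.apply(os.getpid)
            second_pid = pool.apply(os.getpid)
            print(f'{first_pid} was replaced by {second_pid}')
            assert first_pid != second_pid
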
In our case, worker + processes have multiprocessing.Event objects on them, which means that + if a worker process dies for any reason, the whole TaskGraph object + will hang. This worker process monitors for any changes in the PIDs of + a multiprocessing.Pool object and terminates the graph if any are + found. + + Args: + pool_monitor_wait_event (threading.Event): used to sleep the + monitor thread for 0.5 seconds. + """ starting_pool_pids = set(proc.pid for proc in self._worker_pool._pool) while True: @@ -799,6 +814,7 @@ def _process_pool_monitor(self, pool_monitor_wait_event): # Wait 0.5s before looping. pool_monitor_wait_event.wait(timeout=0.5) + LOGGER.debug('_process_pool_monitor shutting down') def join(self, timeout=None): """Join all threads in the graph. From 0162bdccb9628efe7af1c08319e7d67f7eaa3333 Mon Sep 17 00:00:00 2001 From: James Douglass Date: Tue, 8 Jul 2025 15:48:52 -0700 Subject: [PATCH 06/14] Testing that we log the graph's termination. RE:#109 --- tests/test_task.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_task.py b/tests/test_task.py index 46943d3..c63ed3f 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1502,10 +1502,15 @@ def test_multiprocessing_deadlock(self): See https://github.com/natcap/taskgraph/issues/109 """ task_graph = taskgraph.TaskGraph(self.workspace_dir, n_workers=1) - task = task_graph.add_task(_kill_current_process) + with self.assertLogs('taskgraph', level='ERROR') as cm: + _ = task_graph.add_task(_kill_current_process) + task_graph.join() + task_graph.close() - task_graph.join() - task_graph.close() + self.assertEqual(len(cm.output), 1) + self.assertTrue(cm.output[0].startswith('ERROR')) + self.assertIn('A change in process pool PIDs has been detected!', + cm.output[0]) def Fail(n_tries, result_path): From f06a187b547b70bc43e1e6cdce0aa0d0ec07e01e Mon Sep 17 00:00:00 2001 From: James Douglass Date: Tue, 8 Jul 2025 15:54:25 -0700 Subject: [PATCH 07/14] Restoring psutil. It turns out that psutil was not needed for this solution. RE:#109 --- taskgraph/Task.py | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/taskgraph/Task.py b/taskgraph/Task.py index f8e5302..daa65eb 100644 --- a/taskgraph/Task.py +++ b/taskgraph/Task.py @@ -23,7 +23,6 @@ from importlib_metadata import PackageNotFoundError from importlib_metadata import version -import psutil import retrying try: @@ -36,14 +35,19 @@ _VALID_PATH_TYPES = (str, pathlib.PurePath) _TASKGRAPH_DATABASE_FILENAME = 'taskgraph_data.db' -if psutil.WINDOWS: - # Windows' scheduler doesn't use POSIX niceness. - PROCESS_LOW_PRIORITY = psutil.BELOW_NORMAL_PRIORITY_CLASS -else: - # On POSIX, use system niceness. - # -20 is high priority, 0 is normal priority, 19 is low priority. - # 10 here is an arbitrary selection that's probably nice enough. - PROCESS_LOW_PRIORITY = 10 +try: + import psutil + HAS_PSUTIL = True + if psutil.WINDOWS: + # Windows' scheduler doesn't use POSIX niceness. + PROCESS_LOW_PRIORITY = psutil.BELOW_NORMAL_PRIORITY_CLASS + else: + # On POSIX, use system niceness. + # -20 is high priority, 0 is normal priority, 19 is low priority. + # 10 here is an arbitrary selection that's probably nice enough. 
+        PROCESS_LOW_PRIORITY = 10
+except ImportError:
+    HAS_PSUTIL = False
 
 LOGGER = logging.getLogger(__name__)
 _MAX_TIMEOUT = 5.0  # amount of time to wait for threads to terminate
@@ -416,16 +420,17 @@ def __init__(
         self._process_pool_monitor_thread.daemon = True
         self._process_pool_monitor_thread.start()
 
-        parent = psutil.Process()
-        parent.nice(PROCESS_LOW_PRIORITY)
-        for child in parent.children():
-            try:
-                child.nice(PROCESS_LOW_PRIORITY)
-            except psutil.NoSuchProcess:
-                LOGGER.warning(
-                    "NoSuchProcess exception encountered when trying "
-                    "to nice %s. This might be a bug in `psutil` so "
-                    "it should be okay to ignore.")
+        if HAS_PSUTIL:
+            parent = psutil.Process()
+            parent.nice(PROCESS_LOW_PRIORITY)
+            for child in parent.children():
+                try:
+                    child.nice(PROCESS_LOW_PRIORITY)
+                except psutil.NoSuchProcess:
+                    LOGGER.warning(
+                        "NoSuchProcess exception encountered when trying "
+                        "to nice %s. This might be a bug in `psutil` so "
+                        "it should be okay to ignore.")
 
     def __del__(self):
         """Ensure all threads have been joined for cleanup."""

From e7dfa3cf3774d85b975eb2c8ad3d96af3c8f7c2f Mon Sep 17 00:00:00 2001
From: James Douglass
Date: Tue, 8 Jul 2025 15:56:06 -0700
Subject: [PATCH 08/14] Restoring prior state of process initialization.

It turned out that my modifications to this were not needed.

RE:#109
---
 taskgraph/Task.py | 15 +--------------
 1 file changed, 1 insertion(+), 14 deletions(-)

diff --git a/taskgraph/Task.py b/taskgraph/Task.py
index daa65eb..e3b202a 100644
--- a/taskgraph/Task.py
+++ b/taskgraph/Task.py
@@ -101,19 +101,6 @@ def _null_func():
     return None
 
 
-def _initialize_process_pool(logging_queue):
-    """A function to chain together multiple initialization functions.
-
-    Args:
-        logging_queue (multiprocessing.Queue): The queue to use for passing
-            log records back to the main process.
-
-    Returns:
-        ``None``
-    """
-    _initialize_logging_to_queue(logging_queue)
-
-
 def _initialize_logging_to_queue(logging_queue):
@@ -404,7 +391,7 @@ def __init__(
         if n_workers > 0:
             self._logging_queue = multiprocessing.Queue()
             self._worker_pool = NonDaemonicPool(
-                n_workers, initializer=_initialize_process_pool,
+                n_workers, initializer=_initialize_logging_to_queue,
                 initargs=(self._logging_queue,))

From c6edd06d6fd564edba36c57eacb0b78979b6c3f8 Mon Sep 17 00:00:00 2001
From: James Douglass
Date: Tue, 8 Jul 2025 16:01:44 -0700
Subject: [PATCH 09/14] Killing with SIGTERM.

RE:#109
---
 tests/test_task.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_task.py b/tests/test_task.py
index c63ed3f..60ce276 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -150,7 +150,8 @@ def _kill_current_process():
         raise AssertionError(
             "This function is only supposed to be called in a subprocess")
 
-    os.kill(os.getpid(), signal.SIGKILL)
+    # signal.SIGTERM works on both *NIX and Windows.
+    os.kill(os.getpid(), signal.SIGTERM)

From e9ec3aa662eb2017425156b4f6a2dce374af8d3d Mon Sep 17 00:00:00 2001
From: James Douglass
Date: Tue, 8 Jul 2025 16:14:37 -0700
Subject: [PATCH 10/14] Noting change in history.

RE:#109
---
 HISTORY.rst | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/HISTORY.rst b/HISTORY.rst
index 54b09c4..07549c9 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -4,9 +4,13 @@ TaskGraph Release History
 =========================
 
-..
- Unreleased Changes - ------------------ +Unreleased Changes +------------------ +* When using ``n_workers >= 1``, the ``TaskGraph`` object will now monitor the + underlying ``multiprocessing.Pool`` object for any changes to the PIDs of its + processes. If a change is detected, the graph is shut down to avoid a + deadlock. https://github.com/natcap/taskgraph/issues/109 + 0.11.2 (2025-05-21) ------------------- From 18321d163bc47c212641796d54b172c9701b97e0 Mon Sep 17 00:00:00 2001 From: James Douglass Date: Tue, 15 Jul 2025 16:36:58 -0700 Subject: [PATCH 11/14] Reimplementing the Pool with concurrent.futures. RE:#109 --- taskgraph/Task.py | 76 +++++++++++++--------------------------------- tests/test_task.py | 8 ++--- 2 files changed, 23 insertions(+), 61 deletions(-) diff --git a/taskgraph/Task.py b/taskgraph/Task.py index e3b202a..f14c251 100644 --- a/taskgraph/Task.py +++ b/taskgraph/Task.py @@ -1,5 +1,6 @@ """Task graph framework.""" import collections +import concurrent.futures import hashlib import inspect import logging @@ -96,6 +97,15 @@ def __init__(self, *args, **kwargs): super(NonDaemonicPool, self).__init__(*args, **kwargs) +class NoDaemonProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor): + """NonDaemonic Process Pool Executor""" + + def __init__(self, *args, **kwargs): + """Invoke super to set the context of Pool class explicitly.""" + kwargs['mp_context'] = NoDaemonContext() + super(NoDaemonProcessPoolExecutor, self).__init__(*args, **kwargs) + + def _null_func(): """Use when func=None on add_task.""" return None @@ -390,22 +400,16 @@ def __init__( # set up multiprocessing if n_workers > 0 if n_workers > 0: self._logging_queue = multiprocessing.Queue() - self._worker_pool = NonDaemonicPool( - n_workers, initializer=_initialize_logging_to_queue, - initargs=(self._logging_queue,)) + self._worker_pool = NoDaemonProcessPoolExecutor( + max_workers=n_workers, + initializer=_initialize_logging_to_queue, + initargs=(self._logging_queue,) + ) self._logging_monitor_thread = threading.Thread( target=_logging_queue_monitor, args=(self._logging_queue,)) - - self._process_pool_monitor_wait_event = threading.Event() - self._process_pool_monitor_thread = threading.Thread( - target=self._process_pool_monitor, - args=(self._process_pool_monitor_wait_event,)) - self._logging_monitor_thread.daemon = True self._logging_monitor_thread.start() - self._process_pool_monitor_thread.daemon = True - self._process_pool_monitor_thread.start() if HAS_PSUTIL: parent = psutil.Process() @@ -457,8 +461,7 @@ def _task_executor(self): # pool, because otherwise who knows if it's still # executing anything try: - self._worker_pool.close() - self._worker_pool.terminate() + self._worker_pool.shutdown() self._worker_pool = None self._terminate() except Exception: @@ -772,42 +775,6 @@ def _execution_monitor(self, monitor_wait_event): (time.time() - start_time)) % self._reporting_interval) LOGGER.debug("_execution monitor shutting down") - def _process_pool_monitor(self, pool_monitor_wait_event): - """Monitor the state of the multiprocessing pool's workers. - - Python's multiprocessing.Pool has a bunch of logic to make sure that - the pool always has the same number of workers, and it can even limit - the lifespan of the pool's worker processes. In our case, worker - processes have multiprocessing.Event objects on them, which means that - if a worker process dies for any reason, the whole TaskGraph object - will hang. 
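
Patch 11 above swaps multiprocessing.Pool's apply_async()/get() for concurrent.futures' submit()/result(); the two call shapes map one-to-one, as this standalone sketch (hypothetical worker function) shows:

    import concurrent.futures
    import multiprocessing


    def _square(x):
        return x * x


    if __name__ == '__main__':
        # Before patch 11: AsyncResult.get() blocks and re-raises exceptions.
        with multiprocessing.Pool(1) as pool:
            assert pool.apply_async(func=_square, args=(3,)).get() == 9

        # After patch 11: Future.result() behaves the same way.
        with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
            assert executor.submit(_square, 3).result() == 9
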
This worker process monitors for any changes in the PIDs of - a multiprocessing.Pool object and terminates the graph if any are - found. - - Args: - pool_monitor_wait_event (threading.Event): used to sleep the - monitor thread for 0.5 seconds. - """ - starting_pool_pids = set(proc.pid for proc in self._worker_pool._pool) - - while True: - if self._terminated: - break - - current_pids = set( - proc.pid for proc in self._worker_pool._pool) - - if current_pids != starting_pool_pids: - LOGGER.error( - "A change in process pool PIDs has been detected! " - "Shutting down the task graph. " - f"{starting_pool_pids} changed to {current_pids }") - self._terminate() - - # Wait 0.5s before looping. - pool_monitor_wait_event.wait(timeout=0.5) - LOGGER.debug('_process_pool_monitor shutting down') - def join(self, timeout=None): """Join all threads in the graph. @@ -885,8 +852,7 @@ def _terminate(self): self._executor_ready_event.set() LOGGER.debug("shutting down workers") if self._worker_pool is not None: - self._worker_pool.close() - self._worker_pool.terminate() + self._worker_pool.shutdown() self._worker_pool = None # This will terminate the logging worker @@ -1137,12 +1103,12 @@ def _call(self): LOGGER.debug("not precalculated %s", self.task_name) if self._worker_pool is not None: - result = self._worker_pool.apply_async( - func=self._func, args=self._args, kwds=self._kwargs) + result = self._worker_pool.submit( + self._func, *self._args, **self._kwargs) # the following blocks and raises an exception if result # raised an exception - LOGGER.debug("apply_async for task %s", self.task_name) - payload = result.get() + LOGGER.debug("submit for task %s", self.task_name) + payload = result.result() else: LOGGER.debug("direct _func for task %s", self.task_name) payload = self._func(*self._args, **self._kwargs) diff --git a/tests/test_task.py b/tests/test_task.py index 60ce276..dc6b3f5 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,4 +1,5 @@ """Tests for taskgraph.""" +import concurrent.futures.process import hashlib import logging import logging.handlers @@ -1503,16 +1504,11 @@ def test_multiprocessing_deadlock(self): See https://github.com/natcap/taskgraph/issues/109 """ task_graph = taskgraph.TaskGraph(self.workspace_dir, n_workers=1) - with self.assertLogs('taskgraph', level='ERROR') as cm: + with self.assertRaises(concurrent.futures.process.BrokenProcessPool): _ = task_graph.add_task(_kill_current_process) task_graph.join() task_graph.close() - self.assertEqual(len(cm.output), 1) - self.assertTrue(cm.output[0].startswith('ERROR')) - self.assertIn('A change in process pool PIDs has been detected!', - cm.output[0]) - def Fail(n_tries, result_path): """Create a function that fails after ``n_tries``.""" From 128491f8b2065efe72677128a1876e171672cccf Mon Sep 17 00:00:00 2001 From: James Douglass Date: Thu, 17 Jul 2025 14:23:12 -0700 Subject: [PATCH 12/14] Adding an event to communicate the executor broke. 
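
The failure mode that event communicates is the one patch 11's executor turned from a silent hang into an exception; a standalone sketch, assuming only the standard library:

    import concurrent.futures
    import concurrent.futures.process
    import os
    import signal


    def _kill_current_process():
        os.kill(os.getpid(), signal.SIGTERM)


    if __name__ == '__main__':
        with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
            future = executor.submit(_kill_current_process)
            try:
                future.result()
            except concurrent.futures.process.BrokenProcessPool:
                # This is the exception Task._call will catch before setting
                # the new threading.Event.
                print('pool broke; the graph can now shut itself down')
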
RE:#109
---
 taskgraph/Task.py  | 82 +++++++++++++++++++++++++++++++++++++++++-----
 tests/test_task.py |  6 +++-
 2 files changed, 78 insertions(+), 10 deletions(-)

diff --git a/taskgraph/Task.py b/taskgraph/Task.py
index f14c251..3b89b10 100644
--- a/taskgraph/Task.py
+++ b/taskgraph/Task.py
@@ -333,6 +333,8 @@ def __init__(
         # the event to halt other executors
         self._executor_ready_event = threading.Event()
 
+        self._executor_pool_broke_event = threading.Event()
+
         # tasks that have all their dependencies satisfied go in this queue
         # and can be executed immediately
         self._task_ready_priority_queue = queue.PriorityQueue()
@@ -411,6 +413,19 @@ def __init__(
             self._logging_monitor_thread.daemon = True
             self._logging_monitor_thread.start()
 
+            self._executor_pool_broke_monitor_thread = threading.Thread(
+                target=self._handle_broken_process_pool,
+                args=())
+            self._executor_pool_broke_monitor_thread.daemon = True
+            self._executor_pool_broke_monitor_thread.start()
+
+            #self._process_pool_monitor_wait_event = threading.Event()
+            #self._process_pool_monitor_thread = threading.Thread(
+            #    target=self._process_pool_monitor,
+            #    args=(self._process_pool_monitor_wait_event,))
+            #self._process_pool_monitor_thread.daemon = True
+            #self._process_pool_monitor_thread.start()
+
             if HAS_PSUTIL:
                 parent = psutil.Process()
                 parent.nice(PROCESS_LOW_PRIORITY)
@@ -649,7 +664,7 @@ def add_task(
                 ignore_path_list, hash_target_files, ignore_directories,
                 transient_run, self._worker_pool, priority, hash_algorithm,
                 store_result,
-                self._task_database_path)
+                self._task_database_path, self._executor_pool_broke_event)
             self._task_name_map[new_task.task_name] = new_task
             # it may be this task was already created in an earlier call,
@@ -837,6 +852,47 @@ def close(self):
         self._executor_ready_event.set()
         LOGGER.debug("taskgraph closed")
 
+    def _handle_broken_process_pool(self):
+        # block until the event is set, which only happens if the pool broke.
+        self._executor_pool_broke_event.wait()
+        self._terminate()
+
+    def _process_pool_monitor(self, pool_monitor_wait_event):
+        """Monitor the state of the multiprocessing pool's workers.
+
+        Python's multiprocessing.Pool has a bunch of logic to make sure that
+        the pool always has the same number of workers, and it can even limit
+        the lifespan of the pool's worker processes. In our case, worker
+        processes have multiprocessing.Event objects on them, which means that
+        if a worker process dies for any reason, the whole TaskGraph object
+        will hang. This monitor thread watches for any changes in the PIDs of
+        a multiprocessing.Pool object and terminates the graph if any are
+        found.
+
+        Args:
+            pool_monitor_wait_event (threading.Event): used to sleep the
+                monitor thread for 0.5 seconds.
+        """
+
+        # TODO: check to see that the pool's executors are still running?
+        while True:
+            if self._terminated:
+                break
+
+            current_pids = set(
+                proc.pid for proc in self._worker_pool._pool)
+
+            if current_pids != starting_pool_pids:
+                LOGGER.error(
+                    "A change in process pool PIDs has been detected! "
+                    "Shutting down the task graph. "
+                    f"{starting_pool_pids} changed to {current_pids }")
+                self._terminate()
+
+            # Wait 0.5s before looping.
+ pool_monitor_wait_event.wait(timeout=0.5) + LOGGER.debug('_process_pool_monitor shutting down') + def _terminate(self): """Immediately terminate remaining task graph computation.""" LOGGER.debug( @@ -881,7 +937,7 @@ def __init__( self, task_name, func, args, kwargs, target_path_list, ignore_path_list, hash_target_files, ignore_directories, transient_run, worker_pool, priority, hash_algorithm, - store_result, task_database_path): + store_result, task_database_path, pool_broke_event): """Make a Task. Args: @@ -939,7 +995,9 @@ def __init__( for the target files created by the call and listed in ``target_path_list``, and the result of ``func`` is stored in ``result``. - + pool_broke_event (threading.Event): A threading ``Event`` object + that will be set by this ``Task`` when an underlying + executor fails. """ # it is a common error to accidentally pass a non string as to the # target path list, this terminates early if so @@ -978,6 +1036,8 @@ def __init__( # a _call and there are no more attempts at reexecution. self.task_done_executing_event = threading.Event() + self.executor_pool_broke_event = pool_broke_event + # These are used to store and later access the result of the call. self._result = None @@ -1103,12 +1163,16 @@ def _call(self): LOGGER.debug("not precalculated %s", self.task_name) if self._worker_pool is not None: - result = self._worker_pool.submit( - self._func, *self._args, **self._kwargs) - # the following blocks and raises an exception if result - # raised an exception - LOGGER.debug("submit for task %s", self.task_name) - payload = result.result() + try: + result = self._worker_pool.submit( + self._func, *self._args, **self._kwargs) + # the following blocks and raises an exception if result + # raised an exception + LOGGER.debug("submit for task %s", self.task_name) + payload = result.result() + except concurrent.futures.process.BrokenProcessPool: + self.executor_pool_broke_event.set() + LOGGER.exception('Process pool broke!') else: LOGGER.debug("direct _func for task %s", self.task_name) payload = self._func(*self._args, **self._kwargs) diff --git a/tests/test_task.py b/tests/test_task.py index dc6b3f5..203afbf 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1504,11 +1504,15 @@ def test_multiprocessing_deadlock(self): See https://github.com/natcap/taskgraph/issues/109 """ task_graph = taskgraph.TaskGraph(self.workspace_dir, n_workers=1) - with self.assertRaises(concurrent.futures.process.BrokenProcessPool): + with self.assertLogs('taskgraph', level='ERROR') as cm: _ = task_graph.add_task(_kill_current_process) task_graph.join() task_graph.close() + self.assertEqual(len(cm.output), 1) + self.assertTrue(cm.output[0].startswith('ERROR')) + self.assertIn('Process pool broke!', cm.output[0]) + def Fail(n_tries, result_path): """Create a function that fails after ``n_tries``.""" From c821e726ca8fd273c62e77eb941ca52f7be3885a Mon Sep 17 00:00:00 2001 From: James Douglass Date: Thu, 17 Jul 2025 14:36:01 -0700 Subject: [PATCH 13/14] Adding a debug statement. RE:#109 --- taskgraph/Task.py | 1 + 1 file changed, 1 insertion(+) diff --git a/taskgraph/Task.py b/taskgraph/Task.py index 3b89b10..c1852fd 100644 --- a/taskgraph/Task.py +++ b/taskgraph/Task.py @@ -855,6 +855,7 @@ def close(self): def _handle_broken_process_pool(self): # block until the event is set, which only happens if the pool broke. 
         self._executor_pool_broke_event.wait()
+        LOGGER.debug("Broken process event triggered; terminating graph")
         self._terminate()

From 1b40259876cd338c0ac6794f87813c527212334c Mon Sep 17 00:00:00 2001
From: James Douglass
Date: Thu, 17 Jul 2025 14:36:59 -0700
Subject: [PATCH 14/14] Raising the BrokenProcessPool error after logging.

RE:#109
---
 taskgraph/Task.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/taskgraph/Task.py b/taskgraph/Task.py
index c1852fd..ebdb85b 100644
--- a/taskgraph/Task.py
+++ b/taskgraph/Task.py
@@ -1174,6 +1174,7 @@ def _call(self):
             except concurrent.futures.process.BrokenProcessPool:
                 self.executor_pool_broke_event.set()
                 LOGGER.exception('Process pool broke!')
+                raise
         else:
             LOGGER.debug("direct _func for task %s", self.task_name)
             payload = self._func(*self._args, **self._kwargs)
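
For reference, patch 11's NoDaemonProcessPoolExecutor passes mp_context=NoDaemonContext(), a context class defined elsewhere in Task.py and not shown in these patches. A typical construction, offered here as an assumption rather than a quote from the repo, pins the daemon flag to False so pool workers may spawn children of their own:

    import multiprocessing


    class NoDaemonProcess(multiprocessing.Process):
        """A Process whose daemon flag is always False."""

        @property
        def daemon(self):
            return False

        @daemon.setter
        def daemon(self, value):
            # Ignore the pool machinery's attempt to daemonize workers.
            pass


    class NoDaemonContext(type(multiprocessing.get_context())):
        """The default context, but spawning non-daemonic processes."""
        Process = NoDaemonProcess
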
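Likewise, the queue-based logging that patches 02 and 08 leave in place follows the standard library QueueHandler pattern; the sketch below is an approximation of the two functions involved, not the repo's exact code:

    import logging
    import logging.handlers


    def _initialize_logging_to_queue(logging_queue):
        # Runs in each worker via the pool's initializer: forward every
        # record the child logs to the parent's queue.
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.NOTSET)
        root_logger.addHandler(logging.handlers.QueueHandler(logging_queue))


    def _logging_queue_monitor(logging_queue):
        # Runs in a daemon thread in the parent: replay child records until
        # a None sentinel arrives (sent by the graph at shutdown).
        while True:
            record = logging_queue.get()
            if record is None:
                break
            logging.getLogger(record.name).handle(record)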