diff --git a/HISTORY.rst b/HISTORY.rst
index 54b09c4..07549c9 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -4,9 +4,13 @@ TaskGraph Release History
 =========================
 
-..
-  Unreleased Changes
-  ------------------
+Unreleased Changes
+------------------
+* When using ``n_workers > 0``, the ``TaskGraph`` object now monitors the
+  health of its underlying process pool.  If a worker process dies
+  unexpectedly, the graph is shut down rather than deadlocking while
+  waiting on the dead worker.  https://github.com/natcap/taskgraph/issues/109
+
 
 0.11.2 (2025-05-21)
 -------------------
diff --git a/taskgraph/Task.py b/taskgraph/Task.py
index 4ab8379..ebdb85b 100644
--- a/taskgraph/Task.py
+++ b/taskgraph/Task.py
@@ -1,5 +1,6 @@
 """Task graph framework."""
 import collections
+import concurrent.futures
 import hashlib
 import inspect
 import logging
@@ -14,6 +15,7 @@ import sqlite3
 import threading
 import time
 
+
 try:
     from importlib.metadata import PackageNotFoundError
     from importlib.metadata import version
@@ -95,6 +97,15 @@ def __init__(self, *args, **kwargs):
         super(NonDaemonicPool, self).__init__(*args, **kwargs)
 
 
+class NoDaemonProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
+    """A ``ProcessPoolExecutor`` whose worker processes are non-daemonic."""
+
+    def __init__(self, *args, **kwargs):
+        """Invoke super, explicitly setting a non-daemonic mp context."""
+        kwargs['mp_context'] = NoDaemonContext()
+        super(NoDaemonProcessPoolExecutor, self).__init__(*args, **kwargs)
+
+
 def _null_func():
     """Use when func=None on add_task."""
     return None
@@ -322,6 +333,9 @@ def __init__(
         # the event to halt other executors
         self._executor_ready_event = threading.Event()
 
+        # set when a worker process dies unexpectedly, breaking the pool
+        self._executor_pool_broke_event = threading.Event()
+
         # tasks that have all their dependencies satisfied go in this queue
         # and can be executed immediately
         self._task_ready_priority_queue = queue.PriorityQueue()
@@ -389,15 +403,29 @@ def __init__(
         # set up multiprocessing if n_workers > 0
         if n_workers > 0:
             self._logging_queue = multiprocessing.Queue()
-            self._worker_pool = NonDaemonicPool(
-                n_workers, initializer=_initialize_logging_to_queue,
-                initargs=(self._logging_queue,))
+            self._worker_pool = NoDaemonProcessPoolExecutor(
+                max_workers=n_workers,
+                initializer=_initialize_logging_to_queue,
+                initargs=(self._logging_queue,))
             self._logging_monitor_thread = threading.Thread(
                 target=_logging_queue_monitor,
                 args=(self._logging_queue,))
-            self._logging_monitor_thread.daemon = True
             self._logging_monitor_thread.start()
 
+            self._executor_pool_broke_monitor_thread = threading.Thread(
+                target=self._handle_broken_process_pool)
+            self._executor_pool_broke_monitor_thread.daemon = True
+            self._executor_pool_broke_monitor_thread.start()
+
+            # Alternative PID-based pool monitor, kept for reference but
+            # currently disabled; see ``_process_pool_monitor``.
+            # self._process_pool_monitor_wait_event = threading.Event()
+            # self._process_pool_monitor_thread = threading.Thread(
+            #     target=self._process_pool_monitor,
+            #     args=(self._process_pool_monitor_wait_event,))
+            # self._process_pool_monitor_thread.daemon = True
+            # self._process_pool_monitor_thread.start()
+
             if HAS_PSUTIL:
                 parent = psutil.Process()
                 parent.nice(PROCESS_LOW_PRIORITY)
@@ -448,8 +476,7 @@ def _task_executor(self):
                     # pool, because otherwise who knows if it's still
                     # executing anything
                     try:
-                        self._worker_pool.close()
-                        self._worker_pool.terminate()
+                        self._worker_pool.shutdown()
                         self._worker_pool = None
                         self._terminate()
                     except Exception:
@@ -637,7 +664,7 @@ def add_task(
                 ignore_path_list, hash_target_files, ignore_directories,
                 transient_run, self._worker_pool, priority, hash_algorithm,
                 store_result,
-                self._task_database_path)
+                self._task_database_path, self._executor_pool_broke_event)
             self._task_name_map[new_task.task_name] = new_task
 
             # it may be this task was already created in an earlier call,
@@ -825,6 +852,53 @@ def close(self):
         self._executor_ready_event.set()
         LOGGER.debug("taskgraph closed")
 
+    def _handle_broken_process_pool(self):
+        """Terminate the graph when the process pool breaks.
+
+        Blocks until ``self._executor_pool_broke_event`` is set, which
+        only happens if a worker process died and broke the pool.
+        """
+        self._executor_pool_broke_event.wait()
+        LOGGER.debug("Broken process pool event triggered; terminating graph")
+        self._terminate()
+
+    def _process_pool_monitor(self, pool_monitor_wait_event):
+        """Monitor the worker pool for any change in worker PIDs.
+
+        Python's ``multiprocessing.Pool`` tries to keep a constant number
+        of workers alive, replacing any process that dies.  Tasks in this
+        graph block while waiting on results from specific worker
+        processes, so if a worker dies for any reason, the whole
+        ``TaskGraph`` object will hang.  This method watches for any
+        change in the PIDs of the pool's workers and terminates the graph
+        if one is detected.
+
+        Note: this monitor is currently disabled in favor of
+        ``_handle_broken_process_pool``; see ``__init__``.
+
+        Args:
+            pool_monitor_wait_event (threading.Event): used to sleep the
+                monitor thread for 0.5 seconds between checks.
+        """
+        # ProcessPoolExecutor tracks its workers in the private
+        # ``_processes`` dict, which is keyed by PID.
+        starting_pool_pids = set(self._worker_pool._processes)
+        while True:
+            if self._terminated:
+                break
+
+            current_pids = set(self._worker_pool._processes)
+            if current_pids != starting_pool_pids:
+                LOGGER.error(
+                    "A change in process pool PIDs has been detected! "
+                    "Shutting down the task graph. "
+                    f"{starting_pool_pids} changed to {current_pids}")
+                self._terminate()
+
+            # Wait 0.5s before looping.
+            pool_monitor_wait_event.wait(timeout=0.5)
+        LOGGER.debug('_process_pool_monitor shutting down')
+
     def _terminate(self):
         """Immediately terminate remaining task graph computation."""
         LOGGER.debug(
@@ -840,8 +914,7 @@ def _terminate(self):
         self._executor_ready_event.set()
         LOGGER.debug("shutting down workers")
         if self._worker_pool is not None:
-            self._worker_pool.close()
-            self._worker_pool.terminate()
+            self._worker_pool.shutdown()
             self._worker_pool = None
 
         # This will terminate the logging worker
@@ -870,7 +943,7 @@ def __init__(
             self, task_name, func, args, kwargs, target_path_list,
             ignore_path_list, hash_target_files, ignore_directories,
             transient_run, worker_pool, priority, hash_algorithm,
-            store_result, task_database_path):
+            store_result, task_database_path, pool_broke_event):
         """Make a Task.
 
         Args:
@@ -928,7 +1001,9 @@ def __init__(
                 for the target files created by the call and listed in
                 ``target_path_list``, and the result of ``func`` is stored in
                 ``result``.
-
-        """
+            pool_broke_event (threading.Event): A threading ``Event`` object
+                that will be set by this ``Task`` if the underlying process
+                pool breaks.
+        """
         # it is a common error to accidentally pass a non string as to the
         # target path list, this terminates early if so
@@ -967,6 +1042,8 @@ def __init__(
         # a _call and there are no more attempts at reexecution.
        self.task_done_executing_event = threading.Event()
 
+        self.executor_pool_broke_event = pool_broke_event
+
         # These are used to store and later access the result of the call.
         self._result = None
@@ -1092,12 +1169,17 @@ def _call(self):
             LOGGER.debug("not precalculated %s", self.task_name)
             if self._worker_pool is not None:
-                result = self._worker_pool.apply_async(
-                    func=self._func, args=self._args, kwds=self._kwargs)
-                # the following blocks and raises an exception if result
-                # raised an exception
-                LOGGER.debug("apply_async for task %s", self.task_name)
-                payload = result.get()
+                try:
+                    result = self._worker_pool.submit(
+                        self._func, *self._args, **self._kwargs)
+                    # result.result() blocks until the call completes and
+                    # re-raises any exception raised by the call.
+                    LOGGER.debug("submit for task %s", self.task_name)
+                    payload = result.result()
+                except concurrent.futures.process.BrokenProcessPool:
+                    self.executor_pool_broke_event.set()
+                    LOGGER.exception('Process pool broke!')
+                    raise
             else:
                 LOGGER.debug("direct _func for task %s", self.task_name)
                 payload = self._func(*self._args, **self._kwargs)
diff --git a/tests/test_task.py b/tests/test_task.py
index 397946c..203afbf 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -1,4 +1,5 @@
 """Tests for taskgraph."""
+import concurrent.futures.process
 import hashlib
 import logging
 import logging.handlers
@@ -8,6 +9,7 @@ import pickle
 import re
 import shutil
+import signal
 import sqlite3
 import subprocess
 import tempfile
@@ -140,6 +142,19 @@ def _log_from_another_process(logger_name, log_message):
     logger.info(log_message)
 
 
+def _kill_current_process():
+    """Kill the current process.
+
+    Must be run within a taskgraph task process.
+    """
+    if __name__ == '__main__':
+        raise AssertionError(
+            "This function is only supposed to be called in a subprocess")
+
+    # signal.SIGTERM works with os.kill on both *NIX and Windows.
+    os.kill(os.getpid(), signal.SIGTERM)
+
+
 class TaskGraphTests(unittest.TestCase):
     """Tests for the taskgraph."""
 
@@ -1480,6 +1495,24 @@ def test_mtime_mismatch(self):
         with open(target_path) as target_file:
             self.assertEqual(target_file.read(), content)
 
+    def test_multiprocessing_deadlock(self):
+        """Verify the graph shuts down when a worker process is killed.
+
+        This test will deadlock if the functionality it is testing (graph
+        shutdown when a task process is killed) is not available.
+
+        See https://github.com/natcap/taskgraph/issues/109
+        """
+        task_graph = taskgraph.TaskGraph(self.workspace_dir, n_workers=1)
+        with self.assertLogs('taskgraph', level='ERROR') as cm:
+            _ = task_graph.add_task(_kill_current_process)
+            task_graph.join()
+            task_graph.close()
+
+        self.assertEqual(len(cm.output), 1)
+        self.assertTrue(cm.output[0].startswith('ERROR'))
+        self.assertIn('Process pool broke!', cm.output[0])
+
 
 def Fail(n_tries, result_path):
     """Create a function that fails after ``n_tries``."""
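
Reviewer note: the fix depends on a behavioral difference between
``multiprocessing.Pool`` and ``concurrent.futures.ProcessPoolExecutor``.
When a ``Pool`` worker dies, the pool silently replaces it, and any
``apply_async`` result pending on the dead worker never completes, which is
the deadlock described in issue #109.  A ``ProcessPoolExecutor`` instead
marks itself broken and raises ``BrokenProcessPool`` from every pending
``Future.result()`` call, which gives ``Task._call`` a hook for setting
``executor_pool_broke_event`` and shutting the graph down.  A minimal,
standalone sketch of that behavior (not part of the patch; the ``_die``
helper is hypothetical, though it fails the same way the new
``_kill_current_process`` test helper does)::

    import concurrent.futures
    import concurrent.futures.process
    import os
    import signal


    def _die():
        # Simulate an unexpected worker death (e.g. an OOM kill) by
        # terminating the worker process from inside the task itself.
        os.kill(os.getpid(), signal.SIGTERM)


    if __name__ == '__main__':
        with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
            future = pool.submit(_die)
            try:
                # result() blocks and re-raises whatever the call raised;
                # a dead worker surfaces here as BrokenProcessPool.
                future.result()
            except concurrent.futures.process.BrokenProcessPool:
                print('pool broke, as the new Task._call expects')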