10 changes: 7 additions & 3 deletions HISTORY.rst
@@ -4,9 +4,13 @@
TaskGraph Release History
=========================

..
  Unreleased Changes
  ------------------
Unreleased Changes
------------------
* When using ``n_workers >= 1``, the ``TaskGraph`` object now monitors its
  underlying process pool for failure. If a worker process dies unexpectedly
  (a ``BrokenProcessPool`` condition), the graph is shut down to avoid a
  deadlock. https://github.com/natcap/taskgraph/issues/109


0.11.2 (2025-05-21)
-------------------
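From the user's side, the behavior described in the changelog entry above looks roughly like the following sketch (``slow_func`` is a hypothetical task function):

    import taskgraph

    graph = taskgraph.TaskGraph('workspace_dir', n_workers=1)
    task = graph.add_task(slow_func)  # hypothetical task function

    # Before this change, a worker killed mid-task (e.g. by an OS
    # out-of-memory killer) could leave join() blocked forever; now the
    # broken pool is detected and the graph is terminated instead.
    graph.join()
    graph.close()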
111 changes: 94 additions & 17 deletions taskgraph/Task.py
@@ -1,5 +1,6 @@
"""Task graph framework."""
import collections
import concurrent.futures
import hashlib
import inspect
import logging
@@ -14,6 +15,7 @@
import sqlite3
import threading
import time

try:
    from importlib.metadata import PackageNotFoundError
    from importlib.metadata import version
@@ -95,6 +97,15 @@ def __init__(self, *args, **kwargs):
        super(NonDaemonicPool, self).__init__(*args, **kwargs)


class NoDaemonProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
    """A ProcessPoolExecutor whose worker processes are non-daemonic."""

    def __init__(self, *args, **kwargs):
        """Invoke super, explicitly setting a non-daemonic mp context."""
        kwargs['mp_context'] = NoDaemonContext()
        super(NoDaemonProcessPoolExecutor, self).__init__(*args, **kwargs)
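``NoDaemonContext`` is defined earlier in ``Task.py`` and is not shown in this diff. For context, a minimal sketch of the usual non-daemonic context pattern it follows (illustrative, not the verbatim taskgraph source):

    import multiprocessing

    class NonDaemonicProcess(multiprocessing.Process):
        # Always report daemon=False and ignore attempts to change it,
        # so pool workers are allowed to spawn child processes.
        @property
        def daemon(self):
            return False

        @daemon.setter
        def daemon(self, value):
            pass

    class NoDaemonContext(type(multiprocessing.get_context())):
        # A multiprocessing context whose Process is non-daemonic.
        Process = NonDaemonicProcess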


def _null_func():
    """Use when func=None on add_task."""
    return None
@@ -322,6 +333,8 @@ def __init__(
        # the event to halt other executors
        self._executor_ready_event = threading.Event()

        self._executor_pool_broke_event = threading.Event()

        # tasks that have all their dependencies satisfied go in this queue
        # and can be executed immediately
        self._task_ready_priority_queue = queue.PriorityQueue()
@@ -389,15 +402,30 @@ def __init__(
        # set up multiprocessing if n_workers > 0
        if n_workers > 0:
            self._logging_queue = multiprocessing.Queue()
            self._worker_pool = NonDaemonicPool(
                n_workers, initializer=_initialize_logging_to_queue,
                initargs=(self._logging_queue,))
            self._worker_pool = NoDaemonProcessPoolExecutor(
                max_workers=n_workers,
                initializer=_initialize_logging_to_queue,
                initargs=(self._logging_queue,)
            )
            self._logging_monitor_thread = threading.Thread(
                target=_logging_queue_monitor,
                args=(self._logging_queue,))

            self._logging_monitor_thread.daemon = True
            self._logging_monitor_thread.start()

            self._executor_pool_broke_monitor_thread = threading.Thread(
                target=self._handle_broken_process_pool,
                args=())
            self._executor_pool_broke_monitor_thread.daemon = True
            self._executor_pool_broke_monitor_thread.start()

            # self._process_pool_monitor_wait_event = threading.Event()
            # self._process_pool_monitor_thread = threading.Thread(
            #     target=self._process_pool_monitor,
            #     args=(self._process_pool_monitor_wait_event,))
            # self._process_pool_monitor_thread.daemon = True
            # self._process_pool_monitor_thread.start()

            if HAS_PSUTIL:
                parent = psutil.Process()
                parent.nice(PROCESS_LOW_PRIORITY)
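``_initialize_logging_to_queue`` and ``_logging_queue_monitor`` are pre-existing helpers that this diff reuses unchanged with the new executor. For context, the initializer follows the standard ``logging.handlers.QueueHandler`` pattern, roughly (a sketch, not the verbatim source):

    import logging
    import logging.handlers

    def _initialize_logging_to_queue(logging_queue):
        # Runs once in each worker process: route every log record back
        # to the parent process through the shared multiprocessing queue.
        handler = logging.handlers.QueueHandler(logging_queue)
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.NOTSET)
        root_logger.addHandler(handler)

On the parent side, the monitor thread drains the queue and re-dispatches each record through the local logging tree.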
@@ -448,8 +476,7 @@ def _task_executor(self):
                # pool, because otherwise who knows if it's still
                # executing anything
                try:
                    self._worker_pool.close()
                    self._worker_pool.terminate()
                    self._worker_pool.shutdown()
                    self._worker_pool = None
                    self._terminate()
                except Exception:
@@ -637,7 +664,7 @@ def add_task(
                ignore_path_list, hash_target_files, ignore_directories,
                transient_run, self._worker_pool,
                priority, hash_algorithm, store_result,
                self._task_database_path)
                self._task_database_path, self._executor_pool_broke_event)

            self._task_name_map[new_task.task_name] = new_task
            # it may be this task was already created in an earlier call,
@@ -825,6 +852,48 @@ def close(self):
        self._executor_ready_event.set()
        LOGGER.debug("taskgraph closed")

    def _handle_broken_process_pool(self):
        """Terminate the graph when the executor pool breaks."""
        # block until the event is set, which only happens if the pool broke.
        self._executor_pool_broke_event.wait()
        LOGGER.debug("Broken process event triggered; terminating graph")
        self._terminate()

    def _process_pool_monitor(self, pool_monitor_wait_event):
        """Monitor the state of the multiprocessing pool's workers.

        Python's multiprocessing.Pool has a bunch of logic to make sure that
        the pool always has the same number of workers, and it can even limit
        the lifespan of the pool's worker processes. In our case, worker
        processes hold multiprocessing.Event objects, which means that if a
        worker process dies for any reason, the whole TaskGraph object will
        hang. This monitor thread watches for any changes in the PIDs of a
        multiprocessing.Pool object and terminates the graph if any are
        found.

        Note: this method assumes a ``multiprocessing.Pool``-style ``_pool``
        attribute and is currently unused in favor of
        ``_handle_broken_process_pool``.

        Args:
            pool_monitor_wait_event (threading.Event): used to sleep the
                monitor thread for 0.5 seconds.
        """
        # TODO: check to see that the pool's executors are still running?
        # Capture the worker PIDs present at startup for later comparison.
        starting_pool_pids = set(
            proc.pid for proc in self._worker_pool._pool)

        while True:
            if self._terminated:
                break

            current_pids = set(
                proc.pid for proc in self._worker_pool._pool)

            if current_pids != starting_pool_pids:
                LOGGER.error(
                    "A change in process pool PIDs has been detected! "
                    "Shutting down the task graph. "
                    f"{starting_pool_pids} changed to {current_pids}")
                self._terminate()

            # Wait 0.5s before looping.
            pool_monitor_wait_event.wait(timeout=0.5)
        LOGGER.debug('_process_pool_monitor shutting down')

    def _terminate(self):
        """Immediately terminate remaining task graph computation."""
        LOGGER.debug(
@@ -840,8 +909,7 @@ def _terminate(self):
        self._executor_ready_event.set()
        LOGGER.debug("shutting down workers")
        if self._worker_pool is not None:
            self._worker_pool.close()
            self._worker_pool.terminate()
            self._worker_pool.shutdown()
            self._worker_pool = None

        # This will terminate the logging worker
@@ -870,7 +938,7 @@ def __init__(
            self, task_name, func, args, kwargs, target_path_list,
            ignore_path_list, hash_target_files, ignore_directories,
            transient_run, worker_pool, priority, hash_algorithm,
            store_result, task_database_path):
            store_result, task_database_path, pool_broke_event):
        """Make a Task.

        Args:
@@ -928,7 +996,9 @@ def __init__(
                for the target files created by the call and listed in
                ``target_path_list``, and the result of ``func`` is stored in
                ``result``.

            pool_broke_event (threading.Event): A threading ``Event`` object
                that will be set by this ``Task`` when an underlying
                executor fails.
        """
        # it is a common error to accidentally pass a non string as to the
        # target path list, this terminates early if so
@@ -967,6 +1037,8 @@ def __init__(
        # a _call and there are no more attempts at reexecution.
        self.task_done_executing_event = threading.Event()

        self.executor_pool_broke_event = pool_broke_event

        # These are used to store and later access the result of the call.
        self._result = None

@@ -1092,12 +1164,17 @@ def _call(self):
LOGGER.debug("not precalculated %s", self.task_name)

if self._worker_pool is not None:
result = self._worker_pool.apply_async(
func=self._func, args=self._args, kwds=self._kwargs)
# the following blocks and raises an exception if result
# raised an exception
LOGGER.debug("apply_async for task %s", self.task_name)
payload = result.get()
try:
result = self._worker_pool.submit(
self._func, *self._args, **self._kwargs)
# the following blocks and raises an exception if result
# raised an exception
LOGGER.debug("submit for task %s", self.task_name)
payload = result.result()
except concurrent.futures.process.BrokenProcessPool:
self.executor_pool_broke_event.set()
LOGGER.exception('Process pool broke!')
raise
else:
LOGGER.debug("direct _func for task %s", self.task_name)
payload = self._func(*self._args, **self._kwargs)
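The ``submit``/``result``/``BrokenProcessPool`` pattern added to ``_call`` above can be reproduced with ``concurrent.futures`` alone; a minimal standalone sketch (``_die`` is an illustrative task):

    import concurrent.futures
    import concurrent.futures.process
    import os
    import signal

    def _die():
        # Kill the worker process from inside the task.
        os.kill(os.getpid(), signal.SIGTERM)

    if __name__ == '__main__':
        with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
            future = pool.submit(_die)
            try:
                # result() blocks and re-raises anything the task raised.
                future.result()
            except concurrent.futures.process.BrokenProcessPool:
                # The pool is unusable from here on; taskgraph reacts by
                # setting an event that terminates the whole graph.
                print('Process pool broke!')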
33 changes: 33 additions & 0 deletions tests/test_task.py
@@ -1,4 +1,5 @@
"""Tests for taskgraph."""
import concurrent.futures.process
import hashlib
import logging
import logging.handlers
@@ -8,6 +9,7 @@
import pickle
import re
import shutil
import signal
import sqlite3
import subprocess
import tempfile
@@ -140,6 +142,19 @@ def _log_from_another_process(logger_name, log_message):
    logger.info(log_message)


def _kill_current_process():
    """Kill the current process.

    Must be run within a taskgraph task process.
    """
    if __name__ == '__main__':
        raise AssertionError(
            "This function is only supposed to be called in a subprocess")

    # signal.SIGTERM works on both *NIX and Windows.
    os.kill(os.getpid(), signal.SIGTERM)


class TaskGraphTests(unittest.TestCase):
"""Tests for the taskgraph."""

@@ -1480,6 +1495,24 @@ def test_mtime_mismatch(self):
        with open(target_path) as target_file:
            self.assertEqual(target_file.read(), content)

    def test_multiprocessing_deadlock(self):
        """Verify that the graph is shut down in case of deadlock.

        This test will deadlock if the functionality it is testing for (graph
        shutdown when a task process is killed) is not available.

        See https://github.com/natcap/taskgraph/issues/109
        """
        task_graph = taskgraph.TaskGraph(self.workspace_dir, n_workers=1)
        with self.assertLogs('taskgraph', level='ERROR') as cm:
            _ = task_graph.add_task(_kill_current_process)
            task_graph.join()
            task_graph.close()

        self.assertEqual(len(cm.output), 1)
        self.assertTrue(cm.output[0].startswith('ERROR'))
        self.assertIn('Process pool broke!', cm.output[0])


def Fail(n_tries, result_path):
"""Create a function that fails after ``n_tries``."""