diff --git a/Doc/library/multiprocessing.rst b/Doc/library/multiprocessing.rst index 714207cb0aefcd..16d5c75a4353f4 100644 --- a/Doc/library/multiprocessing.rst +++ b/Doc/library/multiprocessing.rst @@ -1211,22 +1211,32 @@ Miscellaneous .. versionchanged:: 3.11 Accepts a :term:`path-like object`. -.. function:: set_forkserver_preload(module_names) +.. function:: set_forkserver_preload(module_names, *, on_error='ignore') Set a list of module names for the forkserver main process to attempt to import so that their already imported state is inherited by forked - processes. Any :exc:`ImportError` when doing so is silently ignored. - This can be used as a performance enhancement to avoid repeated work - in every process. + processes. This can be used as a performance enhancement to avoid repeated + work in every process. For this to work, it must be called before the forkserver process has been launched (before creating a :class:`Pool` or starting a :class:`Process`). + The *on_error* parameter controls how :exc:`ImportError` exceptions during + module preloading are handled: ``"ignore"`` (default) silently ignores + failures, ``"warn"`` causes the forkserver subprocess to emit an + :exc:`ImportWarning` to stderr, and ``"fail"`` causes the forkserver + subprocess to exit with the exception traceback on stderr, making + subsequent process creation fail with :exc:`EOFError` or + :exc:`ConnectionError`. + Only meaningful when using the ``'forkserver'`` start method. See :ref:`multiprocessing-start-methods`. .. versionadded:: 3.4 + .. versionchanged:: next + Added the *on_error* parameter. + .. function:: set_start_method(method, force=False) Set the method which should be used to start child processes. diff --git a/Lib/multiprocessing/context.py b/Lib/multiprocessing/context.py index 051d567d457928..a73261cde856bb 100644 --- a/Lib/multiprocessing/context.py +++ b/Lib/multiprocessing/context.py @@ -177,12 +177,15 @@ def set_executable(self, executable): from .spawn import set_executable set_executable(executable) - def set_forkserver_preload(self, module_names): + def set_forkserver_preload(self, module_names, *, on_error='ignore'): '''Set list of module names to try to load in forkserver process. - This is really just a hint. + + The on_error parameter controls how import failures are handled: + "ignore" (default) silently ignores failures, "warn" emits warnings, + and "fail" raises exceptions breaking the forkserver context. ''' from .forkserver import set_forkserver_preload - set_forkserver_preload(module_names) + set_forkserver_preload(module_names, on_error=on_error) def get_context(self, method=None): if method is None: diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index 8a4e8d835b0c91..7d9033415c0863 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -42,6 +42,7 @@ def __init__(self): self._inherited_fds = None self._lock = threading.Lock() self._preload_modules = ['__main__'] + self._preload_on_error = 'ignore' def _stop(self): # Method used by unit tests to stop the server @@ -64,11 +65,22 @@ def _stop_unlocked(self): self._forkserver_address = None self._forkserver_authkey = None - def set_forkserver_preload(self, modules_names): - '''Set list of module names to try to load in forkserver process.''' + def set_forkserver_preload(self, modules_names, *, on_error='ignore'): + '''Set list of module names to try to load in forkserver process. + + The on_error parameter controls how import failures are handled: + "ignore" (default) silently ignores failures, "warn" emits warnings, + and "fail" raises exceptions breaking the forkserver context. + ''' if not all(type(mod) is str for mod in modules_names): raise TypeError('module_names must be a list of strings') + if on_error not in ('ignore', 'warn', 'fail'): + raise ValueError( + f"on_error must be 'ignore', 'warn', or 'fail', " + f"not {on_error!r}" + ) self._preload_modules = modules_names + self._preload_on_error = on_error def get_inherited_fds(self): '''Return list of fds inherited from parent process. @@ -107,6 +119,14 @@ def connect_to_new_process(self, fds): wrapped_client, self._forkserver_authkey) connection.deliver_challenge( wrapped_client, self._forkserver_authkey) + except (EOFError, ConnectionError, BrokenPipeError) as exc: + if (self._preload_modules and + self._preload_on_error == 'fail'): + exc.add_note( + "Forkserver process may have crashed during module " + "preloading. Check stderr." + ) + raise finally: wrapped_client._detach() del wrapped_client @@ -152,6 +172,8 @@ def ensure_running(self): main_kws['sys_path'] = data['sys_path'] if 'init_main_from_path' in data: main_kws['main_path'] = data['init_main_from_path'] + if self._preload_on_error != 'ignore': + main_kws['on_error'] = self._preload_on_error with socket.socket(socket.AF_UNIX) as listener: address = connection.arbitrary_address('AF_UNIX') @@ -196,8 +218,68 @@ def ensure_running(self): # # +def _handle_preload(preload, main_path=None, sys_path=None, on_error='ignore'): + """Handle module preloading with configurable error handling. + + Args: + preload: List of module names to preload. + main_path: Path to __main__ module if '__main__' is in preload. + sys_path: sys.path to use for imports (None means use current). + on_error: How to handle import errors ("ignore", "warn", or "fail"). + """ + if not preload: + return + + if sys_path is not None: + sys.path[:] = sys_path + + if '__main__' in preload and main_path is not None: + process.current_process()._inheriting = True + try: + spawn.import_main_path(main_path) + except Exception as e: + # Catch broad Exception because import_main_path() uses + # runpy.run_path() which executes the script and can raise + # any exception, not just ImportError + match on_error: + case 'fail': + raise + case 'warn': + import warnings + warnings.warn( + f"Failed to preload __main__ from {main_path!r}: {e}", + ImportWarning, + stacklevel=2 + ) + case 'ignore': + pass + finally: + del process.current_process()._inheriting + + for modname in preload: + try: + __import__(modname) + except ImportError as e: + match on_error: + case 'fail': + raise + case 'warn': + import warnings + warnings.warn( + f"Failed to preload module {modname!r}: {e}", + ImportWarning, + stacklevel=2 + ) + case 'ignore': + pass + + # gh-135335: flush stdout/stderr in case any of the preloaded modules + # wrote to them, otherwise children might inherit buffered data + util._flush_std_streams() + + def main(listener_fd, alive_r, preload, main_path=None, sys_path=None, - *, authkey_r=None): + *, authkey_r=None, on_error='ignore'): """Run forkserver.""" if authkey_r is not None: try: @@ -208,24 +290,7 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None, else: authkey = b'' - if preload: - if sys_path is not None: - sys.path[:] = sys_path - if '__main__' in preload and main_path is not None: - process.current_process()._inheriting = True - try: - spawn.import_main_path(main_path) - finally: - del process.current_process()._inheriting - for modname in preload: - try: - __import__(modname) - except ImportError: - pass - - # gh-135335: flush stdout/stderr in case any of the preloaded modules - # wrote to them, otherwise children might inherit buffered data - util._flush_std_streams() + _handle_preload(preload, main_path, sys_path, on_error) util._close_stdin() diff --git a/Lib/test/test_multiprocessing_forkserver/__init__.py b/Lib/test/test_multiprocessing_forkserver/__init__.py index d91715a344dfa7..7b1b884ab297b5 100644 --- a/Lib/test/test_multiprocessing_forkserver/__init__.py +++ b/Lib/test/test_multiprocessing_forkserver/__init__.py @@ -9,5 +9,8 @@ if sys.platform == "win32": raise unittest.SkipTest("forkserver is not available on Windows") +if not support.has_fork_support: + raise unittest.SkipTest("requires working os.fork()") + def load_tests(*args): return support.load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_multiprocessing_forkserver/test_preload.py b/Lib/test/test_multiprocessing_forkserver/test_preload.py new file mode 100644 index 00000000000000..7bc9ec18a3d471 --- /dev/null +++ b/Lib/test/test_multiprocessing_forkserver/test_preload.py @@ -0,0 +1,179 @@ +"""Tests for forkserver preload functionality.""" + +import multiprocessing +import sys +import tempfile +import unittest +from multiprocessing import forkserver + + +class TestForkserverPreload(unittest.TestCase): + """Tests for forkserver preload functionality.""" + + def setUp(self): + self._saved_warnoptions = sys.warnoptions.copy() + # Remove warning options that would convert ImportWarning to errors: + # - 'error' converts all warnings to errors + # - 'error::ImportWarning' specifically converts ImportWarning + # Keep other specific options like 'error::BytesWarning' that + # subprocess's _args_from_interpreter_flags() expects to remove + sys.warnoptions[:] = [ + opt for opt in sys.warnoptions + if opt not in ('error', 'error::ImportWarning') + ] + self.ctx = multiprocessing.get_context('forkserver') + forkserver._forkserver._stop() + + def tearDown(self): + sys.warnoptions[:] = self._saved_warnoptions + forkserver._forkserver._stop() + + @staticmethod + def _send_value(conn, value): + """Send value through connection. Static method to be picklable as Process target.""" + conn.send(value) + + def test_preload_on_error_ignore_default(self): + """Test that invalid modules are silently ignored by default.""" + self.ctx.set_forkserver_preload(['nonexistent_module_xyz']) + + r, w = self.ctx.Pipe(duplex=False) + p = self.ctx.Process(target=self._send_value, args=(w, 42)) + p.start() + w.close() + result = r.recv() + r.close() + p.join() + + self.assertEqual(result, 42) + self.assertEqual(p.exitcode, 0) + + def test_preload_on_error_ignore_explicit(self): + """Test that invalid modules are silently ignored with on_error='ignore'.""" + self.ctx.set_forkserver_preload(['nonexistent_module_xyz'], on_error='ignore') + + r, w = self.ctx.Pipe(duplex=False) + p = self.ctx.Process(target=self._send_value, args=(w, 99)) + p.start() + w.close() + result = r.recv() + r.close() + p.join() + + self.assertEqual(result, 99) + self.assertEqual(p.exitcode, 0) + + def test_preload_on_error_warn(self): + """Test that invalid modules emit warnings with on_error='warn'.""" + self.ctx.set_forkserver_preload(['nonexistent_module_xyz'], on_error='warn') + + r, w = self.ctx.Pipe(duplex=False) + p = self.ctx.Process(target=self._send_value, args=(w, 123)) + p.start() + w.close() + result = r.recv() + r.close() + p.join() + + self.assertEqual(result, 123) + self.assertEqual(p.exitcode, 0) + + def test_preload_on_error_fail_breaks_context(self): + """Test that invalid modules with on_error='fail' breaks the forkserver.""" + self.ctx.set_forkserver_preload(['nonexistent_module_xyz'], on_error='fail') + + r, w = self.ctx.Pipe(duplex=False) + try: + p = self.ctx.Process(target=self._send_value, args=(w, 42)) + with self.assertRaises((EOFError, ConnectionError, BrokenPipeError)) as cm: + p.start() + notes = getattr(cm.exception, '__notes__', []) + self.assertTrue(notes, "Expected exception to have __notes__") + self.assertIn('Forkserver process may have crashed', notes[0]) + finally: + w.close() + r.close() + + def test_preload_valid_modules_with_on_error_fail(self): + """Test that valid modules work fine with on_error='fail'.""" + self.ctx.set_forkserver_preload(['os', 'sys'], on_error='fail') + + r, w = self.ctx.Pipe(duplex=False) + p = self.ctx.Process(target=self._send_value, args=(w, 'success')) + p.start() + w.close() + result = r.recv() + r.close() + p.join() + + self.assertEqual(result, 'success') + self.assertEqual(p.exitcode, 0) + + def test_preload_invalid_on_error_value(self): + """Test that invalid on_error values raise ValueError.""" + with self.assertRaises(ValueError) as cm: + self.ctx.set_forkserver_preload(['os'], on_error='invalid') + self.assertIn("on_error must be 'ignore', 'warn', or 'fail'", str(cm.exception)) + + +class TestHandlePreload(unittest.TestCase): + """Unit tests for _handle_preload() function.""" + + def test_handle_preload_main_on_error_fail(self): + """Test that __main__ import failures raise with on_error='fail'.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f: + f.write('raise RuntimeError("test error in __main__")\n') + f.flush() + with self.assertRaises(RuntimeError) as cm: + forkserver._handle_preload(['__main__'], main_path=f.name, on_error='fail') + self.assertIn("test error in __main__", str(cm.exception)) + + def test_handle_preload_main_on_error_warn(self): + """Test that __main__ import failures warn with on_error='warn'.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f: + f.write('raise ImportError("test import error")\n') + f.flush() + with self.assertWarns(ImportWarning) as cm: + forkserver._handle_preload(['__main__'], main_path=f.name, on_error='warn') + self.assertIn("Failed to preload __main__", str(cm.warning)) + self.assertIn("test import error", str(cm.warning)) + + def test_handle_preload_main_on_error_ignore(self): + """Test that __main__ import failures are ignored with on_error='ignore'.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f: + f.write('raise ImportError("test import error")\n') + f.flush() + forkserver._handle_preload(['__main__'], main_path=f.name, on_error='ignore') + + def test_handle_preload_main_valid(self): + """Test that valid __main__ preload works.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f: + f.write('test_var = 42\n') + f.flush() + forkserver._handle_preload(['__main__'], main_path=f.name, on_error='fail') + + def test_handle_preload_module_on_error_fail(self): + """Test that module import failures raise with on_error='fail'.""" + with self.assertRaises(ModuleNotFoundError): + forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='fail') + + def test_handle_preload_module_on_error_warn(self): + """Test that module import failures warn with on_error='warn'.""" + with self.assertWarns(ImportWarning) as cm: + forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='warn') + self.assertIn("Failed to preload module", str(cm.warning)) + + def test_handle_preload_module_on_error_ignore(self): + """Test that module import failures are ignored with on_error='ignore'.""" + forkserver._handle_preload(['nonexistent_test_module_xyz'], on_error='ignore') + + def test_handle_preload_combined(self): + """Test preloading both __main__ and modules.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py') as f: + f.write('import sys\n') + f.flush() + forkserver._handle_preload(['__main__', 'os', 'sys'], main_path=f.name, on_error='fail') + + +if __name__ == '__main__': + unittest.main() diff --git a/Misc/ACKS b/Misc/ACKS index f5f15f2eb7ea24..fc9085696e3d03 100644 --- a/Misc/ACKS +++ b/Misc/ACKS @@ -1337,6 +1337,7 @@ Trent Nelson Andrew Nester Osvaldo Santana Neto Chad Netzer +Nick Neumann Max Neunhöffer Anthon van der Neut George Neville-Neil diff --git a/Misc/NEWS.d/next/Library/2025-11-22-20-30-00.gh-issue-141860.frksvr.rst b/Misc/NEWS.d/next/Library/2025-11-22-20-30-00.gh-issue-141860.frksvr.rst new file mode 100644 index 00000000000000..b1efd9c014f1f4 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-11-22-20-30-00.gh-issue-141860.frksvr.rst @@ -0,0 +1,5 @@ +Add an ``on_error`` keyword-only parameter to +:func:`multiprocessing.set_forkserver_preload` to control how import failures +during module preloading are handled. Accepts ``'ignore'`` (default, silent), +``'warn'`` (emit :exc:`ImportWarning`), or ``'fail'`` (raise exception). +Contributed by Nick Neumann and Gregory P. Smith.