diff --git a/daal4py/sklearn/_n_jobs_support.py b/daal4py/sklearn/_n_jobs_support.py index 78d9cf0208..74fe707825 100644 --- a/daal4py/sklearn/_n_jobs_support.py +++ b/daal4py/sklearn/_n_jobs_support.py @@ -19,11 +19,11 @@ import threading from functools import wraps from inspect import Parameter, signature -from multiprocessing import cpu_count from numbers import Integral from warnings import warn import threadpoolctl +from joblib import cpu_count from daal4py import daalinit as set_n_threads from daal4py import num_threads as get_n_threads @@ -46,24 +46,26 @@ def get_suggested_n_threads(n_cpus): Usually, limit is equal to `n_logical_cpus` // `n_jobs`. Returns None if limit is not set. """ - n_threads_map = { - lib_ctl.internal_api: lib_ctl.get_num_threads() - for lib_ctl in threadpool_controller.lib_controllers - if lib_ctl.internal_api != "mkl" - } - # openBLAS is limited to 24, 64 or 128 threads by default - # depending on SW/HW configuration. - # thus, these numbers of threads from openBLAS are uninformative - if "openblas" in n_threads_map and n_threads_map["openblas"] in [24, 64, 128]: - del n_threads_map["openblas"] - # remove default values equal to n_cpus as uninformative - for backend in list(n_threads_map.keys()): - if n_threads_map[backend] == n_cpus: - del n_threads_map[backend] - if len(n_threads_map) > 0: - return min(n_threads_map.values()) - else: - return None + + # Comment 2025-11-18: as of joblib>=1.5.2, by the point that this section + # is reached under a joblib job (e.g. as triggered by sklearn metaestimators) + # or under a threadpoolctl context that doesn't specify 'api', limits for + # openmp will always be set - in the case of joblib, to the number of threads + # divided by the number of parallel jobs, and in the case of threadpoolctl, + # to the number that is passed under the context. However, limits for other + # components like 'openblas' would be set under some setups but not others + # (e.g. if installing SciPy from pip with its bundled openblas, but not if + # installing it from conda-forge with MKL as BLAS backend), and might be set + # by the user to something that doesn't match with the number of parallel + # jobs from joblib - hence this looks at the openmp configuration, even + # though openmp is not used by oneDAL. + for lib_ctl in threadpool_controller.lib_controllers: + if lib_ctl.internal_api == "openmp": + n_threads = lib_ctl.get_num_threads() + # remove default values equal to n_cpus as uninformative + if n_threads is not None and n_threads != n_cpus: + return n_threads + return None def _run_with_n_jobs(method):