Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions daal4py/sklearn/_n_jobs_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import threading
from functools import wraps
from inspect import Parameter, signature
from multiprocessing import cpu_count
from numbers import Integral
from warnings import warn

import threadpoolctl
from joblib import cpu_count

from daal4py import daalinit as set_n_threads
from daal4py import num_threads as get_n_threads
Expand All @@ -46,16 +46,24 @@ def get_suggested_n_threads(n_cpus):
Usually, limit is equal to `n_logical_cpus` // `n_jobs`.
Returns None if limit is not set.
"""

# Comment 2025-11-18: as of joblib>=1.5.2, by the point that this section
# is reached under a joblib job (e.g. as triggered by sklearn metaestimators)
# or under a threadpoolctl context that doesn't specify 'api', limits for
# openmp will always be set - in the case of joblib, to the number of threads
# divided by the number of parallel jobs, and in the case of threadpoolctl,
# to the number that is passed under the context. However, limits for other
# components like 'openblas' would be set under some setups but not others
# (e.g. if installing SciPy from pip with its bundled openblas, but not if
# installing it from conda-forge with MKL as BLAS backend), and might be set
# by the user to something that doesn't match with the number of parallel
# jobs from joblib - hence this looks at the openmp configuration, even
# though openmp is not used by oneDAL.
n_threads_map = {
lib_ctl.internal_api: lib_ctl.get_num_threads()
for lib_ctl in threadpool_controller.lib_controllers
if lib_ctl.internal_api != "mkl"
if lib_ctl.internal_api == "openmp"
}
# openBLAS is limited to 24, 64 or 128 threads by default
# depending on SW/HW configuration.
# thus, these numbers of threads from openBLAS are uninformative
if "openblas" in n_threads_map and n_threads_map["openblas"] in [24, 64, 128]:
del n_threads_map["openblas"]
# remove default values equal to n_cpus as uninformative
for backend in list(n_threads_map.keys()):
if n_threads_map[backend] == n_cpus:
Expand Down
Loading