
Commit 26f9820

Google-ML-Automation authored and jax authors committed
[JAX] Automatically share PGO data for GPU latency-hiding scheduler.
Overall, the idea is to collect profile data for each module a configurable number of times and then recompile the module with the aggregated profile data.

1. We need to track how many times each module has been profiled and collect the profiling results. For this I added a ProfileSessionRunner class to profile.py. The class tracks how many times an instance of it has been asked to profile a session and can aggregate the profile results.
2. We need to associate a profiling session with the module at the interpreter level. To do this I added a dictionary to pjit.py that maps a Jaxpr to its profile session runner.
3. The profile session runner is passed down to pxla.py and invoked there.
4. We need to handle the fast path at the interpreter level correctly, so that JAX does not use the HLO directly while PGLE data still needs to be collected, but also does not recompile the module solely for PGLE. See the changes in pjit.py and in lru_cache.h.
5. Once the FDO profile is collected, we need to share it between hosts to keep compilation deterministic.

PiperOrigin-RevId: 638197166
1 parent 741d1d3 commit 26f9820
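As background for item 1 above, here is a minimal, hypothetical sketch of the run-counting and aggregation idea. The class and method names below are invented for illustration; the actual ProfileSessionRunner added in profile.py and the FDO payload handling are more involved.

# Hypothetical sketch of per-module profiling-run bookkeeping; this is not
# the actual ProfileSessionRunner added in profile.py.
from dataclasses import dataclass, field


@dataclass
class ProfileRunCounter:
  required_runs: int                     # e.g. the jax_pgle_profiling_runs value
  collected: list[bytes] = field(default_factory=list)

  def record(self, fdo_payload: bytes) -> None:
    """Store the profile collected from one execution of the module."""
    self.collected.append(fdo_payload)

  def ready_for_recompile(self) -> bool:
    """True once the module has been profiled the configured number of times."""
    return len(self.collected) >= self.required_runs

  def aggregated_profile(self) -> bytes:
    """Naive aggregation (concatenation); the real aggregation differs."""
    return b"".join(self.collected)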

File tree

9 files changed: 557 additions, 60 deletions


jax/_src/compilation_cache.py

Lines changed: 12 additions & 0 deletions
@@ -157,6 +157,18 @@ def decompress_executable(executable):
   else:
     return zlib.decompress(executable)
 
+
+def is_executable_in_cache(cache_key: str) -> bool:
+  """Checks if the executable is in the cache."""
+  cache = _get_cache()
+  if cache is None:
+    return False
+
+  # TODO(patrios): add check cache key method to cache interface.
+  executable_and_time = cache.get(cache_key)
+  return executable_and_time is not None
+
+
 def get_executable_and_time(
     cache_key: str, compile_options, backend
 ) -> tuple[xla_client.LoadedExecutable | None, int | None]:
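A small, hedged usage sketch of the new helper; the key below is made up, and in practice keys come from compilation_cache.get_cache_key and the persistent cache must be configured (otherwise the helper simply returns False).

from jax._src import compilation_cache

# Hypothetical, precomputed cache key; real keys are hashes derived from the
# module, devices, and compile options.
cache_key = "example-cache-key"

if compilation_cache.is_executable_in_cache(cache_key):
  print("Executable already in the persistent cache; no recompilation needed.")
else:
  print("Cache miss: the module would be compiled and written to the cache.")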

jax/_src/compiler.py

Lines changed: 114 additions & 6 deletions
@@ -21,7 +21,7 @@
 import os
 import tempfile
 import time
-from typing import Any
+from typing import Any, Optional
 import warnings
 
 from jax._src import compilation_cache
@@ -243,6 +243,7 @@ def compile_or_get_cached(
     devices: np.ndarray,
     compile_options: xc.CompileOptions,
     host_callbacks: Sequence[Any],
+    pgle_profiler: profiler.PGLEProfiler | None = None,
 ) -> xc.LoadedExecutable:
   sym_name = computation.operation.attributes['sym_name']
   module_name = ir.StringAttr(sym_name).value
@@ -278,14 +279,55 @@ def compile_or_get_cached(
     return backend_compile(backend, computation, compile_options,
                            host_callbacks)
 
+  is_multi_process = (
+      len({device.process_index for device in devices.flatten()}) > 1)
+  min_device_process_id = (
+      min(devices.flatten(), key=lambda device: device.id).process_index)
+
+  # When PGLE is enabled there are three possible situations:
+  # 1. The PGLE-profiled module (the one recompiled with an FDO profile) is in
+  # the persistent cache. In this case the module should be returned from the
+  # cache and PGLE should be disabled for this module. The module is stored in
+  # the persistent cache under "pgle_profiled_module_key", which is computed
+  # by replacing the FDO profile with a flag that marks the module as PGLE
+  # profiled.
+  # 2. The PGLE-profiled module is not in the persistent cache and the module
+  # is being built with an FDO profile. In this case we need to share the FDO
+  # profile with the other processes and store the result under
+  # "pgle_profiled_module_key", so that in case 1 we will later be able to
+  # find the module.
+  # 3. The PGLE-profiled module is not in the persistent cache and the module
+  # is being compiled in order to be PGLEd (the FDO profile is empty). In this
+  # case we simply return the non-PGLE-profiled module from the persistent
+  # cache.
+  if (config.enable_pgle.value
+      and config.pgle_profiling_runs.value > 0):
+    fdo_profile = compile_options.executable_build_options.fdo_profile
+    compile_options.executable_build_options.fdo_profile = b"pgle profiled"
+
+    pgle_profiled_module_key = compilation_cache.get_cache_key(
+        computation, devices, compile_options, backend)
+    compile_options.executable_build_options.fdo_profile = fdo_profile
+
+    if _is_executable_in_cache(pgle_profiled_module_key):
+      # Load the PGLE-profiled module from the persistent cache.
+      cache_key = pgle_profiled_module_key
+      if pgle_profiler is not None:
+        pgle_profiler.disable()
+    elif fdo_profile is not None and len(fdo_profile) > 0:
+      # Store the module under the PGLE-profiled module cache key.
+      cache_key = pgle_profiled_module_key
+      if is_multi_process and distributed.global_state.client is not None:
+        compile_options.executable_build_options.fdo_profile = _share_fdo_profiles(
+            computation, devices, compile_options, backend,
+            distributed.global_state.client,
+            min_device_process_id
+        )
+
   cache_retrieval_start = time.monotonic()
   retrieved_executable, retrieved_compile_time = _cache_read(
       module_name, cache_key, compile_options, backend)
   cache_retrieval_time = time.monotonic() - cache_retrieval_start
 
-
-  is_multi_process = (
-      len({device.process_index for device in devices.flatten()}) > 1)
   if retrieved_executable is not None:
     assert retrieved_compile_time is not None
     logger.debug("Persistent compilation cache hit for '%s'", module_name)
@@ -315,7 +357,7 @@ def compile_or_get_cached(
         distributed.global_state.client,
         module_name,
         cache_key,
-        min(devices.flatten(), key=lambda device: device.id).process_index
+        min_device_process_id
     )
   elif (
       config.share_autotune_config_between_hosts.value
@@ -330,7 +372,7 @@ def compile_or_get_cached(
         distributed.global_state.client,
         module_name,
         cache_key,
-        min(devices.flatten(), key=lambda device: device.id).process_index
+        min_device_process_id
     )
   else:
     return _compile_and_write_cache(
@@ -342,6 +384,58 @@ def compile_or_get_cached(
         cache_key,
     )
 
+# The process that has the lowest device ID should share the FDO profile with
+# the other processes before compilation.
+def _share_fdo_profiles(
+    computation: ir.Module,
+    devices: np.ndarray,
+    compile_options: xc.CompileOptions,
+    backend: xc.Client,
+    global_client: lib.xla_extension.DistributedRuntimeClient,
+    min_process_id
+) -> Optional[bytes]:
+  sym_name = computation.operation.attributes['sym_name']
+  module_name = ir.StringAttr(sym_name).value
+  fdo_profile = compile_options.executable_build_options.fdo_profile
+  if fdo_profile is None or len(fdo_profile) == 0:
+    return fdo_profile
+
+  compile_options.executable_build_options.fdo_profile = b""
+  profile_key = (
+      compilation_cache.get_cache_key(
+          computation, devices, compile_options, backend
+      )
+      + "_fdo_sync"
+  )
+  if profile_key in _share_fdo_profiles.modules_profiles:
+    return _share_fdo_profiles.modules_profiles[profile_key]
+
+  share_timeout = config.share_binary_between_hosts_timeout_ms.value
+  if distributed.global_state.process_id == min_process_id:
+    logger.debug(
+        "Sharing FDO profile: %s. For module %s. Process %d.",
+        fdo_profile,
+        module_name,
+        min_process_id,
+    )
+    global_client.key_value_set_bytes(profile_key, fdo_profile)
+  else:
+    logger.debug(
+        "Waiting for FDO profile: %s. For module %s. Should be set by process %d.",
+        fdo_profile,
+        module_name,
+        min_process_id,
+    )
+    fdo_profile = global_client.blocking_key_value_get_bytes(
+        profile_key, share_timeout
+    )
+
+  _share_fdo_profiles.modules_profiles[profile_key] = fdo_profile
+  return fdo_profile
+
+
+_share_fdo_profiles.modules_profiles = {}
+
 
 # The process with the first_process_id should compile the module and write an
 # autotune config to the K-V storage.
@@ -520,6 +614,20 @@ def _compile_and_write_cache(
   )
   return executable
 
+def _is_executable_in_cache(cache_key) -> bool:
+  """Checks if an executable is present in the cache for the given key."""
+  try:
+    return compilation_cache.is_executable_in_cache(cache_key)
+  except Exception as ex:
+    if config.raise_persistent_cache_errors.value:
+      raise
+    warnings.warn(
+        f"Error reading persistent compilation cache entry for "
+        f"'{cache_key}': {type(ex).__name__}: {ex}")
+    return False
+
+
 def _cache_read(
     module_name: str, cache_key: str, compile_options: xc.CompileOptions,
     backend: xc.Client
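The sharing logic in _share_fdo_profiles boils down to a leader/follower exchange over the distributed key-value store: the process with the lowest device ID publishes its FDO profile, and every other process blocks until that value appears, so all hosts compile with identical profile data. Below is a minimal sketch of that pattern; `kv` stands in for distributed.global_state.client and the helper name is invented for illustration.

# Leader/follower key-value exchange, sketched with an invented helper name.
def share_blob(kv, key: str, process_id: int, leader_id: int,
               blob: bytes, timeout_ms: int) -> bytes:
  if process_id == leader_id:
    # The leader publishes its locally collected profile for everyone else.
    kv.key_value_set_bytes(key, blob)
    return blob
  # Followers block until the leader has published, then adopt its copy so
  # that every host compiles with the same FDO data (deterministic builds).
  return kv.blocking_key_value_get_bytes(key, timeout_ms)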

jax/_src/config.py

Lines changed: 43 additions & 1 deletion
@@ -217,7 +217,9 @@ def trace_context():
           debug_key_reuse.value,
           jax_xla_profile_version.value,
           # Technically this affects jaxpr->stablehlo lowering, not tracing.
-          hlo_source_file_canonicalization_regex.value)
+          hlo_source_file_canonicalization_regex.value,
+          pgle_profiling_runs.value,
+          enable_pgle.value)
 
 config = Config()
 
@@ -815,6 +817,8 @@ class _GlobalExtraJitContext(NamedTuple):
   threefry_gpu_kernel_lowering: bool = False
   softmax_custom_jvp: bool = False
   xla_profile_version: int = 0
+  pgle_profiling_runs: int = 0
+  enable_pgle: bool = False
 
 
 def _update_global_jit_state(**kw):
@@ -850,6 +854,8 @@ class _ThreadLocalExtraJitContext(NamedTuple):
   threefry_gpu_kernel_lowering: bool | None = None
   softmax_custom_jvp: bool | None = None
   xla_profile_version: int | None = None
+  pgle_profiling_runs: int | None = None
+  enable_pgle: bool | None = None
 
 
 class _ThreadLocalStateCache(threading.local):
@@ -1221,6 +1227,42 @@ def _update_jax_memories_thread_local(val):
     help='Timeout for the compiled module share.',
 )
 
+enable_pgle = define_bool_state(
+    name='jax_enable_pgle',
+    default=False,
+    help=(
+        'If set to True and jax_pgle_profiling_runs is greater than 0, '
+        'modules will be recompiled after running the specified number of '
+        'times, with the collected data provided to the profile-guided '
+        'latency estimator.'
+    ),
+    update_global_hook=lambda val: _update_global_jit_state(enable_pgle=val),
+    update_thread_local_hook=lambda val: update_thread_local_jit_state(
+        enable_pgle=val),
+)
+
+pgle_profiling_runs = define_int_state(
+    name='jax_pgle_profiling_runs',
+    default=3,
+    help=(
+        'Number of times a module should be profiled before recompilation '
+        'when PGLE is used.'
+    ),
+    update_global_hook=lambda val: _update_global_jit_state(
+        pgle_profiling_runs=val
+    ),
+    update_thread_local_hook=lambda val: update_thread_local_jit_state(
+        pgle_profiling_runs=val
+    ),
+)
+
+pgle_aggregation_percentile = define_int_state(
+    name='jax_pgle_aggregation_percentile',
+    default=90,
+    help='Percentile used to aggregate performance data between devices when '
+    'PGLE is used.',
+)
+
 enable_compilation_cache = define_bool_state(
     name='jax_enable_compilation_cache',
     default=True,
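A hedged example of enabling the new options from user code. jax.config.update with the flag names defined above is the standard way to set them; whether PGLE actually takes effect end to end depends on the rest of the machinery in this commit (profiling runs, cache behavior, and backend support).

import jax

# Recompile each module with PGLE data once it has been profiled 3 times.
jax.config.update("jax_enable_pgle", True)
jax.config.update("jax_pgle_profiling_runs", 3)

# Percentile used when aggregating per-device measurements (default is 90).
jax.config.update("jax_pgle_aggregation_percentile", 90)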
