2121import os
2222import tempfile
2323import time
24- from typing import Any
24+ from typing import Any , Optional
2525import warnings
2626
2727from jax ._src import compilation_cache
@@ -243,6 +243,7 @@ def compile_or_get_cached(
243243 devices : np .ndarray ,
244244 compile_options : xc .CompileOptions ,
245245 host_callbacks : Sequence [Any ],
246+ pgle_profiler : profiler .PGLEProfiler | None = None ,
246247) -> xc .LoadedExecutable :
247248 sym_name = computation .operation .attributes ['sym_name' ]
248249 module_name = ir .StringAttr (sym_name ).value
@@ -278,14 +279,55 @@ def compile_or_get_cached(
278279 return backend_compile (backend , computation , compile_options ,
279280 host_callbacks )
280281
282+ is_multi_process = (
283+ len ({device .process_index for device in devices .flatten ()}) > 1 )
284+ min_device_process_id = (
285+ min (devices .flatten (), key = lambda device : device .id ).process_index )
286+
287+ # When PGLE is enabled there might be 3 types of situations:
288+ # 1. PGLE profiled module (the one which was recompiled with FDO profile) is
289+ # in the persistent cache. In this case the module should be returned from
290+ # cache and PGLE should be disabled for this module. Is module is stored in
291+ # the persistent cache under the "pgle_profiled_module_key" which calculated
292+ # with replacing FDO profile with flag which identify that module were PGLE
293+ # profiled.
294+ # 2. PGLE profiled module is not in the persistent cache and the module is
295+ # getting built with an FDO profile. In this case we need to share FDO profile
296+ # with other processes and store the result under the
297+ # "pgle_profiled_module_key" so later in case 1 we will be able to find the
298+ # module.
299+ # 3. PGLE profiled module is not in the persistent cache and the module is
300+ # getting compiled to be PGLEd (FDO profile is empty). In this case we need to
301+ # simply return the non-PGLE profiled module from the persistent cache.
302+ if (config .enable_pgle .value
303+ and config .pgle_profiling_runs .value > 0 ):
304+ fdo_profile = compile_options .executable_build_options .fdo_profile
305+ compile_options .executable_build_options .fdo_profile = b"pgle profiled"
306+
307+ pgle_profiled_module_key = compilation_cache .get_cache_key (
308+ computation , devices , compile_options , backend )
309+ compile_options .executable_build_options .fdo_profile = fdo_profile
310+
311+ if _is_executable_in_cache (pgle_profiled_module_key ):
312+ # Load PGLE profiled module from the persistent cache.
313+ cache_key = pgle_profiled_module_key
314+ if pgle_profiler is not None :
315+ pgle_profiler .disable ()
316+ elif fdo_profile is not None and len (fdo_profile ) > 0 :
317+ # Store module under PGLE profiled module cache key.
318+ cache_key = pgle_profiled_module_key
319+ if is_multi_process and distributed .global_state .client is not None :
320+ compile_options .executable_build_options .fdo_profile = _share_fdo_profiles (
321+ computation , devices , compile_options , backend ,
322+ distributed .global_state .client ,
323+ min_device_process_id
324+ )
325+
281326 cache_retrieval_start = time .monotonic ()
282327 retrieved_executable , retrieved_compile_time = _cache_read (
283328 module_name , cache_key , compile_options , backend )
284329 cache_retrieval_time = time .monotonic () - cache_retrieval_start
285330
286-
287- is_multi_process = (
288- len ({device .process_index for device in devices .flatten ()}) > 1 )
289331 if retrieved_executable is not None :
290332 assert retrieved_compile_time is not None
291333 logger .debug ("Persistent compilation cache hit for '%s'" , module_name )
@@ -315,7 +357,7 @@ def compile_or_get_cached(
315357 distributed .global_state .client ,
316358 module_name ,
317359 cache_key ,
318- min ( devices . flatten (), key = lambda device : device . id ). process_index
360+ min_device_process_id
319361 )
320362 elif (
321363 config .share_autotune_config_between_hosts .value
@@ -330,7 +372,7 @@ def compile_or_get_cached(
330372 distributed .global_state .client ,
331373 module_name ,
332374 cache_key ,
333- min ( devices . flatten (), key = lambda device : device . id ). process_index
375+ min_device_process_id
334376 )
335377 else :
336378 return _compile_and_write_cache (
@@ -342,6 +384,58 @@ def compile_or_get_cached(
342384 cache_key ,
343385 )
344386
387+ # The process that has the lowest device ID should share FDO profile before
388+ # compilation with other processes.
389+ def _share_fdo_profiles (
390+ computation : ir .Module ,
391+ devices : np .ndarray ,
392+ compile_options : xc .CompileOptions ,
393+ backend : xc .Client ,
394+ global_client : lib .xla_extension .DistributedRuntimeClient ,
395+ min_process_id
396+ ) -> Optional [bytes ]:
397+ sym_name = computation .operation .attributes ['sym_name' ]
398+ module_name = ir .StringAttr (sym_name ).value
399+ fdo_profile = compile_options .executable_build_options .fdo_profile
400+ if fdo_profile is None or len (fdo_profile ) == 0 :
401+ return fdo_profile
402+
403+ compile_options .executable_build_options .fdo_profile = b""
404+ profile_key = (
405+ compilation_cache .get_cache_key (
406+ computation , devices , compile_options , backend
407+ )
408+ + "_fdo_sync"
409+ )
410+ if profile_key in _share_fdo_profiles .modules_profiles :
411+ return _share_fdo_profiles .modules_profiles [profile_key ]
412+
413+ share_timeout = config .share_binary_between_hosts_timeout_ms .value
414+ if distributed .global_state .process_id == min_process_id :
415+ logger .debug (
416+ "Sharing FDO profile: %s. For module %s. Process %d." ,
417+ fdo_profile ,
418+ module_name ,
419+ min_process_id ,
420+ )
421+ global_client .key_value_set_bytes (profile_key , fdo_profile )
422+ else :
423+ logger .debug (
424+ "Waiting for FDO profile: %s. For module %s. Should be set by process %d." ,
425+ fdo_profile ,
426+ module_name ,
427+ min_process_id ,
428+ )
429+ fdo_profile = global_client .blocking_key_value_get_bytes (
430+ profile_key , share_timeout
431+ )
432+
433+ _share_fdo_profiles .modules_profiles [profile_key ] = fdo_profile
434+ return fdo_profile
435+
436+
437+ _share_fdo_profiles .modules_profiles = {}
438+
345439
346440# The process with the first_process_id should compile the module and write an
347441# autotune config to the K-V storage.
@@ -520,6 +614,20 @@ def _compile_and_write_cache(
520614 )
521615 return executable
522616
617+ def _is_executable_in_cache (cache_key ) -> bool :
618+ """Checks if executable is presented in cache on a given key
619+ """
620+ try :
621+ return compilation_cache .is_executable_in_cache (cache_key )
622+ except Exception as ex :
623+ if config .raise_persistent_cache_errors .value :
624+ raise
625+ warnings .warn (
626+ f"Error reading persistent compilation cache entry for "
627+ f"'{ cache_key } ': { type (ex ).__name__ } : { ex } " )
628+ return False
629+
630+
523631def _cache_read (
524632 module_name : str , cache_key : str , compile_options : xc .CompileOptions ,
525633 backend : xc .Client
0 commit comments