Skip to content

Commit 49996a4

Browse files
lukebaumannGoogle-ML-Automation
authored andcommitted
Expose profiler advanced configuration as a Python dict.
In profiler.cc, the advanced_configuration property of tensorflow::ProfileOptions is now exposed as a Python dictionary. The getter converts the proto map to a nb::dict, handling different value types (bool, int64, string). Example error: ``` ProfileOptions().advanced_configuration ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: Unable to convert function return value to a Python type! The signature was (self) -> proto2::Map<std::__u::basic_string<char, std::__u::char_traits<char>, std::__u::allocator<char>>, tensorflow::ProfileOptions_AdvancedConfigValue> ``` PiperOrigin-RevId: 841944949
1 parent 2a6de35 commit 49996a4

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tests/profiler_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import jax.numpy as jnp
3131
import jax.profiler
3232
import jax._src.test_util as jtu
33+
import jaxlib
3334

3435
from jax._src import profiler
3536
from jax import jit
@@ -508,5 +509,18 @@ def on_profile():
508509
unittest.mock.ANY,
509510
)
510511

512+
@unittest.skipIf(jaxlib.version <= (0, 8, 2), "advanced_configuration getter is added in jaxlib 0.8.2")
513+
def test_advanced_configuration_getter(self):
514+
options = jax.profiler.ProfileOptions()
515+
advanced_config = {
516+
"tpu_trace_mode": "TRACE_COMPUTE",
517+
"tpu_num_sparse_cores_to_trace": 1,
518+
"enableFwThrottleEvent": True,
519+
}
520+
options.advanced_configuration = advanced_config
521+
returned_config = options.advanced_configuration
522+
self.assertIsInstance(returned_config, dict)
523+
self.assertEqual(returned_config, advanced_config)
524+
511525
if __name__ == "__main__":
512526
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)