diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 3088ced9872a..8ab4940e1bb5 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -508,5 +508,18 @@ def on_profile(): unittest.mock.ANY, ) + @unittest.skipIf(jax._src.lib.ifrt_version < 39, "advanced_configuration getter is newly added") + def test_advanced_configuration_getter(self): + options = jax.profiler.ProfileOptions() + advanced_config = { + "tpu_trace_mode": "TRACE_COMPUTE", + "tpu_num_sparse_cores_to_trace": 1, + "enableFwThrottleEvent": True, + } + options.advanced_configuration = advanced_config + returned_config = options.advanced_configuration + self.assertIsInstance(returned_config, dict) + self.assertEqual(returned_config, advanced_config) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())