Skip to content

Commit 4e909f0

Browse files
hawkinsp authored and Google-ML-Automation committed
[XLA:Python] Fix Python profiler hooks under Python free threading.
PiperOrigin-RevId: 815072110
1 parent b6aa12a commit 4e909f0

File tree

3 files changed

+46
-8
lines changed

3 files changed

+46
-8
lines changed

jaxlib/xla_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
# Please suffix the version number with a brief description of your change
4848
# in a comment. The goal here is to force a merge conflict if two changes
4949
# attempt to grab the same version number.
50-
_version = 378 # Changed compile() signature to accept an mlir ModuleOp
50+
_version = 379 # Fixed thread safety issue in profiler.
5151

5252
# An internal increasing version number for protecting jaxlib code against
5353
# ifrt changes.

tests/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,10 @@ jax_multiplatform_test(
11861186
tags = ["multiaccelerator"],
11871187
deps = [
11881188
"//jax/_src:profiler",
1189-
] + py_deps("absl/testing"),
1189+
] + py_deps([
1190+
"absl/testing",
1191+
"portpicker",
1192+
]),
11901193
)
11911194

11921195
jax_multiplatform_test(

tests/profiler_test.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import concurrent.futures
1516
from functools import partial
1617
import glob
1718
import os
@@ -28,6 +29,7 @@
2829
import jax.numpy as jnp
2930
import jax.profiler
3031
import jax._src.test_util as jtu
32+
from jax._src.lib import jaxlib_extension_version
3133
from jax._src import profiler
3234
from jax import jit
3335

@@ -45,17 +47,21 @@
4547
jax.config.parse_flags_with_absl()
4648

4749

48-
@jtu.thread_unsafe_test_class() # profiler isn't thread-safe
50+
# We do not allow multiple concurrent profiler sessions.
51+
@jtu.thread_unsafe_test_class()
4952
class ProfilerTest(unittest.TestCase):
5053
# These tests simply test that the profiler API does not crash; they do not
5154
# check functional correctness.
5255

5356
def setUp(self):
54-
if sys.version_info >= (3, 14) and jtu.TEST_NUM_THREADS.value > 1:
55-
# TODO(phawkins): try reenabling these after
56-
# https://github.com/python/cpython/issues/132817 is fixed. Simply
57-
# installing the profiler hook is unsafe if there are multiple threads.
58-
self.skipTest("Profiler tests are not thread-safe under Python 3.14")
57+
if (
58+
sys.version_info < (3, 14)
59+
and hasattr(sys, "_is_gil_enabled")
60+
and not sys._is_gil_enabled()
61+
):
62+
self.skipTest(
63+
"Profiler tests are not thread-safe under Python 3.13 free threading"
64+
)
5965

6066
super().setUp()
6167
self.worker_start = threading.Event()
@@ -103,6 +109,35 @@ def testProgrammaticProfiling(self):
103109
self.assertIn(b"/device:TPU", proto)
104110
self.assertIn(b"pxla.py", proto)
105111

112+
@unittest.skipIf(
113+
jaxlib_extension_version < 379, "Requires jaxlib 0.8 or later."
114+
)
115+
def testProgrammaticProfilingConcurrency(self):
116+
def work():
117+
x = jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
118+
jnp.ones(jax.local_device_count()))
119+
jax.block_until_ready(x)
120+
with tempfile.TemporaryDirectory() as tmpdir:
121+
try:
122+
jax.profiler.start_trace(tmpdir)
123+
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
124+
for _ in range(10):
125+
executor.submit(work)
126+
finally:
127+
jax.profiler.stop_trace()
128+
129+
proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
130+
recursive=True)
131+
self.assertEqual(len(proto_path), 1)
132+
with open(proto_path[0], "rb") as f:
133+
proto = f.read()
134+
# Sanity check that serialized proto contains host, device, and
135+
# Python traces without deserializing.
136+
self.assertIn(b"/host:CPU", proto)
137+
if jtu.test_device_matches(["tpu"]):
138+
self.assertIn(b"/device:TPU", proto)
139+
self.assertIn(b"pxla.py", proto)
140+
106141
def testProgrammaticProfilingWithOptions(self):
107142
with tempfile.TemporaryDirectory() as tmpdir:
108143
try:

0 commit comments

Comments
 (0)