|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import concurrent.futures |
15 | 16 | from functools import partial |
16 | 17 | import glob |
17 | 18 | import os |
|
28 | 29 | import jax.numpy as jnp |
29 | 30 | import jax.profiler |
30 | 31 | import jax._src.test_util as jtu |
| 32 | +from jax._src.lib import jaxlib_extension_version |
31 | 33 | from jax._src import profiler |
32 | 34 | from jax import jit |
33 | 35 |
|
|
45 | 47 | jax.config.parse_flags_with_absl() |
46 | 48 |
|
47 | 49 |
|
48 | | -@jtu.thread_unsafe_test_class() # profiler isn't thread-safe |
| 50 | +# We do not allow multiple concurrent profiler sessions. |
| 51 | +@jtu.thread_unsafe_test_class() |
49 | 52 | class ProfilerTest(unittest.TestCase): |
50 | 53 | # These tests simply test that the profiler API does not crash; they do not |
51 | 54 | # check functional correctness. |
52 | 55 |
|
53 | 56 | def setUp(self): |
54 | | - if sys.version_info >= (3, 14) and jtu.TEST_NUM_THREADS.value > 1: |
55 | | - # TODO(phawkins): try reenabling these after |
56 | | - # https://github.com/python/cpython/issues/132817 is fixed. Simply |
57 | | - # installing the profiler hook is unsafe if there are multiple threads. |
58 | | - self.skipTest("Profiler tests are not thread-safe under Python 3.14") |
| 57 | + if ( |
| 58 | + sys.version_info < (3, 14) |
| 59 | + and hasattr(sys, "_is_gil_enabled") |
| 60 | + and not sys._is_gil_enabled() |
| 61 | + ): |
| 62 | + self.skipTest( |
| 63 | + "Profiler tests are not thread-safe under Python 3.13 free threading" |
| 64 | + ) |
59 | 65 |
|
60 | 66 | super().setUp() |
61 | 67 | self.worker_start = threading.Event() |
@@ -103,6 +109,35 @@ def testProgrammaticProfiling(self): |
103 | 109 | self.assertIn(b"/device:TPU", proto) |
104 | 110 | self.assertIn(b"pxla.py", proto) |
105 | 111 |
|
| 112 | + @unittest.skipIf( |
| 113 | + jaxlib_extension_version < 379, "Requires jaxlib 0.8 or later." |
| 114 | + ) |
| 115 | + def testProgrammaticProfilingConcurrency(self): |
| 116 | + def work(): |
| 117 | + x = jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')( |
| 118 | + jnp.ones(jax.local_device_count())) |
| 119 | + jax.block_until_ready(x) |
| 120 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 121 | + try: |
| 122 | + jax.profiler.start_trace(tmpdir) |
| 123 | + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: |
| 124 | + for _ in range(10): |
| 125 | + executor.submit(work) |
| 126 | + finally: |
| 127 | + jax.profiler.stop_trace() |
| 128 | + |
| 129 | + proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"), |
| 130 | + recursive=True) |
| 131 | + self.assertEqual(len(proto_path), 1) |
| 132 | + with open(proto_path[0], "rb") as f: |
| 133 | + proto = f.read() |
| 134 | + # Sanity check that serialized proto contains host, device, and |
| 135 | + # Python traces without deserializing. |
| 136 | + self.assertIn(b"/host:CPU", proto) |
| 137 | + if jtu.test_device_matches(["tpu"]): |
| 138 | + self.assertIn(b"/device:TPU", proto) |
| 139 | + self.assertIn(b"pxla.py", proto) |
| 140 | + |
106 | 141 | def testProgrammaticProfilingWithOptions(self): |
107 | 142 | with tempfile.TemporaryDirectory() as tmpdir: |
108 | 143 | try: |
|
0 commit comments