Skip to content

Commit cf6f0aa

Browse files
bchetioui authored and jax authors committed
[Mosaic GPU] Add a synchronization point at the end of the constructor for BarrierArray.
Without such a synchronization point, calls to `mbarrier_init` may end up happening after uses of the `mbarrier` being initialized—which is undefined behaviour and leads to deadlocks. This allows us to re-enable the previously broken test cases.

PiperOrigin-RevId: 638246527
1 parent 26f9820 commit cf6f0aa

File tree

2 files changed

+1
-12
lines changed

2 files changed

+1
-12
lines changed

jax/experimental/mosaic/gpu/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -502,6 +502,7 @@ def __init__(self, num_barriers: int, arrival_count: int = 1):
502502
with once():
503503
for i in range(num_barriers):
504504
nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index))
505+
gpu.barrier()
505506

506507
def __iter__(self) -> Iterator["Barrier"]:
507508
for offset in range(self.num_barriers):

tests/mosaic/matmul_test.py

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -60,12 +60,6 @@ def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype):
6060
if n < tile_n:
6161
self.skipTest(f"No use in running a test with {n=} < {tile_n=}.")
6262

63-
# TODO(bchetioui): investigate why this test case fails with error
64-
# Illegal barrier arrive operation
65-
# under memcheck.
66-
if tile_m == 64 and tile_n == 64 and stages == 2:
67-
self.skipTest("Broken test case---skipping.")
68-
6963
try:
7064
matmul.verify(
7165
m,
@@ -102,12 +96,6 @@ def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n, high_precision):
10296
if n < tile_n:
10397
self.skipTest(f"No use in running a test with {n=} < {tile_n=}.")
10498

105-
# TODO(bchetioui): investigate why this test case fails with error
106-
# Illegal barrier arrive operation
107-
# under memcheck.
108-
if tile_m == 64 and tile_n == 64 and stages == 2:
109-
self.skipTest("Broken test case---skipping.")
110-
11199
try:
112100
matmul.verify(
113101
m,

0 commit comments

Comments
 (0)