Skip to content

Commit ed07ee7

Browse files
danielsuoGoogle-ML-Automation
authored andcommitted
Reduce number of sampled tests and mark some more tests as thread unsafe.
PiperOrigin-RevId: 815059810
1 parent 30562ff commit ed07ee7

File tree

3 files changed

+3
-1
lines changed

3 files changed

+3
-1
lines changed

jax/experimental/jax2tf/tests/call_tf_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,7 @@ def _transfer_guard(guard_level):
871871
jax2tf.call_tf(tf_fun)(jax_array_on_gpu)
872872

873873

874+
@jtu.thread_unsafe_test_class()
874875
class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
875876
"""Reloading output of jax2tf into JAX with call_tf."""
876877

jax/experimental/jax2tf/tests/shape_poly_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def check_shape_poly(tst, f_jax: Callable, *,
232232
return h.run_test(tst)
233233

234234

235+
@jtu.thread_unsafe_test_class()
235236
class ShapePolyTest(tf_test_util.JaxToTfTestCase):
236237

237238
def test_simple_unary(self):

tests/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1437,7 +1437,7 @@ jax_multiplatform_test(
14371437
},
14381438
# Use fewer cases to prevent timeouts.
14391439
backend_variant_args = {
1440-
"cpu": ["--jax_num_generated_cases=30"],
1440+
"cpu": ["--jax_num_generated_cases=20"],
14411441
"cpu_x32": ["--jax_num_generated_cases=30"],
14421442
"gpu_p100": ["--jax_num_generated_cases=40"],
14431443
"gpu_v100": ["--jax_num_generated_cases=40"],

0 commit comments

Comments
 (0)