Skip to content

Commit 2f62fb1

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix back compat test to ignore warnings
PiperOrigin-RevId: 842272730
1 parent ea5aee9 commit 2f62fb1

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

jax/_src/test_util.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,8 +1248,6 @@ class JaxTestCase(parameterized.TestCase):
12481248
'jax_legacy_prng_key': 'error',
12491249
}
12501250

1251-
1252-
12531251
def setUp(self):
12541252
super().setUp()
12551253
self.enterContext(assert_global_configs_unchanged())

tests/export_back_compat_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,8 @@ def func(x):
784784
data = self.load_testdata(cuda_threefry2x32.data_2024_07_30)
785785
self.run_one_test(func, data)
786786

787+
@jtu.ignore_warning(category=DeprecationWarning,
788+
message='`with mesh:` context manager')
787789
def test_tpu_sharding(self):
788790
# Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
789791
if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2:
@@ -1006,8 +1008,10 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol):
10061008
)
10071009

10081010

1009-
@jtu.with_config(jax_use_shardy_partitioner=True)
10101011
class ShardyCompatTest(bctu.CompatTestBase):
1012+
1013+
@jtu.ignore_warning(category=DeprecationWarning,
1014+
message='`with mesh:` context manager')
10111015
def test_shardy_sharding_ops_with_different_meshes(self):
10121016
# Tests whether we can save and load a module with meshes that have the
10131017
# same axis sizes (and same order) but different axis names.
@@ -1046,7 +1050,5 @@ def shard_map_func(x): # b: f32[2, 4]
10461050
expect_current_custom_calls=custom_call_targets_override)
10471051

10481052

1049-
1050-
10511053
if __name__ == "__main__":
10521054
absltest.main(testLoader=jtu.JaxTestLoader())

tests/fused_attention_stablehlo_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def dot_product_attention_fp8(query, key, value, fp8_metas):
264264
return out[0], (query_grad, key_grad, value_grad)
265265

266266

267+
@jtu.ignore_warning(category=DeprecationWarning,
268+
message='`with mesh:` context manager')
267269
class DotProductAttentionTest(jtu.JaxTestCase):
268270
def setUp(self):
269271
super().setUp()

0 commit comments

Comments
 (0)