2323 * Create a new test method, with a function to be serialized that exercises
2424 the feature you want to test, and a call to self.export_and_serialize.
2525 You can follow the model of the tests below, which are parameterized by
26- the test data . Use `None` for the test data to signal that you want to
27- use a fresh serialization.
26+ the testdata . Use only `None` for the testdata parameter to signal that
27+ you want to use a current serialization and not a saved one .
2828 * Run the test. This will save the serialized data in
2929 TEST_UNDECLARED_OUTPUTS_DIR (or "/tmp/back_compat_testdata" if not set).
3030 * Copy the test data defined in the output file, to the file
5555
5656import jax
5757from jax ._src import config
58+ from jax ._src import core
5859from jax ._src .export import _export
5960from jax ._src .export .serialization import _SERIALIZATION_VERSION
6061from jax .sharding import PartitionSpec as P
6162from jax ._src import test_util as jtu
6263
6364from jax ._src .internal_test_util .export_back_compat_test_data import export_with_specified_sharding
6465from jax ._src .internal_test_util .export_back_compat_test_data import export_with_unspecified_sharding
66+ from jax ._src .internal_test_util .export_back_compat_test_data import export_with_memory_space
6567
6668config .parse_flags_with_absl ()
6769jtu .request_cpu_devices (8 )
@@ -75,14 +77,15 @@ def setUp(self):
7577
7678 def export_and_serialize (self , fun , * args ,
7779 vjp_order = 0 ,
80+ platforms = None ,
7881 ** kwargs ) -> bytearray :
7982 """Export and serialize a function.
8083
8184 The test data is saved in TEST_UNDECLARED_OUTPUTS_DIR (or
8285 "/tmp/back_compat_testdata" if not set) and should be copied as explained
8386 in the module docstring.
8487 """
85- exp = _export .export (fun )(* args , ** kwargs )
88+ exp = _export .export (fun , platforms = platforms )(* args , ** kwargs )
8689 serialized = exp .serialize (vjp_order = vjp_order )
8790 updated_testdata = f"""
8891 # Paste to the test data file (see export_serialization_back_compat_test.py module docstring)
@@ -98,7 +101,8 @@ def export_and_serialize(self, fun, *args,
98101 "/tmp/back_compat_testdata" )
99102 if not os .path .exists (output_dir ):
100103 os .makedirs (output_dir )
101- output_file = os .path .join (output_dir , f"export_{ self ._testMethodName } .py" )
104+ output_file_basename = f"export_{ self ._testMethodName .replace ('test_' , '' )} .py"
105+ output_file = os .path .join (output_dir , output_file_basename )
102106 logging .info ("Writing the updated serialized Exported at %s" , output_file )
103107 with open (output_file , "w" ) as f :
104108 f .write (updated_testdata )
@@ -163,5 +167,38 @@ def f(b):
163167 self .assertEqual (out .addressable_shards [1 ].index , (slice (8 , 16 ), slice (None )))
164168
165169
170+ @jtu .parameterized_filterable (
171+ kwargs = [
172+ dict (testdata = testdata ,
173+ testcase_name = ("current" if testdata is None
174+ else f"v{ testdata ['serialization_version' ]} " ))
175+ for testdata in [None , * export_with_memory_space .serializations ]
176+ ]
177+ )
178+ def test_with_memory_space (self , testdata : dict [str , Any ] | None ):
179+ # This test is based on export_test.py::test_memory_space_from_arg
180+ mesh = jtu .create_mesh ((2 ,), "x" )
181+ with jax .set_mesh (mesh ):
182+ shd = jax .sharding .NamedSharding (mesh , P ("x" , None ),
183+ memory_kind = "pinned_host" )
184+ a = jax .device_put (np .ones ((2 , 3 ), dtype = np .float32 ), shd )
185+ f = jax .jit (lambda x : x )
186+
187+ if testdata is None :
188+ serialized = self .export_and_serialize (
189+ f , a , platforms = ("tpu" , "cuda" ))
190+ else :
191+ serialized = testdata ["exported_serialized" ]
192+
193+ exported = _export .deserialize (serialized )
194+ self .assertEqual (exported .in_avals [0 ].memory_space , core .MemorySpace .Host )
195+ self .assertEqual (exported .out_avals [0 ].memory_space , core .MemorySpace .Host )
196+
197+ if jtu .device_under_test () in ("tpu" , "gpu" ):
198+ b = exported .call (a )
199+ self .assertEqual (b .aval .memory_space , core .MemorySpace .Host )
200+ self .assertEqual (b .sharding , a .sharding )
201+
202+
166203if __name__ == "__main__" :
167204 absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments