Skip to content

Commit 6a1397e

Browse files
Merge pull request #33848 from gnecula:export_memory_space
PiperOrigin-RevId: 842754890
2 parents f837ebc + 723b90e commit 6a1397e

File tree

2 files changed

+64
-4
lines changed

2 files changed

+64
-4
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# ruff: noqa
16+
17+
# Pasted from the test output (see export_serialization_back_compat_test.py module docstring)
18+
serializations = [
19+
dict(
20+
serialization_version=5,
21+
exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00L\x00J\x00D\x00@\x00<\x008\x004\x00.\x00(\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0e\x00\x08\x00\x07\x00\x00\x000\x00*\x00\x00\x00\x00\x00\x00\x01D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00X\x02\x00\x00\x80\x02\x00\x00\x88\x02\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00\xa0\x02\x00\x00\xcc\x02\x00\x00\xcc\x02\x00\x00\x04\x03\x00\x00X\x03\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00 \x02\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01\x1f\x07\x01\x05\t\t\x01\x03\x0f\x03\x03\x13\x05\x05\x17\x1b\x03kE\x0f\x01\x1b\x07\x0b#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x1f\x0b\x0b\x13\x13\x0b\x0b\x1b\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x05\x0b\x17\x0f\x1b\x07\x07\x02\xf9\x1f\x05\t\x03\x07\x07\t\x0b\r\x0f\x11\x05\x0f\x11\x03\x01\x05\x11\x11\x01\t\x05\x13\x11\x01\x05\x05\x15\t\x03\x1d\x19\x01\x05\x17\x05\x03\x1d\x01\x03\x17\t\r\x15\x05!%\x01\x0b\x03#\x01\x01\t\x17\x01\x0b\x01\x01\x01\x1d\x19\x1d\x1b\x03\x05-3\r\x03/1\x1d\x1d\x1d\x1f\r\x05\')5\x1f\x1d!#\t\x03\x03;\r\x05=?\')\x1d#\x1d%\x1d\'\x1d)\x01\x02\x02\x01\t)\x05\t\r\r)\x01\x0b\x11\x05\x07\x05\x03\x05\x1b\t\x04I\x05\x01Q\x01\x05\x01\x07\x047\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04\x1b\x03\x05\x07\x05\r\x0b\x17\x00\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00\x1a\x04+\x0f\x0b\x0f!\x1b!)\x19#\x05\x19%)9\x15\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda\x00x\x00mhlo.memory_kind\x00pinned_host\x00jax.global_constant\x00_platform_index\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08\x1b\x07\x05\'\x01\x05\x1b\x03\x0b+79AC\x02\x00\x00\x00\x14\x00\x00\x00\x04\x00\x00\x00\x04\x00\x00\x00cuda\x00\x00\x00\x00\x03\x00\x00\x00tpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x08\x00\x00\x00<lambda>\x00\x00\x00\x00"),
22+
),
23+
]

tests/export_serialization_back_compat_test.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
* Create a new test method, with a function to be serialized that exercises
2424
the feature you want to test, and a call to self.export_and_serialize.
2525
You can follow the model of the tests below, which are parameterized by
26-
the test data. Use `None` for the test data to signal that you want to
27-
use a fresh serialization.
26+
the testdata. Use only `None` for the testdata parameter to signal that
27+
you want to use a current serialization and not a saved one.
2828
* Run the test. This will save the serialized data in
2929
TEST_UNDECLARED_OUTPUTS_DIR (or "/tmp/back_compat_testdata" if not set).
3030
* Copy the test data defined in the output file, to the file
@@ -55,13 +55,15 @@
5555

5656
import jax
5757
from jax._src import config
58+
from jax._src import core
5859
from jax._src.export import _export
5960
from jax._src.export.serialization import _SERIALIZATION_VERSION
6061
from jax.sharding import PartitionSpec as P
6162
from jax._src import test_util as jtu
6263

6364
from jax._src.internal_test_util.export_back_compat_test_data import export_with_specified_sharding
6465
from jax._src.internal_test_util.export_back_compat_test_data import export_with_unspecified_sharding
66+
from jax._src.internal_test_util.export_back_compat_test_data import export_with_memory_space
6567

6668
config.parse_flags_with_absl()
6769
jtu.request_cpu_devices(8)
@@ -75,14 +77,15 @@ def setUp(self):
7577

7678
def export_and_serialize(self, fun, *args,
7779
vjp_order=0,
80+
platforms=None,
7881
**kwargs) -> bytearray:
7982
"""Export and serialize a function.
8083
8184
The test data is saved in TEST_UNDECLARED_OUTPUTS_DIR (or
8285
"/tmp/back_compat_testdata" if not set) and should be copied as explained
8386
in the module docstring.
8487
"""
85-
exp = _export.export(fun)(*args, **kwargs)
88+
exp = _export.export(fun, platforms=platforms)(*args, **kwargs)
8689
serialized = exp.serialize(vjp_order=vjp_order)
8790
updated_testdata = f"""
8891
# Paste to the test data file (see export_serialization_back_compat_test.py module docstring)
@@ -98,7 +101,8 @@ def export_and_serialize(self, fun, *args,
98101
"/tmp/back_compat_testdata")
99102
if not os.path.exists(output_dir):
100103
os.makedirs(output_dir)
101-
output_file = os.path.join(output_dir, f"export_{self._testMethodName}.py")
104+
output_file_basename = f"export_{self._testMethodName.replace('test_', '')}.py"
105+
output_file = os.path.join(output_dir, output_file_basename)
102106
logging.info("Writing the updated serialized Exported at %s", output_file)
103107
with open(output_file, "w") as f:
104108
f.write(updated_testdata)
@@ -163,5 +167,38 @@ def f(b):
163167
self.assertEqual(out.addressable_shards[1].index, (slice(8, 16), slice(None)))
164168

165169

170+
@jtu.parameterized_filterable(
171+
kwargs=[
172+
dict(testdata=testdata,
173+
testcase_name=("current" if testdata is None
174+
else f"v{testdata['serialization_version']}"))
175+
for testdata in [None, *export_with_memory_space.serializations]
176+
]
177+
)
178+
def test_with_memory_space(self, testdata: dict[str, Any] | None):
179+
# This test is based on export_test.py::test_memory_space_from_arg
180+
mesh = jtu.create_mesh((2,), "x")
181+
with jax.set_mesh(mesh):
182+
shd = jax.sharding.NamedSharding(mesh, P("x", None),
183+
memory_kind="pinned_host")
184+
a = jax.device_put(np.ones((2, 3), dtype=np.float32), shd)
185+
f = jax.jit(lambda x: x)
186+
187+
if testdata is None:
188+
serialized = self.export_and_serialize(
189+
f, a, platforms=("tpu", "cuda"))
190+
else:
191+
serialized = testdata["exported_serialized"]
192+
193+
exported = _export.deserialize(serialized)
194+
self.assertEqual(exported.in_avals[0].memory_space, core.MemorySpace.Host)
195+
self.assertEqual(exported.out_avals[0].memory_space, core.MemorySpace.Host)
196+
197+
if jtu.device_under_test() in ("tpu", "gpu"):
198+
b = exported.call(a)
199+
self.assertEqual(b.aval.memory_space, core.MemorySpace.Host)
200+
self.assertEqual(b.sharding, a.sharding)
201+
202+
166203
if __name__ == "__main__":
167204
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)