Skip to content

Commit 6cb78e5

Browse files
pschuhGoogle-ML-Automation
authored andcommitted
Add LoadedExecutable.serialize() to match Executable.serialize().
PiperOrigin-RevId: 841880441
1 parent d5bf94e commit 6cb78e5

File tree

4 files changed

+12
-2
lines changed

4 files changed

+12
-2
lines changed

jax/_src/compilation_cache.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,10 @@ def put_executable_and_time(
261261
" since cache is disabled/not initialized", cache_key)
262262
return
263263

264-
serialized_executable = backend.serialize_executable(executable)
264+
if hasattr(executable, "serialize") or xla_client._version >= 389:
265+
serialized_executable = executable.serialize()
266+
else:
267+
serialized_executable = backend.serialize_executable(executable)
265268
executable_and_time = combine_executable_and_time(
266269
serialized_executable, compile_time)
267270
executable_and_time = compress_executable(executable_and_time)

jaxlib/_jax/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,7 @@ class LoadedExecutable:
11991199
def client(self) -> Client: ...
12001200
def local_devices(self) -> list[Device]: ...
12011201
def get_hlo_text(self) -> str: ...
1202+
def serialize(self) -> bytes: ...
12021203
def size_of_generated_code_in_bytes(self) -> int: ...
12031204
def get_compiled_memory_stats(self) -> CompiledMemoryStats: ...
12041205
def execute_sharded(

jaxlib/py_executable.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,12 @@ void PyLoadedExecutable::Register(nb::module_& m) {
545545
.def("get_hlo_text",
546546
xla::ValueOrThrowWrapper(
547547
&PyLoadedExecutable::GetHumanReadableProgramText))
548+
.def("serialize",
549+
[](const PyLoadedExecutable& exec) -> nb::bytes {
550+
std::string serialized =
551+
xla::ValueOrThrow(exec.ifrt_loaded_executable()->Serialize());
552+
return nb::bytes(serialized.data(), serialized.size());
553+
})
548554
.def("size_of_generated_code_in_bytes",
549555
&PyLoadedExecutable::SizeOfGeneratedCodeInBytes)
550556
.def(

jaxlib/xla_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
# Please suffix the version number with a brief description of your change
4848
# in a comment. The goal here is to force a merge conflict if two changes
4949
# attempt to grab the same version number.
50-
_version = 388 # Add ArrayMeta
50+
_version = 389 # LoadedExecutable.serialize
5151

5252
# An internal increasing version number for protecting jaxlib code against
5353
# ifrt changes.

0 commit comments

Comments
 (0)