File tree Expand file tree Collapse file tree 4 files changed +12
-2
lines changed
Expand file tree Collapse file tree 4 files changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -261,7 +261,10 @@ def put_executable_and_time(
261261 " since cache is disabled/not initialized" , cache_key )
262262 return
263263
264- serialized_executable = backend .serialize_executable (executable )
264+ if hasattr (executable , "serialize" ) or xla_client ._version >= 389 :
265+ serialized_executable = executable .serialize ()
266+ else :
267+ serialized_executable = backend .serialize_executable (executable )
265268 executable_and_time = combine_executable_and_time (
266269 serialized_executable , compile_time )
267270 executable_and_time = compress_executable (executable_and_time )
Original file line number Diff line number Diff line change @@ -1199,6 +1199,7 @@ class LoadedExecutable:
11991199 def client (self ) -> Client : ...
12001200 def local_devices (self ) -> list [Device ]: ...
12011201 def get_hlo_text (self ) -> str : ...
1202+ def serialize (self ) -> bytes : ...
12021203 def size_of_generated_code_in_bytes (self ) -> int : ...
12031204 def get_compiled_memory_stats (self ) -> CompiledMemoryStats : ...
12041205 def execute_sharded (
Original file line number Diff line number Diff line change @@ -545,6 +545,12 @@ void PyLoadedExecutable::Register(nb::module_& m) {
545545 .def (" get_hlo_text" ,
546546 xla::ValueOrThrowWrapper (
547547 &PyLoadedExecutable::GetHumanReadableProgramText))
548+ .def (" serialize" ,
549+ [](const PyLoadedExecutable& exec) -> nb::bytes {
550+ std::string serialized =
551+ xla::ValueOrThrow (exec.ifrt_loaded_executable ()->Serialize ());
552+ return nb::bytes (serialized.data (), serialized.size ());
553+ })
548554 .def (" size_of_generated_code_in_bytes" ,
549555 &PyLoadedExecutable::SizeOfGeneratedCodeInBytes)
550556 .def (
Original file line number Diff line number Diff line change 4747# Please suffix the version number with a brief description of your change
4848# in a comment. The goal here is to force a merge conflict if two changes
4949# attempt to grab the same version number.
50- _version = 388 # Add ArrayMeta
50+ _version = 389 # LoadedExecutable.serialize
5151
5252# An internal increasing version number for protecting jaxlib code against
5353# ifrt changes.
You can’t perform that action at this time.
0 commit comments