From 49cc0e96e48544d175f258a72365b2e38951289a Mon Sep 17 00:00:00 2001
From: Lucy Qiu
Date: Wed, 10 Dec 2025 16:12:58 -0800
Subject: [PATCH] Add FlatTensorDataMap to low-level pybindings (#15900)

Summary:
Create pybindings around the FlatTensorDataMap::load function.

This allows us to use the low-level program data separation APIs in
pybindings, through program/method as well as Module.

Use `py::capsule` to carry the `const NamedDataMap*` that is passed into
load_method.

Example usage:
```
>>> inputs = (torch.randn(3),)
>>> program = _load_program_from_buffer(program_buffer)
>>> data_map = _load_flat_tensor_data_map("model.ptd")
>>> method = program.load_method("forward", data_map.get_named_data_map())
>>> outputs = method(inputs)[0]
```

Reviewed By: JacobSzwejbka

Differential Revision: D87461455
---
 extension/pybindings/portable_lib.py          |  3 +
 extension/pybindings/pybindings.cpp           | 69 ++++++++++++++++++-
 extension/pybindings/pybindings.pyi           | 57 +++++++++++++++
 extension/pybindings/test/test_pybindings.py  | 46 +++++++++++++
 .../extension/pybindings/pybindings.bzl       | 18 ++--
 5 files changed, 183 insertions(+), 10 deletions(-)

diff --git a/extension/pybindings/portable_lib.py b/extension/pybindings/portable_lib.py
index 27468c8b7b5..29382c3010a 100644
--- a/extension/pybindings/portable_lib.py
+++ b/extension/pybindings/portable_lib.py
@@ -59,6 +59,8 @@
     _get_registered_backend_names,  # noqa: F401
     _is_available,  # noqa: F401
     _load_bundled_program_from_buffer,  # noqa: F401
+    _load_flat_tensor_data_map,  # noqa: F401
+    _load_flat_tensor_data_map_from_buffer,  # noqa: F401
     _load_for_executorch,  # noqa: F401
     _load_for_executorch_from_buffer,  # noqa: F401
     _load_for_executorch_from_bundled_program,  # noqa: F401
@@ -71,6 +73,7 @@
     ExecuTorchMethod,  # noqa: F401
     ExecuTorchModule,  # noqa: F401
     ExecuTorchProgram,  # noqa: F401
+    FlatTensorDataMap,  # noqa: F401
     MethodMeta,  # noqa: F401
     Verification,  # noqa: F401
 )
diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp
index eb81bda22f7..8ec77ec157d 100644
--- a/extension/pybindings/pybindings.cpp
+++ b/extension/pybindings/pybindings.cpp
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
 #include
 #include
 #include
@@ -82,6 +83,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Kernel;
 using ::executorch::ET_RUNTIME_NAMESPACE::Method;
 using ::executorch::ET_RUNTIME_NAMESPACE::Program;
 using ::executorch::extension::BufferDataLoader;
+using ::executorch::extension::FlatTensorDataMap;
 using ::executorch::extension::MallocMemoryAllocator;
 using ::executorch::extension::MmapDataLoader;
 using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
@@ -1367,9 +1369,21 @@ struct PyProgram final {
     return std::string(res.get());
   }
 
-  std::unique_ptr<PyMethod> load_method(const std::string& method_name) {
+  std::unique_ptr<PyMethod> load_method(
+      const std::string& method_name,
+      py::object named_data_map_obj = py::none()) {
+    const NamedDataMap* named_data_map = nullptr;
+    if (!named_data_map_obj.is_none()) {
+      // Extract pointer from py::capsule.
+      py::capsule named_data_map_capsule =
+          named_data_map_obj.cast<py::capsule>();
+      named_data_map = named_data_map_capsule.get_pointer<const NamedDataMap>();
+    }
     Result<Method> res = state_->program_->load_method(
-        method_name.c_str(), memory_->mem_manager(), event_tracer_.get());
+        method_name.c_str(),
+        memory_->mem_manager(),
+        event_tracer_.get(),
+        named_data_map);
     THROW_IF_ERROR(
         res.error(),
         "Failed to load method %s, error: 0x:%" PRIx32,
@@ -1470,6 +1484,40 @@ py::bool_ is_available(const std::string& backend_name) {
   return backend->is_available();
 }
 
+struct PyFlatTensorDataMap final {
+  explicit PyFlatTensorDataMap(
+      std::unique_ptr<DataLoader> loader,
+      FlatTensorDataMap data_map)
+      : loader_(std::move(loader)), data_map_(std::move(data_map)) {}
+  static std::unique_ptr<PyFlatTensorDataMap> load_from_file(
+      const std::string& path) {
+    auto loader = loader_from_file(path);
+    auto result = FlatTensorDataMap::load(loader.get());
+    THROW_IF_ERROR(result.error(), "Failed to load FlatTensorDataMap");
+    return std::make_unique<PyFlatTensorDataMap>(
+        std::move(loader), std::move(result.get()));
+  }
+  static std::unique_ptr<PyFlatTensorDataMap> load_from_buffer(
+      const py::bytes& buffer) {
+    auto loader = loader_from_buffer(
+        buffer.cast<std::string_view>().data(), py::len(buffer));
+    auto result = FlatTensorDataMap::load(loader.get());
+    THROW_IF_ERROR(result.error(), "Failed to load FlatTensorDataMap");
+    return std::make_unique<PyFlatTensorDataMap>(
+        std::move(loader), std::move(result.get()));
+  }
+
+  // Get a pointer to the underlying NamedDataMap as a py::capsule.
+  // The PyFlatTensorDataMap must outlive this pointer.
+  py::capsule get_named_data_map() {
+    return py::capsule(&data_map_, "NamedDataMap");
+  }
+
+ private:
+  std::unique_ptr<DataLoader> loader_; // Keep DataLoader alive.
+  FlatTensorDataMap data_map_;
+};
+
 } // namespace
 
 PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
@@ -1677,6 +1725,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           "load_method",
           &PyProgram::load_method,
           py::arg("method_name"),
+          py::arg("named_data_map") = py::none(),
           call_guard)
       .def(
           "method_meta",
@@ -1728,6 +1777,22 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           py::arg("name"),
           call_guard)
       .def("method_meta", &PyMethod::method_meta, call_guard);
+
+  m.def(
+      "_load_flat_tensor_data_map",
+      &PyFlatTensorDataMap::load_from_file,
+      py::arg("data_path"),
+      call_guard);
+  m.def(
+      "_load_flat_tensor_data_map_from_buffer",
+      &PyFlatTensorDataMap::load_from_buffer,
+      py::arg("data_buffer"),
+      call_guard);
+  py::class_<PyFlatTensorDataMap>(m, "FlatTensorDataMap")
+      .def(
+          "get_named_data_map",
+          &PyFlatTensorDataMap::get_named_data_map,
+          call_guard);
 }
 
 namespace {
diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi
index 9e5ab6211ce..e32ee03e9fa 100644
--- a/extension/pybindings/pybindings.pyi
+++ b/extension/pybindings/pybindings.pyi
@@ -297,3 +297,60 @@ def _threadpool_get_thread_count() -> int:
     This API is experimental and subject to change without notice.
     """
     ...
+
+@experimental("This API is experimental and subject to change without notice.")
+class FlatTensorDataMap:
+    """FlatTensorDataMap loads external data from a .ptd file.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+    """
+
+    def get_named_data_map(self) -> Any:
+        """Get a pointer to the underlying NamedDataMap.
+
+        Returns:
+            A capsule containing a pointer to the internal NamedDataMap
+            that can be passed to ExecuTorchProgram.load_method().
+
+        Warning:
+            The FlatTensorDataMap instance must outlive the returned capsule.
+        """
+        ...
+
+@experimental("This API is experimental and subject to change without notice.")
+def _load_flat_tensor_data_map(
+    data_path: str,
+) -> FlatTensorDataMap:
+    """Load a flat tensor data map from a file.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+
+    Args:
+        data_path: Path to the .ptd file with external data.
+
+    Returns:
+        A FlatTensorDataMap instance that can be used with ExecuTorchProgram.load_method().
+    """
+    ...
+
+@experimental("This API is experimental and subject to change without notice.")
+def _load_flat_tensor_data_map_from_buffer(
+    data_buffer: bytes,
+) -> FlatTensorDataMap:
+    """Load a flat tensor data map from a buffer.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+
+    Args:
+        data_buffer: Buffer holding a .ptd file with external data.
+
+    Returns:
+        A FlatTensorDataMap instance that can be used with ExecuTorchProgram.load_method().
+    """
+    ...
diff --git a/extension/pybindings/test/test_pybindings.py b/extension/pybindings/test/test_pybindings.py
index ec45428c7d7..4838e660980 100644
--- a/extension/pybindings/test/test_pybindings.py
+++ b/extension/pybindings/test/test_pybindings.py
@@ -733,3 +733,49 @@ def test_program_data_separation(self) -> None:
         )
         with self.assertRaises(RuntimeError):
             executorch_module_bundled_no_data.forward(inputs)
+
+    def test_flat_tensor_data_map(self) -> None:
+        eager_module = ModuleLinear()
+        inputs = eager_module.get_inputs()
+        expected = eager_module(inputs[0])
+        exported_program = export(eager_module, inputs, strict=True)
+        exec_program = to_edge(exported_program).to_executorch(
+            config=ExecutorchBackendConfig(
+                # Move all tensor data to '_default_external_constant' file.
+                external_constants=True,
+            )
+        )
+        program_buffer = exec_program.buffer
+        assert len(exec_program._tensor_data) == 1
+        data_buffer = bytes(exec_program._tensor_data.pop("_default_external_constant"))
+
+        # Test 1: Load FlatTensorDataMap from buffer.
+        program_from_buffer = self.load_prog_fn(program_buffer)
+        data_map_from_buffer = self.runtime._load_flat_tensor_data_map_from_buffer(
+            data_buffer
+        )
+        method = program_from_buffer.load_method(
+            "forward", data_map_from_buffer.get_named_data_map()
+        )
+        executorch_output = method(inputs)[0]
+        self.assertTrue(torch.allclose(expected, executorch_output))
+
+        # Test 2: Load FlatTensorDataMap from file.
+        import os
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pte_file = os.path.join(tmpdir, "linear.pte")
+            with open(pte_file, "wb") as f:
+                f.write(program_buffer)
+            ptd_file = os.path.join(tmpdir, "linear.ptd")
+            with open(ptd_file, "wb") as ptd:
+                ptd.write(data_buffer)
+
+            program_from_file = self.runtime._load_program(pte_file)
+            data_map_from_file = self.runtime._load_flat_tensor_data_map(ptd_file)
+            method_1 = program_from_file.load_method(
+                "forward", data_map_from_file.get_named_data_map()
+            )
+            executorch_output1 = method_1(inputs)[0]
+            self.assertTrue(torch.allclose(expected, executorch_output1))
diff --git a/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl b/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl
index 7e14ca8713a..b17431aad80 100644
--- a/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl
+++ b/shim_et/xplat/executorch/extension/pybindings/pybindings.bzl
@@ -8,35 +8,37 @@ MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB = [
 ]
 
 PORTABLE_MODULE_DEPS = [
-    "//executorch/runtime/kernel:operator_registry",
-    "//executorch/runtime/executor:program",
     "//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
-    "//executorch/extension/aten_util:aten_bridge",
     "//executorch/devtools/bundled_program:runtime",
+    "//executorch/devtools/etdump:etdump_flatcc",
+    "//executorch/extension/aten_util:aten_bridge",
     "//executorch/extension/data_loader:buffer_data_loader",
     "//executorch/extension/data_loader:mmap_data_loader",
+    "//executorch/extension/flat_tensor:flat_tensor_data_map",
     "//executorch/extension/memory_allocator:malloc_memory_allocator",
     "//executorch/extension/module:bundled_module",
     "//executorch/extension/module:module",
     "//executorch/extension/tensor:tensor",
+    "//executorch/runtime/executor:program",
     "//executorch/runtime/executor/test:test_backend_compiler_lib",
-    "//executorch/devtools/etdump:etdump_flatcc",
+    "//executorch/runtime/kernel:operator_registry",
 ] + get_all_cpu_backend_targets()
 
 ATEN_MODULE_DEPS = [
-    "//executorch/runtime/kernel:operator_registry_aten",
-    "//executorch/runtime/executor:program_aten",
     "//executorch/runtime/core/exec_aten:lib_aten",
     "//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
+    "//executorch/devtools/bundled_program:runtime_aten",
+    "//executorch/devtools/etdump:etdump_flatcc",
     "//executorch/extension/data_loader:buffer_data_loader",
    "//executorch/extension/data_loader:mmap_data_loader",
+    "//executorch/extension/flat_tensor:flat_tensor_data_map_aten",
     "//executorch/extension/memory_allocator:malloc_memory_allocator",
     "//executorch/extension/module:bundled_module_aten",
     "//executorch/extension/module:module_aten",
     "//executorch/extension/tensor:tensor_aten",
-    "//executorch/devtools/bundled_program:runtime_aten",
     "//executorch/runtime/executor/test:test_backend_compiler_lib_aten",
-    "//executorch/devtools/etdump:etdump_flatcc",
+    "//executorch/runtime/executor:program_aten",
+    "//executorch/runtime/kernel:operator_registry_aten",
 ]
 
 # Generated lib for all ATen ops with aten kernel used by models in model inventory
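
For reference, a minimal sketch of the buffer-based flow exercised by the new test, the counterpart to the file-based example in the summary above. It assumes `_load_program_from_buffer` is re-exported from `executorch.extension.pybindings.portable_lib` alongside the new symbols imported in the portable_lib.py hunk; `program_buffer` and `data_buffer` stand in for serialized .pte and .ptd bytes produced with `external_constants=True`.
```
# Sketch only; module path and _load_program_from_buffer re-export are assumptions.
from executorch.extension.pybindings.portable_lib import (
    _load_flat_tensor_data_map_from_buffer,
    _load_program_from_buffer,
)


def run_forward(program_buffer: bytes, data_buffer: bytes, inputs):
    # Load the program and its external .ptd data from in-memory buffers.
    program = _load_program_from_buffer(program_buffer)
    data_map = _load_flat_tensor_data_map_from_buffer(data_buffer)
    # The capsule from get_named_data_map() only borrows a pointer into
    # data_map, so keep data_map alive while the method is in use.
    method = program.load_method("forward", data_map.get_named_data_map())
    return method(inputs)[0]
```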