
Commit 49cc0e9

lucylq authored and facebook-github-bot committed
Add FlatTensorDataMap to low-level pybindings (#15900)
Summary: Create pybindings around the `FlatTensorDataMap::load` function. This allows us to use the low-level program data separation APIs from pybindings, through program/method as well as Module. Use a `py::capsule` to capture the `const NamedDataMap*` pointer that is passed into `load_method`. Example usage:

```
>>> inputs = (torch.randn(3),)
>>> program = _load_program_from_buffer(program_buffer)
>>> data_map = _load_flat_tensor_data_map("model.ptd")
>>> method = program.load_method("forward", data_map.get_named_data_map())
>>> outputs = method(inputs)[0]
```

Reviewed By: JacobSzwejbka

Differential Revision: D87461455
1 parent ee236cb commit 49cc0e9

5 files changed: +183 additions, −10 deletions

extension/pybindings/portable_lib.py

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,8 @@
     _get_registered_backend_names,  # noqa: F401
     _is_available,  # noqa: F401
     _load_bundled_program_from_buffer,  # noqa: F401
+    _load_flat_tensor_data_map,  # noqa: F401
+    _load_flat_tensor_data_map_from_buffer,  # noqa: F401
     _load_for_executorch,  # noqa: F401
     _load_for_executorch_from_buffer,  # noqa: F401
     _load_for_executorch_from_bundled_program,  # noqa: F401
@@ -71,6 +73,7 @@
     ExecuTorchMethod,  # noqa: F401
     ExecuTorchModule,  # noqa: F401
     ExecuTorchProgram,  # noqa: F401
+    FlatTensorDataMap,  # noqa: F401
     MethodMeta,  # noqa: F401
     Verification,  # noqa: F401
 )
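As a quick sanity check of the new re-exports, a minimal import sketch is below. The symbol names come straight from the diff above; the `executorch.extension.pybindings.portable_lib` module path is assumed here rather than taken from this commit.

```python
# Sketch: import the new data-map APIs alongside an existing loader.
# Assumes the usual portable_lib import path; the names are from this diff.
from executorch.extension.pybindings.portable_lib import (
    FlatTensorDataMap,  # class exposing get_named_data_map()
    _load_flat_tensor_data_map,  # load a .ptd data map from a file path
    _load_flat_tensor_data_map_from_buffer,  # load a .ptd data map from bytes
    _load_for_executorch_from_buffer,  # pre-existing program loader
)
```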

extension/pybindings/pybindings.cpp

Lines changed: 67 additions & 2 deletions
@@ -21,6 +21,7 @@
 #include <executorch/devtools/etdump/etdump_flatcc.h>
 #include <executorch/extension/data_loader/buffer_data_loader.h>
 #include <executorch/extension/data_loader/mmap_data_loader.h>
+#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
 #include <executorch/extension/module/bundled_module.h>
 #include <executorch/extension/module/module.h>
@@ -82,6 +83,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Kernel;
 using ::executorch::ET_RUNTIME_NAMESPACE::Method;
 using ::executorch::ET_RUNTIME_NAMESPACE::Program;
 using ::executorch::extension::BufferDataLoader;
+using ::executorch::extension::FlatTensorDataMap;
 using ::executorch::extension::MallocMemoryAllocator;
 using ::executorch::extension::MmapDataLoader;
 using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
@@ -1367,9 +1369,21 @@ struct PyProgram final {
     return std::string(res.get());
   }
 
-  std::unique_ptr<PyMethod> load_method(const std::string& method_name) {
+  std::unique_ptr<PyMethod> load_method(
+      const std::string& method_name,
+      py::object named_data_map_obj = py::none()) {
+    const NamedDataMap* named_data_map = nullptr;
+    if (!named_data_map_obj.is_none()) {
+      // Extract pointer from py::capsule.
+      py::capsule named_data_map_capsule =
+          named_data_map_obj.cast<py::capsule>();
+      named_data_map = named_data_map_capsule.get_pointer<const NamedDataMap>();
+    }
     Result<Method> res = state_->program_->load_method(
-        method_name.c_str(), memory_->mem_manager(), event_tracer_.get());
+        method_name.c_str(),
+        memory_->mem_manager(),
+        event_tracer_.get(),
+        named_data_map);
     THROW_IF_ERROR(
         res.error(),
         "Failed to load method %s, error: 0x:%" PRIx32,
@@ -1470,6 +1484,40 @@ py::bool_ is_available(const std::string& backend_name) {
   return backend->is_available();
 }
 
+struct PyFlatTensorDataMap final {
+  explicit PyFlatTensorDataMap(
+      std::unique_ptr<DataLoader> loader,
+      FlatTensorDataMap data_map)
+      : loader_(std::move(loader)), data_map_(std::move(data_map)) {}
+  static std::unique_ptr<PyFlatTensorDataMap> load_from_file(
+      const std::string& path) {
+    auto loader = loader_from_file(path);
+    auto result = FlatTensorDataMap::load(loader.get());
+    THROW_IF_ERROR(result.error(), "Failed to load FlatTensorDataMap");
+    return std::make_unique<PyFlatTensorDataMap>(
+        std::move(loader), std::move(result.get()));
+  }
+  static std::unique_ptr<PyFlatTensorDataMap> load_from_buffer(
+      const py::bytes& buffer) {
+    auto loader = loader_from_buffer(
+        buffer.cast<std::string_view>().data(), py::len(buffer));
+    auto result = FlatTensorDataMap::load(loader.get());
+    THROW_IF_ERROR(result.error(), "Failed to load FlatTensorDataMap");
+    return std::make_unique<PyFlatTensorDataMap>(
+        std::move(loader), std::move(result.get()));
+  }
+
+  // Get a pointer to the underlying NamedDataMap as a py::capsule.
+  // The PyFlatTensorDataMap must outlive this pointer.
+  py::capsule get_named_data_map() {
+    return py::capsule(&data_map_, "NamedDataMap");
+  }
+
+ private:
+  std::unique_ptr<DataLoader> loader_; // Keep DataLoader alive.
+  FlatTensorDataMap data_map_;
+};
+
 } // namespace
 
 PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
@@ -1677,6 +1725,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           "load_method",
           &PyProgram::load_method,
           py::arg("method_name"),
+          py::arg("named_data_map") = py::none(),
          call_guard)
      .def(
          "method_meta",
@@ -1728,6 +1777,22 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
          py::arg("name"),
          call_guard)
      .def("method_meta", &PyMethod::method_meta, call_guard);
+
+  m.def(
+      "_load_flat_tensor_data_map",
+      &PyFlatTensorDataMap::load_from_file,
+      py::arg("data_path"),
+      call_guard);
+  m.def(
+      "_load_flat_tensor_data_map_from_buffer",
+      &PyFlatTensorDataMap::load_from_buffer,
+      py::arg("data_buffer"),
+      call_guard);
+  py::class_<PyFlatTensorDataMap>(m, "FlatTensorDataMap")
+      .def(
+          "get_named_data_map",
+          &PyFlatTensorDataMap::get_named_data_map,
+          call_guard);
 }
 
 namespace {
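Because `get_named_data_map()` returns a `py::capsule` that merely points at `data_map_`, the Python-side `FlatTensorDataMap` must stay alive while the loaded method is in use. A minimal sketch of that call pattern, assuming the buffer-based bindings added here plus the `_load_program_from_buffer` loader shown in the commit summary; `program_buffer` and `data_buffer` are placeholder inputs, and the import path is an assumption:

```python
import torch

# Assumed import path; the data-map symbols come from this PR, and
# _load_program_from_buffer is the loader used in the commit summary example.
from executorch.extension.pybindings.portable_lib import (
    _load_flat_tensor_data_map_from_buffer,
    _load_program_from_buffer,
)


def run_forward(program_buffer: bytes, data_buffer: bytes) -> torch.Tensor:
    """Load a .pte program plus its .ptd data map and run `forward` once."""
    program = _load_program_from_buffer(program_buffer)
    data_map = _load_flat_tensor_data_map_from_buffer(data_buffer)
    # The capsule only borrows the NamedDataMap owned by `data_map`, so keep
    # `data_map` referenced while the method is loaded and executed.
    method = program.load_method("forward", data_map.get_named_data_map())
    return method((torch.randn(3),))[0]
```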

extension/pybindings/pybindings.pyi

Lines changed: 57 additions & 0 deletions
@@ -297,3 +297,60 @@ def _threadpool_get_thread_count() -> int:
     This API is experimental and subject to change without notice.
     """
     ...
+
+@experimental("This API is experimental and subject to change without notice.")
+class FlatTensorDataMap:
+    """FlatTensorDataMap loads external data from a .ptd file.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+    """
+
+    def get_named_data_map(self) -> Any:
+        """Get a pointer to the underlying NamedDataMap.
+
+        Returns:
+            A capsule containing a pointer to the internal NamedDataMap
+            that can be passed to ExecuTorchProgram.load_method().
+
+        Warning:
+            The FlatTensorDataMap instance must outlive the returned capsule.
+        """
+        ...
+
+@experimental("This API is experimental and subject to change without notice.")
+def _load_flat_tensor_data_map(
+    data_path: str,
+) -> FlatTensorDataMap:
+    """Load a flat tensor data map from a file.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+
+    Args:
+        data_path: Path to the .ptd file with external data.
+
+    Returns:
+        A FlatTensorDataMap instance that can be used with ExecuTorchProgram.load_method().
+    """
+    ...
+
+@experimental("This API is experimental and subject to change without notice.")
+def _load_flat_tensor_data_map_from_buffer(
+    data_buffer: bytes,
+) -> FlatTensorDataMap:
+    """Load a flat tensor data map from a buffer.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+
+    Args:
+        data_buffer: Buffer holding a .ptd file with external data.
+
+    Returns:
+        A FlatTensorDataMap instance that can be used with ExecuTorchProgram.load_method().
+    """
+    ...
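For completeness, a file-based sketch matching these stub signatures; the `linear.pte`/`linear.ptd` paths are hypothetical, and the module-level `_load_program` binding (which the test below reaches via `self.runtime`) is assumed to be importable from the same place:

```python
from executorch.extension.pybindings.portable_lib import (  # assumed import path
    _load_flat_tensor_data_map,
    _load_program,  # assumed to be exported here as well
)

# Hypothetical paths: a program exported with external_constants=True writes
# its weights to a companion .ptd file alongside the .pte program.
program = _load_program("linear.pte")
data_map = _load_flat_tensor_data_map("linear.ptd")
method = program.load_method("forward", data_map.get_named_data_map())
```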

extension/pybindings/test/test_pybindings.py

Lines changed: 46 additions & 0 deletions
@@ -733,3 +733,49 @@ def test_program_data_separation(self) -> None:
         )
         with self.assertRaises(RuntimeError):
             executorch_module_bundled_no_data.forward(inputs)
+
+    def test_flat_tensor_data_map(self) -> None:
+        eager_module = ModuleLinear()
+        inputs = eager_module.get_inputs()
+        expected = eager_module(inputs[0])
+        exported_program = export(eager_module, inputs, strict=True)
+        exec_program = to_edge(exported_program).to_executorch(
+            config=ExecutorchBackendConfig(
+                # Move all tensor data to '_default_external_constant' file.
+                external_constants=True,
+            )
+        )
+        program_buffer = exec_program.buffer
+        assert len(exec_program._tensor_data) == 1
+        data_buffer = bytes(exec_program._tensor_data.pop("_default_external_constant"))
+
+        # Test 1: Load FlatTensorDataMap from buffer.
+        program_from_buffer = self.load_prog_fn(program_buffer)
+        data_map_from_buffer = self.runtime._load_flat_tensor_data_map_from_buffer(
+            data_buffer
+        )
+        method = program_from_buffer.load_method(
+            "forward", data_map_from_buffer.get_named_data_map()
+        )
+        executorch_output = method(inputs)[0]
+        self.assertTrue(torch.allclose(expected, executorch_output))
+
+        # Test 2: Load FlatTensorDataMap from file.
+        import os
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pte_file = os.path.join(tmpdir, "linear.pte")
+            with open(pte_file, "wb") as f:
+                f.write(program_buffer)
+            ptd_file = os.path.join(tmpdir, "linear.ptd")
+            with open(ptd_file, "wb") as ptd:
+                ptd.write(data_buffer)
+
+            program_from_file = self.runtime._load_program(pte_file)
+            data_map_from_file = self.runtime._load_flat_tensor_data_map(ptd_file)
+            method_1 = program_from_file.load_method(
+                "forward", data_map_from_file.get_named_data_map()
+            )
+            executorch_output1 = method_1(inputs)[0]
+            self.assertTrue(torch.allclose(expected, executorch_output1))

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

Lines changed: 10 additions & 8 deletions
@@ -8,35 +8,37 @@ MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB = [
 ]
 
 PORTABLE_MODULE_DEPS = [
-    "//executorch/runtime/kernel:operator_registry",
-    "//executorch/runtime/executor:program",
     "//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
-    "//executorch/extension/aten_util:aten_bridge",
     "//executorch/devtools/bundled_program:runtime",
+    "//executorch/devtools/etdump:etdump_flatcc",
+    "//executorch/extension/aten_util:aten_bridge",
     "//executorch/extension/data_loader:buffer_data_loader",
     "//executorch/extension/data_loader:mmap_data_loader",
+    "//executorch/extension/flat_tensor:flat_tensor_data_map",
     "//executorch/extension/memory_allocator:malloc_memory_allocator",
     "//executorch/extension/module:bundled_module",
    "//executorch/extension/module:module",
     "//executorch/extension/tensor:tensor",
+    "//executorch/runtime/executor:program",
     "//executorch/runtime/executor/test:test_backend_compiler_lib",
-    "//executorch/devtools/etdump:etdump_flatcc",
+    "//executorch/runtime/kernel:operator_registry",
 ] + get_all_cpu_backend_targets()
 
 ATEN_MODULE_DEPS = [
-    "//executorch/runtime/kernel:operator_registry_aten",
-    "//executorch/runtime/executor:program_aten",
     "//executorch/runtime/core/exec_aten:lib_aten",
     "//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
+    "//executorch/devtools/bundled_program:runtime_aten",
+    "//executorch/devtools/etdump:etdump_flatcc",
     "//executorch/extension/data_loader:buffer_data_loader",
     "//executorch/extension/data_loader:mmap_data_loader",
+    "//executorch/extension/flat_tensor:flat_tensor_data_map_aten",
     "//executorch/extension/memory_allocator:malloc_memory_allocator",
     "//executorch/extension/module:bundled_module_aten",
     "//executorch/extension/module:module_aten",
     "//executorch/extension/tensor:tensor_aten",
-    "//executorch/devtools/bundled_program:runtime_aten",
     "//executorch/runtime/executor/test:test_backend_compiler_lib_aten",
-    "//executorch/devtools/etdump:etdump_flatcc",
+    "//executorch/runtime/executor:program_aten",
+    "//executorch/runtime/kernel:operator_registry_aten",
 ]
 
 # Generated lib for all ATen ops with aten kernel used by models in model inventory
