Skip to content

Commit 3f3c2da

Browse files
lucylq authored and facebook-github-bot committed
Add FlatTensorDataMap to low-level pybindings (pytorch#15900)
Summary: Create pybindings around the FlatTensorDataMap::load function This allows us to use the low-level program data separation APIs in pybindings, through program/method as well as Module. Use `py::capsule` to capture the type `const NamedDataMap*` that is passed into method_load. Example usage: ``` >>> inputs = (torch.randn(3),) >>> program = _load_program_from_buffer(program_buffer) >>> data_map = _load_flat_tensor_data_map("model.ptd") >>> method = program.load_method("forward", data_map.get_named_data_map()) >>> outputs = method(inputs)[0] ``` Differential Revision: D87461455
1 parent 93bf861 commit 3f3c2da

File tree

5 files changed

+189
-15
lines changed

5 files changed

+189
-15
lines changed

extension/pybindings/portable_lib.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
_get_registered_backend_names, # noqa: F401
6060
_is_available, # noqa: F401
6161
_load_bundled_program_from_buffer, # noqa: F401
62+
_load_flat_tensor_data_map, # noqa: F401
63+
_load_flat_tensor_data_map_from_buffer, # noqa: F401
6264
_load_for_executorch, # noqa: F401
6365
_load_for_executorch_from_buffer, # noqa: F401
6466
_load_for_executorch_from_bundled_program, # noqa: F401
@@ -71,6 +73,7 @@
7173
ExecuTorchMethod, # noqa: F401
7274
ExecuTorchModule, # noqa: F401
7375
ExecuTorchProgram, # noqa: F401
76+
FlatTensorDataMap, # noqa: F401
7477
MethodMeta, # noqa: F401
7578
Verification, # noqa: F401
7679
)

extension/pybindings/pybindings.cpp

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <executorch/devtools/etdump/etdump_flatcc.h>
2222
#include <executorch/extension/data_loader/buffer_data_loader.h>
2323
#include <executorch/extension/data_loader/mmap_data_loader.h>
24+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
2425
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
2526
#include <executorch/extension/module/bundled_module.h>
2627
#include <executorch/extension/module/module.h>
@@ -82,6 +83,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Kernel;
8283
using ::executorch::ET_RUNTIME_NAMESPACE::Method;
8384
using ::executorch::ET_RUNTIME_NAMESPACE::Program;
8485
using ::executorch::extension::BufferDataLoader;
86+
using ::executorch::extension::FlatTensorDataMap;
8587
using ::executorch::extension::MallocMemoryAllocator;
8688
using ::executorch::extension::MmapDataLoader;
8789
using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
@@ -295,14 +297,15 @@ struct PyBundledModule : public BundledModule {
295297
: BundledModule(buffer.cast<std::string_view>().data()),
296298
bundled_program_ptr_(buffer),
297299
program_ptr_(static_cast<const void*>(
300+
bundled_program_flatbuffer::GetBundledProgram(
301+
get_bundled_program_ptr())
302+
->program()
303+
->data())),
304+
program_len_(
298305
bundled_program_flatbuffer::GetBundledProgram(
299306
get_bundled_program_ptr())
300307
->program()
301-
->data())),
302-
program_len_(bundled_program_flatbuffer::GetBundledProgram(
303-
get_bundled_program_ptr())
304-
->program()
305-
->size()) {}
308+
->size()) {}
306309

307310
static std::unique_ptr<PyBundledModule> load_from_buffer(
308311
const py::bytes& buffer,
@@ -1367,9 +1370,21 @@ struct PyProgram final {
13671370
return std::string(res.get());
13681371
}
13691372

1370-
std::unique_ptr<PyMethod> load_method(const std::string& method_name) {
1373+
std::unique_ptr<PyMethod> load_method(
1374+
const std::string& method_name,
1375+
py::object named_data_map_obj = py::none()) {
1376+
const NamedDataMap* named_data_map = nullptr;
1377+
if (!named_data_map_obj.is_none()) {
1378+
// Extract pointer from py::capsule.
1379+
py::capsule named_data_map_capsule =
1380+
named_data_map_obj.cast<py::capsule>();
1381+
named_data_map = named_data_map_capsule.get_pointer<const NamedDataMap>();
1382+
}
13711383
Result<Method> res = state_->program_->load_method(
1372-
method_name.c_str(), memory_->mem_manager(), event_tracer_.get());
1384+
method_name.c_str(),
1385+
memory_->mem_manager(),
1386+
event_tracer_.get(),
1387+
named_data_map);
13731388
THROW_IF_ERROR(
13741389
res.error(),
13751390
"Failed to load method %s, error: 0x:%" PRIx32,
@@ -1470,6 +1485,40 @@ py::bool_ is_available(const std::string& backend_name) {
14701485
return backend->is_available();
14711486
}
14721487

1488+
struct PyFlatTensorDataMap final {
1489+
explicit PyFlatTensorDataMap(
1490+
std::unique_ptr<DataLoader> loader,
1491+
FlatTensorDataMap data_map)
1492+
: loader_(std::move(loader)), data_map_(std::move(data_map)) {}
1493+
static std::unique_ptr<PyFlatTensorDataMap> load_from_file(
1494+
const std::string& path) {
1495+
auto loader = loader_from_file(path);
1496+
auto result = FlatTensorDataMap::load(loader.get());
1497+
THROW_IF_ERROR(result.error(), "Failed to load FlatTensorDataMap");
1498+
return std::make_unique<PyFlatTensorDataMap>(
1499+
std::move(loader), std::move(result.get()));
1500+
}
1501+
static std::unique_ptr<PyFlatTensorDataMap> load_from_buffer(
1502+
const py::bytes& buffer) {
1503+
auto loader = loader_from_buffer(
1504+
buffer.cast<std::string_view>().data(), py::len(buffer));
1505+
auto result = FlatTensorDataMap::load(loader.get());
1506+
THROW_IF_ERROR(result.error(), "Failed to load FlatTensorDataMap");
1507+
return std::make_unique<PyFlatTensorDataMap>(
1508+
std::move(loader), std::move(result.get()));
1509+
}
1510+
1511+
// Get a pointer to the underlying NamedDataMap as a py::capsule.
1512+
// The PyFlatTensorDataMap must outlive this pointer.
1513+
py::capsule get_named_data_map() {
1514+
return py::capsule(&data_map_, "NamedDataMap");
1515+
}
1516+
1517+
private:
1518+
std::unique_ptr<DataLoader> loader_; // Keep DataLoader alive.
1519+
FlatTensorDataMap data_map_;
1520+
};
1521+
14731522
} // namespace
14741523

14751524
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
@@ -1677,6 +1726,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
16771726
"load_method",
16781727
&PyProgram::load_method,
16791728
py::arg("method_name"),
1729+
py::arg("named_data_map") = py::none(),
16801730
call_guard)
16811731
.def(
16821732
"method_meta",
@@ -1728,6 +1778,22 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
17281778
py::arg("name"),
17291779
call_guard)
17301780
.def("method_meta", &PyMethod::method_meta, call_guard);
1781+
1782+
m.def(
1783+
"_load_flat_tensor_data_map",
1784+
&PyFlatTensorDataMap::load_from_file,
1785+
py::arg("data_path"),
1786+
call_guard);
1787+
m.def(
1788+
"_load_flat_tensor_data_map_from_buffer",
1789+
&PyFlatTensorDataMap::load_from_buffer,
1790+
py::arg("data_buffer"),
1791+
call_guard);
1792+
py::class_<PyFlatTensorDataMap>(m, "FlatTensorDataMap")
1793+
.def(
1794+
"get_named_data_map",
1795+
&PyFlatTensorDataMap::get_named_data_map,
1796+
call_guard);
17311797
}
17321798

17331799
namespace {

extension/pybindings/pybindings.pyi

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,60 @@ def _threadpool_get_thread_count() -> int:
297297
This API is experimental and subject to change without notice.
298298
"""
299299
...
300+
301+
@experimental("This API is experimental and subject to change without notice.")
302+
class FlatTensorDataMap:
303+
"""FlatTensorDataMap loads external data from a .ptd file.
304+
305+
.. warning::
306+
307+
This API is experimental and subject to change without notice.
308+
"""
309+
310+
def get_named_data_map(self) -> Any:
311+
"""Get a pointer to the underlying NamedDataMap.
312+
313+
Returns:
314+
A capsule containing a pointer to the internal NamedDataMap
315+
that can be passed to ExecuTorchProgram.load_method().
316+
317+
Warning:
318+
The FlatTensorDataMap instance must outlive the returned capsule.
319+
"""
320+
...
321+
322+
@experimental("This API is experimental and subject to change without notice.")
323+
def _load_flat_tensor_data_map(
324+
data_path: str,
325+
) -> FlatTensorDataMap:
326+
"""Load a flat tensor data map from a file.
327+
328+
.. warning::
329+
330+
This API is experimental and subject to change without notice.
331+
332+
Args:
333+
data_path: Path to the .ptd file with external data.
334+
335+
Returns:
336+
A FlatTensorDataMap instance that can be used with ExecuTorchProgram.load_method().
337+
"""
338+
...
339+
340+
@experimental("This API is experimental and subject to change without notice.")
341+
def _load_flat_tensor_data_map_from_buffer(
342+
data_buffer: bytes,
343+
) -> FlatTensorDataMap:
344+
"""Load a flat tensor data map from a buffer.
345+
346+
.. warning::
347+
348+
This API is experimental and subject to change without notice.
349+
350+
Args:
351+
data_buffer: Buffer holding a .ptd file with external data.
352+
353+
Returns:
354+
A FlatTensorDataMap instance that can be used with ExecuTorchProgram.load_method().
355+
"""
356+
...

extension/pybindings/test/test_pybindings.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,3 +733,49 @@ def test_program_data_separation(self) -> None:
733733
)
734734
with self.assertRaises(RuntimeError):
735735
executorch_module_bundled_no_data.forward(inputs)
736+
737+
def test_flat_tensor_data_map(self) -> None:
738+
eager_module = ModuleLinear()
739+
inputs = eager_module.get_inputs()
740+
expected = eager_module(inputs[0])
741+
exported_program = export(eager_module, inputs, strict=True)
742+
exec_program = to_edge(exported_program).to_executorch(
743+
config=ExecutorchBackendConfig(
744+
# Move all tensor data to '_default_external_constant' file.
745+
external_constants=True,
746+
)
747+
)
748+
program_buffer = exec_program.buffer
749+
assert len(exec_program._tensor_data) == 1
750+
data_buffer = bytes(exec_program._tensor_data.pop("_default_external_constant"))
751+
752+
# Test 1: Load FlatTensorDataMap from buffer.
753+
program_from_buffer = self.load_prog_fn(program_buffer)
754+
data_map_from_buffer = self.runtime._load_flat_tensor_data_map_from_buffer(
755+
data_buffer
756+
)
757+
method = program_from_buffer.load_method(
758+
"forward", data_map_from_buffer.get_named_data_map()
759+
)
760+
executorch_output = method(inputs)[0]
761+
self.assertTrue(torch.allclose(expected, executorch_output))
762+
763+
# Test 2: Load FlatTensorDataMap from file.
764+
import os
765+
import tempfile
766+
767+
with tempfile.TemporaryDirectory() as tmpdir:
768+
pte_file = os.path.join(tmpdir, "linear.pte")
769+
with open(pte_file, "wb") as f:
770+
f.write(program_buffer)
771+
ptd_file = os.path.join(tmpdir, "linear.ptd")
772+
with open(ptd_file, "wb") as ptd:
773+
ptd.write(data_buffer)
774+
775+
program_from_file = self.runtime._load_program(pte_file)
776+
data_map_from_file = self.runtime._load_flat_tensor_data_map(ptd_file)
777+
method_1 = program_from_file.load_method(
778+
"forward", data_map_from_file.get_named_data_map()
779+
)
780+
executorch_output1 = method_1(inputs)[0]
781+
self.assertTrue(torch.allclose(expected, executorch_output1))

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,37 @@ MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB = [
88
]
99

1010
PORTABLE_MODULE_DEPS = [
11-
"//executorch/runtime/kernel:operator_registry",
12-
"//executorch/runtime/executor:program",
1311
"//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
14-
"//executorch/extension/aten_util:aten_bridge",
1512
"//executorch/devtools/bundled_program:runtime",
13+
"//executorch/devtools/etdump:etdump_flatcc",
14+
"//executorch/extension/aten_util:aten_bridge",
1615
"//executorch/extension/data_loader:buffer_data_loader",
1716
"//executorch/extension/data_loader:mmap_data_loader",
17+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
1818
"//executorch/extension/memory_allocator:malloc_memory_allocator",
1919
"//executorch/extension/module:bundled_module",
2020
"//executorch/extension/module:module",
2121
"//executorch/extension/tensor:tensor",
22+
"//executorch/runtime/executor:program",
2223
"//executorch/runtime/executor/test:test_backend_compiler_lib",
23-
"//executorch/devtools/etdump:etdump_flatcc",
24+
"//executorch/runtime/kernel:operator_registry",
2425
] + get_all_cpu_backend_targets()
2526

2627
ATEN_MODULE_DEPS = [
27-
"//executorch/runtime/kernel:operator_registry_aten",
28-
"//executorch/runtime/executor:program_aten",
2928
"//executorch/runtime/core/exec_aten:lib_aten",
3029
"//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
30+
"//executorch/devtools/bundled_program:runtime_aten",
31+
"//executorch/devtools/etdump:etdump_flatcc",
3132
"//executorch/extension/data_loader:buffer_data_loader",
3233
"//executorch/extension/data_loader:mmap_data_loader",
34+
"//executorch/extension/flat_tensor:flat_tensor_data_map_aten",
3335
"//executorch/extension/memory_allocator:malloc_memory_allocator",
3436
"//executorch/extension/module:bundled_module_aten",
3537
"//executorch/extension/module:module_aten",
3638
"//executorch/extension/tensor:tensor_aten",
37-
"//executorch/devtools/bundled_program:runtime_aten",
3839
"//executorch/runtime/executor/test:test_backend_compiler_lib_aten",
39-
"//executorch/devtools/etdump:etdump_flatcc",
40+
"//executorch/runtime/executor:program_aten",
41+
"//executorch/runtime/kernel:operator_registry_aten",
4042
]
4143

4244
# Generated lib for all ATen ops with aten kernel used by models in model inventory

0 commit comments

Comments (0)