Skip to content

Commit e135cbe

Browse files
lucylq authored and facebook-github-bot committed
Add FlatTensorDataMap to low-level pybindings
Summary: Create pybindings around the FlatTensorDataMap::load function. This allows us to use the low-level program data separation APIs in pybindings, through program/method as well as Module. Use `py::capsule` to capture the type `const NamedDataMap*` that is passed into method_load. Example usage: ``` >>> inputs = (torch.randn(3),) >>> program = _load_program_from_buffer(program_buffer) >>> data_map = _load_flat_tensor_data_map("model.ptd") >>> method = program.load_method("forward", data_map.get_named_data_map()) >>> outputs = method(inputs)[0] ``` Differential Revision: D87461455
1 parent aff5086 commit e135cbe

File tree

5 files changed

+191
-16
lines changed

5 files changed

+191
-16
lines changed

extension/pybindings/portable_lib.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,8 @@
5959
_get_registered_backend_names, # noqa: F401
6060
_is_available, # noqa: F401
6161
_load_bundled_program_from_buffer, # noqa: F401
62+
_load_flat_tensor_data_map, # noqa: F401
63+
_load_flat_tensor_data_map_from_buffer, # noqa: F401
6264
_load_for_executorch, # noqa: F401
6365
_load_for_executorch_from_buffer, # noqa: F401
6466
_load_for_executorch_from_bundled_program, # noqa: F401
@@ -70,6 +72,7 @@
7072
ExecuTorchMethod, # noqa: F401
7173
ExecuTorchModule, # noqa: F401
7274
ExecuTorchProgram, # noqa: F401
75+
FlatTensorDataMap, # noqa: F401
7376
MethodMeta, # noqa: F401
7477
Verification, # noqa: F401
7578
)

extension/pybindings/pybindings.cpp

Lines changed: 75 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
#include <executorch/devtools/etdump/etdump_flatcc.h>
2222
#include <executorch/extension/data_loader/buffer_data_loader.h>
2323
#include <executorch/extension/data_loader/mmap_data_loader.h>
24+
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
2425
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
2526
#include <executorch/extension/module/bundled_module.h>
2627
#include <executorch/extension/module/module.h>
@@ -82,6 +83,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Kernel;
8283
using ::executorch::ET_RUNTIME_NAMESPACE::Method;
8384
using ::executorch::ET_RUNTIME_NAMESPACE::Program;
8485
using ::executorch::extension::BufferDataLoader;
86+
using ::executorch::extension::FlatTensorDataMap;
8587
using ::executorch::extension::MallocMemoryAllocator;
8688
using ::executorch::extension::MmapDataLoader;
8789
using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
@@ -294,15 +296,17 @@ struct PyBundledModule : public BundledModule {
294296
uint32_t bundled_input_pool_size)
295297
: BundledModule(buffer.cast<std::string_view>().data()),
296298
bundled_program_ptr_(buffer),
297-
program_ptr_(static_cast<const void*>(
299+
program_ptr_(
300+
static_cast<const void*>(
301+
bundled_program_flatbuffer::GetBundledProgram(
302+
get_bundled_program_ptr())
303+
->program()
304+
->data())),
305+
program_len_(
298306
bundled_program_flatbuffer::GetBundledProgram(
299307
get_bundled_program_ptr())
300308
->program()
301-
->data())),
302-
program_len_(bundled_program_flatbuffer::GetBundledProgram(
303-
get_bundled_program_ptr())
304-
->program()
305-
->size()) {}
309+
->size()) {}
306310

307311
static std::unique_ptr<PyBundledModule> load_from_buffer(
308312
const py::bytes& buffer,
@@ -1367,9 +1371,21 @@ struct PyProgram final {
13671371
return std::string(res.get());
13681372
}
13691373

1370-
std::unique_ptr<PyMethod> load_method(const std::string& method_name) {
1374+
std::unique_ptr<PyMethod> load_method(
1375+
const std::string& method_name,
1376+
py::object named_data_map_obj = py::none()) {
1377+
const NamedDataMap* named_data_map = nullptr;
1378+
if (!named_data_map_obj.is_none()) {
1379+
// Extract pointer from py::capsule.
1380+
py::capsule named_data_map_capsule =
1381+
named_data_map_obj.cast<py::capsule>();
1382+
named_data_map = named_data_map_capsule.get_pointer<const NamedDataMap>();
1383+
}
13711384
Result<Method> res = state_->program_->load_method(
1372-
method_name.c_str(), memory_->mem_manager(), event_tracer_.get());
1385+
method_name.c_str(),
1386+
memory_->mem_manager(),
1387+
event_tracer_.get(),
1388+
named_data_map);
13731389
THROW_IF_ERROR(
13741390
res.error(),
13751391
"Failed to load method %s, error: 0x:%" PRIx32,
@@ -1470,6 +1486,40 @@ py::bool_ is_available(const std::string& backend_name) {
14701486
return backend->is_available();
14711487
}
14721488

1489+
struct PyFlatTensorDataMap final {
1490+
explicit PyFlatTensorDataMap(
1491+
std::unique_ptr<DataLoader> loader,
1492+
FlatTensorDataMap data_map)
1493+
: loader_(std::move(loader)), data_map_(std::move(data_map)) {}
1494+
static std::unique_ptr<PyFlatTensorDataMap> load_from_file(
1495+
const std::string& path) {
1496+
auto loader = loader_from_file(path);
1497+
auto result = FlatTensorDataMap::load(loader.get());
1498+
THROW_IF_ERROR(result.error(), "Failed to load FlatTensorDataMap");
1499+
return std::make_unique<PyFlatTensorDataMap>(
1500+
std::move(loader), std::move(result.get()));
1501+
}
1502+
static std::unique_ptr<PyFlatTensorDataMap> load_from_buffer(
1503+
const py::bytes& buffer) {
1504+
auto loader = loader_from_buffer(
1505+
buffer.cast<std::string_view>().data(), py::len(buffer));
1506+
auto result = FlatTensorDataMap::load(loader.get());
1507+
THROW_IF_ERROR(result.error(), "Failed to load FlatTensorDataMap");
1508+
return std::make_unique<PyFlatTensorDataMap>(
1509+
std::move(loader), std::move(result.get()));
1510+
}
1511+
1512+
// Get a pointer to the underlying NamedDataMap as a py::capsule.
1513+
// The PyFlatTensorDataMap must outlive this pointer.
1514+
py::capsule get_named_data_map() {
1515+
return py::capsule(&data_map_, "NamedDataMap");
1516+
}
1517+
1518+
private:
1519+
std::unique_ptr<DataLoader> loader_; // Keep DataLoader alive.
1520+
FlatTensorDataMap data_map_;
1521+
};
1522+
14731523
} // namespace
14741524

14751525
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
@@ -1670,6 +1720,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
16701720
"load_method",
16711721
&PyProgram::load_method,
16721722
py::arg("method_name"),
1723+
py::arg("named_data_map") = py::none(),
16731724
call_guard)
16741725
.def(
16751726
"method_meta",
@@ -1721,6 +1772,22 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
17211772
py::arg("name"),
17221773
call_guard)
17231774
.def("method_meta", &PyMethod::method_meta, call_guard);
1775+
1776+
m.def(
1777+
"_load_flat_tensor_data_map",
1778+
&PyFlatTensorDataMap::load_from_file,
1779+
py::arg("data_path"),
1780+
call_guard);
1781+
m.def(
1782+
"_load_flat_tensor_data_map_from_buffer",
1783+
&PyFlatTensorDataMap::load_from_buffer,
1784+
py::arg("data_buffer"),
1785+
call_guard);
1786+
py::class_<PyFlatTensorDataMap>(m, "FlatTensorDataMap")
1787+
.def(
1788+
"get_named_data_map",
1789+
&PyFlatTensorDataMap::get_named_data_map,
1790+
call_guard);
17241791
}
17251792

17261793
namespace {

extension/pybindings/pybindings.pyi

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -288,3 +288,60 @@ def _unsafe_reset_threadpool(num_threads: int) -> None:
288288
This API is experimental and subject to change without notice.
289289
"""
290290
...
291+
292+
@experimental("This API is experimental and subject to change without notice.")
293+
class FlatTensorDataMap:
294+
"""FlatTensorDataMap loads external data from a .ptd file.
295+
296+
.. warning::
297+
298+
This API is experimental and subject to change without notice.
299+
"""
300+
301+
def get_named_data_map(self) -> Any:
302+
"""Get a pointer to the underlying NamedDataMap.
303+
304+
Returns:
305+
A capsule containing a pointer to the internal NamedDataMap
306+
that can be passed to ExecuTorchProgram.load_method().
307+
308+
Warning:
309+
The FlatTensorDataMap instance must outlive the returned capsule.
310+
"""
311+
...
312+
313+
@experimental("This API is experimental and subject to change without notice.")
314+
def _load_flat_tensor_data_map(
315+
data_path: str,
316+
) -> FlatTensorDataMap:
317+
"""Load a flat tensor data map from a file.
318+
319+
.. warning::
320+
321+
This API is experimental and subject to change without notice.
322+
323+
Args:
324+
data_path: Path to the .ptd file with external data.
325+
326+
Returns:
327+
A FlatTensorDataMap instance that can be used with ExecuTorchProgram.load_method().
328+
"""
329+
...
330+
331+
@experimental("This API is experimental and subject to change without notice.")
332+
def _load_flat_tensor_data_map_from_buffer(
333+
data_buffer: bytes,
334+
) -> FlatTensorDataMap:
335+
"""Load a flat tensor data map from a buffer.
336+
337+
.. warning::
338+
339+
This API is experimental and subject to change without notice.
340+
341+
Args:
342+
data_buffer: Buffer holding a .ptd file with external data.
343+
344+
Returns:
345+
A FlatTensorDataMap instance that can be used with ExecuTorchProgram.load_method().
346+
"""
347+
...

extension/pybindings/test/test_pybindings.py

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -733,3 +733,49 @@ def test_program_data_separation(self) -> None:
733733
)
734734
with self.assertRaises(RuntimeError):
735735
executorch_module_bundled_no_data.forward(inputs)
736+
737+
def test_flat_tensor_data_map(self) -> None:
738+
eager_module = ModuleLinear()
739+
inputs = eager_module.get_inputs()
740+
expected = eager_module(inputs[0])
741+
exported_program = export(eager_module, inputs, strict=True)
742+
exec_program = to_edge(exported_program).to_executorch(
743+
config=ExecutorchBackendConfig(
744+
# Move all tensor data to '_default_external_constant' file.
745+
external_constants=True,
746+
)
747+
)
748+
program_buffer = exec_program.buffer
749+
assert len(exec_program._tensor_data) == 1
750+
data_buffer = bytes(exec_program._tensor_data.pop("_default_external_constant"))
751+
752+
# Test 1: Load FlatTensorDataMap from buffer.
753+
program_from_buffer = self.load_prog_fn(program_buffer)
754+
data_map_from_buffer = self.runtime._load_flat_tensor_data_map_from_buffer(
755+
data_buffer
756+
)
757+
method = program_from_buffer.load_method(
758+
"forward", data_map_from_buffer.get_named_data_map()
759+
)
760+
executorch_output = method(inputs)[0]
761+
self.assertTrue(torch.allclose(expected, executorch_output))
762+
763+
# Test 2: Load FlatTensorDataMap from file.
764+
import os
765+
import tempfile
766+
767+
with tempfile.TemporaryDirectory() as tmpdir:
768+
pte_file = os.path.join(tmpdir, "linear.pte")
769+
with open(pte_file, "wb") as f:
770+
f.write(program_buffer)
771+
ptd_file = os.path.join(tmpdir, "linear.ptd")
772+
with open(ptd_file, "wb") as ptd:
773+
ptd.write(data_buffer)
774+
775+
program_from_file = self.runtime._load_program(pte_file)
776+
data_map_from_file = self.runtime._load_flat_tensor_data_map(ptd_file)
777+
method_1 = program_from_file.load_method(
778+
"forward", data_map_from_file.get_named_data_map()
779+
)
780+
executorch_output1 = method_1(inputs)[0]
781+
self.assertTrue(torch.allclose(expected, executorch_output1))

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,35 +8,37 @@ MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB = [
88
]
99

1010
PORTABLE_MODULE_DEPS = [
11-
"//executorch/runtime/kernel:operator_registry",
12-
"//executorch/runtime/executor:program",
1311
"//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
14-
"//executorch/extension/aten_util:aten_bridge",
1512
"//executorch/devtools/bundled_program:runtime",
13+
"//executorch/devtools/etdump:etdump_flatcc",
14+
"//executorch/extension/aten_util:aten_bridge",
1615
"//executorch/extension/data_loader:buffer_data_loader",
1716
"//executorch/extension/data_loader:mmap_data_loader",
17+
"//executorch/extension/flat_tensor:flat_tensor_data_map",
1818
"//executorch/extension/memory_allocator:malloc_memory_allocator",
1919
"//executorch/extension/module:bundled_module",
2020
"//executorch/extension/module:module",
2121
"//executorch/extension/tensor:tensor",
22+
"//executorch/runtime/executor:program",
2223
"//executorch/runtime/executor/test:test_backend_compiler_lib",
23-
"//executorch/devtools/etdump:etdump_flatcc",
24+
"//executorch/runtime/kernel:operator_registry",
2425
] + get_all_cpu_backend_targets()
2526

2627
ATEN_MODULE_DEPS = [
27-
"//executorch/runtime/kernel:operator_registry_aten",
28-
"//executorch/runtime/executor:program_aten",
2928
"//executorch/runtime/core/exec_aten:lib_aten",
3029
"//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
30+
"//executorch/devtools/bundled_program:runtime_aten",
31+
"//executorch/devtools/etdump:etdump_flatcc",
3132
"//executorch/extension/data_loader:buffer_data_loader",
3233
"//executorch/extension/data_loader:mmap_data_loader",
34+
"//executorch/extension/flat_tensor:flat_tensor_data_map_aten",
3335
"//executorch/extension/memory_allocator:malloc_memory_allocator",
3436
"//executorch/extension/module:bundled_module_aten",
3537
"//executorch/extension/module:module_aten",
3638
"//executorch/extension/tensor:tensor_aten",
37-
"//executorch/devtools/bundled_program:runtime_aten",
3839
"//executorch/runtime/executor/test:test_backend_compiler_lib_aten",
39-
"//executorch/devtools/etdump:etdump_flatcc",
40+
"//executorch/runtime/executor:program_aten",
41+
"//executorch/runtime/kernel:operator_registry_aten",
4042
]
4143

4244
# Generated lib for all ATen ops with aten kernel used by models in model inventory

0 commit comments

Comments (0)