13 changes: 9 additions & 4 deletions exir/capture/_config.py
@@ -6,7 +6,7 @@
 
 # pyre-unsafe
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 
@@ -94,9 +94,14 @@ class ExecutorchBackendConfig:
     # Moreover, static views will be elided from the ExecuTorch graph
     remove_view_copy: bool = True
 
-    # If set to true, all constant tensors will be stored in a separate file,
-    # external to the PTE file.
-    external_constants: bool = False
+    # Bool: if True, all constant tensors are stored in a separate file,
+    # external to the PTE file; if False, they are stored in the PTE file.
+    # Callable: a function from torch.fx.Node to Optional[str], called on each
+    # constant-tensor placeholder node. If it returns None, that constant is
+    # stored in the PTE file; if it returns a string, the node is tagged and
+    # the constant is stored in a file named by that string. E.g.,
+    # lambda x: "model_weights" saves all constants into "model_weights.ptd".
+    external_constants: Union[bool, Callable[[torch.fx.Node], Optional[str]]] = False
 
     # If set to true, all trainable weights will be stored in a separate file,
     # external to the PTE file.
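A minimal usage sketch of the new option (editor's illustration, not part of the diff). The import paths are assumed from this repo's public API, and TinyLinear is a hypothetical module; the lambda mirrors the docstring's example, routing every constant into model_weights.ptd:

# Sketch only: the ExecutorchBackendConfig/to_edge import paths and the
# TinyLinear module are assumptions for illustration.
import torch
from torch.export import export

from executorch.exir import ExecutorchBackendConfig, to_edge


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


# Every constant is tagged "model_weights", so all of them end up in a
# single external file, model_weights.ptd, instead of inside the PTE file.
program = to_edge(
    export(TinyLinear(), (torch.ones(2, 4),), strict=True)
).to_executorch(
    config=ExecutorchBackendConfig(external_constants=lambda node: "model_weights")
)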
29 changes: 29 additions & 0 deletions exir/emit/test/test_emit.py
@@ -1717,9 +1717,38 @@ def forward(self, x):
         external_map = emitter_output.external_constant_map[
             "_default_external_constant"
         ]
+        self.assertEqual(len(external_map), 2)
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
 
+    def test_constant_tagged_tensors_custom(self) -> None:
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = to_edge(
+            export(LinearModule(), (torch.ones(5, 5),), strict=True)
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                external_constants=lambda x: (
+                    "linear_weight" if "weight" in x.name else None
+                ),
+            )
+        )
+        emitter_output = model._emitter_output
+        # constant_buffer holds the reserved empty entry plus the linear bias.
+        self.assertEqual(len(emitter_output.program.constant_buffer), 2)
+        # external_constant_buffer holds the linear weight.
+        self.assertEqual(len(emitter_output.external_constant_buffer), 1)
+        # The lambda tags only the weight, under the key 'linear_weight'.
+        external_map = emitter_output.external_constant_map["linear_weight"]
+        self.assertEqual(len(external_map), 1)
+        self.assertEqual(external_map["linear.weight"], 0)
+
     def test_constant_tagged_tensor_dedup(self) -> None:
         class ConstantModule(nn.Module):
             def __init__(self):
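The new test tags only the weight; a callable may also return several distinct keys, and each key becomes its own entry in external_constant_map (and, per the config docstring, its own .ptd file). A hypothetical variant:

import torch.fx


# Hypothetical tagger: weights land in weights.ptd, biases in biases.ptd,
# and any other constant stays inside the PTE file (None).
def split_weights_and_biases(node: torch.fx.Node):
    if "weight" in node.name:
        return "weights"
    if "bias" in node.name:
        return "biases"
    return None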
14 changes: 11 additions & 3 deletions exir/program/_program.py
@@ -1737,11 +1737,19 @@ def to_executorch( # noqa (FLAKE8) C901
         # TODO(who?)
         p.update_placeholder_tensor_specs(program, new_gm)
 
-        # Extract constants if the config says too.
-        if config.external_constants:
+        # Tag constant weights.
+        if (
+            isinstance(config.external_constants, bool)
+            and config.external_constants
+        ):
             new_gm_res = external_constants_pass(new_gm)
             new_gm = new_gm_res.graph_module
-        elif config.external_mutable_weights:
+        elif callable(config.external_constants):
+            new_gm_res = external_constants_pass(new_gm, config.external_constants)
+            new_gm = new_gm_res.graph_module
+
+        # Tag mutable weights.
+        if config.external_mutable_weights:
             new_gm_res = external_mutable_weights_pass(new_gm, program)
             new_gm = new_gm_res.graph_module
 
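Two details of this change are easy to miss. First, a callable is truthy, so the old `if config.external_constants:` test alone could not tell True from a callable; the isinstance(..., bool) guard keeps the two branches exclusive. Second, external_mutable_weights moved from an elif to its own if, so mutable-weight tagging now runs even when constant tagging is enabled. A standalone sketch of the same normalization (hypothetical helper, not the actual pass API):

from typing import Callable, Optional, Union

import torch.fx

# Hypothetical helper: normalize the config value into a single per-node
# tagging callable, or None when external-constant tagging is off.
TagFn = Callable[[torch.fx.Node], Optional[str]]


def resolve_constant_tagger(
    external_constants: Union[bool, TagFn],
) -> Optional[TagFn]:
    if isinstance(external_constants, bool):
        if external_constants:
            # True: tag every constant with the default key seen in the
            # tests above ("_default_external_constant").
            return lambda node: "_default_external_constant"
        return None  # False: keep all constants inside the PTE file.
    # A user-provided callable decides per node; returning None keeps that
    # constant inside the PTE file.
    return external_constants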