Skip to content

Commit 5b8b7a1

Browse files
lucylqfacebook-github-bot
authored andcommitted
Save external constant tensors to custom filename
Summary: The pass is applied in _program.py, it's hard to do this on the eager model as we have to run the pass after SpecPropPass. Differential Revision: D87280747
1 parent bee30ac commit 5b8b7a1

File tree

3 files changed

+46
-6
lines changed

3 files changed

+46
-6
lines changed

exir/capture/_config.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-unsafe
88
from dataclasses import dataclass, field
9-
from typing import Dict, List, Optional, Union
9+
from typing import Callable, Dict, List, Optional, Union
1010

1111
import torch
1212

@@ -94,9 +94,14 @@ class ExecutorchBackendConfig:
9494
# Moreover, static views will be elided from the ExecuTorch graph
9595
remove_view_copy: bool = True
9696

97-
# If set to true, all constant tensors will be stored in a separate file,
98-
# external to the PTE file.
99-
external_constants: bool = False
97+
# Bool: if True, all constant tensors will be stored in a separate file. If False,
98+
# all constant tensors will be stored in the PTE file.
99+
# Callable: a function from torch.fx.Node to Optional[str]. This will tag all
100+
# placeholder (constant tensor) nodes with the string. If None, the constant
101+
# tensor is stored in the PTE file. Otherwise, it is stored in a file named
102+
# by the string. E.g., a function lambda x: "model_weights" will save all
103+
# constants into a file "model_weights.ptd".
104+
external_constants: Union[bool, Callable[[torch.fx.Node], Optional[str]]] = False
100105

101106
# If set to true, all trainable weights will be stored in a separate file,
102107
# external to the PTE file.

exir/emit/test/test_emit.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,9 +1716,38 @@ def forward(self, x):
17161716
external_map = emitter_output.external_constant_map[
17171717
"_default_external_constant"
17181718
]
1719+
self.assertEqual(len(external_map), 2)
17191720
self.assertEqual(external_map["linear.weight"], 0)
17201721
self.assertEqual(external_map["linear.bias"], 1)
17211722

1723+
def test_constant_tagged_tensors_custom(self) -> None:
1724+
class LinearModule(torch.nn.Module):
1725+
def __init__(self):
1726+
super().__init__()
1727+
self.linear = torch.nn.Linear(5, 5)
1728+
1729+
def forward(self, x):
1730+
return self.linear(x)
1731+
1732+
model = to_edge(
1733+
export(LinearModule(), (torch.ones(5, 5),), strict=True)
1734+
).to_executorch(
1735+
config=ExecutorchBackendConfig(
1736+
external_constants=lambda x: (
1737+
"linear_weight" if "weight" in x.name else None
1738+
),
1739+
)
1740+
)
1741+
emitter_output = model._emitter_output
1742+
# constant_buffer contains placeholder and linear bias.
1743+
self.assertEqual(len(emitter_output.program.constant_buffer), 2)
1744+
# external constant buffer contains linear weight.
1745+
self.assertEqual(len(emitter_output.external_constant_buffer), 1)
1746+
# The lambda saves all constants to the key 'linear_weight'.
1747+
external_map = emitter_output.external_constant_map["linear_weight"]
1748+
self.assertEqual(len(external_map), 1)
1749+
self.assertEqual(external_map["linear.weight"], 0)
1750+
17221751
def test_constant_tagged_tensor_dedup(self) -> None:
17231752
class ConstantModule(nn.Module):
17241753
def __init__(self):

exir/program/_program.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,10 +1734,16 @@ def to_executorch( # noqa (FLAKE8) C901
17341734
# TODO(who?)
17351735
p.update_placeholder_tensor_specs(program, new_gm)
17361736

1737-
# Extract constants if the config says too.
1738-
if config.external_constants:
1737+
# Extract constants if the config says to.
1738+
if (
1739+
isinstance(config.external_constants, bool)
1740+
and config.external_constants
1741+
):
17391742
new_gm_res = external_constants_pass(new_gm)
17401743
new_gm = new_gm_res.graph_module
1744+
elif callable(config.external_constants):
1745+
new_gm_res = external_constants_pass(new_gm, config.external_constants)
1746+
new_gm = new_gm_res.graph_module
17411747
elif config.external_mutable_weights:
17421748
new_gm_res = external_mutable_weights_pass(new_gm, program)
17431749
new_gm = new_gm_res.graph_module

0 commit comments

Comments
 (0)