diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index d4d9e1465ce..8d6d30f7903 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -28,6 +28,10 @@ LoweredBackendModule, ) from executorch.exir.pass_base import ExportPass +from executorch.exir.program._fake_program import ( + get_fake_program, + update_to_real_program, +) from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param from torch.export import ExportedProgram @@ -343,8 +347,14 @@ def to_backend( Returns: ExportedProgram: The input program, with some portions targeted for delegation. """ - copied_edge_program = copy.deepcopy(edge_program) - partitioner_result = partitioner_instance(copied_edge_program) + # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values. + # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback. + try: + fake_edge_program = get_fake_program(edge_program) + except AssertionError as e: + logging.warning(f"No fake mode found for {edge_program.graph_module}: {e}") + fake_edge_program = copy.deepcopy(edge_program) + partitioner_result = partitioner_instance(fake_edge_program) tagged_exported_program = partitioner_result.tagged_exported_program # Check that the partitioner did not modify the original graph @@ -360,6 +370,7 @@ def to_backend( partitioner_result.partition_tags is not None ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec" + update_to_real_program(tagged_exported_program, edge_program) tagged_graph_module = _partition_and_lower( tagged_exported_program.graph_module, partitioner_result, edge_program ) diff --git a/exir/program/TARGETS b/exir/program/TARGETS index 87a28291730..49da0648a06 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -8,6 +8,7 @@ python_library( "__init__.py", ], deps = [ + ":fake_program", ":program", ], ) @@ -38,3 +39,13 @@ 
python_library(
         "//executorch/exir/verification:verifier",
     ],
 )
+
+python_library(
+    name = "fake_program",
+    srcs = [
+        "_fake_program.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
diff --git a/exir/program/__init__.py b/exir/program/__init__.py
index e6b290d8c87..4d00297685a 100644
--- a/exir/program/__init__.py
+++ b/exir/program/__init__.py
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+from executorch.exir.program._fake_program import get_fake_program, update_to_real_program
 from executorch.exir.program._program import (
     _to_edge,
     edge_to_executorch_passes,
@@ -24,4 +25,6 @@
     "edge_to_executorch_passes",
     "EdgeProgramManager",
     "ExecutorchProgramManager",
+    "get_fake_program",
+    "update_to_real_program",
 ]
diff --git a/exir/program/_fake_program.py b/exir/program/_fake_program.py
new file mode 100644
index 00000000000..ce3eaf86ca1
--- /dev/null
+++ b/exir/program/_fake_program.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+from typing import Dict, Union
+
+import torch
+
+from torch._guards import detect_fake_mode
+from torch.export import ExportedProgram
+
+
+def get_fake_program(real_exported_program: ExportedProgram) -> ExportedProgram:
+    """Create a fake exported program. This uses fake tensors for the state dict
+    to prevent mutation, and points to the real constants, to avoid large memory
+    usage from copying when constants are large.
+
+    Args:
+        real_exported_program: the original exported program
+    Returns:
+        A new exported program, with fake tensors.
+ """ + fake_mode = detect_fake_mode( + tuple( + node.meta["val"] + for node in real_exported_program.graph.nodes + if node.op == "placeholder" + ) + ) + if fake_mode is None: + raise AssertionError( + "Could not detect fake mode for graph: ", real_exported_program.graph + ) + + new_state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {} + + for key, tensor in real_exported_program.state_dict.items(): + fake = fake_mode.from_tensor(tensor, static_shapes=True) + new_state_dict[key] = fake + + gm = copy.deepcopy(real_exported_program.graph_module) + fake_exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=copy.deepcopy(real_exported_program.graph_signature), + state_dict=new_state_dict, + range_constraints=copy.deepcopy(real_exported_program.range_constraints), + module_call_graph=copy.deepcopy(real_exported_program.module_call_graph), + verifier=real_exported_program.verifier, + constants=real_exported_program.constants, + ) + return fake_exported_program + + +def update_to_real_program( + fake_exported_program: ExportedProgram, real_exported_program: ExportedProgram +) -> None: + """Update the fake exported program to point to the real state dict. Modifies the + fake exported program in-place. + """ + fake_exported_program._state_dict = real_exported_program.state_dict diff --git a/exir/program/test/TARGETS b/exir/program/test/TARGETS index f67702b8117..f9788c42265 100644 --- a/exir/program/test/TARGETS +++ b/exir/program/test/TARGETS @@ -6,6 +6,7 @@ python_unittest( # @autodeps-skip pybindings don't work well with autodeps name = "test_program", srcs = [ + "test_fake_program.py", "test_program.py", ], deps = [ diff --git a/exir/program/test/test_fake_program.py b/exir/program/test/test_fake_program.py new file mode 100644 index 00000000000..02718f2b7e4 --- /dev/null +++ b/exir/program/test/test_fake_program.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import sys +import unittest + +import torch + +from executorch.exir.program._fake_program import ( + get_fake_program, + update_to_real_program, +) +from torch.export import export, ExportedProgram + + +def get_exported_program() -> ExportedProgram: + class Linear(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf", torch.randn(10, 10), persistent=False) + + def forward(self, arg) -> torch.Tensor: + return self.linear(arg) + self.buf + + linear = Linear() + exported_program = export( + linear, + args=(torch.randn(10, 10),), + ).run_decompositions() + return exported_program + + +class TestFakeProgram(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + + def test_fake_program(self) -> None: + exported_program = get_exported_program() + fake_program = get_fake_program(exported_program) + print(f"Exported program size: {sys.getsizeof(exported_program.state_dict)}") + print(f"Fake program size: {sys.getsizeof(fake_program.state_dict)}") + + # Fake program deep copies attributes besides verifier, state_dict and constants. + self.assertEqual(exported_program.graph_signature, fake_program.graph_signature) + self.assertNotEqual( + id(exported_program.graph_signature), id(fake_program.graph_signature) + ) + self.assertEqual( + exported_program.module_call_graph, fake_program.module_call_graph + ) + self.assertNotEqual( + id(exported_program.module_call_graph), id(fake_program.module_call_graph) + ) + + # Verifier is static. + self.assertEqual(exported_program.verifier, fake_program.verifier) + self.assertEqual(id(exported_program.verifier), id(fake_program.verifier)) + + # Fake program uses fake tensors for the state dict. Size should be smaller. 
+ self.assertLess( + sys.getsizeof(fake_program.state_dict), + sys.getsizeof(exported_program.state_dict), + ) + + # Do not copy constants. + self.assertEqual(exported_program.constants, fake_program.constants) + self.assertEqual(id(exported_program.constants), id(fake_program.constants)) + + update_to_real_program(fake_program, exported_program) + self.assertEqual(exported_program.state_dict, fake_program.state_dict) + self.assertEqual(id(exported_program.state_dict), id(fake_program.state_dict))