diff --git a/examples/selective_build/CMakeLists.txt b/examples/selective_build/CMakeLists.txt index 29791187185..239cdc828de 100644 --- a/examples/selective_build/CMakeLists.txt +++ b/examples/selective_build/CMakeLists.txt @@ -118,7 +118,11 @@ add_executable(selective_build_test ${_executor_runner__srcs}) if(CMAKE_BUILD_TYPE EQUAL "RELEASE") target_link_options(selective_build_test PRIVATE "LINKER:--gc-sections") endif() -target_link_libraries(selective_build_test executorch gflags select_build_lib) +target_link_libraries( + selective_build_test PRIVATE executorch gflags select_build_lib +) +target_link_options_shared_lib(select_build_lib) +target_link_options_shared_lib(executorch) target_compile_options(selective_build_test PUBLIC ${_common_compile_options}) # Print all summary diff --git a/exir/capture/_config.py b/exir/capture/_config.py index d743e4b0329..a2d3b53bcb6 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -75,3 +75,7 @@ class ExecutorchBackendConfig: # be a power of 2. If not provided, uses the value in the schema file. delegate_alignment: Optional[int] = None sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass() + + # If set to true, view_copy operations will be converted to lightweight + # view operations in the ET runtime + remove_view_copy: bool = True diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index af5614bf208..3238c23eda0 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -844,6 +844,32 @@ def _emit_control_flow( ) ) + def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue: + assert len(args) == 2 + + self_arg = self._emit_argument(args[0], torch.TensorType) # pyre-ignore[6] + size_arg = self._emit_argument(args[1], torch.ListType.ofInts()) + out_arg = self._emit_argument( + self._emit_spec(self.node.meta["spec"]), torch.TensorType # pyre-ignore[6] + ) + + op_idx, op = self._get_operator( + name="executorch_prim::et_view", + overload="default", + ) + kernel = Instruction( + KernelCall( + op_idx, + args=[ + self_arg.id, + size_arg.id, + out_arg.id, + ], + ) + ) + self.chain.instructions.append(kernel) + return out_arg + def _add_debug_handle(self, emitter_id: int, target: _Target) -> None: """Updates the debug handle information for the current node. 
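The new `remove_view_copy` flag is consumed at `to_executorch()` time; leaving it at its default of True is what routes `view_copy` nodes through `_emit_view` and the `executorch_prim::et_view` operator added above. A minimal usage sketch, assuming only the APIs touched in this diff (the `Mini` module itself is a made-up example):

import torch
from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig


class Mini(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        return x.view(-1) + 1.0


ep = torch.export.export(Mini(), (torch.ones(2, 3),))

# Default path: view_copy is replaced and emitted as executorch_prim::et_view.
et_new = to_edge(ep).to_executorch()

# Opt-out path: keep the previous aten::view_copy out-variant behavior.
et_old = to_edge(ep).to_executorch(
    config=ExecutorchBackendConfig(remove_view_copy=False)
)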
@@ -1198,6 +1224,9 @@ def call_function( assert len(args) == 1 return self._emit_spec(self.node.meta["spec"]) + elif target == memory.view: + return self._emit_view(args) + elif target == memory.free: assert len(args) == 1 # pyre-ignore diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 3eebe52faef..b55fb5e5dae 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -265,16 +265,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: edge = to_edge(export(f, inputs)) removed_ops = ["aten::relu_", "aten::view"] - expected_ops = ["aten::sin", "aten::relu", "aten::max", "aten::view_copy"] + expected_ops = [ + "aten::sin", + "aten::relu", + "aten::max", + "executorch_prim::et_view", # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False + ] for opname in removed_ops: self.assertEqual( self.count_node(edge.exported_program().graph_module, opname), 0 ) for opname in expected_ops: - self.assertTrue( - self.count_node(edge.exported_program().graph_module, opname) >= 1 - ) + if ( + opname != "executorch_prim::et_view" + ): # et_view appears as call_function with target = memory.view in graph + self.assertTrue( + self.count_node(edge.exported_program().graph_module, opname) >= 1 + ) program = edge.to_executorch().executorch_program for opname in removed_ops: diff --git a/exir/memory_planning.py b/exir/memory_planning.py index b8c47b440c5..675f196fcd8 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -397,6 +397,7 @@ def collect_specs_from_nodes( # noqa: C901 or node.target in [ memory.alloc, + memory.view, operator.getitem, torch.ops.higher_order.cond, exir_while, @@ -534,7 +535,13 @@ def get_node_tensor_specs( has no tensor specs. """ # get tensor specs - specs = node.meta.get("spec") + if node.target == memory.view: + base = node.args[0] + assert isinstance(base, torch.fx.Node) + specs = base.meta.get("spec") + else: + specs = node.meta.get("spec") + if isinstance(specs, TensorSpec): specs = [specs] if not isinstance(specs, (list, tuple)): diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 2611d6a1541..f43b2973a4e 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -248,6 +248,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None: # we won't see it in the input graph to the to_out_variant pass, unless # it's retraced after running to_out_variant with the first trace. 
memory.alloc, + memory.view, executorch_call_delegate, torch.ops.aten.copy_.default, } diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index 33f98304174..a9304f3eec8 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -6,9 +6,9 @@ # pyre-strict +import copy import logging -import math -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple import torch from executorch.exir import memory @@ -36,28 +36,113 @@ def _is_view_copy(node: torch.fx.Node) -> bool: _VIEW_OP = memory.view +class _Guard: + def __init__( + self, name: str, field_lambda, expected_val: Any # pyre-ignore[2] + ) -> None: + self.name: str = name + self.field_lambda = field_lambda # pyre-ignore[4] + self.expected_val = copy.deepcopy(expected_val) # pyre-ignore[4] + + def __call__(self, view_spec) -> None: # pyre-ignore[2] + assert view_spec._unguarded_access + observed_val = self.field_lambda(view_spec) + if observed_val != self.expected_val: + raise Exception( + f"Guard {self.name} failed. Expected to see value {self.expected_val}, but saw value {observed_val}." + ) + + class _ViewSpec(TensorSpec): def __init__(self, base: TensorSpec, shape: List[int]) -> None: """ - A ViewSpec is an immutable TensorSpec that mirrors its base for non-size - related information. - """ + A _ViewSpec is TensorSpec that shares non-size related fields with its base. + The size-related fields are: shape, stride, dim_order, and shape_dynamism. - if math.prod(base.shape) != math.prod(shape): - raise Exception( - f"Cannot create a ViewSpec because the provided shape {shape} is not consistent with the number of elements in the provided base ({math.prod(base.shape)})." - ) + If either the base or view spec updates a non-size related field, the change + is reflected in both specs. But size related fields are not linked and can + be set separately. - self._init_setters = [ - "_frozen", - "_base", - "_guards", + A _ViewSpec can only be created from a non-sparse, strided TensorSpec. + On creation, a _ViewSpec must be compatible with its base with respect to + shape_dynamism, dtype, and nbytes. + + A _ViewSpec contains _guards that are evaluated on every __getattribute__ call. + The purpose of the guards is to make sure the _ViewSpec is still compatible + with its base. + """ + + # Explicitly put all attributes into _self_fields or _base_fields + # Any attribute that is not in _self_fields or _base_fields will + # raise an Exception. If TensorSpec is extended with a new attribute, + # we should explicitly decide how _ViewSpec will handle it. + self._self_fields = [ + # We need to get the debug method from self + # so that the object id it prints is correct. 
+ "debug", # method + "__repr__", # method + # The following are related to size and should use self "shape", "stride", "dim_order", "shape_dynamism", + "nbytes", # method + "allocated_memory", # property + "is_dynamic_shape_tensor", # property + "is_static_shape_tensor", # property + "is_upper_bound_tensor", # property + "is_dynamic_unbound_tensor", # property + ] + self._base_fields = [ + "scalar_type", + "const", + "alignment", + "storage", + "requires_grad", + "layout", + "is_sparse", + "init_mem_planning_fields", # method + "realign", # method + "from_tensor", # class method + "lifetime", + "mem_id", + "mem_obj_id", + "mem_offset", + "dtype", # property ] - self._frozen = False + + # Make sure _self_fields and _base_fields are disjoint + assert len(set(self._self_fields) & set(self._base_fields)) == 0 + + self._guards: List[_Guard] = [] + self._unguarded_access = False + + # Make sure base is not sparse and add a guard + if base.is_sparse: + raise Exception( + "_ViewSpec can only be created from non-sparse TensorSpec, but base.is_sparse=True." + ) + self._guards.append( + _Guard( + "is_sparse", + lambda view_spec: view_spec.is_sparse, + False, + ) + ) + + # Make sure base layout is strided and add a guard + if base.layout != torch.strided: + raise Exception( + f"_ViewSpec can only be created from TensorSpec with layout={torch.strided}, but got layout={base.layout}." + ) + self._guards.append( + _Guard( + "layout", + lambda view_spec: view_spec.layout, + torch.strided, + ) + ) + self._base = base self.shape: List[int] = shape self.stride: Tuple[int] = contiguous_stride_from_shape(torch.Size(self.shape)) @@ -66,66 +151,108 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None: torch.Size(self.shape) ) - # This spec gives a view into its base. - # The base can be modified (e.g., mem_id) and this spec will - # update accordingly, but certain fields we do not expect to change - # We create guards for these - self._guards: Dict[str, Any] = { - "shape_dynamism": base.shape_dynamism, - "scalar_type": base.scalar_type, - "layout": base.layout, - "is_sparse": base.is_sparse, - } - self._frozen = True - - def _check_guards(self) -> None: - for name in self._guards: - if getattr(self._base, name) != self._guards[name]: - raise Exception( - f"The guarded attribute '{name}' has changed value. At creation of the ViewSpec, it was {self._guards[name]}, but it is now {getattr(self._base, name)}." - ) + # Check compatibility with base on creation + if self.shape_dynamism != base.shape_dynamism: + raise Exception( + f"_ViewSpec is incompatible with its base on creation. It has shape_dynamism={self.shape_dynamism}, but its base has shape_dynamism={base.shape_dynamism}." + ) + self._guards.append( + _Guard( + "shape_dynamism_init", + lambda view_spec: view_spec.shape_dynamism, + base.shape_dynamism, + ) + ) + self._guards.append( + _Guard( + "shape_dynamism_eq_base", + lambda view_spec: view_spec.shape_dynamism + == view_spec._base.shape_dynamism, + True, + ) + ) + + if self.dtype != base.dtype: + raise Exception( + f"_ViewSpec is incompatible with its base on creation. It has dtype={self.dtype}, but its base has dtype={base.dtype}." + ) + self._guards.append( + _Guard("dtype", lambda view_spec: view_spec.dtype, base.dtype) + ) + + # We do not guard nbytes because dynamic symints are replaced by upper bounds. + # We do guard on rank, though + if self.nbytes() != base.nbytes(): + raise Exception( + f"_ViewSpec is incompatible with its base on creation. 
It has nbytes={self.nbytes()}, but its base has nbytes={base.nbytes()}." + ) + self._guards.append( + _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape)) + ) - def __getattribute__(self, name): # pyre-ignore + def _run_guards(self) -> None: + unguarded_access = self._unguarded_access + try: + self._unguarded_access = True + for g in self._guards: + g(self) + finally: + self._unguarded_access = unguarded_access + + def __getattribute__(self, name: str): # pyre-ignore + # Special field so we don't recurse infinitely if name in [ - "_init_setters", - "_frozen", "_base", + "_self_fields", + "_base_fields", "_guards", - "_check_guards", - # Adding debug is needed so that view_spec.debug() shows the right id in - # its string (if debug is excluded, it shows the id(view_spec._base) instead - # of id(view_spec)) - "debug", + "_unguarded_access", + "_run_guards", ]: return object.__getattribute__(self, name) - # Guard check after freeze - if self._frozen: - self._check_guards() + # Get some attributes from self + if name in self._self_fields: + val = object.__getattribute__(self, name) + elif name in self._base_fields: + val = object.__getattribute__(self._base, name) + else: + if len(name) > 0 and name[0] != "_": + logger.warning( + f"Getting non-private attribute {name} on self, but it is not in _self_fields or _base_fields. Is this intended?" + ) + val = object.__getattribute__(self, name) - # self._init_setters attributes come from self, others come from base - if name in self._init_setters: - return object.__getattribute__(self, name) - return getattr(self._base, name) + if not self._unguarded_access: + self._run_guards() + return val def __setattr__(self, name: str, val) -> None: # pyre-ignore - if name in ["_init_setters", "_frozen"]: + # Special field so we don't recurse infinitely + if name in [ + "_base", + "_self_fields", + "_base_fields", + "_guards", + "_unguarded_access", + "_run_guards", + ]: object.__setattr__(self, name, val) return - # Allow setting during initialization - if name in self._init_setters and not self._frozen: + if name in self._self_fields: object.__setattr__(self, name, val) return - if name in self._init_setters: - raise Exception( - f"ViewSpec is immutable. Cannot set the attribute '{name}' after creation." - ) + if name in self._base_fields: + object.__setattr__(self._base, name, val) + return - raise Exception( - f"ViewSpec is immutable. To update the non-size related attribute '{name}', update the base." - ) + if len(name) > 0 and name[0] != "_": + logger.warning( + f"Setting non-private attribute {name} on self, but it is not in _self_fields or _base_fields. Is this intended?" + ) + object.__setattr__(self, name, val) class ReplaceViewCopyWithViewPass(PassBase): @@ -151,8 +278,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: node.target = _VIEW_OP # Create spec for the node. - # _ViewSpec is an immutable TensorSpec gives a view into - # its base spec for non-size related information. + # _ViewSpec gives a view into its base spec for non-size + # related information. # the shape is not the same as node.args[1] because node.args[1] # can have an inferred sizes (-1). 
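The core of `_ViewSpec` above is the attribute routing in `__getattribute__`/`__setattr__`: size-related fields (shape, stride, dim_order, shape_dynamism, ...) live on the view spec itself, while everything else is read from and written through to the base spec, with the guards re-run on each access. A stripped-down, self-contained sketch of just that routing pattern (hypothetical `BaseSpec`/`ViewOfSpec` names; no guards or warning fallback):

class BaseSpec:
    # Stand-in for TensorSpec: holds sizes plus memory-planning fields.
    def __init__(self):
        self.shape = [5, 6]
        self.mem_id = None


class ViewOfSpec:
    # Size fields are stored on the view; every other attribute routes to _base.
    _SELF_FIELDS = ("shape",)

    def __init__(self, base, shape):
        object.__setattr__(self, "_base", base)
        object.__setattr__(self, "shape", list(shape))

    def __getattribute__(self, name):
        if name in ("_base", "_SELF_FIELDS"):
            return object.__getattribute__(self, name)
        if name in object.__getattribute__(self, "_SELF_FIELDS"):
            return object.__getattribute__(self, name)
        # Everything else (mem_id, lifetime, dtype, ...) comes from the base.
        return getattr(object.__getattribute__(self, "_base"), name)

    def __setattr__(self, name, val):
        if name in self._SELF_FIELDS:
            object.__setattr__(self, name, val)
        else:
            # Writes to non-size fields update the base, keeping both in sync.
            setattr(self._base, name, val)


base = BaseSpec()
view = ViewOfSpec(base, [6, 5])
view.mem_id = 1                 # forwarded to the base spec
assert base.mem_id == 1
assert view.shape == [6, 5]     # size fields stay independent
assert base.shape == [5, 6]

This is why memory planning only needs to assign mem_id/mem_offset on the base spec: the view spec observes the same values automatically, which is what the new test_spec below checks.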
diff --git a/exir/program/TARGETS b/exir/program/TARGETS index 49da0648a06..5ae3cf1ac59 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -33,8 +33,10 @@ python_library( "//executorch/exir/emit:lib", "//executorch/exir/passes:insert_write_back_for_buffers_pass", "//executorch/exir/passes:lib", + "//executorch/exir/passes:normalize_view_copy_base_pass", "//executorch/exir/passes:remove_graph_asserts_pass", "//executorch/exir/passes:remove_mixed_type_operators", + "//executorch/exir/passes:replace_view_copy_with_view_pass", "//executorch/exir/passes:spec_prop_pass", "//executorch/exir/verification:verifier", ], diff --git a/exir/program/_program.py b/exir/program/_program.py index 10d0043398f..086768b879d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -31,8 +31,14 @@ from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) +from executorch.exir.passes.normalize_view_copy_base_pass import ( + NormalizeViewCopyBasePass, +) from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators +from executorch.exir.passes.replace_view_copy_with_view_pass import ( + ReplaceViewCopyWithViewPass, +) from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.print_program import pretty_print, print_program from executorch.exir.schema import Program @@ -615,8 +621,24 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": return new_ep +def pre_memory_planning_passes(config: ExecutorchBackendConfig) -> List[PassType]: + if config.remove_view_copy: + # pyre-ignore + return [ + NormalizeViewCopyBasePass(), + ReplaceViewCopyWithViewPass(), + config.sym_shape_eval_pass, + config.to_out_var_pass, + ] + else: + # pyre-ignore + return [ + config.sym_shape_eval_pass, + config.to_out_var_pass, + ] + + def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]: - # pyre-ignore passes: List[PassType] = [ *config.passes, SpecPropPass(), @@ -625,9 +647,8 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType] # there exists an unbacked symint operation. 
EdgeToBackendOpsPass(), RemoveGraphAssertsPass(), - config.sym_shape_eval_pass, - config.to_out_var_pass, - ] + ] + pre_memory_planning_passes(config) + return passes diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 0c3232916d6..94a82d8a2bc 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -411,3 +411,17 @@ python_unittest( "//executorch/exir:print_program", ], ) + +python_unittest( + name = "test_remove_view_copy", + srcs = [ + "test_remove_view_copy.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:memory", + "//executorch/exir/capture:config", + "//executorch/exir/passes:lib", + ], +) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 848269c6573..bfa0d393235 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1498,61 +1498,29 @@ def __init__(self): self.parameter = torch.nn.Parameter(torch.ones(1)) def forward(self, x): - o1 = torch.ops.aten.view_copy.default( - self.parameter, [1] - ) # replaceable parameter - o2 = torch.ops.aten.view_copy.default(x, [1]) # replaceable user input - o3 = torch.ops.aten.view_copy.default( - torch.ops.aten.relu.default(x), [1] - ) # replaceable dynamic unbound - o4 = torch.ops.aten.view_copy.default( - torch.ops.aten.gelu.default(x), [1] - ) # replaceable dynamic bound - o5 = torch.ops.aten.view_copy.default( - torch.ops.aten.tanh.default(x), [1] - ) # replaceable static - return o1, o2, o3, o4, o5 + o1 = torch.ops.aten.view_copy.default(x, [1]) + o2 = torch.ops.aten.view_copy.default(self.parameter, [1]) + return o1, o2 ep = torch.export.export( TestViewCopies(), args=(torch.ones(1),), ) - self.assertEqual(len(ep.graph.nodes), 11) for node in ep.graph.nodes: if node.op == "placeholder": node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC - elif node.target == torch.ops.aten.relu.default: - node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) - node.meta["spec"].shape_dynamism = TensorShapeDynamism.DYNAMIC_UNBOUND - elif node.target == torch.ops.aten.gelu.default: - node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) - node.meta["spec"].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND - elif node.target == torch.ops.aten.tanh.default: - node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) - node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC - elif node.target == torch.ops.aten.view_copy.default: - node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) - node.meta["spec"].shape_dynamism = ( - node.args[0].meta["spec"].shape_dynamism - ) - else: - pass # Run tests gm = ep.graph_module # Check before transformation - n_view_copy_before = 0 - n_memory_view_before = 0 - for node in gm.graph.nodes: - if is_view_copy(node): - n_view_copy_before += 1 - if is_memory_view(node): - n_memory_view_before += 1 - - self.assertEqual(n_view_copy_before, 5) - self.assertEqual(n_memory_view_before, 0) + FileCheck().check_count( + "torch.ops.aten.view_copy.default", 2, exactly=True + ).run(gm.code) + FileCheck().check_count("executorch_exir_memory_view", 0, exactly=True).run( + gm.code + ) # Do transformation p = ReplaceViewCopyWithViewPass() @@ -1560,14 +1528,10 @@ def forward(self, x): assert gm_res is not None gm = gm_res.graph_module - # Check after transformation - n_view_copy_after = 0 - n_memory_view_after = 0 - for node in gm.graph.nodes: - if is_view_copy(node): - n_view_copy_after += 1 - if is_memory_view(node): - n_memory_view_after += 1 - - 
-        self.assertEqual(n_view_copy_after, 0)
-        self.assertEqual(n_memory_view_after, 5)
+        # Check after transformation
+        FileCheck().check_count(
+            "torch.ops.aten.view_copy.default", 0, exactly=True
+        ).run(gm.code)
+        FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
+            gm.code
+        )
diff --git a/exir/tests/test_quant_fusion_pass.py b/exir/tests/test_quant_fusion_pass.py
index 00269da92d7..69610a73abe 100644
--- a/exir/tests/test_quant_fusion_pass.py
+++ b/exir/tests/test_quant_fusion_pass.py
@@ -117,7 +117,7 @@ def forward(self, x, y):
             m.exported_program.graph_module.code
         )
 
-        m = m.to_executorch()
+        m = m.to_executorch(exir.ExecutorchBackendConfig(remove_view_copy=False))
         # check that we are using out variant of q/dq/add
         FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
             "torch.ops.aten.view_copy.out"
diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py
new file mode 100644
index 00000000000..0c5b61f8d8f
--- /dev/null
+++ b/exir/tests/test_remove_view_copy.py
@@ -0,0 +1,202 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+import torch.nn as nn
+from executorch.exir import memory, to_edge
+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes import MemoryPlanningPass
+
+
+class TestModel1(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.parameter = nn.Parameter(torch.rand(5, 6))
+        self.parameter.requires_grad = False
+
+    def forward(self, x):
+        v1 = self.parameter.view(
+            6, 5
+        )  # removed, lifetime of parameter will be extended
+        v2 = x.view(6, 5)  # not removed
+        v3 = torch.ops.aten.mul.Tensor(v1, v2).view(
+            30
+        )  # removed, lifetime of mul.Tensor will be extended
+        return v3
+
+    def get_example_inputs(self):
+        return (torch.rand(5, 6),)
+
+
+class TestRemoveViewCopy(unittest.TestCase):
+    def test_disable(self) -> None:
+        model = TestModel1()
+        model.eval()
+        example_inputs = model.get_example_inputs()
+        ep = torch.export.export(model, example_inputs)
+        etpm = to_edge(ep).to_executorch(
+            config=ExecutorchBackendConfig(
+                remove_view_copy=False,
+                memory_planning_pass=MemoryPlanningPass(
+                    "greedy", alloc_graph_input=False
+                ),
+            ),
+        )
+
+        for node in etpm.exported_program().graph_module.graph.nodes:
+            assert node.target != memory.view
+
+    def test_output_matches(self) -> None:
+        model = TestModel1()
+        model.eval()
+        example_inputs = model.get_example_inputs()
+        ep = torch.export.export(model, example_inputs)
+
+        epm_remove = to_edge(ep)
+        epm_no_remove = copy.deepcopy(
+            epm_remove
+        )  # to_executorch modifies the edge_program, so we make a copy
+
+        # Run pass with removal
+        etpm_remove = epm_remove.to_executorch(
+            config=ExecutorchBackendConfig(
+                remove_view_copy=True,
+                memory_planning_pass=MemoryPlanningPass(
+                    "greedy", alloc_graph_input=False
+                ),
+            ),
+        )
+
+        # Run pass with no removal
+        etpm_no_remove = epm_no_remove.to_executorch(
+            config=ExecutorchBackendConfig(
+                remove_view_copy=False,
+                memory_planning_pass=MemoryPlanningPass(
+                    "greedy", alloc_graph_input=False
+                ),
+            ),
+        )
+
+        out_remove = etpm_remove.exported_program().module()(*example_inputs)
+        out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs)
+
+        self.assertTrue(torch.allclose(out_remove, out_no_remove))
+
+    def test_spec(self) -> None:
+        model = 
TestModel1() + model.eval() + example_inputs = model.get_example_inputs() + ep = torch.export.export(model, example_inputs) + + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass( + "greedy", alloc_graph_input=False + ), + ), + ) + + # etpm.exported_program().graph.print_tabular() + + # idx opcode name target args kwargs + # --- ------------- ------------------------ ---------------------------------- -------------------------------------------------- -------------- + # 0 placeholder p_parameter p_parameter () {} + # 1 placeholder x x () {} + # 2 call_function aten_view_copy_default (p_parameter, [6, 5]) {} + # 3 call_function aten_view_copy_default_1 (x, [6, 5]) {} + # 4 call_function alloc (((6, 5), torch.float32),) {} + # 5 call_function aten_mul_tensor aten.mul.out (aten_view_copy_default, aten_view_copy_default_1) {'out': alloc} + # 6 call_function aten_view_copy_default_2 (aten_mul_tensor, [30]) {} + # 7 output output_1 output ((aten_view_copy_default_2,),) {} + + for node in etpm.exported_program().graph.nodes: + if node.name == "p_parameter": + # p_parameter's lifetime is extended through aten_view_copy_default (memory.view) to idx 5 + self.assertEqual(node.meta["spec"].lifetime, [0, 5]) + elif node.name == "aten_view_copy_default": + # aten_view_copy_default is a memory.view of p_parameter. + # p_parameter is a constant with storage, so we check that the view's storage matches the base + + # assert base is p_parameter + self.assertEqual(node.args[0].name, "p_parameter") + + # assert base is const with storage + self.assertTrue(node.args[0].meta["spec"].const) + self.assertTrue(node.args[0].meta["spec"].storage is not None) + self.assertTrue(node.args[0].meta["spec"].mem_id is None) + self.assertTrue(node.args[0].meta["spec"].mem_offset is None) + + # assert self is const with storage + self.assertTrue(node.meta["spec"].const) + self.assertTrue(node.meta["spec"].storage is not None) + self.assertTrue(node.meta["spec"].mem_id is None) + self.assertTrue(node.meta["spec"].mem_offset is None) + + # assert storage matches + self.assertEqual( + node.meta["spec"].storage, node.args[0].meta["spec"].storage + ) + + # assert lifetime matches + self.assertEqual( + node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime + ) + elif node.name == "aten_mul_tensor": + # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 7 + self.assertEqual(node.meta["spec"].lifetime, [4, 7]) + elif node.name == "aten_view_copy_default_2": + # aten_view_copy_default_2 is a memory.view of aten_mul_tensor + + # assert base is aten_mul_tensor + self.assertEqual(node.args[0].name, "aten_mul_tensor") + + # assert base and self are not const, do not have storage, + # but do have mem_id and mem_offset + self.assertFalse(node.args[0].meta["spec"].const) + self.assertTrue(node.args[0].meta["spec"].storage is None) + self.assertTrue(node.args[0].meta["spec"].mem_id is not None) + self.assertTrue(node.args[0].meta["spec"].mem_offset is not None) + + self.assertFalse(node.meta["spec"].const) + self.assertTrue(node.meta["spec"].storage is None) + self.assertTrue(node.meta["spec"].mem_id is not None) + self.assertTrue(node.meta["spec"].mem_offset is not None) + + # assert self and base mem_id, mem_offset, and lifetime matches + self.assertEqual( + node.meta["spec"].mem_id, node.args[0].meta["spec"].mem_id + ) + self.assertEqual( + node.meta["spec"].mem_offset, node.args[0].meta["spec"].mem_offset 
+ ) + self.assertEqual( + node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime + ) + + # Test evalues in execution plan + plan = etpm.executorch_program.execution_plan[0] + self.assertEqual(plan.operators[0].name, "executorch_prim::et_view") + self.assertEqual(plan.operators[1].name, "aten::mul") + + instructions = plan.chains[0].instructions + self.assertEqual(len(instructions), 4) + + self.assertEqual( + instructions[0].instr_args.op_index, 0 # pyre-ignore + ) # view @ idx2 + self.assertEqual( + instructions[1].instr_args.op_index, 0 # pyre-ignore + ) # view @ idx3 + self.assertEqual( + instructions[2].instr_args.op_index, 1 # pyre-ignore + ) # aten:mul @ idx5 + self.assertEqual( + instructions[3].instr_args.op_index, 0 # pyre-ignore + ) # view @ idx6 diff --git a/kernels/prim_ops/et_view.cpp b/kernels/prim_ops/et_view.cpp index 69a75170260..b3d3592fe7b 100644 --- a/kernels/prim_ops/et_view.cpp +++ b/kernels/prim_ops/et_view.cpp @@ -87,18 +87,7 @@ void et_view(RuntimeContext& context, EValue** stack) { // Do some checks ET_CHECK(self.numel() == out.numel()); - // If out has a data_ptr, it must match self - // We hit this path for memory-planned tensors - if (out.const_data_ptr() != nullptr) { - ET_CHECK_MSG( - self.const_data_ptr() == out.const_data_ptr(), - "out has a non-null data_ptr, but it does not equal self's data_ptr."); - - // nothing else to do - return; - } - - // out.const_data_ptr() == nullptr now + // Update data ptr ET_CHECK_MSG( internal::set_tensor_data( out, diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index fdcc13cf13e..7d91a0f6820 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -331,14 +331,13 @@ TEST_F(RegisterPrimOpsTest, TestETView) { EValue(good_outs[0]), EValue(good_outs[1])}; // bad outs expect death - constexpr int N_BAD_OUTS = 3; + constexpr int N_BAD_OUTS = 2; Tensor bad_outs[N_BAD_OUTS] = { tf.ones({1, 3, 2, 1}), // wrong rank - tf.ones({1, 3, 3}), // wrong size - tf.ones({1, 3, 2}) // occupied data_ptr + tf.ones({1, 3, 3}) // wrong size }; EValue bad_out_evalues[N_BAD_OUTS] = { - EValue(bad_outs[0]), EValue(bad_outs[1]), EValue(bad_outs[2])}; + EValue(bad_outs[0]), EValue(bad_outs[1])}; // *************************************************************************** // Run tests @@ -349,7 +348,6 @@ TEST_F(RegisterPrimOpsTest, TestETView) { // Bad out stacks {&self_evalue, &size_int_list_evalue, &bad_out_evalues[0]}, {&self_evalue, &size_int_list_evalue, &bad_out_evalues[1]}, - {&self_evalue, &size_int_list_evalue, &bad_out_evalues[2]}, // Bad size stacks {&self_evalue, &bad_size_int_list_evalue1, &good_out_evalues[0]}, {&self_evalue, &bad_size_int_list_evalue2, &good_out_evalues[0]}};
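Taken together with the `et_view` kernel change above (the out tensor's data pointer is now pointed at self's data rather than required to already match it), the emitted program references `executorch_prim::et_view` directly. A small inspection sketch along the lines of test_spec (the `TwoViews` module is hypothetical; `execution_plan`, `chains`, and `instr_args.op_index` are the fields the new test reads):

import torch
from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig


class TwoViews(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        a = x.view(6, 5)
        b = x.view(6, 5)
        return (a * b).view(30)


ep = torch.export.export(TwoViews(), (torch.rand(5, 6),))
etpm = to_edge(ep).to_executorch(
    config=ExecutorchBackendConfig(remove_view_copy=True)
)

plan = etpm.executorch_program.execution_plan[0]
# With remove_view_copy=True the operator table should list
# executorch_prim::et_view instead of a view_copy out-variant.
print([op.name for op in plan.operators])
# KernelCall instructions reference the operator table by index.
print([getattr(ins.instr_args, "op_index", None) for ins in plan.chains[0].instructions])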