diff --git a/exir/experimental/TARGETS b/exir/experimental/TARGETS deleted file mode 100644 index 418ee54df6e..00000000000 --- a/exir/experimental/TARGETS +++ /dev/null @@ -1,25 +0,0 @@ -load("@fbcode_macros//build_defs:python_library.bzl", "python_library") - -oncall("executorch") - -python_library( - name = "export_pt2", - srcs = ["export_pt2.py"], - deps = [ - "//caffe2:torch", - "//executorch/exir:error", - "//executorch/exir:lib", - "//executorch/exir:tracer", - ], -) - -python_library( - name = "lib", - srcs = [ - "__init__.py", - ], - deps = [ - "//caffe2:torch", - "//executorch/exir:tensor", - ], -) diff --git a/exir/experimental/__init__.py b/exir/experimental/__init__.py deleted file mode 100644 index c3e1dc8317f..00000000000 --- a/exir/experimental/__init__.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import copy -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils._pytree as pytree -from executorch.exir.tensor import TensorSpec -from torch._export.serde.schema import TensorMeta -from torch._export.serde.serialize import ( - _SERIALIZE_TO_TORCH_DTYPE, - serialize_tensor_meta, -) -from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -from torch.fx.experimental.symbolic_shapes import ShapeEnv - - -def add_assertions(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: - modified_graph_module = copy.deepcopy(graph_module) - - graph = modified_graph_module.graph - for node in graph.nodes: - if node.op != "call_function" and node.op != "placeholder": - continue - - # Ignore constants - if node.meta.get("val", None) is None: - continue - - # Ignore non-torch ops - if node.op == "call_function" and ( - not isinstance(node.target, torch._ops.OpOverload) - ): - continue - - shape = node.meta["val"].shape - dtype = node.meta["val"].dtype - node_name = node.name - with graph.inserting_after(node): - - def check_spec( - x: TensorSpec, shape: List[int], dtype: torch.dtype, node_name: str - ) -> None: - assert list(x.shape) == list( - shape - ), f"Expected {node_name} shape to be {shape}, got {x.shape}" - assert ( - x.dtype == dtype - ), f"Expected {node_name} dtype to be {dtype}, got {x.dtype}" - - graph.call_function(check_spec, (node, shape, dtype, node_name)) - - modified_graph_module.recompile() - return modified_graph_module - - -def convert_fake_tensor_to_tensor_meta( - ep: torch.fx.GraphModule, -) -> Tuple[torch.fx.GraphModule, Optional[ShapeEnv]]: - """ - Replace the faketensor metadata with the tensor metadata dataclass since we - cannot serialize faketensors - """ - shape_env = None - for node in ep.graph.nodes: - - def get_shape_env( - val: Union[List[FakeTensor], FakeTensor] - ) -> Optional[ShapeEnv]: - val_flat, _ = pytree.tree_flatten(val) - curr_shape_env = None - for v in val_flat: - if not isinstance(v, FakeTensor): - continue - if curr_shape_env is None: - curr_shape_env = v.fake_mode.shape_env - else: - assert ( - curr_shape_env is v.fake_mode.shape_env - ), "Multiple shape envs detected." - return curr_shape_env - - if (val := node.meta.get("val", None)) is not None: - if shape_env is None: - shape_env = get_shape_env(val) - elif (new_shape_env := get_shape_env(val)) is not None: - assert shape_env is new_shape_env, "Multiple shape envs detected." - - node.meta["tensor_meta"] = pytree.tree_map_only( - torch.Tensor, serialize_tensor_meta, val - ) - del node.meta["val"] - - return ep, shape_env - - -def convert_tensor_meta_to_fake_tensor( - ep: torch.fx.GraphModule, shape_env: Optional[ShapeEnv] = None -) -> torch.fx.GraphModule: - """ - Replace (inplace) the tensor metadata with faketensor - """ - fake_tensor_mode: FakeTensorMode = FakeTensorMode( - allow_non_fake_inputs=True, shape_env=shape_env - ) - for node in ep.graph.nodes: - if (val := node.meta.get("tensor_meta", None)) is not None: - - def _extract_faketensor(tensor_meta: TensorMeta) -> FakeTensor: - return FakeTensor( - fake_tensor_mode, - torch.empty( - # TODO Support dynamic shape. - tuple(s.as_int for s in tensor_meta.sizes), - dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype], - device="meta", - requires_grad=tensor_meta.requires_grad, - ), - torch.device("cpu"), - ) - - node.meta["val"] = pytree.tree_map_only( - TensorMeta, _extract_faketensor, val - ) - return ep diff --git a/exir/experimental/export_pt2.py b/exir/experimental/export_pt2.py deleted file mode 100644 index df040147f4b..00000000000 --- a/exir/experimental/export_pt2.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -# This class is for prototyping PyTorch 2.0 Export -from dataclasses import dataclass -from enum import Enum -from typing import Any, Callable, List, Optional, Tuple - -import torch - -from executorch import exir -from executorch.exir import CaptureConfig -from executorch.exir.error import ExportError, ExportErrorType, InternalError - -from executorch.exir.tracer import Value - -from torch._dynamo.guards import Guard as DynamoGuard - - -class GuardType(Enum): - TENSOR_MATCH = 1 - - -class GuardResolution(Enum): - IGNORE = 1 - CHECK_AT_RUNTIME = 2 - ERROR_AT_EXPORT = 3 - - -@dataclass -class Guard: - """ - This is our own custom Guard class to store - information needed for EXIR. This will only - store things we actually need. - """ - - guard_type: GuardType - obj: Any # pyre-ignore - check_code: str - - -@dataclass -class Trace: - """ - Immutable object that abstracts the result of exir.trace - which is essentially a torch.fx.GraphModule plus all the assumptions - that are made about this tracing that are represented as Guard. - """ - - graph_module: torch.fx.GraphModule - guards: List[Guard] - inputs: Tuple[Value] - - -class ExportSession: - def __init__(self, trace: Trace) -> None: - """ - Mutable object where user can interactively resolve guards to access the final graph_module. - """ - self.trace = trace - self.guard_rules: List[Callable[[Guard], Optional[GuardResolution]]] = [] - - # TODO (make more specific rule) - def default_rule(guard: Guard) -> Optional[GuardResolution]: - if guard.guard_type != GuardType.TENSOR_MATCH: - return GuardResolution.IGNORE - return None - - self.guard_rules.append(default_rule) - - def summary(self) -> str: - """ - Prints the current status of guard resolutions in a module - hierarchical way. - """ - # TODO implement this - return "" - - def export(self) -> Optional[torch.fx.GraphModule]: - """ - Exports a final GraphModule that is ready to be executed. - This will require that all guards imposed on GraphModule are - resolved. - """ - - def _guard_remaining_filter(guard: Guard) -> bool: - guard_resolutions: List[Optional[GuardResolution]] = [ - guard_rule(guard) for guard_rule in self.guard_rules - ] - # if there was no guard resolutions, we should keep the guard - if len(guard_resolutions) == 0: - return True - - # later rules take priority - for idx in range(len(guard_resolutions) - 1, -1, -1): - if guard_resolutions[idx] is None: - continue - assert guard_resolutions is not None - if guard_resolutions[idx] in [ - GuardResolution.CHECK_AT_RUNTIME, - GuardResolution.IGNORE, - ]: - return False - if guard_resolutions[idx] == GuardResolution.ERROR_AT_EXPORT: - return True - # nothing has been resolved - return True - - remaining_guards = list(filter(_guard_remaining_filter, self.trace.guards)) - if len(remaining_guards) > 0: - raise ExportError( - ExportErrorType.VIOLATION_OF_SPEC, - "There are outstanding guards to be resolved to export this graph", - ) - return self.trace.graph_module - - def add_guard_rule( - self, guard_rule: Callable[[Guard], Optional[GuardResolution]] - ) -> None: - """ - Adds user provided guard rule. This rule will be applied when you call export() method. - """ - self.guard_rules.append(guard_rule) - - -def trace(root: Callable[..., Value], concrete_args: Tuple[Value, ...]) -> Trace: - """ - Runs torchdynamo with no-python mode and dispatch trace - to create a Trace object which is graph module plus guards that - need to be resolved. - """ - # TODO (yidi) cannot enable functionalization under exir.capture() pt2 mode - graph_module = exir.capture( - root, - concrete_args, - CaptureConfig(enable_functionalization=False), - ).graph_module - - # TODO convert torchdynamo guards to our own guards - def _convert_dynamo_guard_to_exir_guard( - dynamo_guard: DynamoGuard, - ) -> Optional[Guard]: - if dynamo_guard.guard_types is not None and len(dynamo_guard.guard_types) > 0: - # TODO (make sure this list is always element of 1) - guard_type = dynamo_guard.guard_types[0] - # TODO (add more guard types) - if guard_type == "TENSOR_MATCH": - # pyre-fixme[29]: `Optional[object]` is not a function. - return Guard(GuardType.TENSOR_MATCH, dynamo_guard.obj_weakref(), "") - - raise InternalError(f"Unregistered guard type: {dynamo_guard.guard_types}") - - guards: List[Guard] = [] - for g in graph_module.guards: - try: - guard = _convert_dynamo_guard_to_exir_guard(g) - assert isinstance(guard, Guard) - guards.append(guard) - except InternalError as e: - print(str(e)) - - return Trace(graph_module, guards, concrete_args)