diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 3238c23eda0..ee5fa6af4ee 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -87,6 +87,7 @@ TensorSpec, ) from executorch.exir.types import LeafValueSpec, ValueSpec +from torch._subclasses.fake_tensor import FakeTensor from torch.export.exported_program import ExportedProgram from torch.utils import _pytree as pytree @@ -933,6 +934,35 @@ def _emit_argument( return arg return self._emit_evalue(self._constant_to_evalue(arg, arg_type)) + def _get_sym_ret( + self, + val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]], + ) -> Optional[_AbstractValue]: + """ + Returns the emit ret for sym value. + """ + ret = None + if isinstance(val, torch.SymInt): + ret = self._emit_evalue(EValue(Int(0))) + elif isinstance(val, torch.BoolType): + ret = self._emit_evalue(EValue(Bool(False))) + elif isinstance(val, torch.FloatType): + ret = self._emit_evalue(EValue(Double(0))) + return ret + + def _get_sym_and_fake_tensor_ret( + self, + val: Tuple[Union[torch.SymInt, torch.BoolType, torch.FloatType, FakeTensor]], + spec: TensorSpec, + ) -> Union[List[_AbstractValue], _AbstractValue, Tuple[_AbstractValue, ...]]: + # Try to get the ret if it's a sym value. + ret = self._get_sym_ret(val) + # If the ret is None, it means that the val is not a sym value, but a regular tensor + if ret is None: + ret = self._emit_spec(spec) + assert ret is not None, "Can't have a None ret" + return ret + def _emit_delegate( self, lowered_module: "LoweredBackendModule", # noqa @@ -944,7 +974,40 @@ def _emit_delegate( processed_bytes = lowered_module.processed_bytes delegate_index = self.emitter_state.delegate_cache.get(processed_bytes) - delegate_ret = self._emit_spec(self.node.meta["spec"]) + delegate_ret = None + + if isinstance(self.node.meta["spec"], list): + delegate_ret = [] + for index, _ in enumerate(self.node.meta["val"]): + ret = self._get_sym_and_fake_tensor_ret( + self.node.meta["val"][index], self.node.meta["spec"][index] + ) + delegate_ret.append(ret) + elif isinstance(self.node.meta["spec"], tuple): + if isinstance(self.node.meta["val"], FakeTensor): + # There is a case when node.meta["spec"] is (TensorSpec, ) while node.meta["val"] is FakeTensor + ret = self._get_sym_and_fake_tensor_ret( + self.node.meta["val"], self.node.meta["spec"][0] + ) + delegate_ret = (ret,) + else: + delegate_ret = [] + for index, _ in enumerate(self.node.meta["val"]): + ret = self._get_sym_and_fake_tensor_ret( + self.node.meta["val"][index], self.node.meta["spec"][index] + ) + delegate_ret.append(ret) + delegate_ret = tuple(delegate_ret) + elif isinstance(self.node.meta["spec"], TensorSpec): + ret = self._get_sym_and_fake_tensor_ret( + self.node.meta["val"], self.node.meta["spec"] + ) + delegate_ret = ret + else: + raise NotImplementedError( + f"self.node.meta['spec'] {type(self.node.meta['spec'])} is not supported" + ) + assert delegate_ret is not None, "Can't have a None delegate_ret" if delegate_index is None: # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if # present. @@ -1062,13 +1125,8 @@ def _get_empty_tensor_evalue() -> EValue: torch.BoolType, torch.NumberType, ), f"Only symbolic ops that return a Int Bool Float are supported currently got {type(target._schema.returns[0].type)}." - if type(target._schema.returns[0].type) == torch.IntType: - ret = self._emit_evalue(EValue(Int(0))) - elif type(target._schema.returns[0].type) == torch.BoolType: - ret = self._emit_evalue(EValue(Bool(False))) - elif type(target._schema.returns[0].type) == torch.FloatType: - ret = self._emit_evalue(EValue(Double(0))) - else: # type(target._schema.returns[0].type) == torch.NumberType: + ret = self._get_sym_ret(target._schema.returns[0]) + if ret is None: # type(target._schema.returns[0].type) == torch.NumberType: # Cant definitively say what type this is, the runtime operator just overrides the EValue completely # though so we can just serialize whatever as a placeholder. ret = self._emit_evalue(EValue(Int(0)))