Skip to content

Commit f7b714b

Browse files
committed
up
1 parent 49d11b1 commit f7b714b

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,6 @@ class Model(torch.nn.Module):
8585
def __init__(self) -> None:
8686
super().__init__()
8787

88-
buffer = torch.ones(1)
89-
self.register_buffer("buffer", buffer)
90-
9188
def forward(self, q, k, v, mask):
9289
out = torch.ops.aten.scaled_dot_product_attention.default(
9390
q, k, v, attn_mask=mask
@@ -99,7 +96,10 @@ def forward(self, q, k, v, mask):
9996
out = out.transpose(1, 2)
10097
out = out.view(1, -1)
10198
out = out.permute(0, 1)
102-
out = out.add_(self.buffer)
99+
out = out.add_(1.0)
100+
out = out.mul_(2.0)
101+
out = out.div_(3.0)
102+
out = out.sub_(4.0)
103103
out = torch.ops.aten.view_copy.default(out, (-1,))
104104
out = out.select(0, 0)
105105
return out

exir/program/_program.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,8 +976,33 @@ def keep(op):
976976
schema = op._schema
977977
native_schema = _pybind_schema_to_native_schema(schema)
978978
if native_schema.is_mutable:
979+
logging.warn(
980+
f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable."
981+
)
979982
return False
983+
980984
if native_schema.aliased_return_names() != [None]:
985+
logging.warn(
986+
f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output."
987+
)
988+
return False
989+
990+
# Explicit block list of ops that don't work if asked for
991+
# preservation
992+
if op in [
993+
# Hits infinte recursion error when op is in
994+
# EDGE_DO_NOT_DECOMP namespace
995+
torch.ops.aten._to_copy.default,
996+
# scalar to tensor type promotion does not work on ops
997+
# in EDGE_DO_NOT_DECOMP namespace
998+
torch.ops.aten.mul.Tensor,
999+
torch.ops.aten.add.Tensor,
1000+
torch.ops.aten.sub.Tensor,
1001+
torch.ops.aten.div.Tensor,
1002+
]:
1003+
logging.warn(
1004+
f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist."
1005+
)
9811006
return False
9821007
return True
9831008

0 commit comments

Comments
 (0)