[draft export] generate fake outputs when real tensor prop finds mismatches (pytorch#139766)

Currently, real tensor tracing raises a MetadataMismatchError if a registered fake kernel doesn't match the real kernel (e.g. in shape, aliasing, or dtype). This adds an option to infer a fake kernel from the real tensor outputs and bypass such mismatches; the option defaults to False for real tensor tracing, but is enabled for draft export.
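For illustration, a minimal sketch of the new behavior, mirroring the tests in this PR (`mylib::foo` is a toy op whose fake kernel returns the wrong dtype):

import torch
from torch.export._draft_export import draft_export

@torch.library.custom_op("mylib::foo", mutates_args={})
def foo(a: torch.Tensor) -> torch.Tensor:
    return a.to(torch.bfloat16)  # real kernel returns bfloat16

@foo.register_fake
def foo_fake_impl(a):
    return torch.empty_like(a)  # fake kernel keeps float32: dtype mismatch

class M(torch.nn.Module):
    def forward(self, a):
        return torch.ops.mylib.foo(a)

# Plain real-tensor tracing raises on the mismatch; draft_export instead
# infers a fake kernel from the real outputs and records the mismatch as a
# MISMATCHED_FAKE_KERNEL failure in the returned report.
ep, report = draft_export(M(), (torch.randn(3, 3),))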

Pull Request resolved: pytorch#139766
Approved by: https://github.com/angelayi, https://github.com/zou3519
pianpwk authored and pytorchmergebot committed Nov 21, 2024
1 parent 6647661 commit 1132b67
Showing 5 changed files with 394 additions and 165 deletions.
87 changes: 86 additions & 1 deletion test/export/test_draft_export.py
@@ -1,15 +1,17 @@
# Owner(s): ["oncall: export"]
import copy
from typing import List, Tuple

import torch
-from torch.export import Dim
+from torch.export import Dim, export
from torch.export._draft_export import draft_export, FailureType
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.torchbind_impls import (
    _empty_tensor_queue,
    init_torchbind_implementations,
)
from torch.utils._pytree import tree_leaves


class TestDraftExport(TestCase):
@@ -271,6 +273,89 @@ def forward(self, tq, x):
        self.assertEqual(tq3.size(), 2)
        self.assertEqual(tq.size(), 2)

    def test_override_size_and_dtype_mismatched_fake_kernels(self):
        class M(torch.nn.Module):
            def forward(self, a):
                return torch.ops.mylib.foo(a)

        @torch.library.custom_op("mylib::foo", mutates_args={})
        def foo(a: torch.Tensor) -> List[torch.Tensor]:
            x = a * 2
            y = a.repeat(2, 2)
            z = a.to(torch.bfloat16)
            return [x, y, z]

        @foo.register_fake
        def foo_fake_impl(a):
            x = torch.empty_like(a)  # good
            y = torch.empty_like(a)  # size mismatch
            z = torch.empty_like(a)  # dtype mismatch
            return [x, y, z]

        mod = M()
        inputs = (torch.randn(3, 3),)
        with self.assertRaises(RuntimeError):
            with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
                export(mod, inputs)

        ep, report = draft_export(mod, inputs)
        for ep_out, eager_out in zip(ep.module()(*inputs), mod(*inputs)):
            self.assertTrue(torch.allclose(ep_out, eager_out))
            self.assertEqual(ep_out.dtype, eager_out.dtype)

        self.assertEqual(len(report.failures), 2)
        self.assertEqual(
            report.failures[0].failure_type, FailureType.MISMATCHED_FAKE_KERNEL
        )
        self.assertEqual(
            report.failures[1].failure_type, FailureType.MISMATCHED_FAKE_KERNEL
        )
        self.assertEqual(
            sorted([f.data["reason"] for f in report.failures]),
            [
                "Dtypes torch.bfloat16 and torch.float32 are not equal!",
                "mismatch between fake value 3 and real value 6 ",
            ],
        )

    def test_override_incorrectly_aliasing_kernel(self):
        class M(torch.nn.Module):
            def forward(self, a):
                return torch.ops.mylib.foo(a)

        @torch.library.custom_op("mylib::foo", mutates_args={})
        def foo(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return a * 2, a + 2

        @foo.register_fake
        def foo_fake_impl(a):
            return a, torch.empty_like(a)  # incorrectly aliasing

        mod = M()
        inputs = (torch.randn(3, 3),)
        with self.assertRaisesRegex(
            RuntimeError,
            "Real tensor propagation found an aliasing mismatch",
        ):
            with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
                export(mod, inputs)

        ep, report = draft_export(mod, inputs)
        for ep_out, eager_out in zip(
            tree_leaves(ep.module()(*inputs)), tree_leaves(mod(*inputs))
        ):
            self.assertTrue(torch.allclose(ep_out, eager_out))
            self.assertEqual(ep_out.dtype, eager_out.dtype)

        self.assertEqual(len(report.failures), 1)
        self.assertEqual(
            report.failures[0].failure_type, FailureType.MISMATCHED_FAKE_KERNEL
        )
        self.assertTrue(
            "Mismatched aliasing spec between fake kernel and real kernel"
            in report.failures[0].data["reason"]
        )


if __name__ == "__main__":
run_tests()
37 changes: 32 additions & 5 deletions test/export/test_export.py
@@ -1113,8 +1113,8 @@ def foo_fake_impl(a, b):
        # catch concrete inequality
        with self.assertRaisesRegex(
            error_type,
-            "Real tensor propagation found an output size mismatch between fake shape 8 and real shape 4, "
-            "at output index 0, dimension 0 for func: mylib.foo.default",
+            r"Real tensor propagation found an output size mismatch between fake shape 8 and real shape 4, "
+            r"at output\.size\(0\), for func: mylib.foo.default",
        ):
            export(
                M(),
@@ -1133,8 +1133,8 @@ def foo_fake_impl(a, b):
            )
        with self.assertRaisesRegex(
            error_type,
-            "Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, "
-            "at output index 0, dimension 0 for func: mylib.foo.default",
+            r"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, "
+            r"at output\.size\(0\), for func: mylib.foo.default",
        ):
            export(
                M(),
@@ -1193,7 +1193,7 @@ def foo_fake_impl(a):
        with self.assertRaisesRegex(
            error_type,
            r"Real tensor propagation found a metadata mismatch between fake tensor (.*\n)*.* "
-            r"and real tensor (.*\n)*.* at output index 0, for func: mylib.foo_dtype.default",
+            r"and real tensor (.*\n)*.* at output, for func: mylib.foo_dtype.default",
        ):
            ep = export(N(), (torch.randn(4, 4),))
@@ -1415,6 +1415,33 @@ def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
            ep = export(model, inputs)

    def test_real_tensor_errors_on_aliasing_custom_op(self):
        @torch.library.custom_op("export::foo_alias", mutates_args={})
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x

        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.ops.export.foo_alias(x) * 2

        model = Foo()
        inputs = (torch.randn(4, 4),)
        error_type = (
            RuntimeError
            if is_non_strict_test(self._testMethodName)
            else torch._dynamo.exc.TorchRuntimeError
        )
        with self.assertRaisesRegex(
            error_type,
            (
                r"The output of this custom operator \(1\) must not also be an input "
                r"to this custom operator and \(2\) may not alias any inputs to this "
                r"custom operator or other returns"
            ),
        ):
            with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
                ep = export(model, inputs)

    @testing.expectedFailureSerDer  # SymBool serialization? TODO(pianpwk)
    @testing.expectedFailureSerDerNonStrict
    def test_real_tensor_bool_cast(self):
5 changes: 5 additions & 0 deletions torch/_functorch/config.py
@@ -205,6 +205,11 @@ def remote_autograd_cache_default() -> Optional[bool]:
# Supported formats are defined here https://graphviz.org/docs/outputs/
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")

# Valid only if fake_tensor_propagate_real_tensors = True; if a fake-real
# kernel mismatch is detected, bypasses by making a fake kernel from the
# real tensor outputs.
generate_fake_kernels_from_real_mismatches = False


# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
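Presumably the new flag also composes with real tensor tracing outside of draft export; a minimal sketch of opting in manually, under that assumption (`M` and `inputs` stand in for any module and example inputs, as in the tests above):

import torch
from torch.export import export

with torch._functorch.config.patch(
    fake_tensor_propagate_real_tensors=True,
    generate_fake_kernels_from_real_mismatches=True,  # the flag added by this commit
):
    ep = export(M(), inputs)  # fake/real mismatches are bypassed instead of raising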
(Diffs for the remaining two changed files are not shown.)