diff --git a/exir/tests/models.py b/exir/tests/models.py index c9eb0761935..74c86dab807 100644 --- a/exir/tests/models.py +++ b/exir/tests/models.py @@ -7,7 +7,7 @@ # pyre-strict import itertools -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import executorch.exir as exir @@ -34,6 +34,11 @@ def forward( def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]: return (torch.rand(4), torch.rand(5)) + def get_dynamic_shape(self) -> Any: # pyre-ignore[3] + dim = torch.export.Dim("dim", max=10) + dim2 = torch.export.Dim("dim2", max=10) + return ({0: dim}, {0: dim2}) + class ModelWithUnusedArg(nn.Module): def __init__(self) -> None: diff --git a/exir/tests/test_dynamic_shape_propagation.py b/exir/tests/test_dynamic_shape_propagation.py index 1cd699a4266..abc07d60437 100644 --- a/exir/tests/test_dynamic_shape_propagation.py +++ b/exir/tests/test_dynamic_shape_propagation.py @@ -7,8 +7,10 @@ from unittest import TestCase from executorch import exir +from executorch.exir import to_edge from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass from executorch.exir.tests.models import Repeat +from torch.export import export class TestDynamicShapeProp(TestCase): @@ -17,15 +19,14 @@ def test_repeat(self): inputs = eager_model.get_random_inputs() inputs = inputs[0], inputs[1] - prog = exir.capture( - eager_model, - inputs, - exir.CaptureConfig(enable_dynamic_shape=True), - ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) + prog = to_edge( + export(eager_model, inputs, dynamic_shapes=eager_model.get_dynamic_shape()), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) - new_prog = prog.transform(SpecPropPass(), HintBasedSymShapeEvalPass()) + new_prog = prog.transform([SpecPropPass(), HintBasedSymShapeEvalPass()]) - gm = new_prog.exported_program.graph_module + gm = new_prog.exported_program().graph_module DebugPass(show_spec=True)(gm) *_, return_node = gm.graph.nodes