diff --git a/exir/tests/test_quant_fusion_pass.py b/exir/tests/test_quant_fusion_pass.py index 69610a73abe..791cb3e16ef 100644 --- a/exir/tests/test_quant_fusion_pass.py +++ b/exir/tests/test_quant_fusion_pass.py @@ -10,7 +10,7 @@ import torch from executorch import exir -from executorch.exir import CaptureConfig, EdgeCompileConfig +from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.tests.common import register_additional_test_aten_ops from torch.ao.quantization import ( # @manual @@ -26,6 +26,7 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) +from torch.export import export from torch.nn import functional as F from torch.testing import FileCheck @@ -56,9 +57,11 @@ def forward(self, x, y): ) m = _convert_to_reference_decomposed_fx(m) config = EdgeCompileConfig(_check_ir_validity=False) - m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config) + m = to_edge(export(m, example_inputs), compile_config=config) # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph. - m = m.transform(QuantFusionPass(_fix_node_meta_val=True)) + m = m.transform( + [QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False + ) # check that we are using functional variant of q/dq/add FileCheck().check( "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" @@ -67,12 +70,12 @@ def forward(self, x, y): ).check( "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default" ).run( - m.exported_program.graph_module.code + m.exported_program().graph_module.code ) m = m.to_executorch() # check that we are using out variant of q/dq/add FileCheck().check("torch.ops.quantized_decomposed.add.out").run( - m.exported_program.graph_module.code + m.exported_program().graph_module.code ) def test_reshape(self) -> None: @@ -95,9 +98,11 @@ def forward(self, x, y): m(*example_inputs) m = _convert_to_reference_decomposed_fx(m) config = EdgeCompileConfig(_check_ir_validity=False) - m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config) + m = to_edge(export(m, example_inputs), compile_config=config) # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph. - m = m.transform(QuantFusionPass(_fix_node_meta_val=True)) + m = m.transform( + [QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False + ) # check that we are using functional variant of q/dq/add/reshape # make sure we only have two quant and one dequant since the q/dq around reshape # should be fused @@ -114,14 +119,14 @@ def forward(self, x, y): 1, exactly=True, ).run( - m.exported_program.graph_module.code + m.exported_program().graph_module.code ) m = m.to_executorch(exir.ExecutorchBackendConfig(remove_view_copy=False)) # check that we are using out variant of q/dq/add FileCheck().check("torch.ops.quantized_decomposed.add.out").check( "torch.ops.aten.view_copy.out" - ).run(m.exported_program.graph_module.code) + ).run(m.exported_program().graph_module.code) def test_slice(self) -> None: """We don't proactively quantize slice today, but we'll fuse the dq-slice-q @@ -150,9 +155,11 @@ def forward(self, x, y): ) m = _convert_to_reference_decomposed_fx(m) config = EdgeCompileConfig(_check_ir_validity=False) - m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config) + m = to_edge(export(m, example_inputs), compile_config=config) # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph. - m = m.transform(QuantFusionPass(_fix_node_meta_val=True)) + m = m.transform( + [QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False + ) # check that we are using functional variant of q/dq/add/slice # make sure we only have one quant and one dequant since the q/dq around slice # should be fused @@ -169,14 +176,14 @@ def forward(self, x, y): ).check( "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default" ).run( - m.exported_program.graph_module.code + m.exported_program().graph_module.code ) m = m.to_executorch() # check that we are using out variant of add and slice_copy FileCheck().check("torch.ops.quantized_decomposed.add.out").check( "torch.ops.aten.slice_copy.Tensor_out" - ).run(m.dump_graph_module().code) + ).run(m.exported_program().graph_module.code) def test_cat(self) -> None: class M(torch.nn.Module): @@ -197,9 +204,9 @@ def forward(self, x, y): m(*example_inputs) m = _convert_to_reference_decomposed_fx(m) config = EdgeCompileConfig(_check_ir_validity=False) - m = exir.capture(m, example_inputs, CaptureConfig()).to_edge(config=config) + m = to_edge(export(m, example_inputs), compile_config=config) # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph. - m = m.transform(QuantFusionPass()) + m = m.transform([QuantFusionPass()], check_ir_validity=False) # check that we are using functional variant of q/dq/cat FileCheck().check_count( "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default", @@ -210,7 +217,7 @@ def forward(self, x, y): 1, exactly=True, ).run( - m.exported_program.graph_module.code + m.exported_program().graph_module.code ) m = m.to_executorch() @@ -224,7 +231,7 @@ def forward(self, x, y): ).check("torch.ops.aten.cat.out").check_count( "torch.ops.quantized_decomposed.dequantize_per_tensor.out", 1, exactly=True ).run( - m.dump_graph_module().code + m.exported_program().graph_module.code ) def test_embedding_byte(self) -> None: @@ -292,16 +299,18 @@ def forward(self, indices): _check_ir_validity=False, _use_edge_ops=True, ) - m = exir.capture(m, example_inputs).to_edge(config=compile_config) + m = to_edge(export(m, example_inputs), compile_config=compile_config) # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph. - m = m.transform(QuantFusionPass(_fix_node_meta_val=True)) + m = m.transform( + [QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False + ) # check that we are using functional variant of q/dq/cat FileCheck().check( "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default", ).check( "executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default" ).run( - m.exported_program.graph_module.code + m.exported_program().graph_module.code ) # TODO: enable after the out variants of quantize_per_channel is supported @@ -348,17 +357,18 @@ def forward(self, indices): _check_ir_validity=False, _use_edge_ops=True, ) - m = exir.capture(m, example_inputs).to_edge(config=compile_config) + m = to_edge(export(m, example_inputs), compile_config=compile_config) # QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph. - m = m.transform(QuantFusionPass(_fix_node_meta_val=True)) - m(*example_inputs) + m = m.transform( + [QuantFusionPass(_fix_node_meta_val=True)], check_ir_validity=False + ) # check that we are using functional variant of q/dq/cat FileCheck().check( "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default", ).check( "executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default" ).run( - m.exported_program.graph_module.code + m.exported_program().graph_module.code ) # TODO: enable after the out variants of quantize_per_channel is supported