1212
1313import executorch .backends .cadence .aot .ops_registrations # noqa
1414import torch
15- from executorch .backends .cadence .aot .compiler import export_to_edge
1615from executorch .backends .cadence .aot .graph_builder import single_op_builder
1716from executorch .backends .cadence .aot .pass_utils import count_node
1817from executorch .backends .cadence .aot .simplify_ops import (
@@ -40,82 +39,47 @@ def test_simplify_slice_scatter_op(
4039 end : Optional [int ] = None ,
4140 step : int = 1 ,
4241 ):
43- class SliceScatter (torch .nn .Module ):
44- def __init__ (
45- self , dim : int , start : Optional [int ], end : Optional [int ], step : int
46- ):
47- super ().__init__ ()
48- self .dim = dim
49- self .start = start
50- self .end = end
51- self .step = step
52-
53- def forward (self , x : torch .Tensor , y : torch .Tensor ):
54- return torch .slice_scatter (
55- x , y , self .dim , self .start , self .end , self .step
56- )
57-
58- model = SliceScatter (dim , start , end , step )
59- x = torch .randn (in_shape )
60- y = torch .randn (src_shape )
61- graph_module = export_to_edge (model , (x , y )).exported_program ().graph_module
62-
63- p = SimplifySliceOpPass ()
64-
65- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
66-
67- self .assertEqual (
68- count_node (graph_after_passes , exir_ops .edge .aten .slice_scatter .default ), 0
42+ x = torch .randn (* in_shape )
43+ y = torch .randn (* src_shape )
44+ gm = single_op_builder (
45+ placeholders = (x , y ),
46+ op = exir_ops .edge .aten .slice_scatter .default ,
47+ args = (x , y , dim , start , end , step ),
6948 )
49+ p = SimplifySliceOpPass ()
50+ gm = cast (PassResult , p (gm )).graph_module
51+ self .assertEqual (count_node (gm , exir_ops .edge .aten .slice_scatter .default ), 0 )
7052
7153 @parameterized .expand (
7254 [
73- [(3 , 16 , 5 ), ( 3 , 0 , 5 ), 1 , 15 , 3 , 3 ],
55+ [(3 , 16 , 5 ), 1 , 15 , 3 , 3 ],
7456 ]
7557 )
7658 @torch .no_grad ()
7759 def test_simplify_slice_op (
7860 self ,
7961 in_shape : Tuple [int ],
80- src_shape : Tuple [int ],
8162 dim : int ,
8263 start : Optional [int ] = None ,
8364 end : Optional [int ] = None ,
8465 step : int = 1 ,
8566 ):
86- class SliceCopy (torch .nn .Module ):
87- def __init__ (
88- self , dim : int , start : Optional [int ], end : Optional [int ], step : int
89- ):
90- super ().__init__ ()
91- self .dim = dim
92- self .start = start
93- self .end = end
94- self .step = step
95-
96- def forward (self , x : torch .Tensor ) -> torch .Tensor :
97- return torch .slice_copy (
98- x , dim = self .dim , start = self .start , end = self .end , step = self .step
99- )
100-
101- # Create a model with single slice copy op.
102- model = SliceCopy (dim , start , end , step )
103- x = torch .randn (in_shape )
104- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
105- self .assertEqual (
106- count_node (graph_module , exir_ops .edge .aten .slice_copy .Tensor ), 1
67+ x = torch .randn (* in_shape )
68+ gm = single_op_builder (
69+ placeholders = (x ,),
70+ op = exir_ops .edge .aten .slice_copy .Tensor ,
71+ args = (
72+ x ,
73+ dim ,
74+ start ,
75+ end ,
76+ step ,
77+ ),
10778 )
108-
10979 p = SimplifySliceOpPass ()
110-
111- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
112-
113- self .assertEqual (
114- count_node (graph_after_passes , exir_ops .edge .aten .slice_copy .Tensor ), 0
115- )
116- self .assertEqual (
117- count_node (graph_after_passes , exir_ops .edge .aten .full .default ), 1
118- )
80+ gm = cast (PassResult , p (gm )).graph_module
81+ self .assertEqual (count_node (gm , exir_ops .edge .aten .slice_copy .Tensor ), 0 )
82+ self .assertEqual (count_node (gm , exir_ops .edge .aten .full .default ), 1 )
11983
12084 def test_simplify_slice_op_args (self ) -> None :
12185 x = torch .rand (4 , 5 )
@@ -125,24 +89,10 @@ def test_simplify_slice_op_args(self) -> None:
12589 args = (x , 1 ),
12690 kwargs = {"end" : 3 },
12791 )
128- self .assertEqual (
129- [
130- (n .args [1 :], n .kwargs )
131- for n in gm .graph .find_nodes (
132- op = "call_function" , target = exir_ops .edge .aten .slice_copy .Tensor
133- )
134- ],
135- [((1 ,), {"end" : 3 })],
136- )
137-
92+ original_slice_copy = list (gm .graph .nodes )[1 ]
93+ self .assertEqual (original_slice_copy .args [1 :], (1 ,))
94+ self .assertEqual (original_slice_copy .kwargs , {"end" : 3 })
13895 gm = BindOptionalArgsPass ().call (gm ).graph_module
139-
140- self .assertEqual (
141- [
142- (n .args [1 :], n .kwargs )
143- for n in gm .graph .find_nodes (
144- op = "call_function" , target = exir_ops .edge .aten .slice_copy .Tensor
145- )
146- ],
147- [((1 , None , 3 , 1 ), {})],
148- )
96+ modified_slice_copy = list (gm .graph .nodes )[1 ]
97+ self .assertEqual (modified_slice_copy .args [1 :], (1 , None , 3 , 1 ))
98+ self .assertEqual (modified_slice_copy .kwargs , {})
0 commit comments