@@ -235,36 +235,28 @@ def test_remove_zero_arg_cat(self):
235235 )
236236
237237 def test_remove_clone (self ):
238- class Clone (torch .nn .Module ):
239- def forward (self , x , y ):
240- t1 = x .clone ()
241- t2 = y .clone ()
242- return t1 + t2
243-
244- x = torch .ones (3 , 5 )
245- y = torch .ones (3 , 5 )
246- graph_module = export_to_edge (Clone (), (x , y )).exported_program ().graph_module
247- new_graph_module = RemoveCloneOpPass ()(graph_module ).graph_module
248- new_graph_module .graph .eliminate_dead_code ()
249- # Assert that t1 and t2 are optimized away
250- self .assertEqual (count_node (new_graph_module , torch .ops .aten .clone .out ), 0 )
238+ builder = GraphBuilder ()
239+ x = builder .placeholder ("x" , torch .randn ([3 , 5 ], dtype = torch .float32 ))
240+ clone = builder .call_operator (op = exir_ops .edge .aten .clone .default , args = (x ,))
241+ builder .output ([clone ])
242+ original = builder .get_graph_module ()
243+ graph_after_passes = RemoveCloneOpPass ()(original ).graph_module
244+ self .assertEqual (
245+ count_node (graph_after_passes , torch .ops .aten .clone .default ), 0
246+ )
251247
252248 def test_remove_contiguous (self ):
253- class Contiguous ( torch . nn . Module ):
254- def forward ( self , x , y ):
255- t1 = x . contiguous ()
256- t2 = y . contiguous ( )
257- return t1 + t2
258-
259- x = torch . ones ( 3 , 5 )
260- y = torch . ones ( 3 , 5 )
261- graph_module = (
262- export_to_edge ( Contiguous (), ( x , y )). exported_program (). graph_module
249+ builder = GraphBuilder ()
250+ x = builder . placeholder ( "x" , torch . randn ([ 3 , 5 ], dtype = torch . float32 ))
251+ contiguous = builder . call_operator (
252+ op = exir_ops . edge . aten . contiguous . default , args = ( x , )
253+ )
254+ builder . output ([ contiguous ])
255+ original = builder . get_graph_module ( )
256+ graph_after_passes = RemoveContiguousOpPass ()( original ). graph_module
257+ self . assertEqual (
258+ count_node ( graph_after_passes , torch . ops . aten . contiguous . default ), 0
263259 )
264- new_graph_module = RemoveContiguousOpPass ()(graph_module ).graph_module
265- new_graph_module .graph .eliminate_dead_code ()
266- # Assert that t1 and t2 are optimized away
267- self .assertEqual (count_node (new_graph_module , torch .ops .aten .contiguous .out ), 0 )
268260
269261 @parameterized .expand (
270262 [
@@ -274,119 +266,129 @@ def forward(self, x, y):
274266 )
275267 @torch .no_grad ()
276268 def test_remove_nop_view (self , shape , new_shape ):
277- class View (torch .nn .Module ):
278- def __init__ (self , new_shape ):
279- super ().__init__ ()
280- self .new_shape = new_shape
281-
282- def forward (self , x : torch .Tensor ):
283- return x .view (self .new_shape )
284-
285- model = View (new_shape )
286- x = torch .randn (shape )
287- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
288- p = RemoveNopSliceOrViewOpPass ()
289- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
290- graph_after_passes .graph .eliminate_dead_code ()
291- # Assert that view op was removed
269+ builder = GraphBuilder ()
270+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
271+ view = builder .call_operator (
272+ op = exir_ops .edge .aten .view_copy .default , args = (x , new_shape )
273+ )
274+ builder .output ([view ])
275+ original = builder .get_graph_module ()
276+ graph_after_passes = cast (
277+ PassResult , RemoveNopSliceOrViewOpPass ()(original )
278+ ).graph_module
292279 self .assertEqual (
293280 count_node (graph_after_passes , exir_ops .edge .aten .view_copy .default ), 0
294281 )
295282
296283 def test_remove_nop_slice (self ):
297- class Slice (torch .nn .Module ):
298- def forward (self , x ):
299- return torch .slice_copy (x , dim = 0 , start = 0 , step = 1 )
300-
301- x = torch .ones (3 , 5 )
302- model = Slice ()
303- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
304- p = RemoveNopSliceOrViewOpPass ()
305- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
306- graph_after_passes .graph .eliminate_dead_code ()
307- # Assert that slice op was removed
284+ builder = GraphBuilder ()
285+ x = builder .placeholder ("x" , torch .randn (3 , 5 , dtype = torch .float32 ))
286+ slice_ = builder .call_operator (
287+ op = exir_ops .edge .aten .slice_copy .Tensor ,
288+ args = (
289+ x ,
290+ 0 , # dim
291+ 0 , # start
292+ 3 , # end
293+ ),
294+ )
295+ builder .output ([slice_ ])
296+ original = builder .get_graph_module ()
297+ graph_after_passes = cast (
298+ PassResult , RemoveNopSliceOrViewOpPass ()(original )
299+ ).graph_module
308300 self .assertEqual (
309301 count_node (graph_after_passes , exir_ops .edge .aten .slice_copy .Tensor ), 0
310302 )
311303
312- def test_remove_nop_select (self ):
313- class SelectFeasible1 ( torch . nn . Module ):
314- def forward ( self , x ):
315- y = x . select ( 0 , 0 )
316- z = y . view ([ 1 , 5 , 6 ])
317- return z
318-
319- x = torch . ones ( 1 , 5 , 6 )
320- graph_module = (
321- export_to_edge ( SelectFeasible1 (), ( x ,)). exported_program (). graph_module
304+ def test_remove_nop_select_before_view (self ):
305+ builder = GraphBuilder ()
306+ x = builder . placeholder ( "x" , torch . randn ( 1 , 5 , 6 , dtype = torch . float32 ))
307+ select = builder . call_operator (
308+ op = exir_ops . edge . aten . select_copy . int ,
309+ args = (
310+ x ,
311+ 0 , # dim
312+ 0 , # index
313+ ),
322314 )
323- self .assertEqual (
324- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
315+ view = builder .call_operator (
316+ op = exir_ops .edge .aten .view_copy .default ,
317+ args = (select , [1 , 5 , 6 ]), # new shape
325318 )
326- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
327- # Assert that select op was removed
319+ builder .output ([view ])
320+ original = builder .get_graph_module ()
321+ graph_after_passes = cast (
322+ PassResult , RemoveNopSelectOpPass ()(original )
323+ ).graph_module
328324 self .assertEqual (
329- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
325+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
330326 )
331327
332- class SelectFeasible2 (torch .nn .Module ):
333- def forward (self , x , y ):
334- x = x .select (0 , 0 )
335- z = x + y
336- return z
337-
338- x = torch .ones (1 , 5 , 6 )
339- y = torch .ones (1 , 5 , 6 )
340- graph_module = (
341- export_to_edge (SelectFeasible2 (), (x , y )).exported_program ().graph_module
342- )
343- self .assertEqual (
344- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
328+ def test_remove_nop_select_before_add (self ):
329+ builder = GraphBuilder ()
330+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
331+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
332+ select = builder .call_operator (
333+ op = exir_ops .edge .aten .select_copy .int ,
334+ args = (
335+ x ,
336+ 0 , # dim
337+ 0 , # index
338+ ),
345339 )
346- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
347- # Assert that select op was removed
340+ add = builder .call_operator (op = exir_ops .edge .aten .add .Tensor , args = (select , y ))
341+ builder .output ([add ])
342+ original = builder .get_graph_module ()
343+ graph_after_passes = cast (
344+ PassResult , RemoveNopSelectOpPass ()(original )
345+ ).graph_module
348346 self .assertEqual (
349- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
347+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
350348 )
351349
352- class SelectFeasible3 (torch .nn .Module ):
353- def forward (self , x , y ):
354- x = x .select (0 , 0 )
355- z = x * y
356- return z
357-
358- x = torch .ones (1 , 5 , 6 )
359- y = torch .ones (1 , 5 , 6 )
360- graph_module = (
361- export_to_edge (SelectFeasible3 (), (x , y )).exported_program ().graph_module
362- )
363- self .assertEqual (
364- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
350+ def test_remove_nop_select_before_mul (self ):
351+ builder = GraphBuilder ()
352+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
353+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
354+ select = builder .call_operator (
355+ op = exir_ops .edge .aten .select_copy .int ,
356+ args = (
357+ x ,
358+ 0 , # dim
359+ 0 , # index
360+ ),
365361 )
366- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
367- # Assert that select op was removed
362+ mul = builder .call_operator (op = exir_ops .edge .aten .mul .Tensor , args = (select , y ))
363+ builder .output ([mul ])
364+ original = builder .get_graph_module ()
365+ graph_after_passes = cast (
366+ PassResult , RemoveNopSelectOpPass ()(original )
367+ ).graph_module
368368 self .assertEqual (
369- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
369+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
370370 )
371371
372- class SelectFeasible4 (torch .nn .Module ):
373- def forward (self , x , y ):
374- x = x .select (0 , 0 )
375- z = x / y
376- return z
377-
378- x = torch .ones (1 , 5 , 6 )
379- y = torch .ones (1 , 5 , 6 )
380- graph_module = (
381- export_to_edge (SelectFeasible4 (), (x , y )).exported_program ().graph_module
382- )
383- self .assertEqual (
384- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
372+ def test_remove_nop_select_before_div (self ):
373+ builder = GraphBuilder ()
374+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
375+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
376+ select = builder .call_operator (
377+ op = exir_ops .edge .aten .select_copy .int ,
378+ args = (
379+ x ,
380+ 0 , # dim
381+ 0 , # index
382+ ),
385383 )
386- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
387- # Assert that select op was removed
384+ div = builder .call_operator (op = exir_ops .edge .aten .div .Tensor , args = (select , y ))
385+ builder .output ([div ])
386+ original = builder .get_graph_module ()
387+ graph_after_passes = cast (
388+ PassResult , RemoveNopSelectOpPass ()(original )
389+ ).graph_module
388390 self .assertEqual (
389- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
391+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
390392 )
391393
392394 def test_remove_nop_quant_dequant (self ):
0 commit comments