@@ -320,7 +320,7 @@ def call_operator(self, op, args, kwargs, meta):
320320
321321
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceAddMMWithLinearPass(RemoveOrReplacePassInterface):
    """Replace ``aten.addmm`` with ``aten.linear``.

    ``addmm`` computes ``beta * bias + alpha * mm(mat1, mat2)``. When the
    scale factors can be folded into parameter tensors, this is equivalent
    to ``linear(mat1, mat2.t(), bias)``, which downstream Cadence kernels
    handle more efficiently.
    """

    def __init__(self) -> None:
        super().__init__()
        # Monotonic counter used to generate unique buffer names for the
        # scaled bias/weight tensors registered on the graph module.
        self.counter = 0

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only addmm nodes are candidates for this rewrite.
        return [exir_ops.edge.aten.addmm.default]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Attempt to rewrite a single addmm node as a linear node.

        Returns:
            True if the node was replaced with ``aten.linear``; False if the
            preconditions were not met and the node was left untouched.
        """
        # The addmm op has three concrete args: input (bias), mat1, mat2.
        assert len(node.args) >= 3
        (bias, mat1, mat2) = node.args[0:3]
        # The other two args are optional scale factors.
        beta = float(node.kwargs.get("beta", 1.0))
        alpha = float(node.kwargs.get("alpha", 1.0))

        bias = cast(torch.fx.Node, bias)
        mat1 = cast(torch.fx.Node, mat1)
        mat2 = cast(torch.fx.Node, mat2)

        # AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert
        # it to linear op by multiplying beta to bias, and alpha to mat2.t().
        # However, the following two conditions must hold:
        # a. If bias is not a param, then beta must be 1.0
        # b. If mat2 is not a param, then mat2 must be a transpose op. Also,
        #    the input to the transpose must be a param, or alpha must be 1.0.
        fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0
        fit_mat2 = is_node_with_op(mat2, "get_attr")
        transposed_mat2 = False
        if (
            not fit_mat2
            and is_node_with_op(mat2, "call_function")
            and mat2.target == exir_ops.edge.aten.transpose_copy.int
        ):
            # Look through the transpose: linear expects the transposed
            # weight anyway, so the transpose op can be absorbed.
            mat2, transposed_mat2 = cast(torch.fx.Node, mat2.args[0]), True
            fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0

        if not fit_bias or not fit_mat2:
            return False

        graph = node.graph
        graph_module = graph.owning_module

        # Fold beta into the bias parameter by registering a scaled copy.
        if beta != 1.0:
            assert is_node_with_op(bias, "get_attr")
            bias_tensor = get_tensor_from_attr(graph_module, bias)
            assert isinstance(bias_tensor, torch.Tensor)
            bias_tensor = beta * bias_tensor
            with graph.inserting_before(node):
                bias_name = f"_bias_addmm_to_linear_{self.counter}"
                graph_module.register_buffer(bias_name, bias_tensor)
                bias = graph.get_attr(bias_name)

        # Use associativity of scalar multiplication, and multiply alpha
        # into mat2. Linear takes the transposed weight, so transpose it
        # here unless the graph already had a transpose we absorbed above.
        if is_node_with_op(mat2, "get_attr"):
            mat2_tensor = get_tensor_from_attr(graph_module, mat2)
            assert isinstance(mat2_tensor, torch.Tensor)
            mat2_tensor = alpha * mat2_tensor
            # transpose mat2
            mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t()
            with graph.inserting_before(node):
                mat2_name = f"_mat2_addmm_to_linear_{self.counter}"
                graph_module.register_buffer(mat2_name, mat2_tensor)
                mat2 = graph.get_attr(mat2_name)

        # Construct the linear node and rewire all users of the addmm op.
        linear_args = (mat1, mat2, bias)
        with graph.inserting_before(node):
            linear_node = graph.call_function(
                exir_ops.edge.aten.linear.default, args=linear_args
            )
        linear_node.meta = node.meta
        # Replace all the uses of the addmm op with linear op.
        node.replace_all_uses_with(linear_node)
        self.counter += 1
        return True
407405
408406
409407@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
0 commit comments