from tvm.relay.expr_functor import ExprMutator, ExprVisitor

from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op
+from ...dataflow_pattern import wildcard, is_op, is_constant, rewrite, DFPatternCallback
from .register import register_pattern_table

logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ def _func_wrapper(expr):
_register_external_op_helper("nn.softmax")
_register_external_op_helper("add")
_register_external_op_helper("multiply")
+_register_external_op_helper("nn.layer_norm")


def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -526,3 +527,52 @@ def visit_call(self, call):
    new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
    new_mod = transform.RemoveUnusedFunctions()(new_mod)
    return new_mod
+
+
+class LayerNormRewrite(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
+
+    1   %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    2   %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3   %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
+    4   %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    5   %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    6   %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
+    7   %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
+    8   %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
+    9   %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    10  %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    """
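+    # In plain math, the matched subgraph computes
+    #   y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * gamma + beta
+    # over the last axis, which is what nn.layer_norm evaluates.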
+
+    def __init__(self):
+        super(LayerNormRewrite, self).__init__()
+        self.data = wildcard()
+        self.eps = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        # The cast (step 3 in the docstring) is optional, hence the alternative pattern.
+        cdiff = diff | is_op("cast")(diff)
+        p1 = is_op("power")(cdiff, is_constant())
+        mp1 = is_op("mean")(p1)
+        added_eps = is_op("add")(mp1, self.eps)
+        deno = is_op("sqrt")(added_eps)
+        # Note: divide takes the uncast difference, matching step 8 above.
+        div_out = is_op("divide")(diff, deno)
+        weighted = is_op("multiply")(div_out, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
+
+    def callback(self, pre, post, node_map):
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        # The matched eps constant is not read back; the common 1e-5 value
+        # (step 6 in the docstring) is assumed.
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta, epsilon=1e-5)
+
+
+def rewrite_layer_norm(mod):
+    """Rewrite the input graph, replacing the decomposed operator sequence above with a
+    single TVM-native layer normalization operator so that it can be offloaded to the
+    DNNL layer normalization kernel through BYOC.
+    """
+    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    return mod
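
A usage sketch (not part of the patch): build the decomposed sequence from the
docstring by hand and check that it collapses to a single nn.layer_norm call. The
import path tvm.relay.op.contrib.dnnl and the shapes follow this patch and its
docstring; treat them as assumptions.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import rewrite_layer_norm  # path assumed from this patch

    # Reconstruct the pattern: (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * gamma + beta
    x = relay.var("x", shape=(1, 3136, 64), dtype="float32")
    gamma = relay.const(np.ones(64, dtype="float32"))
    beta = relay.const(np.zeros(64, dtype="float32"))
    mu = relay.mean(x, axis=[-1], keepdims=True)
    diff = relay.subtract(x, mu)
    var = relay.mean(relay.power(diff, relay.const(2.0)), axis=[-1], keepdims=True)
    norm = relay.divide(diff, relay.sqrt(relay.add(var, relay.const(1e-5))))
    out = relay.add(relay.multiply(norm, gamma), beta)

    mod = tvm.IRModule.from_expr(relay.Function([x], out))
    mod = rewrite_layer_norm(mod)
    print(mod["main"])  # the body should now be a single nn.layer_norm call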