@@ -41,7 +41,7 @@
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, rewrite, DFPatternCallback
+from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
 from .register import register_pattern_table
 
 logger = logging.getLogger("DNNL")
@@ -456,6 +456,7 @@ def visit_call(self, call):
456456 "nn.conv3d" ,
457457 "nn.conv3d_transpose" ,
458458 "nn.dense" ,
459+ "nn.layer_norm" ,
459460 ]
460461 )
461462 if isinstance (call .op , tvm .tir .op .Op ):
@@ -530,9 +531,10 @@ def visit_call(self, call):
 
 
 class LayerNormRewrite(DFPatternCallback):
-    '''
+    """
     A callback to rewrite the following operators into a single layer normalization operator.
 
+    Pattern #1:
     1 %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
     2 %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
     3 %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
@@ -541,22 +543,38 @@ class LayerNormRewrite(DFPatternCallback):
     6 %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
     7 %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
     8 %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
-    9 %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
-    10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
-    '''
+    9 %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+    10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1 %0 = mean(%input, axis=[-1], keepdims=True);
+    2 %1 = variance(%input, %0, axis=[-1], keepdims=True);
+    3 %2 = add(%1, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 49, 1), float32] */;
+    4 %3 = subtract(%input, %0);
+    5 %4 = sqrt(%2) /* ty=Tensor[(1, 49, 1), float32] */;
+    6 %5 = divide(%3, %4);
+    7 %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 49, 64), float32] */;
+    8 %7 = add(%6, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 49, 64), float32] */
+
+    """
 
     def __init__(self):
         super(LayerNormRewrite, self).__init__()
         self.data = wildcard()
-        self.eps = wildcard()
         self.gamma = wildcard()
         self.beta = wildcard()
         mu = is_op("mean")(self.data)
         diff = is_op("subtract")(self.data, mu)
         cdiff = diff | is_op("cast")(diff)
-        p1 = is_op("power")(cdiff, is_constant())
-        mp1 = is_op("mean")(p1)
-        added_eps = is_op("add")(mp1, self.eps)
+        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
+        p1 = is_op("power")(cdiff, const_two)
+        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+        eps = is_expr(relay.const(1e-5))
+        added_eps = is_op("add")(mp1, eps)
         deno = is_op("sqrt")(added_eps)
         div_out = is_op("divide")(diff, deno)
         weighted = is_op("multiply")(div_out, self.gamma)
@@ -567,12 +585,12 @@ def callback(self, pre, post, node_map):
         data = node_map[self.data][0]
         gamma = node_map[self.gamma][0]
         beta = node_map[self.beta][0]
-        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta, epsilon=1e-5)
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
 
 
 def rewrite_layer_norm(mod):
     """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
-    operator so that we can offload them to dnnl layer normalization byoc part.
+    operator so that we can offload them to dnnl layer normalization byoc part.
     """
     mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
-    return mod
+    return mod
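
For reference, a minimal usage sketch of the rewrite introduced above. This is not part of the diff: the import path is inferred from this file's relative imports (which suggest python/tvm/relay/op/contrib/dnnl.py), and the toy graph simply reproduces the decomposed chain from Pattern #1 so that rewrite_layer_norm can collapse it.

# Hedged sketch, not part of this commit: build the decomposed layer-norm
# chain from Pattern #1 by hand, then run the module-level rewrite over it.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import rewrite_layer_norm  # assumed module path

x = relay.var("x", shape=(1, 49, 64), dtype="float32")
gamma = relay.const(np.ones((64,), dtype="float32"))
beta = relay.const(np.zeros((64,), dtype="float32"))

mu = relay.mean(x, axis=[-1], keepdims=True)
diff = relay.subtract(x, mu)
var = relay.mean(relay.power(diff, relay.const(2.0)), axis=[-1], keepdims=True)
norm = relay.divide(diff, relay.sqrt(relay.add(var, relay.const(1e-5))))
out = relay.add(relay.multiply(norm, gamma), beta)

mod = tvm.IRModule.from_expr(relay.Function([x], out))
mod = rewrite_layer_norm(mod)  # mean/subtract/power/.../add should collapse into one nn.layer_norm
print(mod["main"])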