
Commit fae8c7d

Enable layer normalization in DNNL byoc.
Parent commit: 274d8fa

2 files changed: +94 −1 lines

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 51 additions & 1 deletion
@@ -41,7 +41,7 @@
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op
+from ...dataflow_pattern import wildcard, is_op, is_constant, rewrite, DFPatternCallback
 from .register import register_pattern_table
 
 logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ def _func_wrapper(expr):
 _register_external_op_helper("nn.softmax")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
+_register_external_op_helper("nn.layer_norm")
 
 
 def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -526,3 +527,52 @@ def visit_call(self, call):
     new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
     new_mod = transform.RemoveUnusedFunctions()(new_mod)
     return new_mod
+
+
+class LayerNormRewrite(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
+
+    1   %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    2   %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3   %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
+    4   %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    5   %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    6   %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
+    7   %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
+    8   %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
+    9   %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    10  %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    """
+
+    def __init__(self):
+        super(LayerNormRewrite, self).__init__()
+        self.data = wildcard()
+        self.eps = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        cdiff = diff | is_op("cast")(diff)
+        p1 = is_op("power")(cdiff, is_constant())
+        mp1 = is_op("mean")(p1)
+        added_eps = is_op("add")(mp1, self.eps)
+        deno = is_op("sqrt")(added_eps)
+        div_out = is_op("divide")(diff, deno)
+        weighted = is_op("multiply")(div_out, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
+
+    def callback(self, pre, post, node_map):
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta, epsilon=1e-5)
+
+
+def rewrite_layer_norm(mod):
+    """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
+    operator, so that the fused op can be offloaded to the DNNL layer normalization BYOC path.
+    """
+    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    return mod
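
With both pieces in place, a Relay module whose layer normalization was decomposed by the frontend can be collapsed back into nn.layer_norm and then handed to the DNNL backend. A minimal usage sketch: the AnnotateTarget / MergeCompilerRegions / PartitionGraph pipeline is the usual TVM BYOC recipe rather than part of this commit, and `mod` stands for any imported Relay module that contains the decomposed pattern:

from tvm import relay
from tvm.relay.op.contrib import dnnl

# mod: a relay.IRModule containing the mean/subtract/power/.../add chain,
# e.g. from a transformer-style model imported into Relay.
mod = dnnl.rewrite_layer_norm(mod)                    # collapse the chain into nn.layer_norm
mod = relay.transform.AnnotateTarget("dnnl")(mod)     # mark ops registered for the DNNL backend
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)           # regions are then executed by dnnl_json_runtime.cc
lib = relay.build(mod, target="llvm")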

src/runtime/contrib/dnnl/dnnl_json_runtime.cc

Lines changed: 43 additions & 0 deletions
@@ -203,6 +203,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
         Binary(nid, dnnl::algorithm::binary_add);
       } else if ("multiply" == op_name) {
         Binary(nid, dnnl::algorithm::binary_mul);
+      } else if ("nn.layer_norm" == op_name) {
+        LayerNorm(nid);
       } else {
         LOG(FATAL) << "Unsupported op: " << op_name;
       }
@@ -449,6 +451,47 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                          {DNNL_ARG_VARIANCE, var_tr}});
   }
 
+  void LayerNorm(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    auto data_entry = node.GetInputs()[0];
+    auto gamma_entry = node.GetInputs()[1];
+    auto beta_entry = node.GetInputs()[2];
+
+    dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+
+    float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+
+    // Memory description.
+    dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+
+    // LN description.
+    auto lnorm_desc = dnnl::layer_normalization_forward::desc(
+        dnnl::prop_kind::forward_inference, data_md, epsilon,
+        dnnl::normalization_flags::use_scale | dnnl::normalization_flags::use_shift);
+
+    auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);
+    auto lnorm_prim = dnnl::layer_normalization_forward(lnorm_prim_desc);
+
+    net_.push_back(lnorm_prim);
+
+    // Memories.
+    auto data_memory = BindDNNLMemory(data_entry, data_md);
+    JSONGraphNodeEntry out_entry(nid, 0);
+    auto dst_memory = BindDNNLMemory(out_entry, data_md);
+    auto scale_memory = BindDNNLMemory(gamma_entry, data_md);
+    auto shift_memory = BindDNNLMemory(beta_entry, data_md);
+    auto mean_memory = dnnl::memory(lnorm_prim_desc.mean_desc(), engine_);
+    auto variance_memory = dnnl::memory(lnorm_prim_desc.variance_desc(), engine_);
+
+    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
+                         {DNNL_ARG_MEAN, mean_memory},
+                         {DNNL_ARG_VARIANCE, variance_memory},
+                         {DNNL_ARG_SCALE, scale_memory},
+                         {DNNL_ARG_SHIFT, shift_memory},
+                         {DNNL_ARG_DST, dst_memory}});
+  }
+
   void Pooling(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
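
For reference, the primitive constructed in LayerNorm() (forward_inference with use_scale | use_shift) normalizes over the last logical dimension of the input and then applies the gamma/beta tensors bound above. A NumPy sketch of the equivalent math, using the shapes and epsilon from the docstring example in dnnl.py; this only illustrates what is being delegated, not how DNNL implements it:

import numpy as np

def layer_norm_reference(x, gamma, beta, epsilon=1e-5):
    # Normalize over the last axis, then apply per-channel scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * gamma + beta

x = np.random.rand(1, 3136, 64).astype("float32")   # shapes from the docstring example above
gamma = np.ones(64, dtype="float32")
beta = np.zeros(64, dtype="float32")
out = layer_norm_reference(x, gamma, beta)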
