diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 2e975cf49c88..c87a7162b070 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -41,7 +41,7 @@
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op
+from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
 from .register import register_pattern_table
 
 logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ def _func_wrapper(expr):
 _register_external_op_helper("nn.softmax")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
+_register_external_op_helper("nn.layer_norm")
 
 
 def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -455,6 +456,7 @@ def visit_call(self, call):
                 "nn.conv3d",
                 "nn.conv3d_transpose",
                 "nn.dense",
+                "nn.layer_norm",
             ]
         )
         if isinstance(call.op, tvm.tir.op.Op):
@@ -526,3 +528,80 @@ def visit_call(self, call):
     new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
     new_mod = transform.RemoveUnusedFunctions()(new_mod)
     return new_mod
+
+
+class LayerNormRewrite(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
+
+    Pattern #1:
+    1   %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    2   %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3   %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
+    4   %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    5   %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    6   %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
+    7   %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
+    8   %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
+    9   %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+    10  %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1   %0 = mean(%input, axis=[-1], keepdims=True);
+    2   %1 = variance(%input, %0, axis=[-1], keepdims=True);
+    3   %2 = add(%1, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 49, 1), float32] */;
+    4   %3 = subtract(%input, %0);
+    5   %4 = sqrt(%2) /* ty=Tensor[(1, 49, 1), float32] */;
+    6   %5 = divide(%3, %4);
+    7   %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 49, 64), float32] */;
+    8   %7 = add(%6, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 49, 64), float32] */
+
+    Only the operator structure and the epsilon constant are matched; the ``axis`` and
+    ``keepdims`` attributes of the reductions shown above are not checked by the pattern.
+    """
+
+    # Epsilon literal the pattern matches; the same value is passed to the fused operator.
+    _EPSILON = 1e-5
+
+    def __init__(self):
+        super(LayerNormRewrite, self).__init__()
+        self.data = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        # Centered input: data - mean(data).
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        # Variance: either mean(power(diff, 2)), with an optional cast, or a variance op.
+        cdiff = diff | is_op("cast")(diff)
+        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
+        p1 = is_op("power")(cdiff, const_two)
+        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+        # Normalize: diff / sqrt(variance + eps), then scale by gamma and shift by beta.
+        eps = is_expr(relay.const(self._EPSILON))
+        added_eps = is_op("add")(mp1, eps)
+        deno = is_op("sqrt")(added_eps)
+        div_out = is_op("divide")(diff, deno)
+        weighted = is_op("multiply")(div_out, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
+
+    def callback(self, pre, post, node_map):
+        """Build the ``nn.layer_norm`` call that replaces a matched subgraph."""
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta, epsilon=self._EPSILON)
+
+
+def rewrite_layer_norm(mod):
+    """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
+    operator so that we can offload them to dnnl layer normalization byoc part.
+
+    The "main" function of ``mod`` is replaced in place and the same module is returned.
+    """
+    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    return mod
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index a2417f012ea4..db8f25e2a6ea 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -203,6 +203,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
           Binary(nid, dnnl::algorithm::binary_add);
         } else if ("multiply" == op_name) {
           Binary(nid, dnnl::algorithm::binary_mul);
+        } else if ("nn.layer_norm" == op_name) {
+          LayerNorm(nid);
         } else {
           LOG(FATAL) << "Unsupported op: " << op_name;
         }
@@ -449,6 +451,53 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                                                              {DNNL_ARG_VARIANCE, var_tr}});
   }
 
+  // Register the DNNL primitives for an "nn.layer_norm" node (forward inference).
+  void LayerNorm(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    auto src_tr = GetInput(nid, 0);
+    auto gamma_tr = GetInput(nid, 1);
+    auto beta_tr = GetInput(nid, 2);
+    auto dst_tr = GetOutput(nid, 0);
+
+    auto axis = GetNodeAttr<int>(node, "axis");
+    auto epsilon = GetNodeAttr<float>(node, "epsilon");
+    auto center = GetNodeAttr<bool>(node, "center");
+    auto scale = GetNodeAttr<bool>(node, "scale");
+
+    ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case";
+
+    // LN description.
+    auto lnorm_desc = dnnl::layer_normalization_forward::desc(
+        dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon,
+        dnnl::normalization_flags::use_scale_shift);
+
+    auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);
+
+    // Concatenate scale and shift tensors
+    auto scale_shift_tr = TensorRequisite::AsIs(lnorm_prim_desc.weights_desc(), GenUniqueEid());
+    auto sc_sh_dims = scale_shift_tr.dims();
+
+    ICHECK(sc_sh_dims.size() == 2);
+    ICHECK(sc_sh_dims[0] == 2);
+    // The crop at offset {0, 0} is filled from gamma, the crop at {1, 0} from beta.
+    sc_sh_dims[0] /= 2;
+    auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze();
+    auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze();
+
+    auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) {
+      dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc());
+      Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
+    };
+
+    register_copy(gamma_tr, scale_tr);
+    register_copy(beta_tr, shift_tr);
+
+    Submit(
+        dnnl::layer_normalization_forward(lnorm_prim_desc),
+        {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}});
+  }
+
   void Pooling(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index babfad4a0c8c..3e4e831aa594 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -111,6 +111,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
                             with tvm.transform.PassContext(opt_level=3):
                                 mod = alter_layout_seq(mod)
 
+    mod = dnnl.rewrite_layer_norm(mod)
+
     byoc_seq = tvm.transform.Sequential(
         [
             transform.MergeComposite(dnnl.pattern_table()),
@@ -454,6 +456,17 @@ def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype
     return relay.nn.relu(conv2d_bias_bn), dic, param_lst
 
 
+def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
+    """Build an nn.layer_norm expression over the "input" var with ones/zeros as gamma/beta."""
+    dic = {"input": x_shape}
+    param_lst = []
+    data = relay.var("input", shape=x_shape, dtype=dtype)
+    beta = relay.const(np.zeros(x_shape[-1]).astype(dtype))
+    gamma = relay.const(np.ones(x_shape[-1]).astype(dtype))
+    out = relay.nn.layer_norm(data, gamma=gamma, beta=beta)
+    return out, dic, param_lst
+
+
 def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
     conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
     sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
@@ -1032,5 +1045,15 @@ def get_graph():
     run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False)
 
 
+def test_layer_norm(run_module, dtype="float32"):
+    """Run the graph built by get_layer_norm through run_and_verify_func."""
+    x_shape = (1, 49, 64)
+
+    ln, dic, param_lst = get_layer_norm(x_shape, dtype=dtype)
+    ln = tvm.IRModule.from_expr(ln)
+    config = ln, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 if __name__ == "__main__":
     tvm.testing.main()