
Commit 7a6b334

Added unit test for layer norm and made the code compatible after introducing TensorRequisite (PR-11345)

1 parent: fae8c7d

3 files changed: +88 −39 lines


python/tvm/relay/op/contrib/dnnl.py

Lines changed: 30 additions & 12 deletions
@@ -41,7 +41,7 @@
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, rewrite, DFPatternCallback
+from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
 from .register import register_pattern_table
 
 logger = logging.getLogger("DNNL")
@@ -456,6 +456,7 @@ def visit_call(self, call):
                 "nn.conv3d",
                 "nn.conv3d_transpose",
                 "nn.dense",
+                "nn.layer_norm",
             ]
         )
         if isinstance(call.op, tvm.tir.op.Op):
@@ -530,9 +531,10 @@ def visit_call(self, call):
 
 
 class LayerNormRewrite(DFPatternCallback):
-    '''
+    """
     A callback to rewrite the following operators into a single layer normalization operator.
 
+    Pattern #1:
     1   %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
     2   %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
     3   %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
@@ -541,22 +543,38 @@ class LayerNormRewrite(DFPatternCallback):
     6   %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
     7   %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
     8   %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
-    9   %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
-    10  %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */;
-    '''
+    9   %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+    10  %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1   %0 = mean(%input, axis=[-1], keepdims=True);
+    2   %1 = variance(%input, %0, axis=[-1], keepdims=True);
+    3   %2 = add(%1, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 49, 1), float32] */;
+    4   %3 = subtract(%input, %0);
+    5   %4 = sqrt(%2) /* ty=Tensor[(1, 49, 1), float32] */;
+    6   %5 = divide(%3, %4);
+    7   %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 49, 64), float32] */;
+    8   %7 = add(%6, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 49, 64), float32] */
+
+    """
 
     def __init__(self):
         super(LayerNormRewrite, self).__init__()
         self.data = wildcard()
-        self.eps = wildcard()
         self.gamma = wildcard()
         self.beta = wildcard()
         mu = is_op("mean")(self.data)
         diff = is_op("subtract")(self.data, mu)
         cdiff = diff | is_op("cast")(diff)
-        p1 = is_op("power")(cdiff, is_constant())
-        mp1 = is_op("mean")(p1)
-        added_eps = is_op("add")(mp1, self.eps)
+        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
+        p1 = is_op("power")(cdiff, const_two)
+        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+        eps = is_expr(relay.const(1e-5))
+        added_eps = is_op("add")(mp1, eps)
         deno = is_op("sqrt")(added_eps)
         div_out = is_op("divide")(diff, deno)
         weighted = is_op("multiply")(div_out, self.gamma)
@@ -567,12 +585,12 @@ def callback(self, pre, post, node_map):
         data = node_map[self.data][0]
         gamma = node_map[self.gamma][0]
         beta = node_map[self.beta][0]
-        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta, epsilon=1e-5)
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
 
 
 def rewrite_layer_norm(mod):
     """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
-    operator so that we can offload them to dnnl layer normalization byoc part.
+    operator so that we can offload them to dnnl layer normalization byoc part.
     """
     mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
-    return mod
+    return mod
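
For context, a minimal standalone sketch (not part of the commit) of the new rewrite in action: it hand-builds a decomposed layer normalization in the "Pattern #1" shape documented above and runs it through dnnl.rewrite_layer_norm, after which the graph should contain a single nn.layer_norm call. The (1, 49, 64) shape, float32 dtype, and the 1e-05 epsilon are assumptions chosen to match the constants the pattern pins down.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

# Build Pattern #1 by hand: mean / subtract / power(2) / mean / add(eps) / sqrt / divide,
# followed by the affine scale and shift.
x = relay.var("input", shape=(1, 49, 64), dtype="float32")
gamma = relay.const(np.ones(64, dtype="float32"))
beta = relay.const(np.zeros(64, dtype="float32"))

mu = relay.mean(x, axis=[-1], keepdims=True)
diff = relay.subtract(x, mu)
sq = relay.power(diff, relay.const(2.0))
var = relay.mean(sq, axis=[-1], keepdims=True)
deno = relay.sqrt(relay.add(var, relay.const(1e-5)))
out = relay.add(relay.multiply(relay.divide(diff, deno), gamma), beta)

mod = tvm.IRModule.from_expr(relay.Function([x], out))
mod = dnnl.rewrite_layer_norm(mod)  # the decomposed ops should collapse into one nn.layer_norm
print(mod["main"])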

src/runtime/contrib/dnnl/dnnl_json_runtime.cc

Lines changed: 33 additions & 27 deletions
@@ -452,44 +452,50 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
   }
 
   void LayerNorm(const size_t& nid) {
+
     auto node = nodes_[nid];
 
-    auto data_entry = node.GetInputs()[0];
-    auto gamma_entry = node.GetInputs()[1];
-    auto beta_entry = node.GetInputs()[2];
-
-    dnnl::memory::dims data_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+    auto src_tr = GetInput(nid, 0);
+    auto gamma_tr = GetInput(nid, 1);
+    auto beta_tr = GetInput(nid, 2);
+    auto dst_tr = GetOutput(nid, 0);
 
-    float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
+    auto axis = GetNodeAttr<int>(node, "axis");
+    auto epsilon = GetNodeAttr<float>(node, "epsilon");
+    auto center = GetNodeAttr<bool>(node, "center");
+    auto scale = GetNodeAttr<bool>(node, "scale");
 
-    // Memory description.
-    dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+    ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case";
 
     // LN description.
     auto lnorm_desc = dnnl::layer_normalization_forward::desc(
-        dnnl::prop_kind::forward_inference, data_md, epsilon,
-        dnnl::normalization_flags::use_scale | dnnl::normalization_flags::use_shift);
+        dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon,
+        dnnl::normalization_flags::use_scale_shift);
 
     auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);
     auto lnorm_prim = dnnl::layer_normalization_forward(lnorm_prim_desc);
 
-    net_.push_back(lnorm_prim);
-
-    // Memories.
-    auto data_memory = BindDNNLMemory(data_entry, data_md);
-    JSONGraphNodeEntry out_entry(nid, 0);
-    auto dst_memory = BindDNNLMemory(out_entry, data_md);
-    auto scale_memory = BindDNNLMemory(gamma_entry, data_md);
-    auto shift_memory = BindDNNLMemory(beta_entry, data_md);
-    auto mean_memory = dnnl::memory(lnorm_prim_desc.mean_desc(), engine_);
-    auto variance_memory = dnnl::memory(lnorm_prim_desc.variance_desc(), engine_);
-
-    net_args_.push_back({{DNNL_ARG_SRC, data_memory},
-                         {DNNL_ARG_MEAN, mean_memory},
-                         {DNNL_ARG_VARIANCE, variance_memory},
-                         {DNNL_ARG_SCALE, scale_memory},
-                         {DNNL_ARG_SHIFT, shift_memory},
-                         {DNNL_ARG_DST, dst_memory}});
+    // Concatenate scale and shift tensors
+    auto scale_shift_tr = TensorRequisite::AsIs(lnorm_prim_desc.weights_desc(), GenUniqueEid());
+    auto sc_sh_dims = scale_shift_tr.dims();
+
+    ICHECK(sc_sh_dims.size() == 2);
+    ICHECK(sc_sh_dims[0] == 2);
+    sc_sh_dims[0] /= 2;
+    auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze();
+    auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze();
+
+    auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) {
+      dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc());
+      Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
+    };
+
+    register_copy(gamma_tr, scale_tr);
+    register_copy(beta_tr, shift_tr);
+
+    Submit(dnnl::layer_normalization_forward(lnorm_prim_desc), {{DNNL_ARG_SRC, src_tr},
+                                                                {DNNL_ARG_DST, dst_tr},
+                                                                {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}});
   }
 
   void Pooling(const size_t& nid, dnnl::algorithm algo) {
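
For reference, the weight packing this runtime change performs can be modeled in plain NumPy (an illustrative sketch, not DNNL or TVM code; layer_norm_ref is a hypothetical helper name): with dnnl::normalization_flags::use_scale_shift, the primitive consumes gamma and beta as a single (2, C) weights tensor, which is exactly what the two register_copy reorders above populate.

import numpy as np

def layer_norm_ref(x, gamma, beta, epsilon=1e-5):
    # Pack scale and shift the way the packed scale_shift weights are laid out: row 0 = gamma, row 1 = beta.
    scale_shift = np.stack([gamma, beta])          # shape (2, C)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / np.sqrt(var + epsilon)
    return normed * scale_shift[0] + scale_shift[1]

x = np.random.randn(1, 49, 64).astype("float32")
out = layer_norm_ref(x, np.ones(64, "float32"), np.zeros(64, "float32"))
print(out.shape)  # (1, 49, 64)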

tests/python/contrib/test_dnnl.py

Lines changed: 25 additions & 0 deletions
@@ -111,6 +111,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
         with tvm.transform.PassContext(opt_level=3):
             mod = alter_layout_seq(mod)
 
+    mod = dnnl.rewrite_layer_norm(mod)
+
     byoc_seq = tvm.transform.Sequential(
         [
             transform.MergeComposite(dnnl.pattern_table()),
@@ -183,6 +185,8 @@ def check_dnnl_used(mod, subgraph_num=None):
                continue
            if use_dnnl:
                processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
+                print("hebi-dbg: processed_mod")
+                print(processed_mod)
                check_dnnl_used(processed_mod)
 
            with tvm.transform.PassContext(opt_level=3):
@@ -192,6 +196,8 @@
            if run_module:
                if isinstance(input, dict):
                    result_dict[result_key] = func(**input, **params)
+                    print("result_dict: result_key = ", result_key)
+                    print(result_dict[result_key])
                else:
                    result_dict[result_key] = func(input, **params)
@@ -454,6 +460,16 @@ def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype
     return relay.nn.relu(conv2d_bias_bn), dic, param_lst
 
 
+def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
+    dic = {"input": x_shape}
+    param_lst = []
+    input = relay.var("input", shape=x_shape)
+    beta = relay.const(np.zeros(x_shape[2]).astype(dtype))
+    gamma = relay.const(np.ones(x_shape[2]).astype(dtype))
+    out = relay.nn.layer_norm(input, gamma=gamma, beta=beta)
+    return out, dic, param_lst
+
+
 def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
     conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
     sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
@@ -1032,5 +1048,14 @@ def get_graph():
     run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False)
 
 
+def test_layer_norm(run_module, dtype="float32"):
+    x_shape = (1, 49, 64)
+
+    ln, dic, param_lst = get_layer_norm(x_shape, dtype=dtype)
+    ln = tvm.IRModule.from_expr(ln)
+    config = ln, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
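
A short usage sketch (not part of the commit) of driving the new pieces by hand; it assumes the script runs from tests/python/contrib so the test module and the helpers above are importable.

import tvm
# Helpers from the modified test file above.
from test_dnnl import get_layer_norm, partition_for_dnnl

ln, shape_dict, _ = get_layer_norm(x_shape=(1, 49, 64))
mod = tvm.IRModule.from_expr(ln)
mod = partition_for_dnnl(mod, params=None, alter_layout=True)
print(mod)  # expectation: the layer norm is offloaded to DNNL, which is what check_dnnl_used verifies

Alternatively, the new case should be runnable through the harness, e.g. pytest test_dnnl.py -k test_layer_norm, relying on the run_module parametrization already defined in that file.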
