70 changes: 69 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
@@ -41,7 +41,7 @@
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op
from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table

logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ def _func_wrapper(expr):
_register_external_op_helper("nn.softmax")
_register_external_op_helper("add")
_register_external_op_helper("multiply")
_register_external_op_helper("nn.layer_norm")


def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -455,6 +456,7 @@ def visit_call(self, call):
"nn.conv3d",
"nn.conv3d_transpose",
"nn.dense",
"nn.layer_norm",
]
)
if isinstance(call.op, tvm.tir.op.Op):
@@ -526,3 +528,69 @@ def visit_call(self, call):
new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
new_mod = transform.RemoveUnusedFunctions()(new_mod)
return new_mod


class LayerNormRewrite(DFPatternCallback):
"""
A callback to rewrite the following operators into a single layer normalization operator.

Pattern #1:
1 %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
2 %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
3 %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
4 %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
5 %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
6 %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
7 %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
8 %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
9 %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */)
/* ty=Tensor[(1, 3136, 64), float32] */;
10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
/* ty=Tensor[(1, 3136, 64), float32] */;

Pattern #2:
1 %0 = mean(%input, axis=[-1], keepdims=True);
2 %1 = variance(%input, %0, axis=[-1], keepdims=True);
3 %2 = add(%1, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 49, 1), float32] */;
4 %3 = subtract(%input, %0);
5 %4 = sqrt(%2) /* ty=Tensor[(1, 49, 1), float32] */;
6 %5 = divide(%3, %4);
7 %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(64), float32] */)
/* ty=Tensor[(1, 49, 64), float32] */;
8 %7 = add(%6, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */)
/* ty=Tensor[(1, 49, 64), float32] */

"""

def __init__(self):
super(LayerNormRewrite, self).__init__()
self.data = wildcard()
self.gamma = wildcard()
self.beta = wildcard()
mu = is_op("mean")(self.data)
diff = is_op("subtract")(self.data, mu)
cdiff = diff | is_op("cast")(diff)
const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
p1 = is_op("power")(cdiff, const_two)
mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
eps = is_expr(relay.const(1e-5))
added_eps = is_op("add")(mp1, eps)
deno = is_op("sqrt")(added_eps)
div_out = is_op("divide")(diff, deno)
weighted = is_op("multiply")(div_out, self.gamma)
added_bias = is_op("add")(weighted, self.beta)
self.pattern = added_bias

def callback(self, pre, post, node_map):
data = node_map[self.data][0]
gamma = node_map[self.gamma][0]
beta = node_map[self.beta][0]
return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)


def rewrite_layer_norm(mod):
"""Rewrite the input graph to replace multiple operators with a TVM native layer normalization
operator so that we can offload them to dnnl layer normalization byoc part.
"""
mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
return mod
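
A rough usage sketch (the shapes and the hand-built decomposition below are illustrative, not part of this patch): applying rewrite_layer_norm to a Relay graph that follows Pattern #1 should collapse the sequence into a single nn.layer_norm call.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

# Illustrative input: a rank-3 float32 tensor normalized over the last axis.
shape = (1, 49, 64)
x = relay.var("x", shape=shape, dtype="float32")
gamma = relay.const(np.ones(shape[-1], dtype="float32"))
beta = relay.const(np.zeros(shape[-1], dtype="float32"))

# Hand-built decomposition following Pattern #1 in the docstring above.
mu = relay.mean(x, axis=-1, keepdims=True)
diff = relay.subtract(x, mu)
var = relay.mean(relay.power(diff, relay.const(2.0)), axis=-1, keepdims=True)
norm = relay.divide(diff, relay.sqrt(relay.add(var, relay.const(1e-5))))
out = relay.add(relay.multiply(norm, gamma), beta)

mod = tvm.IRModule.from_expr(out)
mod = dnnl.rewrite_layer_norm(mod)
print(mod)  # "main" should now contain a single nn.layer_norm call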
47 changes: 47 additions & 0 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -203,6 +203,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
Binary(nid, dnnl::algorithm::binary_add);
} else if ("multiply" == op_name) {
Binary(nid, dnnl::algorithm::binary_mul);
} else if ("nn.layer_norm" == op_name) {
LayerNorm(nid);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
@@ -449,6 +451,51 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{DNNL_ARG_VARIANCE, var_tr}});
}

void LayerNorm(const size_t& nid) {
auto node = nodes_[nid];

auto src_tr = GetInput(nid, 0);
auto gamma_tr = GetInput(nid, 1);
auto beta_tr = GetInput(nid, 2);
auto dst_tr = GetOutput(nid, 0);

auto axis = GetNodeAttr<int>(node, "axis");
auto epsilon = GetNodeAttr<float>(node, "epsilon");
auto center = GetNodeAttr<bool>(node, "center");
auto scale = GetNodeAttr<bool>(node, "scale");

ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case";

// LN description.
auto lnorm_desc = dnnl::layer_normalization_forward::desc(
dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon,
dnnl::normalization_flags::use_scale_shift);

auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);

// Concatenate the gamma and beta tensors into the packed (2, C) scale_shift weights
// layout that DNNL expects when normalization_flags::use_scale_shift is set.
auto scale_shift_tr = TensorRequisite::AsIs(lnorm_prim_desc.weights_desc(), GenUniqueEid());
auto sc_sh_dims = scale_shift_tr.dims();

ICHECK(sc_sh_dims.size() == 2);
ICHECK(sc_sh_dims[0] == 2);
sc_sh_dims[0] /= 2;
auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze();
auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze();

auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) {
dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc());
Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
};

register_copy(gamma_tr, scale_tr);
register_copy(beta_tr, shift_tr);

Submit(
dnnl::layer_normalization_forward(lnorm_prim_desc),
{{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}});
}

void Pooling(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];

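For reference, the arithmetic this primitive performs, with gamma and beta packed into the (2, C) scale_shift tensor assembled above, can be sketched in NumPy (a conceptual illustration of the math, not the runtime's actual data path):

import numpy as np

def layer_norm_ref(x, scale_shift, epsilon=1e-5):
    # scale_shift packs gamma in row 0 and beta in row 1, shape (2, C).
    gamma, beta = scale_shift[0], scale_shift[1]
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as in layer normalization
    return (x - mu) / np.sqrt(var + epsilon) * gamma + beta

x = np.random.rand(1, 49, 64).astype("float32")
scale_shift = np.stack([np.ones(64, "float32"), np.zeros(64, "float32")])
y = layer_norm_ref(x, scale_shift)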
21 changes: 21 additions & 0 deletions tests/python/contrib/test_dnnl.py
@@ -111,6 +111,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
with tvm.transform.PassContext(opt_level=3):
mod = alter_layout_seq(mod)

mod = dnnl.rewrite_layer_norm(mod)

byoc_seq = tvm.transform.Sequential(
[
transform.MergeComposite(dnnl.pattern_table()),
@@ -454,6 +456,16 @@ def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype
return relay.nn.relu(conv2d_bias_bn), dic, param_lst


def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
dic = {"input": x_shape}
param_lst = []
input = relay.var("input", shape=x_shape)
beta = relay.const(np.zeros(x_shape[2]).astype(dtype))
gamma = relay.const(np.ones(x_shape[2]).astype(dtype))
out = relay.nn.layer_norm(input, gamma=gamma, beta=beta)
return out, dic, param_lst


def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
@@ -1032,5 +1044,14 @@ def get_graph():
run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False)


def test_layer_norm(run_module, dtype="float32"):
x_shape = (1, 49, 64)

ln, dic, param_lst = get_layer_norm(x_shape, dtype=dtype)
ln = tvm.IRModule.from_expr(ln)
config = ln, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)


if __name__ == "__main__":
tvm.testing.main()