
Commit 9dd62b4

Author: Ivy Zhang
[BYOC-DNNL] add support for more ops and fusion patterns
1 parent d8e39fd commit 9dd62b4

File tree: 5 files changed, +932 −131 lines

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 39 additions & 11 deletions
@@ -67,22 +67,39 @@ def _func_wrapper(expr):
 
 
 _register_external_op_helper("nn.batch_norm")
+_register_external_op_helper("nn.conv1d")
 _register_external_op_helper("nn.conv2d")
+_register_external_op_helper("nn.conv3d")
+_register_external_op_helper("nn.conv2d_transpose")
+_register_external_op_helper("nn.conv3d_transpose")
 _register_external_op_helper("nn.dense")
+_register_external_op_helper("nn.max_pool2d")
+_register_external_op_helper("nn.avg_pool2d")
+_register_external_op_helper("nn.max_pool3d")
+_register_external_op_helper("nn.avg_pool3d")
+_register_external_op_helper("abs")
+_register_external_op_helper("clip")
+_register_external_op_helper("exp")
+_register_external_op_helper("log")
+_register_external_op_helper("sqrt")
+_register_external_op_helper("round")
+_register_external_op_helper("logsumexp")
 _register_external_op_helper("nn.relu")
+_register_external_op_helper("nn.leaky_relu")
 _register_external_op_helper("tanh")
 _register_external_op_helper("sigmoid")
+_register_external_op_helper("nn.softmax")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
 
 
-def make_conv_pattern(with_bias=True, with_eltwise=None):
-    """Create patterns related to nn.conv2d.
+def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
+    """Create patterns related to conv and deconv.
 
     Parameters
     ----------
     with_bias : bool
-        Whether attach `bias_add` to `nn.conv2d`.
+        Whether attach `bias_add` to `conv / deconv`.
     with_eltwise : str
         The attached elementwise post-op name.
     Returns
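For reference, `_register_external_op_helper` (defined just above this hunk; its inner `def _func_wrapper(expr):` is visible in the hunk header) is a thin wrapper over `tvm.ir.register_op_attr`. Each call above marks a single Relay op as offloadable to DNNL even when no fusion pattern applies. A minimal sketch of the helper, assuming it matches the upstream definition:

    import tvm.ir

    def _register_external_op_helper(op_name, supported=True):
        # Attach a "target.dnnl" attribute so the BYOC annotation pass knows
        # the DNNL backend accepts this op as a standalone (unfused) call.
        @tvm.ir.register_op_attr(op_name, "target.dnnl")
        def _func_wrapper(expr):
            return supported

        return _func_wrapper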
@@ -93,7 +110,7 @@ def make_conv_pattern(with_bias=True, with_eltwise=None):
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
-    conv = is_op("nn.conv2d")(data, weight)
+    conv = is_op(conv_name)(data, weight)
     if with_bias:
         conv_out = is_op("add")(conv, bias)
     else:
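Parameterizing on `conv_name` lets one factory cover every conv/deconv variant. Expanded by hand for one concrete case (illustration only, using the dataflow-pattern builders dnnl.py already imports):

    from tvm.relay.dataflow_pattern import is_op, wildcard

    # make_conv_pattern("nn.conv1d", with_bias=True, with_eltwise="nn.relu")
    data, weight, bias = wildcard(), wildcard(), wildcard()
    conv = is_op("nn.conv1d")(data, weight)
    conv_out = is_op("add")(conv, bias)   # skipped when with_bias=False
    pattern = is_op("nn.relu")(conv_out)  # the optional eltwise post-op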
@@ -146,15 +163,19 @@ def make_dnnl_pattern(op, with_bias, with_eltwise):
     pattern : Tuple(pattern_name, CallPattern)
         Created pattern name, along with its CallPattern.
     """
-    pat_name = "dnnl." + op
+    pat_name = op.replace("nn", "dnnl")
     pat_name += "_bias" if with_bias else ""
     pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
-    if op == "conv2d":
-        dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
-    elif op == "dense":
+    if "conv" in op:
+        dnnl_pattern = (pat_name, make_conv_pattern(op, with_bias, with_eltwise))
+    elif op == "nn.dense":
         dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
     else:
-        logger.warning("Currently, only conv2d and dense op are supported, but got %s.", op)
+        logger.warning(
+            "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and "
+            "dense op are supported, but got %s.",
+            op,
+        )
         dnnl_pattern = ()
     return dnnl_pattern
 
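The rename via `op.replace("nn", "dnnl")` plus the two optional suffixes produces the composite-function names that the C++ serializer below matches against. Traced by hand for one case:

    op, with_bias, with_eltwise = "nn.conv2d_transpose", True, "nn.relu"
    pat_name = op.replace("nn", "dnnl")       # "dnnl.conv2d_transpose"
    pat_name += "_bias" if with_bias else ""  # "dnnl.conv2d_transpose_bias"
    pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
    assert pat_name == "dnnl.conv2d_transpose_bias_relu"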
@@ -174,8 +195,15 @@ def pattern_table():
         for elt in elt_list:
             if not with_bias and not elt:
                 return dnnl_patterns
-            dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
-            dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
+            for conv_name in [
+                "nn.conv1d",
+                "nn.conv2d",
+                "nn.conv3d",
+                "nn.conv2d_transpose",
+                "nn.conv3d_transpose",
+            ]:
+                dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
+            dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
     return dnnl_patterns
 
 
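With this change, `pattern_table()` emits bias/eltwise variants for six ops instead of two. For orientation, a sketch of how such a table is typically consumed by the standard Relay BYOC passes (`partition_for_dnnl` is a hypothetical helper name, not part of this commit):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import pattern_table

    def partition_for_dnnl(mod):
        # Fuse pattern matches into composite functions, annotate supported
        # single ops, then split the graph into DNNL and fallback regions.
        seq = tvm.transform.Sequential(
            [
                relay.transform.MergeComposite(pattern_table()),
                relay.transform.AnnotateTarget("dnnl"),
                relay.transform.MergeCompilerRegions(),
                relay.transform.PartitionGraph(),
            ]
        )
        return seq(mod)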
src/relay/backend/contrib/dnnl/codegen.cc

Lines changed: 36 additions & 17 deletions
@@ -31,6 +31,7 @@
 
 #include <fstream>
 #include <numeric>
+#include <regex>
 #include <sstream>
 
 #include "../../utils.h"
@@ -439,6 +440,23 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
   using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
   using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
 
+  std::map<std::string, std::string> op_map{
+      {"bias", "add"},
+      {"relu", "nn.relu"},
+      {"tanh", "tanh"},
+      {"sigmoid", "sigmoid"},
+  };
+
+  std::vector<std::string> ParsingOpList(std::string op, std::string pattern_name) {
+    std::vector<std::string> op_list = {"nn." + op};
+    for (auto& t : op_map) {
+      if (pattern_name.find(t.first) != std::string::npos) {
+        op_list.push_back(t.second);
+      }
+    }
+    return op_list;
+  }
+
  public:
  DNNLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
 
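`ParsingOpList` recovers the op chain encoded in a composite pattern name, replacing the hard-coded op lists below. A Python restatement of the same logic (illustration only; note that `std::map` iterates its keys in sorted order — "bias" < "relu" < "sigmoid" < "tanh" — which happens to match the op order inside a fused body, the bias `add` before the eltwise op):

    OP_MAP = {"bias": "add", "relu": "nn.relu", "sigmoid": "sigmoid", "tanh": "tanh"}

    def parsing_op_list(op, pattern_name):
        op_list = ["nn." + op]
        for key in sorted(OP_MAP):  # mirrors std::map's sorted iteration
            if key in pattern_name:
                op_list.append(OP_MAP[key])
        return op_list

    assert parsing_op_list("conv2d", "dnnl.conv2d_bias_relu") == ["nn.conv2d", "add", "nn.relu"]
    assert parsing_op_list("dense", "dnnl.dense_bias") == ["nn.dense", "add"]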
@@ -453,28 +471,29 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
     ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions.";
     name = comp.value();
 
-    if (name == "dnnl.conv2d_bias_relu") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
-    } else if (name == "dnnl.conv2d_bias_tanh") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
-      ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_bias_sigmoid") {
-      call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
+    if (name.find("dnnl.conv2d_transpose") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv2d_transpose", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_bias") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
+    } else if (name.find("dnnl.conv3d_transpose") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv3d_transpose", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_relu") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
+    } else if (name.find("dnnl.conv1d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv1d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_tanh") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
+    } else if (name.find("dnnl.conv2d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv2d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.conv2d_sigmoid") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
+    } else if (name.find("dnnl.conv3d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("conv3d", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name == "dnnl.dense_bias") {
-      call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
+    } else if (name.find("dnnl.dense") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList("dense", name);
+      call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
     } else {
       LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
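One subtlety in the new dispatch: because it uses substring `find` rather than exact comparison, the `dnnl.conv2d_transpose` and `dnnl.conv3d_transpose` branches must be tested before the plain `dnnl.conv2d` / `dnnl.conv3d` ones. Each transpose name contains the corresponding non-transpose name as a substring, so reversing the order would send transpose patterns down the wrong branch.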