Skip to content

Commit 34f79a8

Browse files
author
Ivy Zhang
committed
Enable convolution (1d–3d), deconvolution (2d–3d), pooling (2d–3d), and some common CV activation ops.
1 parent 14d0187 commit 34f79a8

File tree

4 files changed

+919
-128
lines changed

4 files changed

+919
-128
lines changed

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,42 @@ def _func_wrapper(expr):
6767

6868

6969
# Ops directly offloaded to DNNL (no pattern matching needed).
_register_external_op_helper("nn.batch_norm")
# Convolution / deconvolution
_register_external_op_helper("nn.conv1d")
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.conv3d")
_register_external_op_helper("nn.conv2d_transpose")
_register_external_op_helper("nn.conv3d_transpose")
_register_external_op_helper("nn.dense")
# Pooling
_register_external_op_helper("nn.max_pool2d")
_register_external_op_helper("nn.avg_pool2d")
_register_external_op_helper("nn.max_pool3d")
_register_external_op_helper("nn.avg_pool3d")
# Activation / elementwise unary
_register_external_op_helper("abs")
_register_external_op_helper("clip")
_register_external_op_helper("exp")
_register_external_op_helper("log")
_register_external_op_helper("sqrt")
_register_external_op_helper("round")
_register_external_op_helper("logsumexp")
_register_external_op_helper("nn.relu")
_register_external_op_helper("nn.leaky_relu")
_register_external_op_helper("tanh")
_register_external_op_helper("sigmoid")
_register_external_op_helper("nn.softmax")
# Binary
_register_external_op_helper("add")
_register_external_op_helper("multiply")
7797

7898

79-
def make_conv_pattern(with_bias=True, with_eltwise=None):
80-
"""Create patterns related to nn.conv2d.
99+
def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
100+
"""Create patterns related to conv and deconv.
81101
82102
Parameters
83103
----------
84104
with_bias : bool
85-
Whether attach `bias_add` to `nn.conv2d`.
105+
Whether attach `bias_add` to `conv / deconv`.
86106
with_eltwise : str
87107
The attached elementwise post-op name.
88108
Returns
@@ -93,7 +113,7 @@ def make_conv_pattern(with_bias=True, with_eltwise=None):
93113
data = wildcard()
94114
weight = wildcard()
95115
bias = wildcard()
96-
conv = is_op("nn.conv2d")(data, weight)
116+
conv = is_op(conv_name)(data, weight)
97117
if with_bias:
98118
conv_out = is_op("add")(conv, bias)
99119
else:
@@ -146,15 +166,18 @@ def make_dnnl_pattern(op, with_bias, with_eltwise):
146166
pattern : Tuple(pattern_name, CallPattern)
147167
Created pattern name, along with its CallPattern.
148168
"""
149-
pat_name = "dnnl." + op
169+
pat_name = op.replace("nn", "dnnl")
150170
pat_name += "_bias" if with_bias else ""
151171
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
152-
if op == "conv2d":
153-
dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
154-
elif op == "dense":
172+
if "conv" in op:
173+
dnnl_pattern = (pat_name, make_conv_pattern(op, with_bias, with_eltwise))
174+
elif op == "nn.dense":
155175
dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
156176
else:
157-
logger.warning("Currently, only conv2d and dense op are supported, but got %s.", op)
177+
logger.warning(
178+
"Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and dense op are supported, but got %s.",
179+
op,
180+
)
158181
dnnl_pattern = ()
159182
return dnnl_pattern
160183

@@ -174,8 +197,15 @@ def pattern_table():
174197
for elt in elt_list:
175198
if not with_bias and not elt:
176199
return dnnl_patterns
177-
dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
178-
dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
200+
for conv_name in [
201+
"nn.conv1d",
202+
"nn.conv2d",
203+
"nn.conv3d",
204+
"nn.conv2d_transpose",
205+
"nn.conv3d_transpose",
206+
]:
207+
dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
208+
dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
179209
return dnnl_patterns
180210

181211

src/relay/backend/contrib/dnnl/codegen.cc

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include <fstream>
3333
#include <numeric>
34+
#include <regex>
3435
#include <sstream>
3536

3637
#include "../../utils.h"
@@ -439,6 +440,23 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
439440
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
440441
using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
441442

443+
// Maps a suffix that can occur in a composite pattern name (e.g.
// "dnnl.conv2d_bias_relu") to the Relay op it stands for.
// NOTE: the op list handed to GetRootCall must follow call-chain order
// (<base op> -> "add" -> <eltwise>). std::map iterates keys in sorted
// order ("bias" < "relu" < "sigmoid" < "tanh"), which guarantees the
// bias "add" is appended before any activation op.
std::map<std::string, std::string> op_map{
    {"bias", "add"},
    {"relu", "nn.relu"},
    {"tanh", "tanh"},
    {"sigmoid", "sigmoid"},
};

/*!
 * \brief Expand a composite pattern name into the ordered op list of the
 *        fused call chain.
 * \param op Base op name without the "nn." prefix, e.g. "conv2d".
 * \param pattern_name Composite pattern name, e.g. "dnnl.conv2d_bias_relu".
 * \return Ops in call order: "nn.<op>" first, followed by "add" and/or the
 *         activation whose suffix appears in pattern_name.
 */
std::vector<std::string> ParsingOpList(const std::string& op, const std::string& pattern_name) {
  std::vector<std::string> op_list = {"nn." + op};
  for (const auto& t : op_map) {
    // Substring match is sufficient: pattern names embed suffixes verbatim.
    if (pattern_name.find(t.first) != std::string::npos) {
      op_list.push_back(t.second);
    }
  }
  return op_list;
}
459+
442460
public:
443461
DNNLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
444462

@@ -453,28 +471,29 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
453471
ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions.";
454472
name = comp.value();
455473

456-
if (name == "dnnl.conv2d_bias_relu") {
457-
call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
458-
} else if (name == "dnnl.conv2d_bias_tanh") {
459-
call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
460-
ICHECK(call->op.as<OpNode>()) << "Not op node";
461-
} else if (name == "dnnl.conv2d_bias_sigmoid") {
462-
call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
474+
if (name.find("dnnl.conv2d_transpose") != std::string::npos) {
475+
std::vector<std::string> op_list = ParsingOpList("conv2d_transpose", name);
476+
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
463477
ICHECK(call->op.as<OpNode>()) << "Not op node";
464-
} else if (name == "dnnl.conv2d_bias") {
465-
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
478+
} else if (name.find("dnnl.conv3d_transpose") != std::string::npos) {
479+
std::vector<std::string> op_list = ParsingOpList("conv3d_transpose", name);
480+
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
466481
ICHECK(call->op.as<OpNode>()) << "Not op node";
467-
} else if (name == "dnnl.conv2d_relu") {
468-
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
482+
} else if (name.find("dnnl.conv1d") != std::string::npos) {
483+
std::vector<std::string> op_list = ParsingOpList("conv1d", name);
484+
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
469485
ICHECK(call->op.as<OpNode>()) << "Not op node";
470-
} else if (name == "dnnl.conv2d_tanh") {
471-
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
486+
} else if (name.find("dnnl.conv2d") != std::string::npos) {
487+
std::vector<std::string> op_list = ParsingOpList("conv2d", name);
488+
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
472489
ICHECK(call->op.as<OpNode>()) << "Not op node";
473-
} else if (name == "dnnl.conv2d_sigmoid") {
474-
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
490+
} else if (name.find("dnnl.conv3d") != std::string::npos) {
491+
std::vector<std::string> op_list = ParsingOpList("conv3d", name);
492+
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
475493
ICHECK(call->op.as<OpNode>()) << "Not op node";
476-
} else if (name == "dnnl.dense_bias") {
477-
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
494+
} else if (name.find("dnnl.dense") != std::string::npos) {
495+
std::vector<std::string> op_list = ParsingOpList("dense", name);
496+
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
478497
ICHECK(call->op.as<OpNode>()) << "Not op node";
479498
} else {
480499
LOG(FATAL) << "Unrecognized DNNL pattern: " << name;

0 commit comments

Comments
 (0)