
Commit b8c748c

Ivy Zhang and pfk-beta authored and committed
[BYOC-DNNL] Support DNNL optimal layout (apache#10421)
* enable dnnl optimal layout for supported ops
* verified cv models with onednn v1.7
* rebase to the latest main branch
* fix format related comments
* remove unnecessary layout transformation
* change deconv into conv_transpose
* rename some variables and functions
* simplify query_layout
* add checks for query_layout
* fix lint
* move partition_for_dnnl from dnnl.py to test_dnnl.py
* remove unnecessary model test
* add more dnnl layout
* rename flag in convolution.cc
* enhance dnnl layout
1 parent 9ac0428 commit b8c748c

File tree: 7 files changed, +990 −174 lines changed

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 207 additions & 39 deletions
@@ -35,9 +35,9 @@
 import logging

 import tvm.ir
-from tvm.relay import transform
-from tvm.relay.build_module import bind_params_by_name
+from tvm import relay

+from ... import _ffi_api
 from ...dataflow_pattern import wildcard, is_op
 from .register import register_pattern_table

@@ -94,12 +94,12 @@ def _func_wrapper(expr):


 def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
-    """Create patterns related to conv and deconv.
+    """Create patterns related to conv and conv_transpose.

     Parameters
     ----------
     with_bias : bool
-        Whether attach `bias_add` to `conv / deconv`.
+        Whether attach `bias_add` to `conv / conv_transpose`.
     with_eltwise : str
         The attached elementwise post-op name.
     Returns
@@ -147,12 +147,12 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
     return dense_out


-def make_dnnl_pattern(op, with_bias, with_eltwise):
+def make_dnnl_pattern(op_name, with_bias, with_eltwise):
     """Create dnnl patterns.

     Parameters
     ----------
-    op : str
+    op_name : str
         The first call node's op name.
     with_bias : bool
         Whether attach `bias_add` to `nn.dense`.
@@ -163,18 +163,20 @@ def make_dnnl_pattern(op, with_bias, with_eltwise):
     pattern : Tuple(pattern_name, CallPattern)
         Created pattern name, along with its CallPattern.
     """
-    pat_name = op.replace("nn", "dnnl")
+    pat_name = op_name.replace("nn", "dnnl")
+    if "_transpose" in op_name:
+        pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
     pat_name += "_bias" if with_bias else ""
     pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
-    if "conv" in op:
-        dnnl_pattern = (pat_name, make_conv_pattern(op, with_bias, with_eltwise))
-    elif op == "nn.dense":
+    if "conv" in op_name:
+        dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
+    elif op_name == "nn.dense":
         dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
     else:
         logger.warning(
             "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and "
             "dense op are supported, but got %s.",
-            op,
+            op_name,
         )
         dnnl_pattern = ()
     return dnnl_pattern
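
For reference, the naming logic above turns a transposed-convolution op into a "deconv"-style pattern name. A minimal standalone sketch of the string handling (not part of the commit; it assumes `with_bias=True` and `with_eltwise="nn.relu"`):

```python
op_name = "nn.conv2d_transpose"
pat_name = op_name.replace("nn", "dnnl")  # "dnnl.conv2d_transpose"
if "_transpose" in op_name:
    # "nn.conv2d_transpose".split("_")[0] is "nn.conv2d"; its last two chars are "2d"
    pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]  # "dnnl.deconv2d"
pat_name += "_bias"  # with_bias=True
pat_name += "_relu"  # with_eltwise="nn.relu" -> suffix "relu"
assert pat_name == "dnnl.deconv2d_bias_relu"
```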
@@ -207,39 +209,205 @@ def pattern_table():
     return dnnl_patterns


-def partition_for_dnnl(mod, params=None):
-    """Partition the graph greedily offloading supported operators to DNNL.
+def get_optimal_layout_for_conv(
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
+):
+    """Get the optimal layout of dnnl, given the shape of conv2d.

     Parameters
     ----------
-    mod : Module
-        The module to run passes on.
-    params : Optional[Dict[str, NDArray]]
-        Constant input parameters.
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
+        : String
+        Input arguments.
+
     Returns
     -------
-    mod : Module
-        Annotated and partitioned module.
+    layouts : string
+        The result.
     """
+    return _ffi_api.get_optimal_layout_for_conv(
+        data_layout,
+        kernel_layout,
+        weight_shape,
+        out_shape,
+        paddings,
+        strides,
+        dilates,
+        groups,
+    )
+
+
+def get_optimal_layout_for_conv_transpose(
+    data_layout,
+    kernel_layout,
+    weight_shape,
+    out_shape,
+    paddings,
+    output_paddings,
+    strides,
+    dilates,
+    groups,
+):
+    """Get the optimal layout of dnnl, given the shape of transposed conv2d.
+
+    Parameters
+    ----------
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
+    dilates, groups
+        : Int, String
+        Input arguments.
+
+    Returns
+    -------
+    layouts : string
+        The result.
+    """
+    return _ffi_api.get_optimal_layout_for_conv_transpose(
+        data_layout,
+        kernel_layout,
+        weight_shape,
+        out_shape,
+        paddings,
+        output_paddings,
+        strides,
+        dilates,
+        groups,
+    )
+
+
+def get_shape(tensor):
+    """Get tensor's shape."""
+    if isinstance(tensor, relay.expr.Var):
+        return tensor.type_annotation.concrete_shape
+    if isinstance(tensor, relay.expr.Constant):
+        return tensor.data.shape
+    if isinstance(tensor, tvm.ir.tensor_type.TensorType):
+        return tensor.concrete_shape
+    if isinstance(tensor, tvm.ir.container.Array):
+        return tensor[-1].shape
+    if isinstance(tensor, relay.expr.Call):
+        return tensor.checked_type.shape
+    raise TypeError("Unsupported data type: %s" % type(tensor))
+
 
-    if params:
-        mod["main"] = bind_params_by_name(mod["main"], params)
-    seq = tvm.transform.Sequential(
-        [
-            transform.CanonicalizeOps(),
-            transform.InferType(),
-            transform.SimplifyInference(),
-            transform.FoldConstant(),
-            transform.FoldScaleAxis(),
-            # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
-            transform.SimplifyExpr(),
-            transform.FoldConstant(),
-            transform.MergeComposite(pattern_table()),
-            transform.AnnotateTarget("dnnl"),
-            transform.MergeCompilerRegions(),
-            transform.PartitionGraph(),
-        ]
+def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
+    """Transfer a layout, denoted with `a, b, c, d, e`,
+    into a valid TVM layout (NCHW / OIHW)."""
+    if "Conv1D" in conv_type:
+        data_dic = {"a": "N", "b": "C", "c": "W"}
+        weight_dic = {"a": "O", "b": "I", "c": "W", "d": "G"}
+    elif "Conv2D" in conv_type:
+        data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"}
+        weight_dic = {"a": "O", "b": "I", "c": "H", "d": "W"}
+        if "e" in input_data:
+            weight_dic = {"a": "G", "b": "O", "c": "I", "d": "H", "e": "W"}
+    elif "Conv3D" in conv_type:
+        data_dic = {"a": "N", "b": "C", "c": "D", "d": "H", "e": "W"}
+        weight_dic = {"a": "O", "b": "I", "c": "D", "d": "H", "e": "W", "f": "G"}
+
+    dic = weight_dic if is_weight else data_dic
+    res = ""
+
+    for i in input_data:
+        if i.isupper():
+            i = i.lower()
+            res += dic[i]
+            dic[i] = dic[i].lower()
+        elif i.islower():
+            res += dic[i]
+        elif i.isdigit():
+            res += i
+        else:
+            raise ValueError("Unsupported layout format: %s" % input_data)
+    return res
+
+
+def legalize_group_conv(attrs, inputs, types):
+    """Legalize group conv / conv_transpose calculation.
+    Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
+    groups = attrs.groups
+    data, weight = inputs
+    if groups == 1:
+        if "Transpose" not in type(attrs).__name__:
+            return relay.nn.conv2d(data, weight, **attrs)
+        return relay.nn.conv2d_transpose(data, weight, **attrs)
+    OC, IC, H, W = get_shape(weight)
+    new_attrs = dict(attrs)
+    weight = relay.reshape(weight, (groups, OC // groups, IC, H, W))
+    if "Transpose" not in type(attrs).__name__:
+        new_attrs["kernel_layout"] = "GOIHW"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+    new_attrs["kernel_layout"] = "GIOHW"
+    return relay.nn.conv2d_transpose(data, weight, **new_attrs)
+
+
+def alter_conv(attrs, inputs, tinfos, out_type):
+    """The convolution's layout auto-query func for dnnl."""
+
+    data, weight = inputs
+    groups = str(attrs.groups)
+    weight_shape = ",".join([str(x) for x in get_shape(weight)])
+    out_shape = ",".join([str(x) for x in get_shape(out_type)])
+    paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")])
+    strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
+    dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
+    new_attrs = dict(attrs)
+    conv_type = type(attrs).__name__.split("Attrs")[0]
+
+    res = get_optimal_layout_for_conv(
+        attrs["data_layout"],
+        attrs["kernel_layout"],
+        weight_shape,
+        out_shape,
+        paddings,
+        strides,
+        dilates,
+        groups,
     )
-    with tvm.transform.PassContext(opt_level=3):
-        mod = seq(mod)
-    return mod
+    src_df, weight_df, dst_df = res.split(",")
+    new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
+    new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
+    new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)
+
+    if conv_type == "Conv1D":
+        return relay.nn.conv1d(data, weight, **new_attrs)
+    if conv_type == "Conv2D":
+        return relay.nn.conv2d(data, weight, **new_attrs)
+    return relay.nn.conv3d(data, weight, **new_attrs)
+
+
+def alter_conv_transpose(attrs, inputs, tinfos, out_type):
+    """The transposed convolution's layout auto-query func for dnnl."""
+
+    data, weight = inputs
+    weight_shape = ",".join([str(x) for x in get_shape(weight)])
+    out_shape = ",".join([str(x) for x in get_shape(out_type)])
+    paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")])
+    output_paddings = ",".join([str(x) for x in attrs.get_int_tuple("output_padding")])
+    strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
+    dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
+    groups = str(attrs.groups)
+    new_attrs = dict(attrs)
+    conv_type = type(attrs).__name__.split("Attrs")[0]
+
+    res = get_optimal_layout_for_conv_transpose(
+        attrs["data_layout"],
+        attrs["kernel_layout"],
+        weight_shape,
+        out_shape,
+        paddings,
+        output_paddings,
+        strides,
+        dilates,
+        groups,
+    )
+    src_df, weight_df, dst_df = res.split(",")
+    new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
+    new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
+    new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)
+
+    if conv_type == "Conv1DTranspose":
+        return relay.nn.conv1d_transpose(data, weight, **new_attrs)
+    if conv_type == "Conv2DTranspose":
+        return relay.nn.conv2d_transpose(data, weight, **new_attrs)
+    return relay.nn.conv3d_transpose(data, weight, **new_attrs)
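
The query functions above return a comma-separated triple of oneDNN format tags, "src,weight,dst", and `tag2layout` translates each tag into a TVM layout string: lowercase letters map through the plain dictionary, while an uppercase letter paired with a digit run denotes a blocked dimension. A minimal sketch of the translation (the tag strings here are assumed examples of what the oneDNN query might return, not guaranteed outputs):

```python
# Plain and blocked oneDNN tags for a 2D convolution:
tag2layout("abcd", is_weight=False, conv_type="Conv2D")     # -> "NCHW"
tag2layout("aBcd8b", is_weight=False, conv_type="Conv2D")   # -> "NCHW8c"
tag2layout("ABcd8b8a", is_weight=True, conv_type="Conv2D")  # -> "OIHW8i8o"
```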
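`alter_conv` / `alter_conv_transpose` are meant to be hooked into Relay's AlterOpLayout pass. One plausible wiring (an assumption, mirroring how TVM tests typically install such hooks; `mod` is a Relay IRModule):

```python
from tvm.relay.testing.temp_op_attr import TempOpAttr

# Temporarily register alter_conv as the layout-altering hook for nn.conv2d,
# then run the pass so each conv is rewritten into its dnnl-preferred layout.
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
    mod = tvm.relay.transform.AlterOpLayout()(mod)
```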

src/relay/backend/contrib/dnnl/codegen.cc

Lines changed: 29 additions & 13 deletions
@@ -445,14 +445,30 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
       {"relu", "nn.relu"},
       {"tanh", "tanh"},
       {"sigmoid", "sigmoid"},
+      {"nn.deconv2d", "nn.conv2d_transpose"},
+      {"nn.deconv3d", "nn.conv3d_transpose"},
   };

-  std::vector<std::string> ParsingOpList(std::string op, std::string pattern_name) {
-    std::vector<std::string> op_list = {"nn." + op};
-    for (auto& t : op_map) {
-      if (pattern_name.find(t.first) != std::string::npos) {
-        op_list.push_back(t.second);
+  std::vector<std::string> ParsingOpList(const std::string& pattern_name,
+                                         std::string interval = "_") {
+    ICHECK_NE(pattern_name, "");
+    std::vector<std::string> op_list;
+    size_t pos = 0, start = 0;
+    while ((pos = pattern_name.find(interval, start)) != std::string::npos) {
+      std::string op_name = pattern_name.substr(start, pos - start);
+      if (op_name.find("dnnl") != std::string::npos) {
+        op_name.replace(op_name.find("dnnl"), 4, "nn");
+        if (op_name.find("deconv") != std::string::npos) {
+          op_name = op_map[op_name];
+        }
+      } else {
+        op_name = op_map[op_name];
       }
+      if (pos > start) op_list.push_back(op_name);
+      start = pos + interval.size();
+    }
+    if (pattern_name.size() > start) {
+      op_list.push_back(op_map[pattern_name.substr(start)]);
     }
     return op_list;
   }
@@ -471,28 +487,28 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
     ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions.";
     name = comp.value();

-    if (name.find("dnnl.conv2d_transpose") != std::string::npos) {
-      std::vector<std::string> op_list = ParsingOpList("conv2d_transpose", name);
+    if (name.find("dnnl.deconv2d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList(name);
       call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
-    } else if (name.find("dnnl.conv3d_transpose") != std::string::npos) {
-      std::vector<std::string> op_list = ParsingOpList("conv3d_transpose", name);
+    } else if (name.find("dnnl.deconv3d") != std::string::npos) {
+      std::vector<std::string> op_list = ParsingOpList(name);
       call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
     } else if (name.find("dnnl.conv1d") != std::string::npos) {
-      std::vector<std::string> op_list = ParsingOpList("conv1d", name);
+      std::vector<std::string> op_list = ParsingOpList(name);
       call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
     } else if (name.find("dnnl.conv2d") != std::string::npos) {
-      std::vector<std::string> op_list = ParsingOpList("conv2d", name);
+      std::vector<std::string> op_list = ParsingOpList(name);
       call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
     } else if (name.find("dnnl.conv3d") != std::string::npos) {
-      std::vector<std::string> op_list = ParsingOpList("conv3d", name);
+      std::vector<std::string> op_list = ParsingOpList(name);
       call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
     } else if (name.find("dnnl.dense") != std::string::npos) {
-      std::vector<std::string> op_list = ParsingOpList("dense", name);
+      std::vector<std::string> op_list = ParsingOpList(name);
       call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
       ICHECK(call->op.as<OpNode>()) << "Not op node";
     } else {
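
With the new signature, the composite name itself drives the op list: "dnnl.deconv2d_bias_relu" is split on "_" into "dnnl.deconv2d" / "bias" / "relu", the "dnnl" prefix is rewritten to "nn", deconv tokens go through op_map, and the remaining tokens are looked up directly. A Python sketch of the same splitting (illustration only; it normalizes every token uniformly, whereas the C++ above passes the trailing token straight through op_map, and the "bias" -> "nn.bias_add" entry is assumed from the pre-existing op_map):

```python
op_map = {
    "bias": "nn.bias_add",  # assumed pre-existing entry
    "relu": "nn.relu",
    "nn.deconv2d": "nn.conv2d_transpose",
    "nn.deconv3d": "nn.conv3d_transpose",
}

def parsing_op_list(pattern_name, interval="_"):
    op_list = []
    for token in pattern_name.split(interval):
        if "dnnl" in token:
            token = token.replace("dnnl", "nn")  # "dnnl.deconv2d" -> "nn.deconv2d"
            if "deconv" in token:
                token = op_map[token]            # -> "nn.conv2d_transpose"
        else:
            token = op_map[token]
        op_list.append(token)
    return op_list

assert parsing_op_list("dnnl.deconv2d_bias_relu") == [
    "nn.conv2d_transpose", "nn.bias_add", "nn.relu"
]
```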
