Skip to content

Commit 57e24e5

Browse files
author
Ivy
committed
fix lint
1 parent e0c78cc commit 57e24e5

File tree

3 files changed

+56
-49
lines changed

3 files changed

+56
-49
lines changed

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,16 @@ def pattern_table():
211211
return dnnl_patterns
212212

213213

214-
def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
215-
out_shape, paddings, strides, dilates, groups):
214+
def get_optimal_layout_for_conv(
215+
data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
216+
):
216217
"""Get the optimal layout of dnnl, given shape of conv2d.
217218
218219
Parameters
219220
----------
220-
data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups :String
221-
Input argument.
221+
data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups
222+
: String
223+
Input argument.
222224
223225
Returns
224226
-------
@@ -238,13 +240,22 @@ def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
238240

239241

240242
def get_optimal_layout_for_conv_transpose(
241-
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
243+
data_layout,
244+
kernel_layout,
245+
weight_shape,
246+
out_shape,
247+
paddings,
248+
output_paddings,
249+
strides,
250+
dilates,
251+
groups,
242252
):
243253
"""Get the optimal layout of dnnl, given shape of tranposed conv2d.
244254
245255
Parameters
246256
----------
247-
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
257+
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
258+
dilates, groups
248259
: Int, String
249260
Input argument.
250261
@@ -255,7 +266,7 @@ def get_optimal_layout_for_conv_transpose(
255266
"""
256267
return _ffi_api.get_optimal_layout_for_conv_transpose(
257268
data_layout,
258-
kernel_layout,
269+
kernel_layout,
259270
weight_shape,
260271
out_shape,
261272
paddings,
@@ -270,16 +281,15 @@ def get_shape(tensor):
270281
"""Get tensor's shape."""
271282
if isinstance(tensor, relay.expr.Var):
272283
return tensor.type_annotation.concrete_shape
273-
elif isinstance(tensor, relay.expr.Constant):
284+
if isinstance(tensor, relay.expr.Constant):
274285
return tensor.data.shape
275-
elif isinstance(tensor, tvm.ir.tensor_type.TensorType):
286+
if isinstance(tensor, tvm.ir.tensor_type.TensorType):
276287
return tensor.concrete_shape
277-
elif isinstance(tensor, tvm.ir.container.Array):
288+
if isinstance(tensor, tvm.ir.container.Array):
278289
return tensor[-1].shape
279-
elif isinstance(tensor, relay.expr.Call):
290+
if isinstance(tensor, relay.expr.Call):
280291
return tensor.checked_type.shape
281-
else:
282-
raise TypeError("Unsupport data type: %s" % type(tensor))
292+
raise TypeError("Unsupport data type: %s" % type(tensor))
283293

284294

285295
def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
@@ -318,16 +328,14 @@ def legalize_group_conv(attrs, inputs, types):
318328
"""Legalize group conv / conv_transpose calculation.
319329
Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
320330
groups = attrs.groups
321-
if groups == 1:
322-
return
323-
data, weight = inputs
324-
OC, IC, H, W = get_shape(weight)
325-
new_attrs = dict(attrs)
326-
weight = relay.reshape(weight, (groups, OC // groups, IC, H, W))
327-
if "Transpose" not in type(attrs).__name__:
328-
new_attrs["kernel_layout"] = "GOIHW"
329-
return relay.nn.conv2d(data, weight, **new_attrs)
330-
else:
331+
if groups > 1:
332+
data, weight = inputs
333+
OC, IC, H, W = get_shape(weight)
334+
new_attrs = dict(attrs)
335+
weight = relay.reshape(weight, (groups, OC // groups, IC, H, W))
336+
if "Transpose" not in type(attrs).__name__:
337+
new_attrs["kernel_layout"] = "GOIHW"
338+
return relay.nn.conv2d(data, weight, **new_attrs)
331339
new_attrs["kernel_layout"] = "GIOHW"
332340
return relay.nn.conv2d_transpose(data, weight, **new_attrs)
333341

@@ -346,21 +354,25 @@ def alter_conv(attrs, inputs, tinfos, out_type):
346354
conv_type = type(attrs).__name__.split("Attrs")[0]
347355

348356
res = get_optimal_layout_for_conv(
349-
attrs["data_layout"], attrs["kernel_layout"], weight_shape, out_shape, paddings,
350-
strides, dilates, groups,
357+
attrs["data_layout"],
358+
attrs["kernel_layout"],
359+
weight_shape,
360+
out_shape,
361+
paddings,
362+
strides,
363+
dilates,
364+
groups,
351365
)
352366
src_df, weight_df, dst_df = res.split(",")
353367
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
354-
new_attrs["kernel_layout"] = tag2layout(
355-
weight_df, is_weight=True, conv_type=conv_type
356-
)
368+
new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
357369
new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)
358370

359371
if conv_type == "Conv1D":
360372
return relay.nn.conv1d(data, weight, **new_attrs)
361-
elif conv_type == "Conv2D":
373+
if conv_type == "Conv2D":
362374
return relay.nn.conv2d(data, weight, **new_attrs)
363-
elif conv_type == "Conv3D":
375+
if conv_type == "Conv3D":
364376
return relay.nn.conv3d(data, weight, **new_attrs)
365377

366378

@@ -380,7 +392,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
380392

381393
res = get_optimal_layout_for_conv_transpose(
382394
attrs["data_layout"],
383-
attrs["kernel_layout"],
395+
attrs["kernel_layout"],
384396
weight_shape,
385397
out_shape,
386398
paddings,
@@ -391,16 +403,14 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
391403
)
392404
src_df, weight_df, dst_df = res.split(",")
393405
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
394-
new_attrs["kernel_layout"] = tag2layout(
395-
weight_df, is_weight=True, conv_type=conv_type
396-
)
406+
new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
397407
new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)
398408

399409
if conv_type == "Conv1DTranspose":
400410
return relay.nn.conv1d_transpose(data, weight, **new_attrs)
401-
elif conv_type == "Conv2DTranspose":
411+
if conv_type == "Conv2DTranspose":
402412
return relay.nn.conv2d_transpose(data, weight, **new_attrs)
403-
elif conv_type == "Conv3DTranspose":
413+
if conv_type == "Conv3DTranspose":
404414
return relay.nn.conv3d_transpose(data, weight, **new_attrs)
405415

406416

@@ -418,10 +428,10 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
418428
mod : Module
419429
Annotated and partitioned module.
420430
"""
431+
from tvm.relay.testing.temp_op_attr import TempOpAttr
421432

422433
if params:
423434
mod["main"] = bind_params_by_name(mod["main"], params)
424-
from tvm.relay.testing.temp_op_attr import TempOpAttr
425435

426436
with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_group_conv):
427437
with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", legalize_group_conv):
@@ -443,8 +453,6 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
443453
with tvm.transform.PassContext(opt_level=3):
444454
mod = seq(mod)
445455
if alter_layout:
446-
from tvm.relay.testing.temp_op_attr import TempOpAttr
447-
448456
with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", alter_conv):
449457
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
450458
with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", alter_conv):

src/relay/backend/contrib/dnnl/query_layout.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ std::string md2fmt_tag_str(const dnnl::memory::desc* md) {
146146
return s;
147147
}
148148

149-
dnnl::memory::dims str2dims(const std::string& str_shape,
150-
bool dilates = false,
149+
dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false,
151150
std::string interval = ",") {
152151
// Split strings
153152
std::vector<std::string> str_dims;
@@ -164,21 +163,21 @@ dnnl::memory::dims str2dims(const std::string& str_shape,
164163
dnnl::memory::dims out_dims;
165164
if (dilates) {
166165
std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims),
167-
[](const std::string& str) { return std::stoi(str) - 1; });
166+
[](const std::string& str) { return std::stoi(str) - 1; });
168167
} else {
169168
std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims),
170-
[](const std::string& str) { return std::stoi(str); });
169+
[](const std::string& str) { return std::stoi(str); });
171170
}
172171
return out_dims;
173172
}
174173

175174
void check_shapes(const std::vector<std::string> shapes) {
176175
std::regex valid_pat("(\\d*)(,(\\d*))*");
177176
bool checked = std::regex_match(shapes[0], valid_pat);
178-
for (size_t i = 1; i < shapes.size()-1; i++) {
177+
for (size_t i = 1; i < shapes.size() - 1; i++) {
179178
checked &= std::regex_match(shapes[i], valid_pat);
180179
}
181-
checked &= std::regex_match(shapes[shapes.size()-1], std::regex("\\d*"));
180+
checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*"));
182181
if (!checked) {
183182
LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
184183
}
@@ -193,7 +192,7 @@ void check_layout(bool var, bool ref) {
193192
std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout,
194193
std::string weight_shape, std::string out_shape,
195194
std::string paddings, std::string strides,
196-
std::string dilates, std::string G) {
195+
std::string dilates, std::string G) {
197196
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
198197
check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
199198
check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});
@@ -272,7 +271,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
272271
return res;
273272
}
274273

275-
std::string get_optimal_layout_for_conv_transpose(std::string data_layout, std::string kernel_layout,
274+
std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
275+
std::string kernel_layout,
276276
std::string weight_shape, std::string out_shape,
277277
std::string paddings, std::string output_paddings,
278278
std::string strides, std::string dilates,

src/runtime/contrib/dnnl/dnnl_json_runtime.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
505505
padding_dims_l, padding_dims_r);
506506

507507
// Enable elementwise post-ops.
508-
auto deconv_prim_desc =
509-
dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);
508+
auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);
510509

511510
// Push to the network.
512511
auto deconv = dnnl::deconvolution_forward(deconv_prim_desc);

0 commit comments

Comments
 (0)