Commit 56ef1a1

Author: Ivy
Message: fix lint
1 parent e0c78cc

File tree

3 files changed: +54 −47 lines

python/tvm/relay/op/contrib/dnnl.py
Lines changed: 45 additions & 37 deletions
@@ -38,6 +38,7 @@
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.testing.temp_op_attr import TempOpAttr

 from ... import _ffi_api
 from ...dataflow_pattern import wildcard, is_op
@@ -211,14 +212,16 @@ def pattern_table():
     return dnnl_patterns


-def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
-                                out_shape, paddings, strides, dilates, groups):
+def get_optimal_layout_for_conv(
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
+):
     """Get the optimal layout of dnnl, given shape of conv2d.

     Parameters
     ----------
-    data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups :String
-        Input argument.
+    data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups
+    : String
+        Input argument.

     Returns
     -------
@@ -238,13 +241,22 @@ def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,


 def get_optimal_layout_for_conv_transpose(
-    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
+    data_layout,
+    kernel_layout,
+    weight_shape,
+    out_shape,
+    paddings,
+    output_paddings,
+    strides,
+    dilates,
+    groups,
 ):
     """Get the optimal layout of dnnl, given shape of tranposed conv2d.

     Parameters
     ----------
-    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
+    dilates, groups
     : Int, String
         Input argument.

@@ -255,7 +267,7 @@ def get_optimal_layout_for_conv_transpose(
     """
    return _ffi_api.get_optimal_layout_for_conv_transpose(
         data_layout,
-        kernel_layout, 
+        kernel_layout,
         weight_shape,
         out_shape,
         paddings,
@@ -270,16 +282,15 @@ def get_shape(tensor):
     """Get tensor's shape."""
     if isinstance(tensor, relay.expr.Var):
         return tensor.type_annotation.concrete_shape
-    elif isinstance(tensor, relay.expr.Constant):
+    if isinstance(tensor, relay.expr.Constant):
         return tensor.data.shape
-    elif isinstance(tensor, tvm.ir.tensor_type.TensorType):
+    if isinstance(tensor, tvm.ir.tensor_type.TensorType):
         return tensor.concrete_shape
-    elif isinstance(tensor, tvm.ir.container.Array):
+    if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].shape
-    elif isinstance(tensor, relay.expr.Call):
+    if isinstance(tensor, relay.expr.Call):
         return tensor.checked_type.shape
-    else:
-        raise TypeError("Unsupport data type: %s" % type(tensor))
+    raise TypeError("Unsupport data type: %s" % type(tensor))


 def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
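
As context for the flattened isinstance chain above: pylint's no-else-return rule fires when every branch returns, so the elif/else ladder becomes a sequence of plain if returns ending in an unconditional raise. Behavior is unchanged; a quick sketch of what the dispatch does, with an illustrative input:

    from tvm import relay
    from tvm.relay.op.contrib.dnnl import get_shape

    x = relay.var("x", shape=(1, 32, 14, 14))  # relay.expr.Var
    get_shape(x)   # -> (1, 32, 14, 14), via type_annotation.concrete_shape
    get_shape(42)  # falls through every branch, raises TypeError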
@@ -318,18 +329,19 @@ def legalize_group_conv(attrs, inputs, types):
318329
"""Legalize group conv / conv_transpose calculation.
319330
Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
320331
groups = attrs.groups
321-
if groups == 1:
322-
return
323332
data, weight = inputs
333+
if groups == 1:
334+
if "Transpose" not in type(attrs).__name__:
335+
return relay.nn.conv2d(data, weight, **attrs)
336+
return relay.nn.conv2d_transpose(data, weight, **attrs)
324337
OC, IC, H, W = get_shape(weight)
325338
new_attrs = dict(attrs)
326339
weight = relay.reshape(weight, (groups, OC // groups, IC, H, W))
327340
if "Transpose" not in type(attrs).__name__:
328341
new_attrs["kernel_layout"] = "GOIHW"
329342
return relay.nn.conv2d(data, weight, **new_attrs)
330-
else:
331-
new_attrs["kernel_layout"] = "GIOHW"
332-
return relay.nn.conv2d_transpose(data, weight, **new_attrs)
343+
new_attrs["kernel_layout"] = "GIOHW"
344+
return relay.nn.conv2d_transpose(data, weight, **new_attrs)
333345

334346

335347
def alter_conv(attrs, inputs, tinfos, out_type):
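
Besides the lint-driven else removal, the hunk above changes the groups == 1 case: instead of returning None, the op is rebuilt explicitly with the original attrs. For the grouped path, the reshape inserts an explicit group axis ahead of the per-group output channels. A worked example of that reshape; the concrete numbers are illustrative, not from the commit:

    # OIHW weight of a grouped conv2d: OC=32, IC(per group)=4, 3x3 kernel, groups=4
    groups, (OC, IC, H, W) = 4, (32, 4, 3, 3)
    new_shape = (groups, OC // groups, IC, H, W)  # -> (4, 8, 4, 3, 3), kernel_layout "GOIHW"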
@@ -346,22 +358,25 @@ def alter_conv(attrs, inputs, tinfos, out_type):
     conv_type = type(attrs).__name__.split("Attrs")[0]

     res = get_optimal_layout_for_conv(
-        attrs["data_layout"], attrs["kernel_layout"], weight_shape, out_shape, paddings,
-        strides, dilates, groups,
+        attrs["data_layout"],
+        attrs["kernel_layout"],
+        weight_shape,
+        out_shape,
+        paddings,
+        strides,
+        dilates,
+        groups,
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
-    new_attrs["kernel_layout"] = tag2layout(
-        weight_df, is_weight=True, conv_type=conv_type
-    )
+    new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
     new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)

     if conv_type == "Conv1D":
         return relay.nn.conv1d(data, weight, **new_attrs)
-    elif conv_type == "Conv2D":
+    if conv_type == "Conv2D":
         return relay.nn.conv2d(data, weight, **new_attrs)
-    elif conv_type == "Conv3D":
-        return relay.nn.conv3d(data, weight, **new_attrs)
+    return relay.nn.conv3d(data, weight, **new_attrs)


 def alter_conv_transpose(attrs, inputs, tinfos, out_type):
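
As the hunk above shows, the layout query returns a single comma-separated string of three dnnl format tags (source, weight, destination), which alter_conv splits on "," and translates one by one through tag2layout. A hedged sketch of that flow; the tag values below are made-up placeholders, real tags come from the dnnl query and may be blocked formats:

    res = "nchw,goihw,nchw"  # hypothetical "src,weight,dst" tag triple
    src_df, weight_df, dst_df = res.split(",")
    # each tag then becomes a relay layout string:
    # tag2layout(src_df, is_weight=False, conv_type="Conv2D"), etc.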
@@ -380,7 +395,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):

     res = get_optimal_layout_for_conv_transpose(
         attrs["data_layout"],
-        attrs["kernel_layout"], 
+        attrs["kernel_layout"],
         weight_shape,
         out_shape,
         paddings,
@@ -391,17 +406,14 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
-    new_attrs["kernel_layout"] = tag2layout(
-        weight_df, is_weight=True, conv_type=conv_type
-    )
+    new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
     new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)

     if conv_type == "Conv1DTranspose":
         return relay.nn.conv1d_transpose(data, weight, **new_attrs)
-    elif conv_type == "Conv2DTranspose":
+    if conv_type == "Conv2DTranspose":
         return relay.nn.conv2d_transpose(data, weight, **new_attrs)
-    elif conv_type == "Conv3DTranspose":
-        return relay.nn.conv3d_transpose(data, weight, **new_attrs)
+    return relay.nn.conv3d_transpose(data, weight, **new_attrs)


 def partition_for_dnnl(mod, params=None, alter_layout=True):
@@ -418,10 +430,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
     mod : Module
         Annotated and partitioned module.
     """
-
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
-    from tvm.relay.testing.temp_op_attr import TempOpAttr

     with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_group_conv):
         with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", legalize_group_conv):
@@ -443,8 +453,6 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
     with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     if alter_layout:
-        from tvm.relay.testing.temp_op_attr import TempOpAttr
-
         with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", alter_conv):
             with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
                 with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", alter_conv):

src/relay/backend/contrib/dnnl/query_layout.cc
Lines changed: 8 additions & 8 deletions
@@ -146,8 +146,7 @@ std::string md2fmt_tag_str(const dnnl::memory::desc* md) {
   return s;
 }

-dnnl::memory::dims str2dims(const std::string& str_shape,
-                            bool dilates = false,
+dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false,
                             std::string interval = ",") {
   // Split strings
   std::vector<std::string> str_dims;
@@ -164,21 +163,21 @@ dnnl::memory::dims str2dims(const std::string& str_shape,
   dnnl::memory::dims out_dims;
   if (dilates) {
     std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims),
-                       [](const std::string& str) { return std::stoi(str) - 1; });
+                   [](const std::string& str) { return std::stoi(str) - 1; });
   } else {
     std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims),
-                       [](const std::string& str) { return std::stoi(str); });
+                   [](const std::string& str) { return std::stoi(str); });
   }
   return out_dims;
 }

 void check_shapes(const std::vector<std::string> shapes) {
   std::regex valid_pat("(\\d*)(,(\\d*))*");
   bool checked = std::regex_match(shapes[0], valid_pat);
-  for (size_t i = 1; i < shapes.size()-1; i++) {
+  for (size_t i = 1; i < shapes.size() - 1; i++) {
     checked &= std::regex_match(shapes[i], valid_pat);
   }
-  checked &= std::regex_match(shapes[shapes.size()-1], std::regex("\\d*"));
+  checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*"));
   if (!checked) {
     LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
   }
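
For readers skimming the C++ above: str2dims splits a comma-separated shape string into integers, and the dilates branch subtracts one because oneDNN encodes dilation zero-based (a framework dilation rate r becomes r - 1). A Python rendering of the function as shown, offered as a sketch rather than part of the commit:

    def str2dims(str_shape, dilates=False, interval=","):
        dims = [int(s) for s in str_shape.split(interval)]
        # dnnl dilation is zero-based: framework dilation r maps to r - 1
        return [d - 1 for d in dims] if dilates else dims

    str2dims("3,3")                # [3, 3]
    str2dims("1,1", dilates=True)  # [0, 0]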
@@ -193,7 +192,7 @@ void check_layout(bool var, bool ref) {
 std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout,
                                         std::string weight_shape, std::string out_shape,
                                         std::string paddings, std::string strides,
-                                         std::string dilates, std::string G) {
+                                        std::string dilates, std::string G) {
   check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
   check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
   check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});
@@ -272,7 +271,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
   return res;
 }

-std::string get_optimal_layout_for_conv_transpose(std::string data_layout, std::string kernel_layout,
+std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
+                                                  std::string kernel_layout,
                                                   std::string weight_shape, std::string out_shape,
                                                   std::string paddings, std::string output_paddings,
                                                   std::string strides, std::string dilates,

src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Lines changed: 1 addition & 2 deletions
@@ -505,8 +505,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                                         padding_dims_l, padding_dims_r);

     // Enable elementwise post-ops.
-    auto deconv_prim_desc =
-        dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);
+    auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);

     // Push to the network.
     auto deconv = dnnl::deconvolution_forward(deconv_prim_desc);
