Skip to content

Commit e0c78cc

Browse files
author
Ivy
committed
add checkes for query_layout
1 parent 3a06dcf commit e0c78cc

File tree

2 files changed

+52
-15
lines changed

2 files changed

+52
-15
lines changed

python/tvm/relay/op/contrib/dnnl.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,20 +211,23 @@ def pattern_table():
211211
return dnnl_patterns
212212

213213

214-
def get_optimal_layout_for_conv(weight_shape, out_shape, paddings, strides, dilates, groups):
214+
def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
215+
out_shape, paddings, strides, dilates, groups):
215216
"""Get the optimal layout of dnnl, given shape of conv2d.
216217
217218
Parameters
218219
----------
219-
weight_shape, out_shape, paddings, strides, dilates, groups : Int, String
220-
Input argument.
220+
data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups :String
221+
Input argument.
221222
222223
Returns
223224
-------
224225
layouts : string
225226
The result.
226227
"""
227228
return _ffi_api.get_optimal_layout_for_conv(
229+
data_layout,
230+
kernel_layout,
228231
weight_shape,
229232
out_shape,
230233
paddings,
@@ -235,13 +238,13 @@ def get_optimal_layout_for_conv(weight_shape, out_shape, paddings, strides, dila
235238

236239

237240
def get_optimal_layout_for_conv_transpose(
238-
weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
241+
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
239242
):
240243
"""Get the optimal layout of dnnl, given shape of tranposed conv2d.
241244
242245
Parameters
243246
----------
244-
weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
247+
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
245248
: Int, String
246249
Input argument.
247250
@@ -251,6 +254,8 @@ def get_optimal_layout_for_conv_transpose(
251254
The result.
252255
"""
253256
return _ffi_api.get_optimal_layout_for_conv_transpose(
257+
data_layout,
258+
kernel_layout,
254259
weight_shape,
255260
out_shape,
256261
paddings,
@@ -341,7 +346,8 @@ def alter_conv(attrs, inputs, tinfos, out_type):
341346
conv_type = type(attrs).__name__.split("Attrs")[0]
342347

343348
res = get_optimal_layout_for_conv(
344-
weight_shape, out_shape, paddings, strides, dilates, groups
349+
attrs["data_layout"], attrs["kernel_layout"], weight_shape, out_shape, paddings,
350+
strides, dilates, groups,
345351
)
346352
src_df, weight_df, dst_df = res.split(",")
347353
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
@@ -373,6 +379,8 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
373379
conv_type = type(attrs).__name__.split("Attrs")[0]
374380

375381
res = get_optimal_layout_for_conv_transpose(
382+
attrs["data_layout"],
383+
attrs["kernel_layout"],
376384
weight_shape,
377385
out_shape,
378386
paddings,

src/relay/backend/contrib/dnnl/query_layout.cc

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,32 @@ dnnl::memory::dims str2dims(const std::string& str_shape,
172172
return out_dims;
173173
}
174174

175-
std::string get_optimal_layout_for_conv(std::string weight_shape,
176-
std::string out_shape, std::string paddings,
177-
std::string strides, std::string dilates, std::string G) {
175+
void check_shapes(const std::vector<std::string> shapes) {
176+
std::regex valid_pat("(\\d*)(,(\\d*))*");
177+
bool checked = std::regex_match(shapes[0], valid_pat);
178+
for (size_t i = 1; i < shapes.size()-1; i++) {
179+
checked &= std::regex_match(shapes[i], valid_pat);
180+
}
181+
checked &= std::regex_match(shapes[shapes.size()-1], std::regex("\\d*"));
182+
if (!checked) {
183+
LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
184+
}
185+
}
186+
187+
void check_layout(bool var, bool ref) {
188+
if (var != ref) {
189+
LOG(FATAL) << "Invalid input layout for query dnnl optimal layout.";
190+
}
191+
}
192+
193+
std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout,
194+
std::string weight_shape, std::string out_shape,
195+
std::string paddings, std::string strides,
196+
std::string dilates, std::string G) {
197+
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
198+
check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
199+
check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});
200+
178201
dnnl::engine eng(dnnl::engine::kind::cpu, 0);
179202
dnnl::stream s(eng);
180203
using tag = dnnl::memory::format_tag;
@@ -249,10 +272,15 @@ std::string get_optimal_layout_for_conv(std::string weight_shape,
249272
return res;
250273
}
251274

252-
std::string get_optimal_layout_for_conv_transpose(std::string weight_shape,
253-
std::string out_shape, std::string paddings,
254-
std::string output_paddings, std::string strides,
255-
std::string dilates, std::string G) {
275+
std::string get_optimal_layout_for_conv_transpose(std::string data_layout, std::string kernel_layout,
276+
std::string weight_shape, std::string out_shape,
277+
std::string paddings, std::string output_paddings,
278+
std::string strides, std::string dilates,
279+
std::string G) {
280+
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
281+
check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
282+
check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});
283+
256284
dnnl::engine eng(dnnl::engine::kind::cpu, 0);
257285
dnnl::stream s(eng);
258286
using tag = dnnl::memory::format_tag;
@@ -335,13 +363,14 @@ std::string get_optimal_layout_for_conv_transpose(std::string weight_shape,
335363

336364
TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv")
337365
.set_body([](TVMArgs args, TVMRetValue* rv) {
338-
*rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5]);
366+
*rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5],
367+
args[6], args[7]);
339368
});
340369

341370
TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose")
342371
.set_body([](TVMArgs args, TVMRetValue* rv) {
343372
*rv = get_optimal_layout_for_conv_transpose(args[0], args[1], args[2], args[3], args[4],
344-
args[5], args[6]);
373+
args[5], args[6], args[7], args[8]);
345374
});
346375

347376
} // namespace contrib

0 commit comments

Comments
 (0)