
Commit 3a06dcf

Author: Ivy
Commit message: simplify query_layout
Parent: f63e6bd

File tree: 3 files changed (+68, -62 lines)


python/tvm/relay/op/contrib/dnnl.py

Lines changed: 9 additions & 12 deletions
@@ -211,21 +211,20 @@ def pattern_table():
     return dnnl_patterns
 
 
-def get_optimal_layout_for_conv(input_size, weight_shape, out_shape, paddings, strides, dilates, groups):
+def get_optimal_layout_for_conv(weight_shape, out_shape, paddings, strides, dilates, groups):
     """Get the optimal layout of dnnl, given shape of conv2d.
 
     Parameters
     ----------
-    input_size, weight_shape, out_shape, paddings, strides, dilates, groups : Int, String
-        Input argument.
+    weight_shape, out_shape, paddings, strides, dilates, groups : Int, String
+        Input argument.
 
     Returns
     -------
     layouts : string
         The result.
     """
     return _ffi_api.get_optimal_layout_for_conv(
-        input_size,
         weight_shape,
         out_shape,
         paddings,
@@ -236,13 +235,13 @@ def get_optimal_layout_for_conv(input_size, weight_shape, out_shape, paddings, s
 
 
 def get_optimal_layout_for_conv_transpose(
-    input_size, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
+    weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
 ):
     """Get the optimal layout of dnnl, given shape of tranposed conv2d.
 
     Parameters
     ----------
-    input_size, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
+    weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
     : Int, String
         Input argument.
 
@@ -252,7 +251,6 @@ def get_optimal_layout_for_conv_transpose(
         The result.
     """
     return _ffi_api.get_optimal_layout_for_conv_transpose(
-        input_size,
         weight_shape,
         out_shape,
         paddings,
@@ -282,15 +280,15 @@ def get_shape(tensor):
 def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
     """Transfer layout, denoted with `a, b, c, d, e`,
     into valid layout (NCHW / OIHW) of TVM."""
-    if conv_type == "Conv1D":
+    if "Conv1D" in conv_type:
         data_dic = {"a": "N", "b": "C", "c": "W"}
         weight_dic = {"a": "O", "b": "I", "c": "W", "d": "G"}
-    elif conv_type == "Conv2D":
+    elif "Conv2D" in conv_type:
         data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"}
         weight_dic = {"a": "O", "b": "I", "c": "H", "d": "W"}
         if "e" in input_data:
             weight_dic = {"a": "G", "b": "O", "c": "I", "d": "H", "e": "W"}
-    elif conv_type == "Conv3D":
+    elif "Conv3D" in conv_type:
         data_dic = {"a": "N", "b": "C", "c": "D", "d": "H", "e": "W"}
         weight_dic = {"a": "O", "b": "I", "c": "D", "d": "H", "e": "W", "f": "G"}
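Why the equality checks became substring checks: conv_type is derived from the attrs class name (type(attrs).__name__.split("Attrs")[0]), which yields e.g. "Conv2D" for an ordinary convolution but "Conv2DTranspose" for the transposed variant, so a membership test presumably lets both variants share the same layout dictionaries. A tiny illustrative check (the names here are assumed, not from the commit):

    # Hypothetical class-name stems as produced by type(attrs).__name__.split("Attrs")[0].
    for name in ["Conv2D", "Conv2DTranspose"]:
        print("Conv2D" in name)   # True for both variants (new check)
        print(name == "Conv2D")   # True only for plain Conv2D (old check)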

@@ -343,7 +341,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
     conv_type = type(attrs).__name__.split("Attrs")[0]
 
     res = get_optimal_layout_for_conv(
-        len(get_shape(out_type)), weight_shape, out_shape, paddings, strides, dilates, groups
+        weight_shape, out_shape, paddings, strides, dilates, groups
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
@@ -375,7 +373,6 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
     conv_type = type(attrs).__name__.split("Attrs")[0]
 
     res = get_optimal_layout_for_conv_transpose(
-        len(get_shape(out_type)),
         weight_shape,
         out_shape,
         paddings,
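With input_size dropped, the tensor rank is now inferred on the C++ side from the comma-separated shape strings themselves. A sketch of the new calling convention; the shape values below are illustrative, not taken from the commit:

    # Hypothetical NCHW conv2d: 64 output channels, 3x3 kernel, stride 1, padding 1.
    res = get_optimal_layout_for_conv(
        "64,64,3,3",   # weight_shape (OIHW)
        "1,64,56,56",  # out_shape (NCHW); rank 4 is inferred from the four fields
        "1,1,1,1",     # paddings (left half, then right half)
        "1,1",         # strides
        "1,1",         # dilates (TVM convention: 1 means no dilation)
        "1",           # groups
    )
    src_df, weight_df, dst_df = res.split(",")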

src/relay/backend/contrib/dnnl/query_layout.cc

Lines changed: 43 additions & 34 deletions
@@ -146,24 +146,33 @@ std::string md2fmt_tag_str(const dnnl::memory::desc* md) {
   return s;
 }
 
-dnnl::memory::dims str2dims(std::string str_shape, int input_size) {
-  std::string str_reg = "(\\d*)";
-  for (int i = 0; i < input_size - 1; i++) {
-    str_reg.append(",(\\d*)");
+dnnl::memory::dims str2dims(const std::string& str_shape,
+                            bool dilates = false,
+                            std::string interval = ",") {
+  // Split strings
+  std::vector<std::string> str_dims;
+  size_t pos = 0, start = 0;
+  while ((pos = str_shape.find(interval, start)) != std::string::npos) {
+    std::string str_dim = str_shape.substr(start, pos - start);
+    if (pos > start) str_dims.push_back(str_dim);
+    start = pos + interval.size();
   }
-  std::regex rex(str_reg);
-  std::smatch m;
+  if (str_shape.size() > start) {
+    str_dims.push_back(str_shape.substr(start));
+  }
+  // transfer string to dims
   dnnl::memory::dims out_dims;
-  if (std::regex_search(str_shape, m, rex)) {
-    std::transform(m.begin() + 1, m.end(), std::back_inserter(out_dims),
-                   [](const std::string& str) { return std::stoi(str); });
+  if (dilates) {
+    std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims),
+                   [](const std::string& str) { return std::stoi(str) - 1; });
   } else {
-    LOG(FATAL) << "Unsupported shape for querying optimal dnnl layout: " << str_shape;
+    std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims),
+                   [](const std::string& str) { return std::stoi(str); });
   }
   return out_dims;
 }
 
-std::string get_optimal_layout_for_conv(int input_size, std::string weight_shape,
+std::string get_optimal_layout_for_conv(std::string weight_shape,
                                         std::string out_shape, std::string paddings,
                                         std::string strides, std::string dilates, std::string G) {
   dnnl::engine eng(dnnl::engine::kind::cpu, 0);
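The regex that was built field-by-field (and therefore needed input_size to know how many fields to expect) is replaced by a plain split, so the rank falls out of the string itself; the dilates flag additionally folds in DNNL's rate-minus-one convention. A rough Python paraphrase of the new logic, for orientation only:

    def str2dims(str_shape, dilates=False, interval=","):
        # Split on the separator, drop empty fields, convert to ints.
        dims = [int(s) for s in str_shape.split(interval) if s]
        # DNNL stores dilation as (rate - 1): TVM dilation 1 means no dilation.
        return [d - 1 for d in dims] if dilates else dims

    assert str2dims("1,64,56,56") == [1, 64, 56, 56]  # rank inferred, no input_size
    assert str2dims("2,2", dilates=True) == [1, 1]    # DNNL dilation convention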
@@ -172,35 +181,36 @@ std::string get_optimal_layout_for_conv(int input_size, std::string weight_shape
   using dt = dnnl::memory::data_type;
 
   dnnl::memory::dim groups = std::stoi(G);
-  dnnl::memory::dims weight_dims_ = str2dims(weight_shape, input_size);
+  dnnl::memory::dims weight_dims_ = str2dims(weight_shape);
   dnnl::memory::dims weight_dims = weight_dims_;
+
   if (groups > 1) {
     if (weight_dims_.size() == 5) {
-      weight_dims = {weight_dims_[0] * weight_dims_[1], weight_dims_[2], weight_dims_[3],
+      weight_dims = {groups * weight_dims_[1], groups * weight_dims_[2], weight_dims_[3],
                      weight_dims_[4]};
     } else {
       weight_dims[1] = weight_dims[1] * groups;
     }
   }
-  dnnl::memory::dims out_dims = str2dims(out_shape, input_size);
-  dnnl::memory::dims padding_dims = str2dims(paddings, 2 * (input_size - 2));
+
+  dnnl::memory::dims out_dims = str2dims(out_shape);
+  dnnl::memory::dims padding_dims = str2dims(paddings);
   dnnl::memory::dims padding_dims_l(padding_dims.begin(),
                                     padding_dims.begin() + padding_dims.size() / 2);
   dnnl::memory::dims padding_dims_r(padding_dims.end() - padding_dims.size() / 2,
                                     padding_dims.end());
-  dnnl::memory::dims strides_dims = str2dims(strides, input_size - 2);
-  dnnl::memory::dims dilates_dims = str2dims(dilates, input_size - 2);
+  dnnl::memory::dims strides_dims = str2dims(strides);
+  dnnl::memory::dims dilates_dims = str2dims(dilates, true);
 
   dnnl::memory::dims input_dims = out_dims;
   input_dims[1] = weight_dims[1];
-  for (int i = 2; i < input_size; i++) {
+  for (int i = 2; i < out_dims.size(); i++) {
     dnnl::memory::dim K = weight_dims[i];
     dnnl::memory::dim S = strides_dims[i - 2];
-    dnnl::memory::dim D = dilates_dims[i - 2] - 1;
+    dnnl::memory::dim D = dilates_dims[i - 2];
     dnnl::memory::dim PL = padding_dims_l[i - 2];
     dnnl::memory::dim PR = padding_dims_r[i - 2];
     dnnl::memory::dim DK = 1 + (K - 1) * (D + 1);
-    dilates_dims[i - 2] = D;
     input_dims[i] = out_dims[i] * S - PL - PR + DK - 1;
   }
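Because str2dims(dilates, true) already returns DNNL-style rates, the in-loop minus-one adjustment and the write-back into dilates_dims are gone, and the loop bound switches from the removed input_size to out_dims.size(). A worked check of the input-size reconstruction with illustrative numbers (out = 56, K = 3, S = 1, PL = PR = 1, TVM dilation 1 so D = 0):

    K, S, D, PL, PR, out = 3, 1, 0, 1, 1, 56
    DK = 1 + (K - 1) * (D + 1)        # dilated kernel extent: 3
    inp = out * S - PL - PR + DK - 1  # 56*1 - 1 - 1 + 3 - 1 = 56
    assert inp == 56                  # "same" padding round-trips the size

Note the grouped 5-D weight branch also changes meaning: the collapsed shape now scales both channel dims by groups (full O and full I) rather than folding the group axis into the output channels only, which appears to be what the later input_dims[1] = weight_dims[1] assignment expects.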

@@ -210,6 +220,7 @@ std::string get_optimal_layout_for_conv(int input_size, std::string weight_shape
     conv_weights_dims = {groups, out_dims[1] / groups, input_dims[1] / groups};
     conv_weights_dims.insert(conv_weights_dims.end(), weight_dims.begin() + 2, weight_dims.end());
   }
+
   dnnl::memory::dims conv_dst_dims = out_dims;
   dnnl::memory::dims conv_strides = strides_dims;
   dnnl::memory::dims conv_dilates = dilates_dims;
@@ -238,7 +249,7 @@ std::string get_optimal_layout_for_conv(int input_size, std::string weight_shape
   return res;
 }
 
-std::string get_optimal_layout_for_conv_transpose(int input_size, std::string weight_shape,
+std::string get_optimal_layout_for_conv_transpose(std::string weight_shape,
                                                   std::string out_shape, std::string paddings,
                                                   std::string output_paddings, std::string strides,
                                                   std::string dilates, std::string G) {
@@ -248,25 +259,25 @@ std::string get_optimal_layout_for_conv_transpose(int input_size, std::string we
   using dt = dnnl::memory::data_type;
 
   dnnl::memory::dim groups = std::stoi(G);
-  dnnl::memory::dims weight_dims_ = str2dims(weight_shape, input_size);
+  dnnl::memory::dims weight_dims_ = str2dims(weight_shape);
   dnnl::memory::dims weight_dims = weight_dims_;
   if (groups > 1) {
     if (weight_dims_.size() == 5) {
-      weight_dims = {weight_dims_[0] * weight_dims_[1], weight_dims_[2], weight_dims_[3],
+      weight_dims = {groups * weight_dims_[1], groups * weight_dims_[2], weight_dims_[3],
                      weight_dims_[4]};
     } else {
       weight_dims[1] = weight_dims[1] * groups;
     }
   }
-  dnnl::memory::dims out_dims = str2dims(out_shape, input_size);
-  dnnl::memory::dims padding_dims = str2dims(paddings, 2 * (input_size - 2));
+  dnnl::memory::dims out_dims = str2dims(out_shape);
+  dnnl::memory::dims padding_dims = str2dims(paddings);
   dnnl::memory::dims padding_dims_l(padding_dims.begin(),
                                     padding_dims.begin() + padding_dims.size() / 2);
   dnnl::memory::dims padding_dims_r(padding_dims.end() - padding_dims.size() / 2,
                                     padding_dims.end());
-  dnnl::memory::dims output_padding_dims = str2dims(output_paddings, input_size - 2);
-  dnnl::memory::dims strides_dims = str2dims(strides, input_size - 2);
-  dnnl::memory::dims dilates_dims = str2dims(dilates, input_size - 2);
+  dnnl::memory::dims output_padding_dims = str2dims(output_paddings);
+  dnnl::memory::dims strides_dims = str2dims(strides);
+  dnnl::memory::dims dilates_dims = str2dims(dilates, true);
 
   dnnl::memory::dims input_dims = out_dims;
   if (out_dims[1] == weight_dims[0]) {
@@ -275,15 +286,14 @@ std::string get_optimal_layout_for_conv_transpose(int input_size, std::string we
     input_dims[1] = weight_dims[0];
     std::swap(weight_dims[0], weight_dims[1]);
   }
-  for (int i = 2; i < input_size; i++) {
+  for (int i = 2; i < out_dims.size(); i++) {
     dnnl::memory::dim K = weight_dims[i];
     dnnl::memory::dim S = strides_dims[i - 2];
-    dnnl::memory::dim D = dilates_dims[i - 2] - 1;
+    dnnl::memory::dim D = dilates_dims[i - 2];
     dnnl::memory::dim PL = padding_dims_l[i - 2];
     dnnl::memory::dim PR = padding_dims_r[i - 2];
     dnnl::memory::dim OP = output_padding_dims[i - 2];
     dnnl::memory::dim DK = 1 + (K - 1) * (D + 1);
-    dilates_dims[i - 2] = D;
     input_dims[i] = (out_dims[i] - DK + PL + PR - OP) / S + 1;
   }
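For the transposed case the loop inverts the deconvolution size relation instead. A quick consistency check with illustrative numbers (out = 112, K = 3, S = 2, PL = PR = 1, OP = 1, D = 0):

    K, S, D, PL, PR, OP, out = 3, 2, 0, 1, 1, 1, 112
    DK = 1 + (K - 1) * (D + 1)                # dilated kernel extent: 3
    inp = (out - DK + PL + PR - OP) // S + 1  # (112 - 3 + 2 - 1) // 2 + 1 = 56
    assert inp == 56
    # Forward direction recovers the output size:
    assert (inp - 1) * S - PL - PR + DK + OP == out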

@@ -325,14 +335,13 @@ std::string get_optimal_layout_for_conv_transpose(int input_size, std::string we
 
 TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      *rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5],
-                                        args[6]);
+      *rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5]);
    });
 
 TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
       *rv = get_optimal_layout_for_conv_transpose(args[0], args[1], args[2], args[3], args[4],
-                                                  args[5], args[6], args[7]);
+                                                  args[5], args[6]);
     });
 
 }  // namespace contrib

src/runtime/contrib/dnnl/dnnl_json_runtime.cc

Lines changed: 16 additions & 16 deletions
@@ -216,9 +216,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     return out_dims;
   }
 
-  dnnl::memory::dims TransformStr2Dims(std::vector<std::string> strs, std::string str_name) {
+  dnnl::memory::dims TransformStr2Dims(std::vector<std::string> strs, bool dilates = false) {
     dnnl::memory::dims out_dims;
-    if (str_name == "dilates") {
+    if (dilates) {
       std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims),
                      [](const std::string& str) { return std::stoi(str) - 1; });
     } else {
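The runtime side mirrors the same refactor: the magic-string tag ("dilates"/"strides"/"padding") becomes a boolean, since dilation is the only case needing the rate-minus-one conversion. Illustrative arithmetic for why that conversion exists (values are assumed, not from the commit):

    # A TVM dilation rate of 2 on a 3-tap kernel; DNNL stores the rate as (rate - 1).
    tvm_dilation = 2
    dnnl_dilation = tvm_dilation - 1
    K = 3
    effective_K = 1 + (K - 1) * (dnnl_dilation + 1)  # 1 + 2*2 = 5 positions of reach
    assert effective_K == 5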
@@ -338,10 +338,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
     dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout);
     dnnl::memory::dims bias_dims = {channels};
-    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides");
-    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates");
-    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding");
-    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding");
+    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
+    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
+    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
+    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
     dnnl::memory::dims dst_dims = src_dims;
     dst_dims[1] = channels;
     weights_dims_[0] = channels;
@@ -463,11 +463,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       }
     }
     dnnl::memory::dims bias_dims = {channels};
-    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides");
-    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates");
-    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding");
-    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding");
-    dnnl::memory::dims out_padding = TransformStr2Dims(str_out_padding, "padding");
+    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
+    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
+    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
+    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
+    dnnl::memory::dims out_padding = TransformStr2Dims(str_out_padding);
     dnnl::memory::dims dst_dims = src_dims;
     dst_dims[1] = channels;
     for (int i = 2; i < src_dims.size(); i++) {
@@ -675,11 +675,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
     dnnl::memory::dims src_dims = TransDims2Plain(input_shape, layout);
     dnnl::memory::dims dst_dims = TransDims2Plain(out_shape, layout);
-    dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel, "kernel");
-    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides");
-    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates");
-    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding");
-    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding");
+    dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel);
+    dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
+    dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
+    dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l);
+    dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
 
     // Memory descriptions.
     auto pool_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[layout]);
