
Commit e53cbe4

Fix LayoutRewriter (#10118)
* Fix layout pass
* add unit test
* fix lint
* fix lint
* fix lint
1 parent 8ce1b6c commit e53cbe4
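
In short, LayoutRewriter used to overwrite the producer-facing input layouts (old_in) with the layouts re-inferred by the current node; when an operator such as transpose changes the meaning of an axis, the two can disagree, and a single fused layout_transform is no longer correct. The patch keeps both layout sets (old_in vs. old_in2) and, when they differ, inserts two transforms instead. Below is a minimal sketch of the graph pattern that used to fail; it mirrors the test_axis_semantic_change unit test added in this commit, and the variable names are only illustrative.

# Sketch of the failing pattern (mirrors the new unit test); names are illustrative.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 1, 24, 48))
w1 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
w2 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
y = relay.nn.conv2d(x, w1, kernel_size=(1, 1), padding=(0, 0), channels=1)
# The transpose swaps the H and W axes, so the layout the second conv2d infers
# for its input no longer matches the layout the first conv2d actually produced.
y = relay.transpose(y, (0, 1, 3, 2))
z = relay.nn.conv2d(y, w2, kernel_size=(1, 1), padding=(0, 0), channels=1)
mod = tvm.IRModule.from_expr(relay.Function([x], z))
with tvm.transform.PassContext(opt_level=3):  # opt_level=3 so the layout-altering passes run
    relay.build(mod, target="llvm")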


3 files changed: +40 −22 lines changed


src/relay/op/nn/nn.cc

Lines changed: 0 additions & 12 deletions
@@ -210,10 +210,6 @@ InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
-  // Respect input layout, if explicitly specified (for example, "NW").
-  if (new_in_layouts.size() > 0 && new_in_layouts[0].defined()) {
-    return InferCorrectLayoutOutput({new_in_layouts[0], "NC"}, {"NC"}, attrs);
-  }
   return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs);
 }

@@ -283,14 +279,6 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
                                                      const Array<tvm::relay::Type>& old_in_types) {
   auto params = attrs.as<DensePackAttrs>();
   ICHECK(params);
-  // Respect input layout, if explicitly specified (for example, "NW").
-  // However, a packed layout such as "NC8c" is not supported by dense_pack op. For such cases,
-  // we insert a layout transform "NC8c" -> "NC".
-  // We do not expect to get a packed layout like "NW8w", which is not compatitble with "NC",
-  // since packing is always done on the "C" axis.
-  if (new_in_layouts.size() > 0 && new_in_layouts[0].defined() && new_in_layouts[0].ndim() == 2) {
-    return InferCorrectLayoutOutput({new_in_layouts[0], params->weight_layout}, {"NC"}, attrs);
-  }
   return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
 }

src/relay/transforms/transform_layout.h

Lines changed: 27 additions & 10 deletions
@@ -320,7 +320,10 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   }
 
   // old_in, new_in = state[inputs]
-  Array<Layout> old_in, old_out, new_in, new_out, new_in2;
+  // naming rule:
+  // old_in, new_in: the input layouts given by downstream node.
+  // old_in2, new_in2: the input layouts inferred by the current node.
+  Array<Layout> old_in, old_in2, old_out, new_in, new_out, new_in2;
   for (auto inp : inputs) {
     old_in.push_back(inp->old_layout);
     new_in.push_back(inp->new_layout);
@@ -336,17 +339,18 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   InferCorrectLayoutOutput infer_out;
   std::tie(infer_out, success) =
       InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, types);
-  old_in = infer_out->input_layouts;
+  old_in2 = infer_out->input_layouts;
   old_out = infer_out->output_layouts;
   if (!success) {
     return Expr(nullptr);
   }
-  ICHECK_EQ(old_in.size(), new_in.size());
+  ICHECK_EQ(old_in2.size(), new_in.size());
 
-  // if new_in == 'undef': new_in = old_in
-  for (size_t i = 0; i < new_in.size(); ++i) {
-    if (!new_in[i].defined()) {
-      new_in.Set(i, old_in[i]);
+  Array<Layout> new_in_tmp = new_in;  // for backward compatibility of InferCorrectLayouts
+  // if new_in_tmp == 'undef': new_in_tmp = old_in2
+  for (size_t i = 0; i < new_in_tmp.size(); ++i) {
+    if (!new_in_tmp[i].defined()) {
+      new_in_tmp.Set(i, old_in2[i]);
     }
   }

@@ -356,7 +360,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   // new_in2, new_out = op.infer(new_in)
   if (new_call->op->IsInstance<OpNode>()) {
     success = false;
-    std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types);
+    std::tie(infer_out, success) = InferCorrectLayouts(new_call, new_in_tmp, old_in2, types);
     new_in2 = infer_out->input_layouts;
     new_out = infer_out->output_layouts;
     if (!success) {
@@ -371,6 +375,17 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   ICHECK_EQ(new_in.size(), new_in2.size())
       << "The number of input nodes should keep the same during alter_op_layout";
 
+  auto transform_layout = [&memorizer](Expr arg_item, const Layout& old_in, const Layout& old_in2,
+                                       const Layout& new_in, const Layout& new_in2) {
+    if (old_in2.Equals(old_in)) {  // the two transforms can be fused to one
+      arg_item = memorizer.Transform(arg_item, new_in, new_in2);
+    } else {
+      if (old_in.defined()) arg_item = memorizer.Transform(arg_item, new_in, old_in);
+      arg_item = memorizer.Transform(arg_item, old_in2, new_in2);
+    }
+    return arg_item;
+  };
+
   // if (new_in != new_in2): insert transform (new_in -> new_in2)
   Array<Expr> transformed_args;
   size_t pt = 0;
@@ -380,12 +395,14 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
       Array<Expr> transformed_tuple_arg;
       transformed_tuple_arg.reserve(tuple_arg->fields.size());
       for (auto arg_item : tuple_arg->fields) {
-        transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
+        transformed_tuple_arg.push_back(
+            transform_layout(arg_item, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
         pt++;
       }
       transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg));
     } else {
-      transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
+      transformed_args.push_back(
+          transform_layout(arg, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
       pt++;
     }
   }

tests/python/relay/test_pass_alter_op_layout.py

Lines changed: 13 additions & 0 deletions
@@ -1471,5 +1471,18 @@ def test_conv2d_reduce_channels():
     relay.build(mod, params=params, target="llvm")
 
 
+def test_axis_semantic_change():
+    x = relay.var("x", shape=(1, 1, 24, 48))
+    w1 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
+    w2 = relay.const(np.random.uniform(size=(1, 1, 1, 1)))
+    y = relay.nn.conv2d(x, w1, kernel_size=(1, 1), padding=(0, 0), channels=1)
+    y = relay.transpose(y, (0, 1, 3, 2))
+    z = relay.nn.conv2d(y, w2, kernel_size=(1, 1), padding=(0, 0), channels=1)
+    func = relay.Function([x], z)
+    mod = tvm.IRModule.from_expr(func)
+    with tvm.transform.PassContext(opt_level=3):
+        relay.build(mod, target="llvm")
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
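
To exercise only the new test locally, something like the following should work from the repository root (assuming a TVM build whose Python bindings are importable); the node id comes from the diff above.

import pytest
# Run just the test added in this commit.
pytest.main(["tests/python/relay/test_pass_alter_op_layout.py::test_axis_semantic_change"])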
