@@ -320,7 +320,10 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
320320 }
321321
322322 // old_in, new_in = state[inputs]
323- Array<Layout> old_in, old_out, new_in, new_out, new_in2;
323+ // naming rule:
324+ // old_in, new_in: the input layouts given by downstream node.
325+ // old_in2, new_in2: the input layouts inferred by the current node.
326+ Array<Layout> old_in, old_in2, old_out, new_in, new_out, new_in2;
324327 for (auto inp : inputs) {
325328 old_in.push_back (inp->old_layout );
326329 new_in.push_back (inp->new_layout );
@@ -336,17 +339,18 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
336339 InferCorrectLayoutOutput infer_out;
337340 std::tie (infer_out, success) =
338341 InferCorrectLayouts (ref_call, Array<Layout>(nullptr ), old_in, types);
339- old_in = infer_out->input_layouts ;
342+ old_in2 = infer_out->input_layouts ;
340343 old_out = infer_out->output_layouts ;
341344 if (!success) {
342345 return Expr (nullptr );
343346 }
344- ICHECK_EQ (old_in .size (), new_in.size ());
347+ ICHECK_EQ (old_in2 .size (), new_in.size ());
345348
346- // if new_in == 'undef': new_in = old_in
347- for (size_t i = 0 ; i < new_in.size (); ++i) {
348- if (!new_in[i].defined ()) {
349- new_in.Set (i, old_in[i]);
349+ Array<Layout> new_in_tmp = new_in; // for backward compatibility of InferCorrectLayouts
350+ // if new_in_tmp == 'undef': new_in_tmp = old_in2
351+ for (size_t i = 0 ; i < new_in_tmp.size (); ++i) {
352+ if (!new_in_tmp[i].defined ()) {
353+ new_in_tmp.Set (i, old_in2[i]);
350354 }
351355 }
352356
@@ -356,7 +360,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
356360 // new_in2, new_out = op.infer(new_in_tmp)
357361 if (new_call->op ->IsInstance <OpNode>()) {
358362 success = false ;
359- std::tie (infer_out, success) = InferCorrectLayouts (new_call, new_in, old_in , types);
363+ std::tie (infer_out, success) = InferCorrectLayouts (new_call, new_in_tmp, old_in2 , types);
360364 new_in2 = infer_out->input_layouts ;
361365 new_out = infer_out->output_layouts ;
362366 if (!success) {
@@ -371,6 +375,17 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
371375 ICHECK_EQ (new_in.size (), new_in2.size ())
372376 << " The number of input nodes should keep the same during alter_op_layout" ;
373377
378+ auto transform_layout = [&memorizer](Expr arg_item, const Layout& old_in, const Layout& old_in2, // helper: rewrites one argument via memorizer.Transform, in one or two steps, and returns the result
379+ const Layout& new_in, const Layout& new_in2) { // per the naming rule above: *_in come from the producer side, *_in2 are inferred by the current node
380+ if (old_in2.Equals (old_in)) { // old layouts agree: new_in -> old_in and old_in2 -> new_in2 can be fused into a single new_in -> new_in2 transform
381+ arg_item = memorizer.Transform (arg_item, new_in, new_in2);
382+ } else { // old layouts differ: route the argument through the old layouts in two steps
383+ if (old_in.defined ()) arg_item = memorizer.Transform (arg_item, new_in, old_in); // step 1 (skipped when old_in is undefined): new_in -> old_in
384+ arg_item = memorizer.Transform (arg_item, old_in2, new_in2); // step 2: old_in2 -> new_in2
385+ }
386+ return arg_item; // the (possibly transformed) argument expression
387+ };
388+
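374389 // insert layout transforms on each argument: new_in -> new_in2 when old_in == old_in2, otherwise new_in -> old_in (only if old_in is defined) followed by old_in2 -> new_in2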
375390 Array<Expr> transformed_args;
376391 size_t pt = 0 ;
@@ -380,12 +395,14 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
380395 Array<Expr> transformed_tuple_arg;
381396 transformed_tuple_arg.reserve (tuple_arg->fields .size ());
382397 for (auto arg_item : tuple_arg->fields ) {
383- transformed_tuple_arg.push_back (memorizer.Transform (arg_item, new_in[pt], new_in2[pt]));
398+ transformed_tuple_arg.push_back (
399+ transform_layout (arg_item, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
384400 pt++;
385401 }
386402 transformed_args.push_back (WithFields (tuple_arg, transformed_tuple_arg));
387403 } else {
388- transformed_args.push_back (memorizer.Transform (arg, new_in[pt], new_in2[pt]));
404+ transformed_args.push_back (
405+ transform_layout (arg, old_in[pt], old_in2[pt], new_in[pt], new_in2[pt]));
389406 pt++;
390407 }
391408 }
0 commit comments