@@ -308,30 +308,59 @@ InferLayoutOutput InferLayoutConv2d(const Call& call,
308308 Layout desired_data_layout = (*it).second [0 ];
309309 Layout desired_weight_layout = (*it).second [1 ];
310310 Layout desired_output_layout = (*it).second .size () == 3 ? (*it).second [2 ] : (*it).second [0 ];
311- ICHECK_EQ (desired_data_layout.ndim (), desired_data_layout.ndim_primal ()) << " Axis swap only" ;
312- ICHECK_EQ (desired_weight_layout.ndim (), desired_weight_layout.ndim_primal ())
313- << " Axis swap only" ;
314- ICHECK_EQ (desired_output_layout.ndim (), desired_output_layout.ndim_primal ())
315- << " Axis swap only" ;
316- data_layout = TransposeLike (InitialLayout (4 ), attrs->data_layout , desired_data_layout);
317- weight_layout = TransposeLike (InitialLayout (4 ), attrs->kernel_layout , desired_weight_layout);
318- output_layout = TransposeLike (InitialLayout (4 ), attrs->out_layout , desired_output_layout);
319- new_attrs->data_layout = (*it).second [0 ];
320- new_attrs->kernel_layout = (*it).second [1 ];
321- new_attrs->out_layout = (*it).second .size () == 3 ? (*it).second [2 ] : (*it).second [0 ];
322- } else {
323- // We don't have a desired layout for conv2d.
324- // We can just propagate the layout from the input.
325- data_layout = GetLayoutDecision (var_layout_map, call->args [0 ]);
326- weight_layout = GetLayoutDecision (var_layout_map, call->args [1 ]);
327- output_layout = data_layout;
328- new_attrs->data_layout =
329- TransposeLike (attrs->data_layout , InitialLayout (4 ), data_layout->layout ).name ();
330- new_attrs->kernel_layout =
331- TransposeLike (attrs->kernel_layout , InitialLayout (4 ), weight_layout->layout ).name ();
332- new_attrs->out_layout =
333- TransposeLike (attrs->out_layout , InitialLayout (4 ), output_layout->layout ).name ();
311+ tir::Layout input_layout (attrs->data_layout , DataType::Int (64 ));
312+ tir::Layout kernel_layout (attrs->kernel_layout , DataType::Int (64 ));
313+ tir::Layout out_layout (attrs->out_layout , DataType::Int (64 ));
314+
315+ if ((desired_data_layout.ndim () == input_layout.ndim ()) &&
316+ (desired_weight_layout.ndim () == kernel_layout.ndim ()) &&
317+ (desired_output_layout.ndim () == out_layout.ndim ())) {
318+ // Just a transpose
319+ data_layout = TransposeLike (InitialLayout (4 ), attrs->data_layout , desired_data_layout);
320+ weight_layout = TransposeLike (InitialLayout (4 ), attrs->kernel_layout , desired_weight_layout);
321+ output_layout = TransposeLike (InitialLayout (4 ), attrs->out_layout , desired_output_layout);
322+ new_attrs->data_layout = (*it).second [0 ];
323+ new_attrs->kernel_layout = (*it).second [1 ];
324+ new_attrs->out_layout = (*it).second .size () == 3 ? (*it).second [2 ] : (*it).second [0 ];
325+ return InferLayoutOutput ({data_layout, weight_layout}, {output_layout}, Attrs (new_attrs));
326+ } else {
327+ // Layout Transform
328+ auto data_si = GetStructInfo (call->args [0 ]);
329+ auto kernel_si = GetStructInfo (call->args [1 ]);
330+ TensorStructInfo data_sinfo = data_si.as <TensorStructInfo>().value ();
331+ TensorStructInfo kernel_sinfo = kernel_si.as <TensorStructInfo>().value ();
332+ Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape .as <ShapeExprNode>());
333+ Optional<ShapeExpr> kernel_shape = GetRef<ShapeExpr>(kernel_sinfo->shape .as <ShapeExprNode>());
334+
335+ bool can_data_proved =
336+ CanProveLayoutTransform (input_layout, desired_data_layout, data_shape.value ()->values );
337+ bool can_kernel_proved = CanProveLayoutTransform (kernel_layout, desired_weight_layout,
338+ kernel_shape.value ()->values );
339+
340+ if (can_data_proved && can_kernel_proved) {
341+ data_layout = TransposeSubLayoutLike (InitialLayout (4 ), input_layout, desired_data_layout);
342+ weight_layout =
343+ TransposeSubLayoutLike (InitialLayout (4 ), kernel_layout, desired_weight_layout);
344+ output_layout = TransposeSubLayoutLike (InitialLayout (4 ), out_layout, desired_output_layout);
345+ new_attrs->data_layout = (*it).second [0 ];
346+ new_attrs->kernel_layout = (*it).second [1 ];
347+ new_attrs->out_layout = (*it).second .size () == 3 ? (*it).second [2 ] : (*it).second [0 ];
348+ return InferLayoutOutput ({data_layout, weight_layout}, {output_layout}, Attrs (new_attrs));
349+ }
350+ }
334351 }
352+
353+ // We don't have a desired layout for conv2d, or the desired layouts are not compatible.
354+ // We can just propagate the layout from the input.
355+ data_layout = GetLayoutDecision (var_layout_map, call->args [0 ]);
356+ weight_layout = GetLayoutDecision (var_layout_map, call->args [1 ]);
357+ output_layout = data_layout;
358+ new_attrs->data_layout =
359+ TransposeLike (attrs->data_layout , InitialLayout (4 ), data_layout->layout ).name ();
360+ new_attrs->kernel_layout =
361+ TransposeLike (attrs->kernel_layout , InitialLayout (4 ), weight_layout->layout ).name ();
362+ new_attrs->out_layout =
363+ TransposeLike (attrs->out_layout , InitialLayout (4 ), output_layout->layout ).name ();
335364 return InferLayoutOutput ({data_layout, weight_layout}, {output_layout}, Attrs (new_attrs));
336365}
337366
0 commit comments