diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index dbff3f601655..bd84d837c58e 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -323,6 +323,7 @@ def conv2d_nhwc( "NHWC", out_dtype, auto_scheduler_rewritten_layout, + auto_scheduler_should_rewrite_layout=True, ) @@ -714,6 +715,7 @@ def conv( order: str, out_dtype: Union[str, None] = None, auto_scheduler_rewritten_layout: Optional[str] = None, + auto_scheduler_should_rewrite_layout: bool = False, ): """Convolution operator in NCHW or NHWC layout. @@ -752,6 +754,11 @@ def conv( Elements are converted to this type before elementwise multiplication and summation. + auto_scheduler_should_rewrite_layout : bool + Should auto scheduler be allowed to rewrite the layout of the filter + tensor. Defaults to false. This can cause errors if used with grouped + convs. + auto_scheduler_rewritten_layout: str Layout from autoscheduler's layout rewritting. @@ -862,7 +869,7 @@ def compute(*args): # tag is expected to be lowercase tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}", name=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}", - attrs={"layout_free_placeholders": [filt]}, + attrs={"layout_free_placeholders": [filt]} if auto_scheduler_should_rewrite_layout else {}, varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[permutation_from]), ) # if we used autoscheduler's changed layout we need to rewrite the ordering