Skip to content

Commit 6551ece

Browse files
Tristan Konoligepfk-beta
authored andcommitted
[FIX] Only allow autoscheduler layout rewritting in conv2d_nhwc (apache#10522)
Autoscheduler cannot handle layout rewritting for grouped convolutions and non-nhwc layouts. Previously layout rewriting was enabled for all convolutions causing errors where autoscheduler generated too large layouts like `64N4n1n1n1H1W68C1n1h1w2c2n`. Autoscheduler is now only enabled on non-grouped conv2d_nhwc.
1 parent 72bd08c commit 6551ece

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

python/tvm/topi/nn/conv2d.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def conv2d_nhwc(
323323
"NHWC",
324324
out_dtype,
325325
auto_scheduler_rewritten_layout,
326+
auto_scheduler_should_rewrite_layout=True,
326327
)
327328

328329

@@ -714,6 +715,7 @@ def conv(
714715
order: str,
715716
out_dtype: Union[str, None] = None,
716717
auto_scheduler_rewritten_layout: Optional[str] = None,
718+
auto_scheduler_should_rewrite_layout: bool = False,
717719
):
718720
"""Convolution operator in NCHW or NHWC layout.
719721
@@ -752,6 +754,11 @@ def conv(
752754
Elements are converted to this type before elementwise multiplication
753755
and summation.
754756
757+
auto_scheduler_should_rewrite_layout : bool
758+
Should auto scheduler be allowed to rewrite the layout of the filter
759+
tensor. Defaults to false. This can cause errors if used with grouped
760+
convs.
761+
755762
auto_scheduler_rewritten_layout: str
756763
Layout from autoscheduler's layout rewritting.
757764
@@ -862,7 +869,7 @@ def compute(*args):
862869
# tag is expected to be lowercase
863870
tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}",
864871
name=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}",
865-
attrs={"layout_free_placeholders": [filt]},
872+
attrs={"layout_free_placeholders": [filt]} if auto_scheduler_should_rewrite_layout else {},
866873
varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[permutation_from]),
867874
)
868875
# if we used autoscheduler's changed layout we need to rewrite the ordering

0 commit comments

Comments
 (0)