Skip to content

Commit 49e2c89

Browse files
committed
do not offload depthwise conv2d
1 parent cd83677 commit 49e2c89

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

python/tvm/relay/op/contrib/cutlass.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,22 @@ def check_batch_matmul(call):
9090
return check_dtype(lhs, rhs) and not transpose_a and transpose_b
9191

9292

93+
def is_depthwise_conv2d(ic, oc, groups):
94+
return ic == oc == groups
95+
96+
9397
def check_conv2d(call):
9498
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
9599
conv2d = get_root_call(call, "nn.conv2d")
96100
data_layout = conv2d.attrs.data_layout
97101
kernel_layout = conv2d.attrs.kernel_layout
98102
data = conv2d.args[0].checked_type
99103
weight = conv2d.args[1].checked_type
100-
return data_layout == "NHWC" and kernel_layout == "OHWI" and check_dtype(data, weight)
104+
if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype(data, weight):
105+
return False
106+
IC = data.shape[3]
107+
OC = weight.shape[0]
108+
return not is_depthwise_conv2d(IC, OC, call.attrs.groups)
101109

102110

103111
def partition_for_cutlass(mod):

0 commit comments

Comments
 (0)