File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed
python/tvm/relay/op/contrib Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -90,14 +90,22 @@ def check_batch_matmul(call):
9090 return check_dtype (lhs , rhs ) and not transpose_a and transpose_b
9191
9292
93+ def is_depthwise_conv2d (ic , oc , groups ):
94+ return ic == oc == groups
95+
96+
9397def check_conv2d (call ):
9498 """Check if the given conv2d workload can be offloaded to CUTLASS."""
9599 conv2d = get_root_call (call , "nn.conv2d" )
96100 data_layout = conv2d .attrs .data_layout
97101 kernel_layout = conv2d .attrs .kernel_layout
98102 data = conv2d .args [0 ].checked_type
99103 weight = conv2d .args [1 ].checked_type
100- return data_layout == "NHWC" and kernel_layout == "OHWI" and check_dtype (data , weight )
104+ if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype (data , weight ):
105+ return False
106+ IC = data .shape [3 ]
107+ OC = weight .shape [0 ]
108+ return not is_depthwise_conv2d (IC , OC , call .attrs .groups )
101109
102110
103111def partition_for_cutlass (mod ):
You can’t perform that action at this time.
0 commit comments