Skip to content

Commit 6cdf205

Browse files
committed
Add dtype and layout checks in pattern match
1 parent 7743cc6 commit 6cdf205

File tree

1 file changed

+39
-7
lines changed

1 file changed

+39
-7
lines changed

python/tvm/relay/op/contrib/cutlass.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,29 +56,61 @@ def make_batch_matmul_pattern():
5656

5757

5858
def make_conv2d_pattern():
    """Build a relay pattern matching a bare ``nn.conv2d`` call.

    Both the data and the weight operand are left unconstrained; dtype and
    layout restrictions are enforced by a separate predicate rather than by
    the pattern itself.
    """
    data = wildcard()
    weight = wildcard()
    return is_op("nn.conv2d")(data, weight)
6160

6261

62+
def check_dtype(lhs, rhs):
    """Check if dtypes in the given workload are supported by CUTLASS.

    Parameters
    ----------
    lhs, rhs : objects exposing a ``dtype`` attribute (e.g. relay TensorType)
        The checked types of the two operands.

    Returns
    -------
    bool
        True iff both operands are "float16", the only dtype this
        integration currently offloads.
    """
    # The original chained `lhs.dtype == rhs.dtype` in front of the two
    # "float16" checks; that equality is redundant once both sides are
    # required to equal the same literal, so it is dropped here.
    return lhs.dtype == "float16" and rhs.dtype == "float16"
65+
66+
67+
def check_gemm(call):
    """Check if the given dense workload can be offloaded to CUTLASS.

    Looks up the checked types of the two matmul operands and delegates
    the dtype validation to ``check_dtype``.
    """
    lhs_type = call.args[0].checked_type
    rhs_type = call.args[1].checked_type
    return check_dtype(lhs_type, rhs_type)
72+
73+
74+
def check_batch_matmul(call):
    """Check if the given batch_matmul workload can be offloaded to CUTLASS.

    Beyond the generic GEMM dtype constraint, only the NT layout
    (A not transposed, B transposed) is accepted.
    """
    attrs = call.attrs
    # Keep explicit `==` comparisons: the attribute values are TVM attr
    # objects, so truthiness (`not attrs.transpose_a`) is not assumed here.
    layout_ok = attrs.transpose_a == False and attrs.transpose_b == True
    return check_gemm(call) and layout_ok
79+
80+
81+
def check_conv2d(call):
    """Check if the given conv2d workload can be offloaded to CUTLASS.

    Requires NHWC activations with OHWI-layout kernels, plus the generic
    dtype constraint shared with the GEMM paths.
    """
    attrs = call.attrs
    data_type = call.args[0].checked_type
    weight_type = call.args[1].checked_type
    layouts_ok = attrs.data_layout == "NHWC" and attrs.kernel_layout == "OHWI"
    # Short-circuit: dtypes are only inspected once the layouts match,
    # mirroring the original evaluation order.
    return layouts_ok and check_dtype(data_type, weight_type)
88+
89+
6390
def partition_for_cutlass(mod):
6491
"""Partition the input module into CUTLASS-supported subgraphs."""
65-
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
66-
dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None))
67-
dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"))
68-
dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu"))
92+
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
93+
dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
94+
dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
95+
dense_bias_gelu_fp16_pat = (
96+
"cutlass.dense_bias_gelu_fp16",
97+
make_gemm_pattern(True, "gelu"),
98+
check_gemm,
99+
)
69100
dense_bias_gelu_fp32_pat = (
70101
"cutlass.dense_bias_gelu_fp32",
71102
make_gemm_pattern(True, "gelu", out_dtype="float32"),
103+
check_gemm,
72104
)
73105
cutlass_patterns = [
74106
dense_bias_gelu_fp16_pat,
75107
dense_bias_gelu_fp32_pat,
76108
dense_bias_relu_pat,
77109
dense_bias_pat,
78110
dense_pat,
79-
("cutlass.batch_matmul", make_batch_matmul_pattern()),
111+
("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
80112
# TODO(masahi): Add more conv2d patterns
81-
("cutlass.conv2d", make_conv2d_pattern()),
113+
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
82114
]
83115
mod = transform.MergeComposite(cutlass_patterns)(mod)
84116
mod = transform.AnnotateTarget(["cutlass"])(mod)

0 commit comments

Comments
 (0)