@@ -56,29 +56,61 @@ def make_batch_matmul_pattern():
 
 
 def make_conv2d_pattern():
-    # TODO(masahi): Check layout and alignment
     return is_op("nn.conv2d")(wildcard(), wildcard())
 
 
+def check_dtype(lhs, rhs):
+    """Check if dtypes in the given workload are supported by CUTLASS."""
+    return lhs.dtype == rhs.dtype and lhs.dtype == "float16" and rhs.dtype == "float16"
+
+
+def check_gemm(call):
+    """Check if the given dense workload can be offloaded to CUTLASS."""
+    lhs = call.args[0].checked_type
+    rhs = call.args[1].checked_type
+    return check_dtype(lhs, rhs)
+
+
+def check_batch_matmul(call):
+    """Check if the given batch_matmul workload can be offloaded to CUTLASS."""
+    transpose_a = call.attrs.transpose_a
+    transpose_b = call.attrs.transpose_b
+    return check_gemm(call) and transpose_a == False and transpose_b == True
+
+
+def check_conv2d(call):
+    """Check if the given conv2d workload can be offloaded to CUTLASS."""
+    data_layout = call.attrs.data_layout
+    kernel_layout = call.attrs.kernel_layout
+    data = call.args[0].checked_type
+    weight = call.args[1].checked_type
+    return data_layout == "NHWC" and kernel_layout == "OHWI" and check_dtype(data, weight)
+
+
 def partition_for_cutlass(mod):
     """Partition the input module into CUTLASS-supported subgraphs."""
-    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
-    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None))
-    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"))
-    dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu"))
+    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
+    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
+    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
+    dense_bias_gelu_fp16_pat = (
+        "cutlass.dense_bias_gelu_fp16",
+        make_gemm_pattern(True, "gelu"),
+        check_gemm,
+    )
     dense_bias_gelu_fp32_pat = (
         "cutlass.dense_bias_gelu_fp32",
         make_gemm_pattern(True, "gelu", out_dtype="float32"),
+        check_gemm,
     )
     cutlass_patterns = [
         dense_bias_gelu_fp16_pat,
         dense_bias_gelu_fp32_pat,
         dense_bias_relu_pat,
         dense_bias_pat,
         dense_pat,
-        ("cutlass.batch_matmul", make_batch_matmul_pattern()),
+        ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
         # TODO(masahi): Add more conv2d patterns
-        ("cutlass.conv2d", make_conv2d_pattern()),
+        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
     mod = transform.MergeComposite(cutlass_patterns)(mod)
     mod = transform.AnnotateTarget(["cutlass"])(mod)
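For context, a minimal sketch (not part of the commit) of how the new check predicates might be exercised end to end. It assumes the standard TVM Relay Python API (relay.var, relay.nn.dense, tvm.IRModule.from_expr) and uses illustrative float16 shapes that the checks above would accept; the function name example_partition is made up for illustration.

# Sketch only: a dense workload that check_gemm should accept.
import tvm
from tvm import relay


def example_partition():
    # float16 operands match the dtype check in check_dtype above.
    data = relay.var("data", shape=(16, 32), dtype="float16")
    weight = relay.var("weight", shape=(64, 32), dtype="float16")
    dense = relay.nn.dense(data, weight)
    mod = tvm.IRModule.from_expr(dense)
    # Run type inference so checked_type, which the check functions read,
    # is populated before pattern matching.
    mod = relay.transform.InferType()(mod)
    # MergeComposite now receives (name, pattern, check) triples, so a
    # composite like "cutlass.dense" is only formed when check_gemm passes.
    return partition_for_cutlass(mod)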