1616# under the License.
1717# pylint: disable=invalid-name
1818"""Patterns supported by CUTLASS."""
19+ from functools import partial
1920from tvm import relay
2021from tvm .ir .transform import Sequential , PassContext
2122from tvm .relay import transform
@@ -89,6 +90,19 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
8990 return conv2d_out
9091
9192
def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
    """Build a pattern for a residual block on top of an existing pattern.

    The block is ``binary_op(tensor_op_out, residual)`` with the operands
    accepted in either order, where ``residual`` is matched by a wildcard.
    When ``with_act`` is ``"relu"`` the result is additionally wrapped in
    ``nn.relu``; any other value leaves the binary op as the pattern output.
    """
    skip_input = wildcard()

    # The residual input may appear on either side of the binary op.
    tensor_first = is_op(binary_op)(tensor_op_out, skip_input)
    tensor_second = is_op(binary_op)(skip_input, tensor_op_out)
    block_out = tensor_first | tensor_second

    if with_act == "relu":
        block_out = is_op("nn.relu")(block_out)

    return block_out
104+
105+
92106def check_dtype (lhs , rhs ):
93107 """Check if dtypes in the given workload are supported by CUTLASS."""
94108 # Only fp16 inputs are supported for now.
@@ -139,6 +153,25 @@ def check_conv2d(call):
139153 return not is_depthwise_conv2d (IC , OC , conv2d .attrs .groups )
140154
141155
def check_conv2d_residual(call, binary_op):
    """Check if the given conv2d residual block can be offloaded to CUTLASS.

    Parameters
    ----------
    call
        The expression matched by a residual block pattern. It is handed to
        ``get_root_call`` and ``check_conv2d`` as-is.
    binary_op : str
        Name of the residual binary op to look up in ``call``
        (e.g. ``"add"`` or ``"multiply"``).

    Returns
    -------
    bool
        True only if the underlying conv2d passes ``check_conv2d``, the
        residual input does not itself originate from the root conv2d, and
        both operands of the binary op have identical shapes.
    """
    conv2d = get_root_call(call, "nn.conv2d")
    if not check_conv2d(call):
        return False

    residual_binop = get_root_call(call, binary_op)
    lhs = residual_binop.args[0]
    rhs = residual_binop.args[1]

    # residual_input is pattern-matched as a wildcard. Make sure it does not sit between
    # the residual binary op and the root conv2d of this pattern: if the root conv2d is
    # the parent of both lhs and rhs, there is no independent residual input, so the
    # pattern must be rejected.
    if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs, "nn.conv2d") == conv2d:
        return False

    lhs_shape = lhs.checked_type.shape
    rhs_shape = rhs.checked_type.shape

    # zip() stops at the shorter shape, so operands of different rank would otherwise
    # be compared only on their leading dimensions and could wrongly pass.
    if len(lhs_shape) != len(rhs_shape):
        return False

    return all(x == y for (x, y) in zip(lhs_shape, rhs_shape))
173+
174+
142175def partition_for_cutlass (mod , params = None ):
143176 """Partition the input module into CUTLASS-supported subgraphs."""
144177 dense_pat = ("cutlass.dense" , make_gemm_pattern (False , None ), check_gemm )
@@ -165,16 +198,6 @@ def partition_for_cutlass(mod, params=None):
165198 ]
166199
167200 conv2d_patterns = [
168- (
169- "cutlass.conv2d_bias_hardswish" ,
170- make_conv2d_pattern (with_bias = True , with_act = "hardswish" ),
171- check_conv2d ,
172- ),
173- (
174- "cutlass.conv2d_bias_silu" ,
175- make_conv2d_pattern (with_bias = True , with_act = "silu" ),
176- check_conv2d ,
177- ),
178201 (
179202 "cutlass.conv2d_bias_hardswish" ,
180203 make_conv2d_pattern (with_bias = True , with_act = "hardswish" ),
@@ -199,7 +222,20 @@ def partition_for_cutlass(mod, params=None):
199222 ("cutlass.conv2d" , make_conv2d_pattern (), check_conv2d ),
200223 ]
201224
202- cutlass_patterns = dense_patterns + conv2d_patterns
225+ residual_block_patterns = []
226+
227+ for with_act , postfix in [("relu" , "_relu" ), (None , "" )]:
228+ for name , pat , _ in conv2d_patterns [:- 1 ]:
229+ for bin_op in ["add" , "multiply" ]:
230+ residual_block_patterns .append (
231+ (
232+ name + "_residual_" + bin_op + postfix ,
233+ make_residual_block_pattern (pat , bin_op , with_act = with_act ),
234+ partial (check_conv2d_residual , binary_op = bin_op ),
235+ )
236+ )
237+
238+ cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns
203239
204240 if params is not None :
205241 mod ["main" ] = bind_params_by_name (mod ["main" ], params )
@@ -217,6 +253,7 @@ def partition_for_cutlass(mod, params=None):
217253 seq = Sequential (
218254 [
219255 transform .InferType (),
256+ transform .SimplifyExpr (),
220257 transform .MergeComposite (cutlass_patterns ),
221258 transform .AnnotateTarget (["cutlass" ], include_non_call_ops = False ),
222259 transform .PartitionGraph (bind_constants = False ),
0 commit comments