Commit fda151b

Support residual block fusion
1 parent: ce9d52f

8 files changed (+283 lines, -38 lines)
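This commit extends TVM's CUTLASS BYOC integration so that a conv2d → bias → activation chain followed by a residual binary op (add or multiply) and an optional trailing ReLU can be offloaded as a single fused kernel. A minimal sketch of the kind of Relay graph the new patterns target (shapes, dtype, and layout below are illustrative only and not taken from the commit; the CUTLASS partitioner applies its own dtype/layout checks):

```python
import tvm
from tvm import relay

# conv2d -> bias_add -> relu, then a residual add with the block input and a final relu.
data = relay.var("data", shape=(1, 32, 56, 56), dtype="float16")
weight = relay.var("weight", shape=(32, 32, 3, 3), dtype="float16")
bias = relay.var("bias", shape=(32,), dtype="float16")

conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
act = relay.nn.relu(relay.nn.bias_add(conv, bias))
out = relay.nn.relu(relay.add(act, data))  # residual connection

# A candidate for the new "*_residual_add_relu" composite patterns added below.
mod = tvm.IRModule.from_expr(out)
```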


python/tvm/contrib/cutlass/conv2d_operation.py

Lines changed: 35 additions & 10 deletions

```diff
@@ -150,6 +150,7 @@ def __init__(self):
       ${element_accumulator},
       ${element_epilogue}
     >"""
+
         self.epilogue_no_beta_scaling = """
     ${epilogue_functor}<
       ${element_c},
@@ -159,10 +160,22 @@ def __init__(self):
       cutlass::epilogue::thread::ScaleType::NoBetaScaling
     >"""
 
+        self.epilogue_residual_block = """
+    ${epilogue_functor}<
+      ${element_c},
+      ${element_accumulator},
+      ${element_epilogue},
+      ${element_c},
+      ${epilogue_vector_length},
+      ${activation},
+      ${binary_op},
+      ${unary_op}
+    >"""
+
         self.template = """
 // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
 using ${operation_name} =
-typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
+typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}${conv_kernel_postfix}<
   ${element_a},
   ${layout_a},
   ${element_b},
@@ -186,7 +199,7 @@ def __init__(self):
 >::Kernel;
 """
 
-    def emit(self, operation, no_beta_scaling=False):
+    def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
         """Instantiate a Conv2d kernel from given `operation`."""
         warp_shape = [
             int(
@@ -246,14 +259,26 @@ def emit(self, operation, no_beta_scaling=False):
             ],
             "align_a": str(operation.A.alignment),
             "align_b": str(operation.B.alignment),
+            "conv_kernel_postfix": "",
         }
 
-        template = substitute_template(
-            self.template,
-            {
-                "epilogue": self.epilogue_no_beta_scaling
-                if no_beta_scaling
-                else self.epilogue_default
-            },
-        )
+        if residual_block_info:
+            template = substitute_template(
+                self.template, {"epilogue": self.epilogue_residual_block}
+            )
+            values.update(
+                {
+                    "unary_op": residual_block_info["unary_op"],
+                    "binary_op": residual_block_info["binary_op"],
+                    "activation": residual_block_info["activation"],
+                    "conv_kernel_postfix": "WithBroadcast",
+                }
+            )
+        elif no_beta_scaling:
+            template = substitute_template(
+                self.template, {"epilogue": self.epilogue_no_beta_scaling}
+            )
+        else:
+            template = substitute_template(self.template, {"epilogue": self.epilogue_default})
+
         return substitute_template(template, values)
```
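A hedged sketch of how the new `residual_block_info` argument is meant to be driven; the dict keys mirror the branch added to `emit` above, and the `Conv2dOperation` instance `op` is assumed to come from the usual generator path in `gen_conv2d.py`:

```python
from tvm.contrib.cutlass.conv2d_operation import EmitConv2dInstance


def emit_residual_fused_instance(op):
    """Sketch: emit a conv2d kernel whose epilogue fuses bias + ReLU, a residual add, and a final ReLU."""
    residual_block_info = {
        # C++ functor names substituted into the epilogue_residual_block template above.
        "activation": "cutlass::epilogue::thread::ReLu",
        "binary_op": "cutlass::plus",
        "unary_op": "cutlass::epilogue::thread::ReLu",
    }
    # Passing a dict (instead of the default False) selects epilogue_residual_block and
    # switches the kernel to the DefaultConv2d...WithBroadcast variant via conv_kernel_postfix.
    return EmitConv2dInstance().emit(op, residual_block_info=residual_block_info)
```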

python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 29 additions & 2 deletions

```diff
@@ -39,7 +39,32 @@ def create_conv2d_operator_with_epilogue(
     Instantiate a cutlass kernel from the given configuration,
     along with the epilouge functor
     """
-    epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+    if "residual" in op_type:
+        activation_map = {
+            "cutlass.conv2d_bias_hardswish": "cutlass::epilogue::thread::HardSwish",
+            "cutlass.conv2d_bias_silu": "cutlass::epilogue::thread::SiLu",
+            "cutlass.conv2d_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid",
+            "cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu",
+            "cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity",
+        }
+        prefix = op_type[: op_type.find("_residual")]
+        activation = activation_map[prefix]
+        binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus"
+        unary_op = (
+            "cutlass::epilogue::thread::ReLu"
+            if op_type.endswith("relu")
+            else "cutlass::epilogue::thread::Identity"
+        )
+        residual_block_info = {
+            "activation": activation,
+            "binary_op": binary_op,
+            "unary_op": unary_op,
+        }
+        epilogue = EpilogueFunctor.LinearCombinationResidualBlock
+        no_beta_scaling = False
+    else:
+        residual_block_info = None
+        epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
 
     element_a, element_b, element_c, element_epilogue = data_type
 
@@ -62,7 +87,9 @@ def create_conv2d_operator_with_epilogue(
     )
 
     name = op.procedural_name()
-    opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling)
+    opdef = EmitConv2dInstance().emit(
+        op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info
+    )
 
     return name, opdef
 
```
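The residual variants are driven purely by the composite pattern name. A worked decomposition of one such name, mirroring the new branch above (illustrative, not code from the commit):

```python
op_type = "cutlass.conv2d_bias_silu_residual_multiply_relu"

prefix = op_type[: op_type.find("_residual")]  # "cutlass.conv2d_bias_silu"
# activation_map[prefix]         -> activation = "cutlass::epilogue::thread::SiLu"
# "residual_multiply" in op_type -> binary_op  = "cutlass::multiplies"
# op_type.endswith("relu")       -> unary_op   = "cutlass::epilogue::thread::ReLu"
# epilogue                       -> EpilogueFunctor.LinearCombinationResidualBlock
```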

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -165,6 +165,21 @@ def get_tile_descriptions(math_inst):
     80: generate_sm80_tensor_op_16816,
 }
 
+EPILOGUE_MAP = {
+    "cutlass.dense": (EpilogueFunctor.LinearCombination, True),
+    "cutlass.dense_bias": (EpilogueFunctor.LinearCombinationBias, True),
+    "cutlass.dense_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
+    "cutlass.dense_bias_gelu_fp16": (EpilogueFunctor.LinearCombinationGelu, False),
+    "cutlass.dense_bias_gelu_fp32": (EpilogueFunctor.LinearCombinationGelu, False),
+    "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, True),
+    "cutlass.conv2d_bias_hardswish": (EpilogueFunctor.LinearCombinationHardSwish, False),
+    "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False),
+    "cutlass.conv2d_bias_sigmoid": (EpilogueFunctor.LinearCombinationSigmoid, False),
+    "cutlass.conv2d_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
+    "cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
+    "cutlass.conv2d": (EpilogueFunctor.LinearCombination, True),
+}
+
 
 # (Epilogue functor name, no_beta_scaling)
 EPILOGUE_MAP = {
```
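For the non-residual patterns, the epilogue functor and the `no_beta_scaling` flag still come from a plain table lookup. A small usage sketch, assuming `EPILOGUE_MAP` is imported from this module by callers such as `gen_conv2d.py` (the import itself is not shown in this diff):

```python
from tvm.contrib.cutlass.gen_tensor_op import EPILOGUE_MAP

epilogue, no_beta_scaling = EPILOGUE_MAP["cutlass.conv2d_bias_relu"]
# epilogue        -> EpilogueFunctor.LinearCombinationRelu
# no_beta_scaling -> True
```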

python/tvm/contrib/cutlass/library.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -151,6 +151,7 @@ class EpilogueFunctor(enum.Enum):
     LinearCombinationSigmoid = enum_auto()
     LinearCombinationSilu = enum_auto()
     LinearCombinationHardSwish = enum_auto()
+    LinearCombinationResidualBlock = enum_auto()
 
 
 EpilogueFunctorTag = {
@@ -161,6 +162,7 @@ class EpilogueFunctor(enum.Enum):
     EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
     EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
     EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish",
+    EpilogueFunctor.LinearCombinationResidualBlock: "cutlass::epilogue::thread::LinearCombinationResidualBlock",
 }
 
 
```
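The new enum member resolves to its C++ functor name through the existing `EpilogueFunctorTag` table; a quick sketch:

```python
from tvm.contrib.cutlass.library import EpilogueFunctor, EpilogueFunctorTag

tag = EpilogueFunctorTag[EpilogueFunctor.LinearCombinationResidualBlock]
assert tag == "cutlass::epilogue::thread::LinearCombinationResidualBlock"
```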

python/tvm/relay/op/contrib/cutlass.py

Lines changed: 48 additions & 11 deletions

```diff
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name
 """Patterns supported CUTLASS."""
+from functools import partial
 from tvm import relay
 from tvm.ir.transform import Sequential, PassContext
 from tvm.relay import transform
@@ -89,6 +90,19 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
     return conv2d_out
 
 
+def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
+    """Add pattern for residual blocks."""
+    residual_input = wildcard()
+    binary_out = is_op(binary_op)(tensor_op_out, residual_input) | is_op(binary_op)(
+        residual_input, tensor_op_out
+    )
+
+    if with_act is not None and with_act == "relu":
+        return is_op("nn.relu")(binary_out)
+
+    return binary_out
+
+
 def check_dtype(lhs, rhs):
     """Check if dtypes in the given workload are supported by CUTLASS."""
     # Only fp16 inputs are supported for now.
@@ -139,6 +153,25 @@ def check_conv2d(call):
     return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)
 
 
+def check_conv2d_residual(call, binary_op):
+    """Check if the given conv2d workload can be offloaded to CUTLASS."""
+    conv2d = get_root_call(call, "nn.conv2d")
+    if not check_conv2d(call):
+        return False
+
+    residual_binop = get_root_call(call, binary_op)
+    lhs = residual_binop.args[0]
+    rhs = residual_binop.args[1]
+
+    # residual_input is pattern-matched as a wildcard. Make sure it does not sit between
+    # residual binary op and the root conv2d of this pattern.
+    # If the root conv2d is the parent of both lhs and rhs, we should reject this pattern.
+    if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs, "nn.conv2d") == conv2d:
+        return True
+
+    return all([x == y for (x, y) in zip(lhs.checked_type.shape, rhs.checked_type.shape)])
+
+
 def partition_for_cutlass(mod, params=None):
     """Partition the input module into CUTLASS-supported subgraphs."""
     dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
@@ -165,16 +198,6 @@ def partition_for_cutlass(mod, params=None):
     ]
 
     conv2d_patterns = [
-        (
-            "cutlass.conv2d_bias_hardswish",
-            make_conv2d_pattern(with_bias=True, with_act="hardswish"),
-            check_conv2d,
-        ),
-        (
-            "cutlass.conv2d_bias_silu",
-            make_conv2d_pattern(with_bias=True, with_act="silu"),
-            check_conv2d,
-        ),
         (
             "cutlass.conv2d_bias_hardswish",
             make_conv2d_pattern(with_bias=True, with_act="hardswish"),
@@ -199,7 +222,20 @@ def partition_for_cutlass(mod, params=None):
         ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
 
-    cutlass_patterns = dense_patterns + conv2d_patterns
+    residual_block_patterns = []
+
+    for with_act, postfix in [("relu", "_relu"), (None, "")]:
+        for name, pat, _ in conv2d_patterns[:-1]:
+            for bin_op in ["add", "multiply"]:
+                residual_block_patterns.append(
+                    (
+                        name + "_residual_" + bin_op + postfix,
+                        make_residual_block_pattern(pat, bin_op, with_act=with_act),
+                        partial(check_conv2d_residual, binary_op=bin_op),
+                    )
+                )
+
+    cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns
 
     if params is not None:
         mod["main"] = bind_params_by_name(mod["main"], params)
@@ -217,6 +253,7 @@ def partition_for_cutlass(mod, params=None):
     seq = Sequential(
         [
             transform.InferType(),
+            transform.SimplifyExpr(),
             transform.MergeComposite(cutlass_patterns),
             transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
             transform.PartitionGraph(bind_constants=False),
```
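One concrete `(name, pattern, predicate)` triple that the nested loops above produce, written out by hand for the bias + ReLU / residual add / ReLU case (a worked example that assumes `cutlass.conv2d_bias_relu` is among `conv2d_patterns[:-1]`; not code from the commit):

```python
from functools import partial

from tvm.relay.op.contrib.cutlass import (
    check_conv2d_residual,
    make_conv2d_pattern,
    make_residual_block_pattern,
)

conv2d_bias_relu = make_conv2d_pattern(with_bias=True, with_act="relu")
residual_entry = (
    "cutlass.conv2d_bias_relu_residual_add_relu",
    make_residual_block_pattern(conv2d_bias_relu, "add", with_act="relu"),
    partial(check_conv2d_residual, binary_op="add"),
)
```

The `partial` binding of `binary_op` into the predicate is also why `from functools import partial` is added at the top of the file in this diff.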
