
Commit d1dafbd

[CUTLASS] More robust support for pattern matching and alignment (#9698)
* Fix a bug in im2col encoding
* Skip legalize when the batch size is dynamic
* Add sm75 kernels to sm80 profiling
* Add dtype and layout checks to pattern matching
* Use align1 kernels for unusual channel counts (IC = 3 etc.)
* Test IC = 3 convolution
* Fix check functions for fused cases; run InferType before MergeComposite
* Check alignment on the N dim
* Add a comment on the IC == 3 case
* Do not offload depthwise conv2d
* Lint fixes
* Trigger CI
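The "IC = 3" bullets are the heart of the alignment work: a 3x3 convolution over an RGB input yields an implicit-GEMM reduction extent that no vectorized kernel can handle. A quick worked check (a sketch only; the enforced rule is check_align in gen_gemm.py below):

R, S, IC = 3, 3, 3   # 3x3 filter on a 3-channel (RGB) input
K = R * S * IC       # implicit-GEMM reduction extent = 27
# 27 is divisible by neither 8, 4, nor 2, so only align1 kernels qualify.
assert [K % a for a in (8, 4, 2, 1)] == [3, 3, 1, 0]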
1 parent 4e70931 commit d1dafbd

File tree: 6 files changed (+139, -54)


python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 6 additions & 6 deletions
@@ -128,16 +128,16 @@ def profile(
         If profile_all is False, return immediately after the first applicable kernel is found.
         If use_multiprocessing is True, compile all profiler executables in parallel.
         """
-        B, H, W, C = d_shape
-        K, R, S, _ = w_shape
+        B, _, _, IC = d_shape
+        OC, R, S, _ = w_shape
         _, P, Q, _ = out_shape

-        M = B * H * W
-        K = R * S * C
-        N = B * P * Q
+        M = B * P * Q
+        N = OC
+        K = R * S * IC

         gemm_profile_result = self.gemm_profiler.profile(
-            M, K, N, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
+            M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
         )

         tile_description = gemm_profile_result["tile_description"]
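The hunk above fixes two things at once: the GEMM extents were derived from the input image rather than the output, and they were passed to the profiler out of order. A minimal self-contained sketch of the corrected implicit-GEMM mapping (the function name is mine, for illustration only):

def conv2d_gemm_extents(d_shape, w_shape, out_shape):
    """Map an NHWC conv2d with OHWI weights onto GEMM extents."""
    B, _, _, IC = d_shape    # batch, input channels
    OC, R, S, _ = w_shape    # output channels, filter height/width
    _, P, Q, _ = out_shape   # output height/width
    M = B * P * Q            # one GEMM row per output pixel
    N = OC                   # one GEMM column per output channel
    K = R * S * IC           # reduction over the filter window
    return M, N, K

# 16 NHWC images of 32x32x16, 32 3x3 filters, 32x32 output (padded conv):
assert conv2d_gemm_extents((16, 32, 32, 16), (32, 3, 3, 16), (16, 32, 32, 32)) == (16384, 32, 144)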

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 12 additions & 9 deletions
@@ -141,12 +141,13 @@ def create_gemm_operator(
 # TODO(masahi): A sensible way to pick reasonable default kernels
 DEFAULT_KERNELS = {
     75: {
-        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
-        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
+    # align1 variants do not seem to be available for sm80
     80: {
-        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
-        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
 }

@@ -160,14 +161,16 @@ def __init__(self, sm, cutlass_path, binary_path):
         self.sm = sm
         self.cache = {}

-    def check_align(self, op_name, M):
+    def check_align(self, op_name, M, N, K):
         """Filter out kernels that cannot be supported."""
         aligns = re.findall(r"align[1|2|4|8]", op_name)
         assert len(aligns) == 1
+        # The same alignment is used for all axes
         align = int(aligns[0][-1])
-        if M % align != 0:
-            return False
-        return True
+        # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
+        # See https://github.com/NVIDIA/cutlass/issues/362.
+        # When the above issue is resolved, we can remove the alignment check on M below.
+        return all([dim % align == 0 for dim in [M, N, K]])

     def get_default(self, out_dtype, batched=False):
         """Return the default kernel for the requested architecture.
@@ -194,7 +197,7 @@ def profile(
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
         )
-        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
+        ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))

         for op in ops:
             op["runtime"] = -1

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 3 additions & 1 deletion
@@ -152,9 +152,11 @@ def get_tile_descriptions(math_inst):
         TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
     ]

-    return generate_tensor_op_common(
+    sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, op_creator)
+    sm80_kernels = generate_tensor_op_common(
         math_instructions, alignment_constraints, get_tile_descriptions, op_creator
     )
+    return sm75_kernels + sm80_kernels


 class ProfilerEngine:

python/tvm/relay/op/contrib/cutlass.py

Lines changed: 70 additions & 11 deletions
@@ -14,7 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name
 """Patterns supported CUTLASS."""
+from tvm.ir.transform import Sequential
 from tvm.relay import transform
 from ...dataflow_pattern import wildcard, is_op, is_constant

@@ -56,31 +58,88 @@ def make_batch_matmul_pattern():


 def make_conv2d_pattern():
-    # TODO(masahi): Check layout and alignment
     return is_op("nn.conv2d")(wildcard(), wildcard())


+def check_dtype(lhs, rhs):
+    """Check if dtypes in the given workload are supported by CUTLASS."""
+    # Only fp16 inputs are supported for now.
+    return lhs.dtype == rhs.dtype and lhs.dtype == "float16" and rhs.dtype == "float16"
+
+
+def get_root_call(call, root_op_name):
+    if str(call.op) == root_op_name:
+        return call
+    return get_root_call(call.args[0], root_op_name)
+
+
+def check_gemm(call):
+    """Check if the given dense workload can be offloaded to CUTLASS."""
+    dense = get_root_call(call, "nn.dense")
+    lhs = dense.args[0].checked_type
+    rhs = dense.args[1].checked_type
+    return check_dtype(lhs, rhs)
+
+
+def check_batch_matmul(call):
+    """Check if the given batch_matmul workload can be offloaded to CUTLASS."""
+    batch_matmul = get_root_call(call, "nn.batch_matmul")
+    lhs = batch_matmul.args[0].checked_type
+    rhs = batch_matmul.args[1].checked_type
+    transpose_a = batch_matmul.attrs.transpose_a
+    transpose_b = batch_matmul.attrs.transpose_b
+    return check_dtype(lhs, rhs) and not transpose_a and transpose_b
+
+
+def is_depthwise_conv2d(ic, oc, groups):
+    return ic == oc == groups
+
+
+def check_conv2d(call):
+    """Check if the given conv2d workload can be offloaded to CUTLASS."""
+    conv2d = get_root_call(call, "nn.conv2d")
+    data_layout = conv2d.attrs.data_layout
+    kernel_layout = conv2d.attrs.kernel_layout
+    data = conv2d.args[0].checked_type
+    weight = conv2d.args[1].checked_type
+    if data_layout != "NHWC" or kernel_layout != "OHWI" or not check_dtype(data, weight):
+        return False
+    IC = data.shape[3]
+    OC = weight.shape[0]
+    return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)
+
+
 def partition_for_cutlass(mod):
     """Partition the input module into CUTLASS-supported subgraphs."""
-    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
-    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None))
-    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"))
-    dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu"))
+    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
+    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
+    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
+    dense_bias_gelu_fp16_pat = (
+        "cutlass.dense_bias_gelu_fp16",
+        make_gemm_pattern(True, "gelu"),
+        check_gemm,
+    )
     dense_bias_gelu_fp32_pat = (
         "cutlass.dense_bias_gelu_fp32",
         make_gemm_pattern(True, "gelu", out_dtype="float32"),
+        check_gemm,
     )
     cutlass_patterns = [
         dense_bias_gelu_fp16_pat,
         dense_bias_gelu_fp32_pat,
         dense_bias_relu_pat,
         dense_bias_pat,
         dense_pat,
-        ("cutlass.batch_matmul", make_batch_matmul_pattern()),
+        ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
         # TODO(masahi): Add more conv2d patterns
-        ("cutlass.conv2d", make_conv2d_pattern()),
+        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
-    mod = transform.MergeComposite(cutlass_patterns)(mod)
-    mod = transform.AnnotateTarget(["cutlass"])(mod)
-    mod = transform.PartitionGraph()(mod)
-    return mod
+    seq = Sequential(
+        [
+            transform.InferType(),
+            transform.MergeComposite(cutlass_patterns),
+            transform.AnnotateTarget(["cutlass"]),
+            transform.PartitionGraph(),
+        ]
+    )
+    return seq(mod)
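End to end, each pattern tuple is now (name, pattern, check), and MergeComposite only fuses a match when its check returns True; that is why InferType must run first, since the checks read checked_type. A hedged usage sketch against real TVM APIs:

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

# An fp16 dense workload: check_gemm accepts it, so it lands in a CUTLASS partition.
data = relay.var("data", shape=(16, 64), dtype="float16")
weight = relay.var("weight", shape=(32, 64), dtype="float16")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight))
mod = partition_for_cutlass(mod)
print(mod)  # the dense op now sits in a function annotated with Compiler="cutlass"

# The same module in float32 would fail check_dtype and stay with TVM's own schedules.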

python/tvm/topi/cuda/conv2d_alter_op.py

Lines changed: 4 additions & 0 deletions
@@ -450,6 +450,10 @@ def _conv2d_legalize(attrs, inputs, arg_types):

     elif data_dtype in ["float16"]:
         if data_layout == "NHWC" and kernel_layout == "HWIO":
+            if isinstance(data_tensor.shape[0], tvm.tir.expr.Any):
+                # Skip legalize when the batch size is dynamic
+                return None
+
             batch = data_tensor.shape[0].value
             in_channel = data_tensor.shape[3].value
             out_channel = kernel_tensor.shape[3].value
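The guard works because a relay.Any() batch dimension reaches the legalize hook as a tvm.tir.Any node, which has no .value to read; without the early return, the line that follows would crash. A small sketch of the condition:

import tvm
from tvm import relay

data = relay.var("data", shape=(relay.Any(), 32, 32, 16), dtype="float16")
batch = data.type_annotation.shape[0]
assert isinstance(batch, tvm.tir.Any)  # dynamic: the hook returns None and skips legalization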

tests/python/contrib/test_cutlass.py

Lines changed: 44 additions & 27 deletions
@@ -242,6 +242,8 @@ def verify_batch_matmul(
 def test_dense():
     verify_dense(get_dense(M, N, K), M, N, K)
     verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K)
+    # Test align1 case
+    verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K)


 def test_dense_bias():
@@ -312,13 +314,14 @@ def convert_conv2d_layout(mod, desired_layouts):


 def verify_conv2d(
-    mod_nchw,
-    mod_ref,
+    mod_nchw,  # can be dynamic batch
+    mod_ref,  # always static batch
     d_shape,
     w_shape,
     sm=80,
     atol=1e-5,
     rtol=1e-5,
+    use_cudnn_ref=False,
     run_benchmark=False,
 ):
     if not has_cutlass():
@@ -332,52 +335,66 @@ def verify_conv2d(
     typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
     use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)

+    mod_weight_ohwi = convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]})
+
     if use_vm:
-        rt_mod, dev, num_cutlass_partition = profile_and_build_vm(
-            convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm
-        )
+        rt_mod, _, num_cutlass_partition = profile_and_build_vm(mod_weight_ohwi, params, sm)
         out = get_output_vm(rt_mod, ["data"], [np_data])
     else:
-        rt_mod, dev, num_cutlass_partition = profile_and_build(
-            convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}),
+        rt_mod, _, num_cutlass_partition = profile_and_build(
+            mod_weight_ohwi,
             params,
             sm,
         )
         out = get_output(rt_mod, ["data"], [np_data])

     assert num_cutlass_partition > 0

-    rt_mod_ref, _ = get_ref_rt_mod(
-        convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
-        params,
-        target="cuda",
-    )
-    ref_out = get_output(rt_mod_ref, ["data"], [np_data])
+    if use_cudnn_ref:
+        rt_mod_ref, dev = get_ref_rt_mod(
+            convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "OHWI"]}),
+            params,
+            target="cuda -libs=cudnn",
+        )
+    else:
+        rt_mod_ref, dev = get_ref_rt_mod(
+            convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
+            params,
+            target="cuda",
+        )

-    np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
+    ref_out = get_output(rt_mod_ref, ["data"], [np_data])

     if run_benchmark:
         print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
         print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))

+    np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
+

 def test_conv2d():
+    for IC in [3, 16]:
+        d_shape = (16, IC, 32, 32)
+        w_shape = (32, IC, 3, 3)
+        mod_nchw = get_conv2d_nchw(d_shape, w_shape)
+
+        verify_conv2d(
+            mod_nchw,
+            mod_nchw,
+            d_shape,
+            w_shape,
+            sm=80,
+            atol=1e-5,
+            rtol=1e-5,
+            use_cudnn_ref=(IC == 3),  # The autotvm kernel has an accuracy issue with IC == 3 case
+            run_benchmark=False,
+        )
+
     d_shape = (16, 16, 32, 32)
     w_shape = (32, 16, 3, 3)
-    mod_nchw = get_conv2d_nchw(d_shape, w_shape)
-
-    verify_conv2d(
-        mod_nchw,
-        mod_nchw,
-        d_shape,
-        w_shape,
-        sm=80,
-        atol=1e-5,
-        rtol=1e-5,
-        run_benchmark=False,
-    )
-
     dyn_batch_shape = (relay.Any(),) + d_shape[1:]
+
+    mod_nchw = get_conv2d_nchw(d_shape, w_shape)
     mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)

     verify_conv2d(
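For completeness, a hedged end-to-end sketch of the dynamic-batch path this test exercises (assumes a CUTLASS-enabled TVM build with an sm80 GPU; layouts are chosen to satisfy check_conv2d):

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

weight = relay.var("weight", shape=(32, 3, 3, 16), dtype="float16")         # OHWI
data = relay.var("data", shape=(relay.Any(), 32, 32, 16), dtype="float16")  # NHWC, dynamic batch
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), channels=32,
    data_layout="NHWC", kernel_layout="OHWI", out_dtype="float16",
)
mod = partition_for_cutlass(tvm.IRModule.from_expr(conv))
# The dynamic batch triggers the use_vm branch of verify_conv2d above,
# and the new legalize guard keeps _conv2d_legalize from touching it.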
