Commit ffce47d

use align1 kernel for unusual channel cases (IC = 3 etc)
1 parent 6cdf205 commit ffce47d

File tree

1 file changed: +12 -9 lines changed

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 12 additions & 9 deletions
@@ -141,12 +141,13 @@ def create_gemm_operator(
 # TODO(masahi): A sensible way to pick reasonable default kernels
 DEFAULT_KERNELS = {
     75: {
-        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
-        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
+    # align1 variants do not seem to be available for sm80
     80: {
-        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
-        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
 }

@@ -160,14 +161,16 @@ def __init__(self, sm, cutlass_path, binary_path):
         self.sm = sm
         self.cache = {}

-    def check_align(self, op_name, M):
+    def check_align(self, op_name, M, K):
         """Filter out kernels that cannot be supported."""
         aligns = re.findall(r"align[1|2|4|8]", op_name)
         assert len(aligns) == 1
+        # The same alignment is used for all axes
         align = int(aligns[0][-1])
-        if M % align != 0:
-            return False
-        return True
+        # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
+        # See https://github.com/NVIDIA/cutlass/issues/362.
+        # When the above issue is resolved, we can remove the alignment check on M below.
+        return M % align == 0 and K % align == 0

     def get_default(self, out_dtype, batched=False):
         """Return the default kernel for the requested architecture.

@@ -194,7 +197,7 @@ def profile(
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
         )
-        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
+        ops = list(filter(lambda op: self.check_align(op["name"], M, K), ops))

         for op in ops:
             op["runtime"] = -1

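For illustration (not part of the commit), here is a minimal, self-contained sketch of the updated check_align logic, showing why a convolution with IC = 3, which lowers to a GEMM with K = 3, can only be served by align1 kernels:

    import re

    def check_align(op_name, M, K):
        # Same logic as the updated method in the diff above: read the
        # alignment off the kernel name and require both M and K to be
        # multiples of it.
        aligns = re.findall(r"align[1|2|4|8]", op_name)
        assert len(aligns) == 1
        align = int(aligns[0][-1])
        return M % align == 0 and K % align == 0

    # K = 3 rules out the old align4 defaults; only align1 kernels pass.
    print(check_align("cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4", 128, 3))  # False
    print(check_align("cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", 128, 3))  # True

This also matches the DEFAULT_KERNELS change above, where the defaults become align1 variants so they remain valid for such shapes.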
0 commit comments