Skip to content

Commit 6b780db

Browse files
committed
check align on N dim
1 parent 308c4da commit 6b780db

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,7 @@ def __init__(self, sm, cutlass_path, binary_path):
161161
self.sm = sm
162162
self.cache = {}
163163

164-
def check_align(self, op_name, M, K):
164+
def check_align(self, op_name, M, N, K):
165165
"""Filter out kernels that cannot be supported."""
166166
aligns = re.findall(r"align[1|2|4|8]", op_name)
167167
assert len(aligns) == 1
@@ -170,7 +170,7 @@ def check_align(self, op_name, M, K):
170170
# TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
171171
# See https://github.com/NVIDIA/cutlass/issues/362.
172172
# When the above issue is resolved, we can remove the alignment check on M below.
173-
return M % align == 0 and K % align == 0
173+
return all([dim % align == 0 for dim in [M, N, K]])
174174

175175
def get_default(self, out_dtype, batched=False):
176176
"""Return the default kernel for the requested architecture.
@@ -197,7 +197,7 @@ def profile(
197197
ops = GENERATOR_FUNC_TABLE[self.sm](
198198
out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
199199
)
200-
ops = list(filter(lambda op: self.check_align(op["name"], M, K), ops))
200+
ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))
201201

202202
for op in ops:
203203
op["runtime"] = -1

tests/python/contrib/test_cutlass.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -242,6 +242,8 @@ def verify_batch_matmul(
242242
def test_dense():
243243
verify_dense(get_dense(M, N, K), M, N, K)
244244
verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K)
245+
# Test align1 case
246+
verify_dense(get_dense_bias(M, N + 1, K), M, N + 1, K)
245247

246248

247249
def test_dense_bias():

0 commit comments

Comments
 (0)