@@ -161,7 +161,7 @@ def __init__(self, sm, cutlass_path, binary_path):
161161 self .sm = sm
162162 self .cache = {}
163163
164- def check_align (self , op_name , M , K ):
164+ def check_align (self , op_name , M , N , K ):
165165 """Filter out kernels that cannot be supported."""
166166 aligns = re .findall (r"align[1|2|4|8]" , op_name )
167167 assert len (aligns ) == 1
@@ -170,7 +170,7 @@ def check_align(self, op_name, M, K):
170170 # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
171171 # See https://github.com/NVIDIA/cutlass/issues/362.
172172 # When the above issue is resolved, we can remove the alignment check on M below.
173- return M % align == 0 and K % align == 0
173+ return all ([ dim % align == 0 for dim in [ M , N , K ]])
174174
175175 def get_default (self , out_dtype , batched = False ):
176176 """Return the default kernel for the requested architecture.
@@ -197,7 +197,7 @@ def profile(
197197 ops = GENERATOR_FUNC_TABLE [self .sm ](
198198 out_dtype , op_creator = partial (create_gemm_operator , batched = batched )
199199 )
200- ops = list (filter (lambda op : self .check_align (op ["name" ], M , K ), ops ))
200+ ops = list (filter (lambda op : self .check_align (op ["name" ], M , N , K ), ops ))
201201
202202 for op in ops :
203203 op ["runtime" ] = - 1
0 commit comments