@@ -141,12 +141,13 @@ def create_gemm_operator(
 # TODO(masahi): A sensible way to pick reasonable default kernels
 DEFAULT_KERNELS = {
     75: {
-        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
-        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
+    # align1 variants do not seem to be available for sm80
     80: {
-        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
-        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
+        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
+        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
     },
 }
 
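For reference, the profiler's `get_default` (second hunk below) falls back to this table by indexing on compute capability and output dtype. A minimal sketch of that lookup, with the table abridged from the diff above:

```python
# Minimal sketch of the fallback lookup (not the actual profiler flow;
# table abridged from the diff above).
DEFAULT_KERNELS = {
    75: {"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1"},
    80: {"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1"},
}
sm, out_dtype = 80, "float16"
default_kernel = DEFAULT_KERNELS[sm][out_dtype]  # used without profiling
```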
@@ -160,14 +161,16 @@ def __init__(self, sm, cutlass_path, binary_path):
         self.sm = sm
         self.cache = {}
 
-    def check_align(self, op_name, M):
+    def check_align(self, op_name, M, K):
         """Filter out kernels that cannot be supported."""
         aligns = re.findall(r"align[1|2|4|8]", op_name)
         assert len(aligns) == 1
+        # The same alignment is used for all axes
         align = int(aligns[0][-1])
-        if M % align != 0:
-            return False
-        return True
+        # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
+        # See https://github.com/NVIDIA/cutlass/issues/362.
+        # When the above issue is resolved, we can remove the alignment check on M below.
+        return M % align == 0 and K % align == 0
 
     def get_default(self, out_dtype, batched=False):
         """Return the default kernel for the requested architecture.
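A standalone sketch of the new check, to show why the align1 defaults above always survive the filter (the function body mirrors the hunk; the example shapes are illustrative):

```python
import re

def check_align(op_name, M, K):
    """Mirror of the method above: reject kernels whose alignment
    does not divide both the M and K extents."""
    aligns = re.findall(r"align[1|2|4|8]", op_name)
    assert len(aligns) == 1
    align = int(aligns[0][-1])
    return M % align == 0 and K % align == 0

# An align4 kernel is filtered out when K is not a multiple of 4,
# while the align1 variant passes for any shape.
assert not check_align("cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4", 128, 30)
assert check_align("cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1", 128, 30)
```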
@@ -194,7 +197,7 @@ def profile(
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
         )
-        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
+        ops = list(filter(lambda op: self.check_align(op["name"], M, K), ops))
 
         for op in ops:
             op["runtime"] = -1