diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 3a9cbf1e8445..3d14a427b1a3 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -285,6 +285,11 @@ def alignments(dtype): raise ValueError("Unsupported data type: %s" % dtype) return alignments + alignments_c = [align for align in alignments(out_dtype) if OC % align == 0] + + if not profile_all_alignments: + alignments_c = [alignments_c[0]] + ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, data_dtype, @@ -294,7 +299,7 @@ def alignments(dtype): conv_kind, stride_support, split_k_slices, - [align for align in alignments(out_dtype) if OC % align == 0], + alignments_c, ), lambda align: all([dim % align == 0 for dim in [IC]]), use_3xtf32,