diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index c919ff283343..fb59d02f9450 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -104,7 +104,7 @@ def select_gemm_kernel(
     arg1_dtype,
     use_3xtf32,
     batched,
-    profile_all,
+    find_first_valid,
     use_multiprocessing,
 ):
     """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
@@ -126,10 +126,10 @@ def select_gemm_kernel(
             arg1_dtype,
             use_3xtf32,
             batched=batched,
-            profile_all=profile_all,
+            find_first_valid=find_first_valid,
             use_multiprocessing=use_multiprocessing,
         )
-        if profile_all:
+        if not find_first_valid:
             logger.info("The best kernel is %s", name)
         else:
             logger.info("Picked the first kernel found %s", name)
@@ -146,7 +146,7 @@ def handle_batch_matmul(
     arg0_dtype,
     arg1_dtype,
     use_3xtf32,
-    profile_all,
+    find_first_valid,
     use_multiprocessing,
 ):
     """Profile and select a kernel for batch_matmul op workload."""
@@ -165,7 +165,7 @@ def handle_batch_matmul(
         arg1_dtype,
         use_3xtf32,
         True,
-        profile_all,
+        find_first_valid,
         use_multiprocessing,
     )

@@ -191,7 +191,7 @@ def handle_dense(
     arg0_dtype,
     arg1_dtype,
     use_3xtf32,
-    profile_all,
+    find_first_valid,
     use_multiprocessing,
 ):
     """Profile and select a kernel for dense op workload."""
@@ -210,7 +210,7 @@ def handle_dense(
         arg1_dtype,
         use_3xtf32,
         False,
-        profile_all,
+        find_first_valid,
         use_multiprocessing,
     )

@@ -237,7 +237,8 @@ def handle_conv2d(
     data_dtype,
     weight_dtype,
     use_3xtf32,
-    profile_all,
+    profile_all_alignments,
+    find_first_valid,
     use_multiprocessing,
 ):
     """Profile and select a kernel for conv2d op workload."""
@@ -257,10 +258,11 @@ def handle_conv2d(
         data_dtype,
         weight_dtype,
         use_3xtf32,
-        profile_all=profile_all,
+        profile_all_alignments,
+        find_first_valid=find_first_valid,
         use_multiprocessing=use_multiprocessing,
     )
-    if profile_all:
+    if not find_first_valid:
         logger.info("The best kernel is %s", name)
     else:
         logger.info("Picked the first kernel found %s", name)
@@ -272,7 +274,13 @@ def handle_conv2d(


 def tune_cutlass_kernels(
-    mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"
+    mod,
+    sm,
+    use_3xtf32=True,
+    profile_all_alignments=False,
+    find_first_valid=False,
+    use_multiprocessing=False,
+    tmp_dir="./tmp",
 ):
     """Given a module partitioned for CUTLASS offloading, profile each workload to select which
     kernels to emit.
@@ -286,7 +294,14 @@ def tune_cutlass_kernels(
         An integer specifying the compute capability. For example, 75 for Turing and
         80 or 86 for Ampere.

-    profile_all : bool
+    use_3xtf32 : bool
+        Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for
+        fp32 inputs on tensorcore.
+
+    profile_all_alignments : bool
+        When True, profile all kernel variants with smaller alignments than the largest possible.
+
+    find_first_valid : bool
         Whether or not profile all candidate kernels, or stop profiling after the first
         applicable kernel is found.

@@ -342,7 +357,8 @@ def tune_cutlass_kernels(
                         arg0_dtype,
                         arg1_dtype,
                         use_3xtf32,
-                        profile_all,
+                        profile_all_alignments,
+                        find_first_valid,
                         use_multiprocessing,
                     )
                 )
@@ -357,7 +373,7 @@ def tune_cutlass_kernels(
                         arg0_dtype,
                         arg1_dtype,
                         use_3xtf32,
-                        profile_all,
+                        find_first_valid,
                         use_multiprocessing,
                     )
                 )
@@ -372,7 +388,7 @@ def tune_cutlass_kernels(
                         arg0_dtype,
                         arg1_dtype,
                         use_3xtf32,
-                        profile_all,
+                        find_first_valid,
                         use_multiprocessing,
                     )
                 )
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index c09017adfd95..b6dba009f2b2 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name
 """Conv2d kernel generator and profiler for CUTLASS."""
-import re
 from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
 from .gen_gemm import CutlassGemmProfiler
 from .conv2d_profiler import Conv2dProfilerEmitter
@@ -168,14 +167,6 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
         )
         return {"name": name, "opdef": opdef}

-    def check_align(self, op_name, C, K):
-        """Filter out kernels that cannot be supported."""
-        match = re.match(".*_align([1-9]+)", op_name)
-        assert match is not None and len(match.groups()) == 1
-        # The same alignment is used for all axes
-        align = int(match.groups()[0])
-        return all([dim % align == 0 for dim in [C, K]])
-
     def select_op(
         self,
         d_shape,
@@ -187,7 +178,8 @@ def select_op(
         data_dtype,
         weight_dtype,
         use_3xtf32,
-        profile_all=True,
+        profile_all_alignments=False,
+        find_first_valid=False,
         use_multiprocessing=False,
     ):
         """
@@ -216,12 +208,16 @@ def select_op(
             return self.cache[workload]

         ops = GENERATOR_FUNC_TABLE[self.sm](
-            out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32
+            out_dtype,
+            data_dtype,
+            weight_dtype,
+            enumerate_conv2d_operators,
+            lambda align: all([dim % align == 0 for dim in [IC, OC]]),
+            use_3xtf32,
+            profile_all_alignments,
         )
-        ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))
-
-        if profile_all:
+        if not find_first_valid:
             self.engine.compile_all(ops, use_multiprocessing)

         args = (
@@ -232,7 +228,7 @@ def select_op(
         for op in ops:
             out = self.engine.evaluate(op, args.split(" "))
             op["runtime"] = out
-            if out < float("inf") and not profile_all:
+            if out < float("inf") and find_first_valid:
                 self.cache[workload] = op
                 return op

@@ -252,11 +248,12 @@ def profile(
         data_dtype,
         weight_dtype,
         use_3xtf32=True,
-        profile_all=True,
+        profile_all_alignments=False,
+        find_first_valid=False,
         use_multiprocessing=False,
     ):
         """Profile and select the best kernel from candidate kernels.
-        If profile_all is False, return immediately after the first applicable kernel is found.
+        If find_first_valid is True, return immediately after the first applicable kernel is found.
         If use_multiprocessing is True, compile all profiler executables in parallel.
         """
         op = self.select_op(
@@ -269,8 +266,9 @@ def profile(
             data_dtype,
             weight_dtype,
             use_3xtf32,
-            profile_all=profile_all,
-            use_multiprocessing=use_multiprocessing,
+            profile_all_alignments,
+            find_first_valid,
+            use_multiprocessing,
         )

         name, opdef = create_conv2d_operator_with_epilogue(
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 445acb9305c8..bb591985cab5 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name
 """GEMM kernel generator and profiler for CUTLASS."""
-import re
 from .gemm_operation import GemmOperation, EmitGemmInstance
 from .gemm_profiler import GemmProfilerEmitter
 from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
@@ -63,8 +62,9 @@ def create_gemm_operator_with_epilogue(
         swizzling_functor,
     )

-    return op.procedural_name(), EmitGemmInstance().emit(
-        op, no_beta_scaling=no_beta_scaling, batched=batched
+    return (
+        op.procedural_name(),
+        EmitGemmInstance().emit(op, no_beta_scaling=no_beta_scaling, batched=batched),
     )


@@ -150,17 +150,6 @@ def __init__(self, sm, cutlass_path, binary_path):
         self.sm = sm
         self.cache = {}

-    def check_align(self, op_name, M, N, K):
-        """Filter out kernels that cannot be supported."""
-        match = re.match(".*_align([1-9]+)", op_name)
-        assert match is not None and len(match.groups()) == 1
-        # The same alignment is used for all axes
-        align = int(match.groups()[0])
-        # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
-        # See https://github.com/NVIDIA/cutlass/issues/362.
-        # When the above issue is resolved, we can remove the alignment check on M below.
-        return all([dim % align == 0 for dim in [M, N, K]])
-
     def get_default(
         self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
     ):
@@ -168,8 +157,15 @@ def get_default(
         For now, the default kernel was picked arbitrary.
         """
         ops = GENERATOR_FUNC_TABLE[self.sm](
-            out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32
+            out_dtype,
+            arg0_dtype,
+            arg1_dtype,
+            enumerate_gemm_operators,
+            lambda align: align == 1,  # Only request align1 kernels
+            use_3xtf32,
+            profile_all_alignments=True,  # To include all align1 kernels
         )
+
         default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]

         if arg0_dtype == "float32":
@@ -200,7 +196,8 @@ def select_op(
         arg0_dtype,
         arg1_dtype,
         use_3xtf32,
-        profile_all=True,
+        profile_all_alignments=False,
+        find_first_valid=False,
         use_multiprocessing=False,
     ):
         """
@@ -211,22 +208,27 @@ def select_op(
             op = self.cache[(M, N, K)]
             return op

+        # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
+        # See https://github.com/NVIDIA/cutlass/issues/362.
+        # When the above issue is resolved, we can remove the alignment check on M below.
+
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype,
             arg0_dtype,
             arg1_dtype,
             enumerate_gemm_operators,
-            use_3xtf32=use_3xtf32,
+            lambda align: all([dim % align == 0 for dim in [M, N, K]]),
+            use_3xtf32,
+            profile_all_alignments=profile_all_alignments,
         )
-        ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))

-        if profile_all:
+        if not find_first_valid:
             self.engine.compile_all(ops, use_multiprocessing)

         for op in ops:
             out = self.engine.evaluate(op, [M, N, K])
             op["runtime"] = out
-            if out < float("inf") and not profile_all:
+            if out < float("inf") and find_first_valid:
                 self.cache[(M, N, K)] = op
                 return op

@@ -244,12 +246,13 @@ def profile(
         arg0_dtype,
         arg1_dtype,
         use_3xtf32=True,
-        profile_all=True,
+        profile_all_alignments=False,
+        find_first_valid=False,
         use_multiprocessing=False,
         batched=False,
     ):
         """Profile and select the best kernel from candidate kernels.
-        If profile_all is False, return immediately after the first applicable kernel is found.
+        If find_first_valid is True, return immediately after the first applicable kernel is found.
         If use_multiprocessing is True, compile all profiler executables in parallel.
""" op = self.select_op( @@ -260,7 +263,8 @@ def profile( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all=profile_all, + profile_all_alignments=profile_all_alignments, + find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 6bb4f290233e..97af84e76990 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -62,7 +62,9 @@ def generate_tensor_op_common( return ops -def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): +def generate_sm75_tensor_op_1688( + out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False +): """Generate GEMM or Conv2D kernels for Turing.""" assert out_dtype in ["float32", "float16", "int32"] min_cc = 75 @@ -114,6 +116,12 @@ def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): ([64, 64, 64], 2, [2, 2, 1], min_cc, max_cc), ] + alignment_constraints = [align for align in alignment_constraints if check_align(align)] + assert len(alignment_constraints) > 0 + + if not profile_all_alignments: + alignment_constraints = [alignment_constraints[0]] + def get_tile_descriptions(math_inst): return [ TileDescription(threadblock_shape, stages, warp_count, math_inst, min_cc, max_cc) @@ -125,7 +133,15 @@ def get_tile_descriptions(math_inst): ) -def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator, use_3xtf32=True): +def generate_sm80_tensor_op_16816( + out_dtype, + arg0_dtype, + arg1_dtype, + op_creator, + check_align, + use_3xtf32=True, + profile_all_alignments=False, +): """Generate GEMM or Conv2D kernels for Ampere.""" min_cc = 80 max_cc = 1024 @@ -218,15 +234,31 @@ def get_tile_descriptions(math_inst): for threadblock_shape, stages, warp_count, min_cc, max_cc in tile_descriptions ] + alignment_constraints = [align for align in alignment_constraints if check_align(align)] + + if len(alignment_constraints) > 0 and not profile_all_alignments: + alignment_constraints = [alignment_constraints[0]] + if arg0_dtype != "float32" and arg1_dtype != "float32": - sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator) + sm75_kernels = generate_sm75_tensor_op_1688( + out_dtype, + arg0_dtype, + arg1_dtype, + op_creator, + check_align, + False, + profile_all_alignments, + ) else: # TF32 (float32 + float32 case) is only supported on sm80 sm75_kernels = [] - sm80_kernels = generate_tensor_op_common( - math_instructions, alignment_constraints, get_tile_descriptions, op_creator - ) + if len(alignment_constraints) > 0: + sm80_kernels = generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, op_creator + ) + else: + sm80_kernels = [] # TODO(masahi): For int8 kernels, The CUTLASS generator modifies the output tensor alignment # after ops are created. Revisit how important this modification is. diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index da7c52ad119e..cf2787dda750 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -411,7 +411,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No "c", "cc", "cpp", - ], "The module.format needs to be either c, cc or cpp" + "cu", + ], "The module.format needs to be either c, cc, cpp or cu." 
                     object_format = module.format
                     has_c_module = True
                 else:
@@ -426,7 +427,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
                             "c",
                             "cc",
                             "cpp",
-                        ], "The module.format needs to be either c, cc or cpp"
+                            "cu",
+                        ], "The module.format needs to be either c, cc, cpp, or cu."
                         object_format = module.format
                     else:
                         object_format = "c"
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 57f2f39c641b..00506ecf0527 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -188,7 +188,8 @@ def profile_and_build(
         mod,
         sm,
         use_3xtf32=use_3xtf32,
-        profile_all=False,
+        profile_all_alignments=False,
+        find_first_valid=True,
         use_multiprocessing=False,
         tmp_dir=tmp_dir,
     )
@@ -239,6 +240,9 @@ def verify_dense(
 ):
     if not has_cutlass():
         return
+    if sm < 80 and data_dtype == "float32":
+        return
+
     mod = tvm.IRModule.from_expr(func)
     typ = relay.transform.InferType()(mod)["main"].body.checked_type
     out_dtype = typ.dtype
@@ -450,6 +454,8 @@ def verify_conv2d(
 ):
     if not has_cutlass():
         return
+    if sm < 80 and data_dtype == "float32":
+        return

     mod_nchw = tvm.IRModule.from_expr(expr_nchw)
     mod_ref = tvm.IRModule.from_expr(expr_ref)
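
Usage note: the sketch below, closely modeled on profile_and_build in tests/python/contrib/test_cutlass.py touched above, shows how the renamed profiling knobs are expected to be driven end to end. The helper name profile_and_build_example is illustrative, and the partition_for_cutlass, build_cutlass_kernels, and graph_executor imports come from the pre-existing CUTLASS BYOC flow rather than from this diff.

    # Sketch only: assumes the existing CUTLASS BYOC helpers that this patch does not modify.
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor
    from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels
    from tvm.relay.op.contrib.cutlass import partition_for_cutlass

    def profile_and_build_example(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
        mod = partition_for_cutlass(mod)
        # find_first_valid=True stops at the first kernel that compiles and runs, while
        # profile_all_alignments=False keeps only the largest alignment that divides the shapes.
        mod, num_cutlass_partition = tune_cutlass_kernels(
            mod,
            sm,
            use_3xtf32=True,
            profile_all_alignments=False,
            find_first_valid=True,
            use_multiprocessing=False,
            tmp_dir=tmp_dir,
        )
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target="cuda", params=params)
        lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path)
        dev = tvm.device("cuda", 0)
        rt_mod = graph_executor.GraphModule(lib["default"](dev))
        return rt_mod, dev, num_cutlass_partition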