diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index 3887fc2e2e26..3d14a427b1a3 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,6 +16,8 @@
 # under the License.
 # pylint: disable=invalid-name, dangerous-default-value
 """Conv2d kernel generator and profiler for CUTLASS."""
+import os
+import pickle
 from functools import partial
 from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
 from .gen_gemm import CutlassGemmProfiler
@@ -40,6 +42,7 @@ def create_conv2d_operator_with_epilogue(
     tile_description,
     data_type,
     alignment,
+    alignment_epilogue,
     swizzling_functor,
     split_k_slices,
 ):
@@ -78,7 +81,7 @@ def create_conv2d_operator_with_epilogue(
 
     A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
     B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
-    C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
+    C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment_epilogue)
 
     op = Conv2dOperation(
         conv_kind,
@@ -110,6 +113,7 @@ def enumerate_conv2d_operators(
     conv_kind,
     stride_support,
     split_k_slices,
+    alignment_c,
     tile_descriptions,
     data_type,
     alignment_constraints,
@@ -128,47 +132,49 @@ def enumerate_conv2d_operators(
 
     for split_k_slice in split_k_slices:
         for tile in tile_descriptions:
-            for alignment in alignment_constraints:
-
-                A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
-                B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
-                C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
-
-                if element_c == DataType.s32 and A.alignment == 1:
-                    tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
-                    tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)
-
-                op = Conv2dOperation(
-                    conv_kind,
-                    IteratorAlgorithm.Optimized,
-                    tile.minimum_compute_capability,
-                    tile,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    stride_support,
-                    EpilogueFunctor.LinearCombination,
-                    swizzling_functor,
-                    split_k_slice,
-                )
-
-                ret.append(
-                    {
-                        "src": profiler_emitter.emit(
-                            kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
-                            op.procedural_name(),
-                            element_output=element_c,
-                            split_k_slices=split_k_slice,
-                        ),
-                        "name": op.procedural_name(),
-                        "tile_description": tile,
-                        "alignment": alignment,
-                        "data_type": data_type,
-                        "swizzle_functor": swizzling_functor,
-                        "split_k_slices": split_k_slice,
-                    }
-                )
+            for alignmentAB in alignment_constraints:
+                for alignmentC in alignment_c:
+
+                    A = TensorDescription(element_a, LayoutType.TensorNHWC, alignmentAB)
+                    B = TensorDescription(element_b, LayoutType.TensorNHWC, alignmentAB)
+                    C = TensorDescription(element_c, LayoutType.TensorNHWC, alignmentC)
+
+                    if element_c == DataType.s32 and A.alignment == 1:
+                        tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
+                        tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)
+
+                    op = Conv2dOperation(
+                        conv_kind,
+                        IteratorAlgorithm.Optimized,
+                        tile.minimum_compute_capability,
+                        tile,
+                        A,
+                        B,
+                        C,
+                        element_epilogue,
+                        stride_support,
+                        EpilogueFunctor.LinearCombination,
+                        swizzling_functor,
+                        split_k_slice,
+                    )
+
+                    ret.append(
+                        {
+                            "src": profiler_emitter.emit(
+                                kernel_emitter.emit(op, emit_reduction=split_k_slice > 1),
+                                op.procedural_name(),
+                                element_output=element_c,
+                                split_k_slices=split_k_slice,
+                            ),
+                            "name": op.procedural_name(),
+                            "tile_description": tile,
+                            "alignment": alignmentAB,
+                            "alignment_epilogue": alignmentC,
+                            "data_type": data_type,
+                            "swizzle_functor": swizzling_functor,
+                            "split_k_slices": split_k_slice,
+                        }
+                    )
 
     return ret
 
@@ -181,7 +187,11 @@ def __init__(self, sm, cutlass_path, binary_path):
         self.sm = sm
         assert sm in GENERATOR_FUNC_TABLE, f"sm{sm} not supported yet."
         self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
-        self.cache = {}
+        self.cache_path = os.path.join(binary_path, "cutlass_conv2d_cache.pickle")
+        if os.path.exists(self.cache_path):
+            self.cache = pickle.load(open(self.cache_path, "rb"))
+        else:
+            self.cache = {}
 
     def get_default(
         self,
@@ -216,6 +226,7 @@ def get_default(
             tile_description,
             data_type,
             alignment,
+            alignment,
             swizzling_functor,
             split_k_slices=1,
         )
@@ -265,12 +276,32 @@ def select_op(
         if workload in self.cache:
             return self.cache[workload]
 
+        def alignments(dtype):
+            if dtype in ["float16"]:
+                alignments = [8, 4, 2, 1]
+            elif dtype in ["float", "float32"]:
+                alignments = [4, 2, 1]
+            else:
+                raise ValueError("Unsupported data type: %s" % dtype)
+            return alignments
+
+        alignments_c = [align for align in alignments(out_dtype) if OC % align == 0]
+
+        if not profile_all_alignments:
+            alignments_c = [alignments_c[0]]
+
         ops = GENERATOR_FUNC_TABLE[self.sm](
             out_dtype,
             data_dtype,
             weight_dtype,
-            partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices),
-            lambda align: all([dim % align == 0 for dim in [IC, OC]]),
+            partial(
+                enumerate_conv2d_operators,
+                conv_kind,
+                stride_support,
+                split_k_slices,
+                alignments_c,
+            ),
+            lambda align: all([dim % align == 0 for dim in [IC]]),
             use_3xtf32,
             profile_all_alignments,
             # Use fp32 accumulation for wgrad to align with cuDNN
@@ -294,6 +325,8 @@ def select_op(
 
         op = min(ops, key=lambda i: i["runtime"])
         self.cache[workload] = op
+        with open(self.cache_path, "wb") as f:
+            pickle.dump(self.cache, f)
         return op
 
     def profile(
@@ -350,6 +383,7 @@ def profile(
             op["tile_description"],
             op["data_type"],
             op["alignment"],
+            op["alignment_epilogue"],
             op["swizzle_functor"],
             op["split_k_slices"],
         )
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 78e19f510daa..ed963c2e6144 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -16,6 +16,9 @@
 # under the License.
 # pylint: disable=invalid-name
 """GEMM kernel generator and profiler for CUTLASS."""
+import os
+import pickle
+
 from .gemm_operation import EmitGemmInstance, GemmOperation
 from .gemm_profiler import GemmProfilerEmitter
 from .gen_tensor_op import EPILOGUE_MAP, GENERATOR_FUNC_TABLE, ProfilerEngine
@@ -152,7 +155,11 @@ def __init__(self, sm, cutlass_path, binary_path):
         assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, f"sm{sm} not supported yet."
         self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
         self.sm = sm
-        self.cache = {}
+        self.cache_path = os.path.join(binary_path, "cutlass_gemm_cache.pickle")
+        if os.path.exists(self.cache_path):
+            self.cache = pickle.load(open(self.cache_path, "rb"))
+        else:
+            self.cache = {}
 
     def get_default(
         self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
@@ -242,6 +249,8 @@ def select_op(
 
         op = min(ops, key=lambda i: i["runtime"])
         self.cache[(M, N, K)] = op
+        with open(self.cache_path, "wb") as f:
+            pickle.dump(self.cache, f)
         return op
 
     def profile(
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 855d8dc2d18d..b93de837d92f 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -213,8 +213,9 @@ def generate_sm80_tensor_op_16816(
 
     def get_default_tile_descriptions(block_k_factor):
         return [
-            ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc),
             ([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc),
+            ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc),
+            ([256, 64, int(32 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc),
             ([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc),
             ([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc),
             ([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
@@ -228,6 +229,9 @@ def get_default_tile_descriptions(block_k_factor):
             ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited),
             ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited),
             ([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc),
+            ([256, 64, int(64 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc),
+            ([64, 256, int(64 * block_k_factor)], 3, [1, 4, 1], min_cc, max_cc),
+            ([128, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
             ([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
             ([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc),
             ([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc),
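
Both profiler classes above now persist their tuning results to disk. For readers skimming the diff, here is a minimal, self-contained sketch of that load-on-init / dump-on-insert pattern; the class name DiskBackedCache and its method names are invented for illustration and are not part of TVM.

import os
import pickle


class DiskBackedCache:
    """Illustrative stand-in for the pickle-backed caches added above."""

    def __init__(self, binary_path, filename="example_cache.pickle"):
        # Load previously profiled results if a cache file already exists.
        self.cache_path = os.path.join(binary_path, filename)
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "rb") as f:
                self.cache = pickle.load(f)
        else:
            self.cache = {}

    def lookup(self, workload):
        return self.cache.get(workload)

    def insert(self, workload, op):
        self.cache[workload] = op
        # Flush after every insert so a later run can skip re-profiling.
        with open(self.cache_path, "wb") as f:
            pickle.dump(self.cache, f)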
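
The new alignments helper and alignments_c list in select_op decouple the epilogue (output tensor C) alignment from the A/B alignment. Below is a small sketch of that selection logic under hypothetical helper names (candidate_alignments, epilogue_alignments) with an example output-channel count; only the dtype-to-alignment table and the divisibility check come from the diff.

def candidate_alignments(dtype):
    # Same dtype table as the `alignments` helper added to select_op above.
    if dtype in ["float16"]:
        return [8, 4, 2, 1]
    if dtype in ["float", "float32"]:
        return [4, 2, 1]
    raise ValueError("Unsupported data type: %s" % dtype)


def epilogue_alignments(out_dtype, out_channels, profile_all_alignments=False):
    # Keep only alignments that divide the output channel count (OC); each of
    # these is then paired with every A/B alignment when enumerating kernels.
    aligns = [a for a in candidate_alignments(out_dtype) if out_channels % a == 0]
    return aligns if profile_all_alignments else aligns[:1]


# A float16 output with 12 channels cannot use 8-wide epilogue stores:
print(epilogue_alignments("float16", 12, profile_all_alignments=True))  # [4, 2, 1]
print(epilogue_alignments("float16", 12))  # [4]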