From 194c01b3b3f823d5ebc7db62546aea901746fb63 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 11 Mar 2023 19:28:56 -0800 Subject: [PATCH 1/6] conv2d alignment --- python/tvm/contrib/cutlass/gen_conv2d.py | 118 ++++++++++++-------- python/tvm/contrib/cutlass/gen_gemm.py | 10 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 14 ++- 3 files changed, 91 insertions(+), 51 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index bb26a47a5548..15d19cf5ae0f 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name, dangerous-default-value """Conv2d kernel generator and profiler for CUTLASS.""" +import os +import pickle from functools import partial from .conv2d_operation import Conv2dOperation, EmitConv2dInstance from .gen_gemm import CutlassGemmProfiler @@ -40,6 +42,7 @@ def create_conv2d_operator_with_epilogue( tile_description, data_type, alignment, + alignment_epilogue, swizzling_functor, split_k_slices, ): @@ -78,7 +81,7 @@ def create_conv2d_operator_with_epilogue( A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment) B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment) - C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment) + C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment_epilogue) op = Conv2dOperation( conv_kind, @@ -110,6 +113,7 @@ def enumerate_conv2d_operators( conv_kind, stride_support, split_k_slices, + alignment_c, tile_descriptions, data_type, alignment_constraints, @@ -128,47 +132,49 @@ def enumerate_conv2d_operators( for split_k_slice in split_k_slices: for tile in tile_descriptions: - for alignment in alignment_constraints: - - A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment) - B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment) - C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment) - - if element_c == DataType.s32 and A.alignment == 1: - tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128) - tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128) - - op = Conv2dOperation( - conv_kind, - IteratorAlgorithm.Optimized, - tile.minimum_compute_capability, - tile, - A, - B, - C, - element_epilogue, - stride_support, - EpilogueFunctor.LinearCombination, - swizzling_functor, - split_k_slice, - ) - - ret.append( - { - "src": profiler_emitter.emit( - kernel_emitter.emit(op, emit_reduction=split_k_slice > 1), - op.procedural_name(), - element_output=element_c, - split_k_slices=split_k_slice, - ), - "name": op.procedural_name(), - "tile_description": tile, - "alignment": alignment, - "data_type": data_type, - "swizzle_functor": swizzling_functor, - "split_k_slices": split_k_slice, - } - ) + for alignmentAB in alignment_constraints: + for alignmentC in alignment_c: + + A = TensorDescription(element_a, LayoutType.TensorNHWC, alignmentAB) + B = TensorDescription(element_b, LayoutType.TensorNHWC, alignmentAB) + C = TensorDescription(element_c, LayoutType.TensorNHWC, alignmentC) + + if element_c == DataType.s32 and A.alignment == 1: + tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128) + tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128) + + op = Conv2dOperation( + conv_kind, + IteratorAlgorithm.Optimized, + tile.minimum_compute_capability, + tile, + A, + B, + C, + element_epilogue, + stride_support, + EpilogueFunctor.LinearCombination, + swizzling_functor, + 
split_k_slice, + ) + + ret.append( + { + "src": profiler_emitter.emit( + kernel_emitter.emit(op, emit_reduction=split_k_slice > 1), + op.procedural_name(), + element_output=element_c, + split_k_slices=split_k_slice, + ), + "name": op.procedural_name(), + "tile_description": tile, + "alignment": alignmentAB, + "alignment_epilogue": alignmentC, + "data_type": data_type, + "swizzle_functor": swizzling_functor, + "split_k_slices": split_k_slice, + } + ) return ret @@ -181,7 +187,11 @@ def __init__(self, sm, cutlass_path, binary_path): self.sm = sm assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." % sm self.engine = ProfilerEngine(sm, cutlass_path, binary_path) - self.cache = {} + self.cache_path = os.path.join(binary_path, "cutlass_conv2d_cache.pickle") + if os.path.exists(self.cache_path): + self.cache = pickle.load(open(self.cache_path, "rb")) + else: + self.cache = {} def get_default( self, @@ -265,12 +275,27 @@ def select_op( if workload in self.cache: return self.cache[workload] + def alignments(dtype): + if dtype in ["float16"]: + alignments = [8, 4, 2, 1] + elif dtype in ["float", "float32"]: + alignments = [4, 2, 1] + else: + raise ValueError("Unsupported data type: %s" % dtype) + return alignments + ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, data_dtype, weight_dtype, - partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices), - lambda align: all([dim % align == 0 for dim in [IC, OC]]), + partial( + enumerate_conv2d_operators, + conv_kind, + stride_support, + split_k_slices, + [align for align in alignments(out_dtype) if OC % align == 0], + ), + lambda align: all([dim % align == 0 for dim in [IC]]), use_3xtf32, profile_all_alignments, # Use fp32 accumulation for wgrad to align with cuDNN @@ -294,6 +319,8 @@ def select_op( op = min(ops, key=lambda i: i["runtime"]) self.cache[workload] = op + with open(self.cache_path, "wb") as f: + pickle.dump(self.cache, f) return op def profile( @@ -350,6 +377,7 @@ def profile( op["tile_description"], op["data_type"], op["alignment"], + op["alignment_epilogue"], op["swizzle_functor"], op["split_k_slices"], ) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index f5f160a4000a..0ea6231b8196 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name """GEMM kernel generator and profiler for CUTLASS.""" +import os +import pickle from functools import partial from .gemm_operation import EmitGemmInstance, GemmOperation @@ -195,7 +197,11 @@ def __init__(self, sm, cutlass_path, binary_path): assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." 
% sm self.engine = ProfilerEngine(sm, cutlass_path, binary_path) self.sm = sm - self.cache = {} + self.cache_path = os.path.join(binary_path, "cutlass_gemm_cache.pickle") + if os.path.exists(self.cache_path): + self.cache = pickle.load(open(self.cache_path, "rb")) + else: + self.cache = {} def get_default( self, @@ -294,6 +300,8 @@ def select_op( op = min(ops, key=lambda i: i["runtime"]) self.cache[(M, N, K)] = op + with open(self.cache_path, "wb") as f: + pickle.dump(self.cache, f) return op def profile( diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 6b2587a0b0f1..58ed2e5519dc 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -230,8 +230,9 @@ def generate_sm80_tensor_op_16816( def get_default_tile_descriptions(block_k_factor): return [ - ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc), ([128, 256, int(32 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc), + ([256, 128, int(32 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc), + ([256, 64, int(32 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc), ([256, 64, int(32 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc), ([64, 256, int(32 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc), ([128, 128, int(32 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), @@ -240,11 +241,14 @@ def get_default_tile_descriptions(block_k_factor): ([128, 64, int(32 * block_k_factor)], 6, [2, 2, 1], min_cc, max_cc), ([64, 128, int(32 * block_k_factor)], 6, [2, 2, 1], min_cc, max_cc), ([64, 64, int(32 * block_k_factor)], 10, [2, 2, 1], min_cc, max_cc), - ([256, 128, int(64 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc_smem_limited), - ([128, 256, int(64 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc_smem_limited), - ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited), - ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited), + ([256, 128, int(64 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc), + ([128, 256, int(64 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc), + ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc), + ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc), ([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc), + ([256, 64, int(64 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc), + ([64, 256, int(64 * block_k_factor)], 3, [1, 4, 1], min_cc, max_cc), + ([128, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), ([128, 64, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), ([64, 128, int(64 * block_k_factor)], 3, [2, 2, 1], min_cc, max_cc), ([64, 64, int(64 * block_k_factor)], 5, [2, 2, 1], min_cc, max_cc), From 4bf1b147245358c80f791f445c1d9e42b4742011 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 11 Mar 2023 19:30:21 -0800 Subject: [PATCH 2/6] conv2d alignment --- python/tvm/contrib/cutlass/gen_tensor_op.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 58ed2e5519dc..177304e10229 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -241,10 +241,10 @@ def get_default_tile_descriptions(block_k_factor): ([128, 64, int(32 * block_k_factor)], 6, [2, 2, 1], min_cc, max_cc), ([64, 128, int(32 * block_k_factor)], 6, [2, 2, 1], min_cc, max_cc), ([64, 64, int(32 * block_k_factor)], 10, [2, 2, 1], min_cc, 
max_cc), - ([256, 128, int(64 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc), - ([128, 256, int(64 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc), - ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc), - ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc), + ([256, 128, int(64 * block_k_factor)], 3, [4, 2, 1], min_cc, max_cc_smem_limited), + ([128, 256, int(64 * block_k_factor)], 3, [2, 4, 1], min_cc, max_cc_smem_limited), + ([256, 64, int(64 * block_k_factor)], 4, [4, 1, 1], min_cc, max_cc_smem_limited), + ([64, 256, int(64 * block_k_factor)], 4, [1, 4, 1], min_cc, max_cc_smem_limited), ([128, 128, int(64 * block_k_factor)], 4, [2, 2, 1], min_cc, max_cc), ([256, 64, int(64 * block_k_factor)], 3, [4, 1, 1], min_cc, max_cc), ([64, 256, int(64 * block_k_factor)], 3, [1, 4, 1], min_cc, max_cc), From ad62a846e1f7aba86104e37a2aa1971ea719c6ad Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sun, 12 Mar 2023 11:53:52 -0700 Subject: [PATCH 3/6] gemm refactor --- python/tvm/contrib/cutlass/gemm_operation.py | 8 ++--- python/tvm/contrib/cutlass/gen_conv2d.py | 29 ++------------- python/tvm/contrib/cutlass/gen_gemm.py | 37 +++----------------- python/tvm/contrib/cutlass/gen_tensor_op.py | 31 ++++++++++++++++ 4 files changed, 40 insertions(+), 65 deletions(-) diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index eb9f92dad39a..8c8fc16c08d7 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -173,8 +173,7 @@ def __init__(self): ${element_epilogue}, cutlass::epilogue::thread::ScaleType::NoBetaScaling >""" - - self.epilogue_residual_block = """ + self.epilogue_residual = """ ${epilogue_functor}< ${element_c}, ${element_accumulator}, @@ -185,7 +184,6 @@ def __init__(self): ${binary_op}, ${unary_op} >""" - self.gemm_template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}< @@ -206,7 +204,7 @@ def __init__(self): >; """ - def emit(self, operation, no_beta_scaling=False, batched=False, residual_block_info=False): + def emit(self, operation, no_beta_scaling=False, residual_block_info=False, batched=False): """Instantiate a GEMM kernel from given `operation`.""" warp_shape = [ operation.tile_description.threadblock_shape[idx] @@ -277,7 +275,7 @@ def emit(self, operation, no_beta_scaling=False, batched=False, residual_block_i ) else: template = substitute_template(self.gemm_template, {"epilogue": self.epilogue_default}) - +g return substitute_template(template, values) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 15d19cf5ae0f..9bd47a7b34de 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -22,7 +22,7 @@ from .conv2d_operation import Conv2dOperation, EmitConv2dInstance from .gen_gemm import CutlassGemmProfiler from .conv2d_profiler import Conv2dProfilerEmitter -from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP +from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, epilogue_creator from .library import ( DataType, EpilogueFunctor, @@ -50,32 +50,7 @@ def create_conv2d_operator_with_epilogue( Instantiate a cutlass kernel from the given configuration, along with the epilouge functor """ - if "residual" in op_type: - activation_map = { - "cutlass.conv2d_bias_hardswish": "cutlass::epilogue::thread::HardSwish", - "cutlass.conv2d_bias_silu": 
"cutlass::epilogue::thread::SiLu", - "cutlass.conv2d_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid", - "cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu", - "cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity", - } - prefix = op_type[: op_type.find("_residual")] - activation = activation_map[prefix] - binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus" - unary_op = ( - "cutlass::epilogue::thread::ReLu" - if op_type.endswith("relu") - else "cutlass::epilogue::thread::Identity" - ) - residual_block_info = { - "activation": activation, - "binary_op": binary_op, - "unary_op": unary_op, - } - epilogue = EpilogueFunctor.LinearCombinationResidualBlock - no_beta_scaling = False - else: - residual_block_info = None - epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] + residual_block_info, epilogue, no_beta_scaling = epilogue_creator(op_type) element_a, element_b, element_c, element_epilogue = data_type diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 0ea6231b8196..750c94124d06 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -22,7 +22,7 @@ from .gemm_operation import EmitGemmInstance, GemmOperation from .gemm_profiler import GemmProfilerEmitter -from .gen_tensor_op import EPILOGUE_MAP, GENERATOR_FUNC_TABLE, ProfilerEngine +from .gen_tensor_op import epilogue_creator, GENERATOR_FUNC_TABLE, ProfilerEngine from .library import ( DataType, DataTypeTag, @@ -46,6 +46,8 @@ def create_gemm_operator_with_epilogue( Instantiate a cutlass kernel from the given configuration, along with the epilouge functor """ + residual_block_info, epilogue, no_beta_scaling = epilogue_creator(op_type) + element_a, element_b, element_c, element_epilogue = data_type A = TensorDescription(element_a, LayoutType.RowMajor, alignment) @@ -55,37 +57,6 @@ def create_gemm_operator_with_epilogue( if batched: swizzling_functor = SwizzlingFunctor.Batched - if "residual" in op_type: - if "hardswish" in op_type: - activation = "cutlass::epilogue::thread::HardSwish" - elif "silu" in op_type: - activation = "cutlass::epilogue::thread::SiLu" - elif "sigmoid" in op_type: - activation = "cutlass::epilogue::thread::Sigmoid" - elif "gelu" in op_type: - activation = "cutlass::epilogue::thread::GELU" - elif "relu" in op_type: - activation = "cutlass::epilogue::thread::ReLu" - else: - activation = "cutlass::epilogue::thread::Identity" - - binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus" - unary_op = ( - "cutlass::epilogue::thread::ReLu" - if op_type.endswith("relu") - else "cutlass::epilogue::thread::Identity" - ) - residual_block_info = { - "activation": activation, - "binary_op": binary_op, - "unary_op": unary_op, - } - epilogue = EpilogueFunctor.LinearCombinationResidualBlock - no_beta_scaling = False - else: - residual_block_info = None - epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] - op = GemmOperation( tile_description.minimum_compute_capability, tile_description, @@ -102,8 +73,8 @@ def create_gemm_operator_with_epilogue( EmitGemmInstance().emit( op, no_beta_scaling=no_beta_scaling, - batched=batched, residual_block_info=residual_block_info, + batched=batched, ), ) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 177304e10229..c05ac47b5cec 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -367,6 +367,36 @@ def 
get_tile_descriptions(math_inst):
 }
 
 
+def epilogue_creator(op_type):
+    if "residual" in op_type:
+        activation_map = {
+            "_bias_hardswish": "cutlass::epilogue::thread::HardSwish",
+            "_bias_silu": "cutlass::epilogue::thread::SiLu",
+            "_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid",
+            "_bias_relu": "cutlass::epilogue::thread::ReLu",
+            "_bias": "cutlass::epilogue::thread::Identity",
+        }
+        prefix = op_type[op_type.find("_bias") : op_type.find("_residual")]
+        activation = activation_map[prefix]
+        binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus"
+        unary_op = (
+            "cutlass::epilogue::thread::ReLu"
+            if op_type.endswith("relu")
+            else "cutlass::epilogue::thread::Identity"
+        )
+        residual_block_info = {
+            "activation": activation,
+            "binary_op": binary_op,
+            "unary_op": unary_op,
+        }
+        epilogue = EpilogueFunctor.LinearCombinationResidualBlock
+        no_beta_scaling = False
+    else:
+        residual_block_info = None
+        epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+    return residual_block_info, epilogue, no_beta_scaling
+
+
 # (Epilogue functor name, no_beta_scaling)
 EPILOGUE_MAP = {
     "cutlass.dense": (EpilogueFunctor.LinearCombination, False),
@@ -537,6 +567,7 @@ def get_batch_stride(stride_annot, arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
     if "dense" in func_name or "matmul" in func_name:
         batched = "batch" in annotations
         transposed = "transposed" in func_name
+
         lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0)
         rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1)
         bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None)

From c315646d56d445155073e73c4e3c3383d080bb60 Mon Sep 17 00:00:00 2001
From: spectrometerHBH
Date: Fri, 17 Mar 2023 19:08:21 -0700
Subject: [PATCH 4/6] remove stray character in gemm_operation.py

---
 python/tvm/contrib/cutlass/gemm_operation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 8c8fc16c08d7..e13af5006dd9 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -275,7 +275,7 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False, batc
         )
     else:
         template = substitute_template(self.gemm_template, {"epilogue": self.epilogue_default})
-g
+
     return substitute_template(template, values)


From cc6b2d135ce3ac1ae0b499b0c092a8bee6ffc210 Mon Sep 17 00:00:00 2001
From: spectrometerHBH
Date: Fri, 17 Mar 2023 19:14:34 -0700
Subject: [PATCH 5/6] revert the epilogue_creator refactor
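
This reverts the refactor from PATCH 3/6: the shared epilogue_creator helper
is dropped from gen_tensor_op.py, and gen_conv2d.py / gen_gemm.py go back to
building the residual-block epilogue inline and looking the (epilogue functor,
no_beta_scaling) pair up in EPILOGUE_MAP. The epilogue_residual_block template
name and the keyword order of EmitGemmInstance.emit (batched before
residual_block_info) are restored as well. In short:

    # PATCH 3/6 (reverted here):
    residual_block_info, epilogue, no_beta_scaling = epilogue_creator(op_type)

    # after this patch, for non-residual ops:
    epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]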
--- python/tvm/contrib/cutlass/gemm_operation.py | 6 ++-- python/tvm/contrib/cutlass/gen_conv2d.py | 29 +++++++++++++-- python/tvm/contrib/cutlass/gen_gemm.py | 37 +++++++++++++++++--- python/tvm/contrib/cutlass/gen_tensor_op.py | 31 ---------------- 4 files changed, 64 insertions(+), 39 deletions(-) diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index e13af5006dd9..eb9f92dad39a 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -173,7 +173,8 @@ def __init__(self): ${element_epilogue}, cutlass::epilogue::thread::ScaleType::NoBetaScaling >""" - self.epilogue_residual = """ + + self.epilogue_residual_block = """ ${epilogue_functor}< ${element_c}, ${element_accumulator}, @@ -184,6 +185,7 @@ def __init__(self): ${binary_op}, ${unary_op} >""" + self.gemm_template = """ // Gemm operator ${operation_name} using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}< @@ -204,7 +206,7 @@ def __init__(self): >; """ - def emit(self, operation, no_beta_scaling=False, residual_block_info=False, batched=False): + def emit(self, operation, no_beta_scaling=False, batched=False, residual_block_info=False): """Instantiate a GEMM kernel from given `operation`.""" warp_shape = [ operation.tile_description.threadblock_shape[idx] diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 9bd47a7b34de..15d19cf5ae0f 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -22,7 +22,7 @@ from .conv2d_operation import Conv2dOperation, EmitConv2dInstance from .gen_gemm import CutlassGemmProfiler from .conv2d_profiler import Conv2dProfilerEmitter -from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, epilogue_creator +from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP from .library import ( DataType, EpilogueFunctor, @@ -50,7 +50,32 @@ def create_conv2d_operator_with_epilogue( Instantiate a cutlass kernel from the given configuration, along with the epilouge functor """ - residual_block_info, epilogue, no_beta_scaling = epilogue_creator(op_type) + if "residual" in op_type: + activation_map = { + "cutlass.conv2d_bias_hardswish": "cutlass::epilogue::thread::HardSwish", + "cutlass.conv2d_bias_silu": "cutlass::epilogue::thread::SiLu", + "cutlass.conv2d_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid", + "cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu", + "cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity", + } + prefix = op_type[: op_type.find("_residual")] + activation = activation_map[prefix] + binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus" + unary_op = ( + "cutlass::epilogue::thread::ReLu" + if op_type.endswith("relu") + else "cutlass::epilogue::thread::Identity" + ) + residual_block_info = { + "activation": activation, + "binary_op": binary_op, + "unary_op": unary_op, + } + epilogue = EpilogueFunctor.LinearCombinationResidualBlock + no_beta_scaling = False + else: + residual_block_info = None + epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] element_a, element_b, element_c, element_epilogue = data_type diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 750c94124d06..0ea6231b8196 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -22,7 +22,7 @@ from .gemm_operation import EmitGemmInstance, 
GemmOperation from .gemm_profiler import GemmProfilerEmitter -from .gen_tensor_op import epilogue_creator, GENERATOR_FUNC_TABLE, ProfilerEngine +from .gen_tensor_op import EPILOGUE_MAP, GENERATOR_FUNC_TABLE, ProfilerEngine from .library import ( DataType, DataTypeTag, @@ -46,8 +46,6 @@ def create_gemm_operator_with_epilogue( Instantiate a cutlass kernel from the given configuration, along with the epilouge functor """ - residual_block_info, epilogue, no_beta_scaling = epilogue_creator(op_type) - element_a, element_b, element_c, element_epilogue = data_type A = TensorDescription(element_a, LayoutType.RowMajor, alignment) @@ -57,6 +55,37 @@ def create_gemm_operator_with_epilogue( if batched: swizzling_functor = SwizzlingFunctor.Batched + if "residual" in op_type: + if "hardswish" in op_type: + activation = "cutlass::epilogue::thread::HardSwish" + elif "silu" in op_type: + activation = "cutlass::epilogue::thread::SiLu" + elif "sigmoid" in op_type: + activation = "cutlass::epilogue::thread::Sigmoid" + elif "gelu" in op_type: + activation = "cutlass::epilogue::thread::GELU" + elif "relu" in op_type: + activation = "cutlass::epilogue::thread::ReLu" + else: + activation = "cutlass::epilogue::thread::Identity" + + binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus" + unary_op = ( + "cutlass::epilogue::thread::ReLu" + if op_type.endswith("relu") + else "cutlass::epilogue::thread::Identity" + ) + residual_block_info = { + "activation": activation, + "binary_op": binary_op, + "unary_op": unary_op, + } + epilogue = EpilogueFunctor.LinearCombinationResidualBlock + no_beta_scaling = False + else: + residual_block_info = None + epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] + op = GemmOperation( tile_description.minimum_compute_capability, tile_description, @@ -73,8 +102,8 @@ def create_gemm_operator_with_epilogue( EmitGemmInstance().emit( op, no_beta_scaling=no_beta_scaling, - residual_block_info=residual_block_info, batched=batched, + residual_block_info=residual_block_info, ), ) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index c05ac47b5cec..177304e10229 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -367,36 +367,6 @@ def get_tile_descriptions(math_inst): } -def epilogue_creator(op_type): - if "residual" in op_type: - activation_map = { - "_bias_hardswish": "cutlass::epilogue::thread::HardSwish", - "_bias_silu": "cutlass::epilogue::thread::SiLu", - "_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid", - "_bias_relu": "cutlass::epilogue::thread::ReLu", - "_bias": "cutlass::epilogue::thread::Identity", - } - prefix = op_type[op_type.find("_bias") : op_type.find("_residual")] - activation = activation_map[prefix] - binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus" - unary_op = ( - "cutlass::epilogue::thread::ReLu" - if op_type.endswith("relu") - else "cutlass::epilogue::thread::Identity" - ) - residual_block_info = { - "activation": activation, - "binary_op": binary_op, - "unary_op": unary_op, - } - epilogue = EpilogueFunctor.LinearCombinationResidualBlock - no_beta_scaling = False - else: - residual_block_info = None - epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] - return residual_block_info, epilogue, no_beta_scaling - - # (Epilogue functor name, no_beta_scaling) EPILOGUE_MAP = { "cutlass.dense": (EpilogueFunctor.LinearCombination, False), @@ -567,7 +537,6 @@ def get_batch_stride(stride_annot, 
arg0_idx, arg1_idx, arg0_axis_idx, arg1_axis_
     if "dense" in func_name or "matmul" in func_name:
         batched = "batch" in annotations
         transposed = "transposed" in func_name
-
         lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0)
         rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1)
         bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None)

From 745ef3c9afababd5d22dccf6972f044f6c83575c Mon Sep 17 00:00:00 2001
From: spectrometerHBH
Date: Sat, 18 Mar 2023 08:40:54 -0700
Subject: [PATCH 6/6] lint: pass the new alignment_epilogue argument in
 get_default

---
 python/tvm/contrib/cutlass/gen_conv2d.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index 15d19cf5ae0f..9e9e16426ba6 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -226,6 +226,7 @@ def get_default(
         tile_description,
         data_type,
         alignment,
+        alignment,
         swizzling_functor,
         split_k_slices=1,
     )
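
For context on the series title: the central change of PATCH 1/6 is that the
A/B operand alignment and the epilogue (C) alignment of a conv2d kernel are
enumerated independently. A/B candidates are filtered by the input channel
count IC, while C candidates are filtered by the output channel count OC
through the new alignments(out_dtype) helper. A minimal standalone sketch of
that enumeration follows; the function names are illustrative, and applying
alignments() to the A/B side is an assumption (in the patch, the A/B candidate
set comes from the per-SM generator tables):

    def alignments(dtype):
        # Candidate vector widths in elements, as in the helper added to select_op.
        if dtype == "float16":
            return [8, 4, 2, 1]
        if dtype in ("float", "float32"):
            return [4, 2, 1]
        raise ValueError("Unsupported data type: %s" % dtype)


    def alignment_pairs(ic, oc, in_dtype, out_dtype):
        # A/B (data/weight) alignments must divide IC; the epilogue tensor C only
        # needs to divide OC, so the two sets are enumerated independently.
        ab = [a for a in alignments(in_dtype) if ic % a == 0]
        c = [a for a in alignments(out_dtype) if oc % a == 0]
        return [(a, b) for a in ab for b in c]


    print(alignment_pairs(64, 62, "float16", "float16"))
    # [(8, 2), (8, 1), (4, 2), (4, 1), (2, 2), (2, 1), (1, 2), (1, 1)]

For IC=64 and OC=62 in float16, A and B can keep the full 8-wide vectorized
access while C is limited to 2; under the old single shared alignment
(filtered by both IC and OC), the same workload would have forced A and B down
to 2 as well.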