From fe839b36a7803a3b20365cf0cbc849c2a3762a38 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 31 Dec 2021 16:05:50 +0900
Subject: [PATCH 1/4] [CUTLASS] Support residual block fusion for conv2d

commit d4a78a3e13530974e852b4c0480b7c8d0f792e68
Author: Masahiro Masuda
Date:   Thu Dec 23 16:33:41 2021 +0900

    fixed residual block check condition

commit 6ee5a3913333e8ba2d5d0ed6842a58fe37baa547
Author: Masahiro Masuda
Date:   Thu Dec 23 16:25:04 2021 +0900

    minor fix

commit 8af8b3078f11ee293d2e22d9e37e715c617ffb75
Author: Masahiro Masuda
Date:   Thu Dec 23 16:18:50 2021 +0900

    remove SimplifyExpr pass

commit 20ae2d874917c69fabc6fcf03a3d47aff98eee91
Author: Masahiro Masuda
Date:   Thu Dec 23 16:16:46 2021 +0900

    fix bad merge

commit 17eed222c5e69e7863c95563b638e5390c634b1b
Author: Masahiro Masuda
Date:   Thu Dec 23 16:13:53 2021 +0900

    black

commit fda151b524cb28581256befa74575bbfa23efa4c
Author: Masahiro Masuda
Date:   Thu Dec 23 16:09:45 2021 +0900

    Support residual block fusion

commit ce9d52fd629d6119abdd471b00ff6a79223d6752
Author: Masahiro Masuda
Date:   Thu Dec 23 15:56:32 2021 +0900

    Remove SimplifyExpr pass from the pipeline (makes DETR result nan)

commit d3b681d95977b6fc0965a0a3ec8af3f866bd9e91
Author: Masahiro Masuda
Date:   Thu Dec 23 15:47:07 2021 +0900

    fix no_beta_scaling values

commit 87b36dbbb11adb582ffb628fc6ad62668dcdee7e
Author: Masahiro Masuda
Date:   Thu Dec 23 14:59:40 2021 +0900

    fill in TODO doc

commit fd67595831c7b8741f30577bc91488bcce34a76a
Author: Masahiro Masuda
Date:   Thu Dec 23 14:31:06 2021 +0900

    Refactor cutlass kernel generation and selection

---
 3rdparty/cutlass                              |  2 +-
 .../tvm/contrib/cutlass/conv2d_operation.py   | 45 +++++++--
 python/tvm/contrib/cutlass/gen_conv2d.py      | 31 +++++-
 python/tvm/contrib/cutlass/library.py         |  2 +
 python/tvm/relay/op/contrib/cutlass.py        | 48 +++++++++-
 src/relay/backend/contrib/cutlass/codegen.cc  | 96 +++++++++++++++++--
 src/relay/backend/utils.h                     | 18 +++-
 tests/python/contrib/test_cutlass.py          | 55 ++++++++++-
 8 files changed, 268 insertions(+), 29 deletions(-)

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index dceabd4c5a2a..c2ee13a0fe99 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit dceabd4c5a2aa8cb29ce5a05311a57519baadddc
+Subproject commit c2ee13a0fe99241b0e798ce647acf98e237f1d0c
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index 1c7f9a31b955..5318cc7d74c4 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -150,6 +150,7 @@ def __init__(self):
       ${element_accumulator},
       ${element_epilogue}
     >"""
+
         self.epilogue_no_beta_scaling = """
     ${epilogue_functor}<
       ${element_c},
@@ -159,10 +160,22 @@ def __init__(self):
       cutlass::epilogue::thread::ScaleType::NoBetaScaling
     >"""
 
+        self.epilogue_residual_block = """
+    ${epilogue_functor}<
+      ${element_c},
+      ${element_accumulator},
+      ${element_epilogue},
+      ${element_c},
+      ${epilogue_vector_length},
+      ${activation},
+      ${binary_op},
+      ${unary_op}
+    >"""
+
         self.template = """
 // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
   using ${operation_name} =
-  typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
+  typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}${conv_kernel_postfix}<
     ${element_a},
     ${layout_a},
     ${element_b},
@@ -186,7 +199,7 @@ def __init__(self):
 >::Kernel;
 """
 
-    def emit(self, operation, no_beta_scaling=False):
+    def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"""Instantiate a Conv2d kernel from given `operation`.""" warp_shape = [ int( @@ -246,14 +259,26 @@ def emit(self, operation, no_beta_scaling=False): ], "align_a": str(operation.A.alignment), "align_b": str(operation.B.alignment), + "conv_kernel_postfix": "", } - template = substitute_template( - self.template, - { - "epilogue": self.epilogue_no_beta_scaling - if no_beta_scaling - else self.epilogue_default - }, - ) + if residual_block_info: + template = substitute_template( + self.template, {"epilogue": self.epilogue_residual_block} + ) + values.update( + { + "unary_op": residual_block_info["unary_op"], + "binary_op": residual_block_info["binary_op"], + "activation": residual_block_info["activation"], + "conv_kernel_postfix": "WithBroadcast", + } + ) + elif no_beta_scaling: + template = substitute_template( + self.template, {"epilogue": self.epilogue_no_beta_scaling} + ) + else: + template = substitute_template(self.template, {"epilogue": self.epilogue_default}) + return substitute_template(template, values) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 4e4a7b2458e2..39db9fd01319 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -39,7 +39,32 @@ def create_conv2d_operator_with_epilogue( Instantiate a cutlass kernel from the given configuration, along with the epilouge functor """ - epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] + if "residual" in op_type: + activation_map = { + "cutlass.conv2d_bias_hardswish": "cutlass::epilogue::thread::HardSwish", + "cutlass.conv2d_bias_silu": "cutlass::epilogue::thread::SiLu", + "cutlass.conv2d_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid", + "cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu", + "cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity", + } + prefix = op_type[: op_type.find("_residual")] + activation = activation_map[prefix] + binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus" + unary_op = ( + "cutlass::epilogue::thread::ReLu" + if op_type.endswith("relu") + else "cutlass::epilogue::thread::Identity" + ) + residual_block_info = { + "activation": activation, + "binary_op": binary_op, + "unary_op": unary_op, + } + epilogue = EpilogueFunctor.LinearCombinationResidualBlock + no_beta_scaling = False + else: + residual_block_info = None + epilogue, no_beta_scaling = EPILOGUE_MAP[op_type] element_a, element_b, element_c, element_epilogue = data_type @@ -62,7 +87,9 @@ def create_conv2d_operator_with_epilogue( ) name = op.procedural_name() - opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling) + opdef = EmitConv2dInstance().emit( + op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info + ) return name, opdef diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index efc5dd5ccd97..08cdb323c126 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -151,6 +151,7 @@ class EpilogueFunctor(enum.Enum): LinearCombinationSigmoid = enum_auto() LinearCombinationSilu = enum_auto() LinearCombinationHardSwish = enum_auto() + LinearCombinationResidualBlock = enum_auto() EpilogueFunctorTag = { @@ -161,6 +162,7 @@ class EpilogueFunctor(enum.Enum): EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid", EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu", 
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index efc5dd5ccd97..08cdb323c126 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -151,6 +151,7 @@ class EpilogueFunctor(enum.Enum):
     LinearCombinationSigmoid = enum_auto()
     LinearCombinationSilu = enum_auto()
     LinearCombinationHardSwish = enum_auto()
+    LinearCombinationResidualBlock = enum_auto()
 
 
 EpilogueFunctorTag = {
@@ -161,6 +162,7 @@ class EpilogueFunctor(enum.Enum):
     EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
     EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
     EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish",
+    EpilogueFunctor.LinearCombinationResidualBlock: "cutlass::epilogue::thread::LinearCombinationResidualBlock",
 }
diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index cbbc45a5d1c0..31f0408c0f04 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name
 """Patterns supported CUTLASS."""
+from functools import partial
 from tvm import relay
 from tvm.ir.transform import Sequential, PassContext
 from tvm.relay import transform
@@ -89,6 +90,19 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
     return conv2d_out
 
 
+def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
+    """Add pattern for residual blocks."""
+    residual_input = wildcard()
+    binary_out = is_op(binary_op)(tensor_op_out, residual_input) | is_op(binary_op)(
+        residual_input, tensor_op_out
+    )
+
+    if with_act is not None and with_act == "relu":
+        return is_op("nn.relu")(binary_out)
+
+    return binary_out
+
+
 def check_dtype(lhs, rhs):
     """Check if dtypes in the given workload are supported by CUTLASS."""
     # Only fp16 inputs are supported for now.
@@ -139,6 +153,25 @@ def check_conv2d(call):
     return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)
 
 
+def check_conv2d_residual(call, binary_op):
+    """Check if the given conv2d workload can be offloaded to CUTLASS."""
+    conv2d = get_root_call(call, "nn.conv2d")
+    if not check_conv2d(call):
+        return False
+
+    residual_binop = get_root_call(call, binary_op)
+    lhs = residual_binop.args[0]
+    rhs = residual_binop.args[1]
+
+    # residual_input is pattern-matched as a wildcard. Make sure it does not sit between
+    # residual binary op and the root conv2d of this pattern.
+    # If the root conv2d is the parent of both lhs and rhs, we should reject this pattern.
+    if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs, "nn.conv2d") == conv2d:
+        return False
+
+    return all(x == y for (x, y) in zip(lhs.checked_type.shape, rhs.checked_type.shape))
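To make the rejection rule above concrete: the wildcard residual input must not itself depend on the pattern's root conv2d, otherwise both arguments of the binary op trace back to the same conv2d call. A hedged sketch, with illustrative variable names and shapes:

# Illustrative only: what check_conv2d_residual accepts and rejects.
from tvm import relay

d_shape, w_shape = (16, 16, 32, 32), (16, 16, 3, 3)
data = relay.var("data", shape=d_shape, dtype="float16")
weight = relay.var("weight", shape=w_shape, dtype="float16")
conv_out = relay.nn.conv2d(data, weight, padding=(1, 1))

ok = relay.nn.relu(conv_out + data)        # accepted: data is a true residual input
bad = conv_out + relay.nn.relu(conv_out)   # rejected: both args reach the same conv2d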
+
+
 def partition_for_cutlass(mod, params=None):
     """Partition the input module into CUTLASS-supported subgraphs."""
     dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
@@ -189,7 +222,20 @@ def partition_for_cutlass(mod, params=None):
         ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
 
-    cutlass_patterns = dense_patterns + conv2d_patterns
+    residual_block_patterns = []
+
+    for with_act, postfix in [("relu", "_relu"), (None, "")]:
+        for name, pat, _ in conv2d_patterns[:-1]:
+            for bin_op in ["add", "multiply"]:
+                residual_block_patterns.append(
+                    (
+                        name + "_residual_" + bin_op + postfix,
+                        make_residual_block_pattern(pat, bin_op, with_act=with_act),
+                        partial(check_conv2d_residual, binary_op=bin_op),
+                    )
+                )
+
+    cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns
 
     if params is not None:
         mod["main"] = bind_params_by_name(mod["main"], params)
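The nested loop above derives every residual pattern name mechanically from an existing conv2d pattern name. A standalone sketch for one such entry, using "cutlass.conv2d_bias_relu" as an illustrative base name:

# Illustrative: the four residual pattern names produced for one base pattern.
name = "cutlass.conv2d_bias_relu"
generated = [
    name + "_residual_" + bin_op + postfix
    for _, postfix in [("relu", "_relu"), (None, "")]
    for bin_op in ["add", "multiply"]
]
# ['cutlass.conv2d_bias_relu_residual_add_relu',
#  'cutlass.conv2d_bias_relu_residual_multiply_relu',
#  'cutlass.conv2d_bias_relu_residual_add',
#  'cutlass.conv2d_bias_relu_residual_multiply']

These names are exactly what create_conv2d_operator_with_epilogue decodes back into an activation, binary op, and unary op on the kernel-generation side.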
"->data);\n"); } CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n"); CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); - if (has_bias && no_bias_scaling) { + if (has_bias && no_bias_scaling && !has_residual_block) { CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); } else { CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); @@ -325,25 +336,53 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n"); CutlassPrint(conv2d_decl, "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n"); + + if (has_residual_block) { + if (attrs.at("P") == attrs.at("residual_H") && attrs.at("Q") == attrs.at("residual_W")) { + CutlassPrint(conv2d_decl, + "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n"); + } else { + ICHECK(attrs.at("residual_H") == "1" && attrs.at("residual_W") == "1"); + // Handle broadcast ops (MobilenetV3 and EfficientNetV2) in a residual block-like pattern + CutlassPrint(conv2d_decl, "// Broadcast in a residual block \n"); + CutlassPrint(conv2d_decl, "TensorNHWC layout_C(TensorNHWC(0, 0, K));\n\n"); + } + } else { + CutlassPrint(conv2d_decl, + "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n"); + } + CutlassPrint(conv2d_decl, - "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n"); + "TensorNHWC layout_D(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n"); + CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n"); CutlassPrint(conv2d_decl, " problem_size,\n"); CutlassPrint(conv2d_decl, " {static_cast(ptr_a), layout_A},\n"); CutlassPrint(conv2d_decl, " {static_cast(ptr_b), layout_B},\n"); - if (has_bias) { + + if (has_residual_block) { + CutlassPrint(conv2d_decl, " {static_cast(ptr_residual), layout_C},\n"); + } else if (has_bias) { CutlassPrint( conv2d_decl, " {static_cast(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n"); } else { - CutlassPrint(conv2d_decl, " {static_cast(ptr_out),layout_C},\n"); + CutlassPrint(conv2d_decl, " {static_cast(ptr_out), layout_C},\n"); } - CutlassPrint(conv2d_decl, " {static_cast(ptr_out),layout_C},\n"); - if (has_bias && no_bias_scaling) { + + CutlassPrint(conv2d_decl, " {static_cast(ptr_out),layout_D},\n"); + + if (has_residual_block) { + CutlassPrint(conv2d_decl, "{alpha, beta},\n"); + CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices + CutlassPrint(conv2d_decl, "static_cast(ptr_bias),\n"); + CutlassPrint(conv2d_decl, "nullptr, 0, K};\n"); + } else if (has_bias && no_bias_scaling) { CutlassPrint(conv2d_decl, " {alpha}\n};\n"); } else { CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n"); } + CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n"); CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n"); @@ -432,6 +471,20 @@ class CodegenCutlass : public MemoizedExprTranslator>, publi return arg_names; } + bool IsConv2dResidualBlock(const std::string& func_name) { + return func_name.find("conv2d") != std::string::npos && + func_name.find("residual") != std::string::npos; + } + + bool IsAncestor(const CallNode* x, const CallNode* y) { + if (x == y) return true; + for (auto arg : y->args) { + const CallNode* arg_ptr = arg.as(); + if (arg_ptr && IsAncestor(x, arg_ptr)) return true; + } + return false; + } + GenerateBodyOutput 
   GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
                                                    const CallNode* caller) {
     const auto pattern_name = callee->GetAttr<String>(attr::kComposite);
@@ -515,6 +568,27 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
           GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
       return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish", GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_)));
+    } else if (IsConv2dResidualBlock(pattern_name.value())) {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      const CallNode* binop =
+          current_call->args.size() == 1 ? current_call->args[0].as<CallNode>() : current_call;
+      ICHECK(binop->args.size() == 2);
+      int residual_index;
+      if (binop->args[1].as<VarNode>()) {
+        residual_index = 1;
+      } else if (binop->args[0].as<VarNode>()) {
+        residual_index = 0;
+      } else {
+        const CallNode* lhs = binop->args[0].as<CallNode>();
+        const CallNode* rhs = binop->args[1].as<CallNode>();
+        ICHECK(lhs && rhs);
+        residual_index = IsAncestor(rhs, lhs) ? 1 : 0;
+      }
+      const auto* conv2d_call =
+          GetRootCall(binop->args[!residual_index].as<CallNode>(), "nn.conv2d");
+      ICHECK(conv2d_call);
+      return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }
 
     LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -560,6 +634,8 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_batch_matmul") {
       ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
+    } else if (IsConv2dResidualBlock(func_name)) {
+      ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args, true);
     } else if (func_name.find("conv2d") != std::string::npos) {
       ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }
@@ -623,6 +699,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
     code_stream_ << "#include \n";
     code_stream_ << "#include \n";
     code_stream_ << "#include \n";
+    code_stream_ << "#include \n";
+    code_stream_ << "#include \n";
 
     ICHECK(ref->IsInstance<FunctionNode>());
     auto res = GenCutlassFunc(Downcast<Function>(ref));
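The residual-index logic at the end of the codegen change has a compact analogue worth spelling out. The following is a toy sketch of the same decision rule; the Node class is a stand-in for Relay's CallNode, not a TVM API:

# Toy sketch of the residual_index selection in GenerateCompositeFunctionCall.
class Node:
    def __init__(self, name, *args):
        self.name, self.args = name, args

def is_ancestor(x, y):
    # Mirrors CodegenCutlass::IsAncestor: walk y's inputs looking for x.
    return x is y or any(is_ancestor(x, a) for a in y.args)

def residual_index(binop):
    lhs, rhs = binop.args
    # The residual input also feeds the conv2d branch, so it is an
    # ancestor of the non-residual input.
    return 1 if is_ancestor(rhs, lhs) else 0

x = Node("x")
conv = Node("conv2d", x)
relu = Node("relu", conv)
binop = Node("add", relu, x)  # relu(conv2d(x)) + x
assert residual_index(binop) == 1  # x is the residual input

The root conv2d is then recovered from the other argument of the binary op, which is what the GetRootCall overload added to utils.h below is for.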
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index df25a8641792..cc1adff7ef55 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -389,7 +389,6 @@ inline bool IsOp(const CallNode* call, const std::string& op_name) {
  * "nn.relu"}
  * \return A CallNode corresponding to the root op, whose name is expected_op_names[0]
  */
-
 inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
                                    const std::vector<std::string>& expected_op_names) {
   ICHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() &&
@@ -405,6 +404,23 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
   return GetRootCall(next_call, depth - 1, expected_op_names);
 }
 
+/*!
+ * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
+ * \param call A Relay call node.
+ * \param op_name The name of an op to look for.
+ * "nn.relu"}
+ * \return A CallNode corresponding to the root op with the given op_name
+ */
+inline const CallNode* GetRootCall(const CallNode* current_call, const std::string& op_name) {
+  if (current_call == nullptr) return nullptr;
+  if (IsOp(current_call, op_name)) return current_call;
+
+  ICHECK_GT(current_call->args.size(), 0);
+
+  const auto* next_call = current_call->args[0].as<CallNode>();
+  return GetRootCall(next_call, op_name);
+}
+
 /*!
  * \brief Get the external symbol of the Relay function name.
  *
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index b2bdb8ca91a0..54738ddd772b 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -130,6 +130,17 @@ def get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype="float16"):
     return relay.nn.bias_add(conv2d, bias)
 
 
+def silu(x):
+    return x * relay.sigmoid(x)
+
+
+def hardswish(x, out_dtype="float16"):
+    return x * (
+        relay.clip(x + relay.const(3, dtype=out_dtype), a_min=0, a_max=6)
+        / relay.const(6, dtype=out_dtype)
+    )
+
+
 def get_conv2d_nchw_bias_relu(d_shape, w_shape, padding, out_dtype="float16"):
     return relay.nn.relu(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))
 
@@ -140,15 +151,29 @@ def get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16")
 def get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float16"):
     conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
-    return conv_out * relay.sigmoid(conv_out)
+    return silu(conv_out)
 
 
 def get_conv2d_nchw_bias_hardswish(d_shape, w_shape, padding, out_dtype="float16"):
-    conv2d_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
-    return conv2d_out * (
-        relay.clip(conv2d_out + relay.const(3, dtype=out_dtype), a_min=0, a_max=6)
-        / relay.const(6, dtype=out_dtype)
-    )
+    conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
+    return hardswish(conv_out, out_dtype)
+
+
+def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding, out_dtype="float16"):
+    data = relay.var("data", shape=d_shape, dtype="float16")
+    weight = relay.var("weight", shape=w_shape, dtype="float16")
+    bias = relay.var("bias", shape=(w_shape[0],), dtype=out_dtype)
+    out_channel = w_shape[0]
+    conv2d = relay.nn.conv2d(
+        data=data,
+        weight=weight,
+        kernel_size=w_shape[2:],
+        channels=out_channel,
+        padding=padding,
+        out_dtype=out_dtype,
+    )
+    bias_add = relay.nn.bias_add(conv2d, bias)
+    return bias_add, data
 
 
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False):
@@ -492,5 +517,25 @@ def test_conv2d_fusion():
     )
 
 
+def test_conv2d_residual_block():
+    d_shape = (16, 16, 32, 32)
+    w_shape = (16, 16, 3, 3)
+    padding = (1, 1)
+
+    bias_add, residual_input = get_conv2d_nchw_bias_residual(d_shape, w_shape, padding)
+
+    for func, tol in [
+        (relay.nn.relu(bias_add + residual_input), 1e-5),
+        (relay.nn.relu(bias_add) + residual_input, 1e-5),
+        (relay.sigmoid(bias_add) * residual_input, 1e-5),
+        (relay.nn.relu(silu(bias_add) * residual_input), 1e-5),
+        # HardSwish requires higher tolerance since the residual block epilogue
+        # is vectorized in CUTLASS.
+        # TODO(masahi): Investigate this issue
+        (relay.nn.relu(hardswish(bias_add) + residual_input), 1e-3),
+    ]:
+        verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
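Taken together, patch 1 lets a whole residual block be offloaded as one composite function. A hedged usage sketch reusing the helpers from the test file above (the expected composite name follows from the pattern table; it is an expectation under those assumptions, not verified output):

# Minimal sketch: partition a residual block for CUTLASS, assuming the
# test helpers above (get_conv2d_nchw_bias_residual etc.) are in scope.
import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

d_shape, w_shape, padding = (16, 16, 32, 32), (16, 16, 3, 3), (1, 1)
bias_add, residual_input = get_conv2d_nchw_bias_residual(d_shape, w_shape, padding)
out = relay.nn.relu(bias_add + residual_input)

mod = tvm.IRModule.from_expr(out)
mod = partition_for_cutlass(mod)
# The partitioned main should now call a function whose Composite attribute
# is "cutlass.conv2d_bias_residual_add_relu", per the naming scheme above.
print(mod)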
From 467c6ad0dea89977cfb80cf6be9a92ab254e9461 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 31 Dec 2021 20:34:36 +0900
Subject: [PATCH 2/4] do not try to support broadcast binary op

---
 src/relay/backend/contrib/cutlass/codegen.cc | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index 8153f89f4e42..1ade0fe24d29 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -338,15 +338,9 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
                "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n");
 
   if (has_residual_block) {
-    if (attrs.at("P") == attrs.at("residual_H") && attrs.at("Q") == attrs.at("residual_W")) {
-      CutlassPrint(conv2d_decl,
-                   "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
-    } else {
-      ICHECK(attrs.at("residual_H") == "1" && attrs.at("residual_W") == "1");
-      // Handle broadcast ops (MobilenetV3 and EfficientNetV2) in a residual block-like pattern
-      CutlassPrint(conv2d_decl, "// Broadcast in a residual block \n");
-      CutlassPrint(conv2d_decl, "TensorNHWC layout_C(TensorNHWC(0, 0, K));\n\n");
-    }
+    ICHECK(attrs.at("P") == attrs.at("residual_H") && attrs.at("Q") == attrs.at("residual_W"));
+    CutlassPrint(conv2d_decl,
+                 "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
   } else {
     CutlassPrint(conv2d_decl,
                  "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
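With the broadcast fallback removed, a residual input only reaches the codegen when its shape equals the conv2d output shape, which check_conv2d_residual already guarantees on the Relay side. A small sketch of that predicate, with illustrative NHWC shapes:

# check_conv2d_residual ends with this element-wise comparison, so a
# 1x1-broadcast residual input (as in MobileNetV3/EfficientNetV2-style
# blocks) no longer matches and is left to the regular TVM path:
def shapes_match(lhs_shape, rhs_shape):
    return all(x == y for x, y in zip(lhs_shape, rhs_shape))

assert shapes_match((16, 32, 32, 16), (16, 32, 32, 16))    # offloaded
assert not shapes_match((16, 32, 32, 16), (16, 1, 1, 16))  # not offloaded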
From 35a7ca34bd4bbb7e3ead28c03cec77b2ac5a42ce Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 31 Dec 2021 20:52:01 +0900
Subject: [PATCH 3/4] add comments

---
 src/relay/backend/contrib/cutlass/codegen.cc | 12 ++++++++----
 src/relay/backend/utils.h                    |  5 +++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index 1ade0fe24d29..afc0d7524af7 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -470,6 +470,7 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
            func_name.find("residual") != std::string::npos;
   }
 
+  // Is node `x` an ancestor of `y`?
   bool IsAncestor(const CallNode* x, const CallNode* y) {
     if (x == y) return true;
     for (auto arg : y->args) {
@@ -564,9 +565,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
     } else if (IsConv2dResidualBlock(pattern_name.value())) {
       const CallNode* current_call = callee->body.as<CallNode>();
-      const CallNode* binop =
-          current_call->args.size() == 1 ? current_call->args[0].as<CallNode>() : current_call;
+      bool has_relu = current_call->args.size() == 1;
+      const CallNode* binop = has_relu ? current_call->args[0].as<CallNode>() : current_call;
       ICHECK(binop->args.size() == 2);
+      // Figure out which of the first or second argument corresponds to the residual input
+      // The root conv2d call can be reached via the other input of the binary op
       int residual_index;
       if (binop->args[1].as<VarNode>()) {
         residual_index = 1;
@@ -576,10 +579,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
         residual_index = 0;
       } else {
         const CallNode* lhs = binop->args[0].as<CallNode>();
         const CallNode* rhs = binop->args[1].as<CallNode>();
         ICHECK(lhs && rhs);
+        // The residual input should be an ancestor of the non-residual input
         residual_index = IsAncestor(rhs, lhs) ? 1 : 0;
       }
-      const auto* conv2d_call =
-          GetRootCall(binop->args[!residual_index].as<CallNode>(), "nn.conv2d");
+      const auto* non_residual_input = binop->args[!residual_index].as<CallNode>();
+      const auto* conv2d_call = GetRootCall(non_residual_input, "nn.conv2d");
       ICHECK(conv2d_call);
       return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_)));
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index cc1adff7ef55..658283b5dc36 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -406,9 +406,10 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
 
 /*!
  * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
+ * Unlike the previous definition, it does not verify operator names of intermediate nodes. Instead,
+ * it recursively visits child nodes until it finds a call node with the given op_name.
  * \param call A Relay call node.
- * \param op_name The name of an op to look for.
- * "nn.relu"}
+ * \param op_name The name of an op to look for, such as "nn.conv2d".
  * \return A CallNode corresponding to the root op with the given op_name
  */
 inline const CallNode* GetRootCall(const CallNode* current_call, const std::string& op_name) {
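The rewritten comment distinguishes the two GetRootCall overloads: the original walks a fixed op sequence to a given depth, while the new one simply follows the first argument until the requested op appears. A toy Python analogue of the new overload (Node is a stand-in, not a TVM API):

# Toy analogue of the single-op GetRootCall overload documented above.
class Node:
    def __init__(self, op, *args):
        self.op, self.args = op, args

def get_root_call(call, op_name):
    # Follow the first argument until a node with the requested op is found,
    # without checking the ops in between (unlike the depth/sequence variant).
    if call is None:
        return None
    if call.op == op_name:
        return call
    assert call.args, "ran out of arguments before finding " + op_name
    return get_root_call(call.args[0], op_name)

conv = Node("nn.conv2d", Node("data"))
fused = Node("nn.relu", Node("add", conv, Node("bias")))  # relu(add(conv2d, bias))
assert get_root_call(fused, "nn.conv2d") is conv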
From 0c1ffaed42a0e99c92cdc52ff30b2ca3e35bde3d Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 31 Dec 2021 21:28:05 +0900
Subject: [PATCH 4/4] remove residual input shape check

---
 src/relay/backend/contrib/cutlass/codegen.cc | 19 ++-----------------
 1 file changed, 2 insertions(+), 17 deletions(-)

diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index afc0d7524af7..dc03eea014ab 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -254,13 +254,6 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
   args["stride_w"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(1));
   args["dilation_h"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(0));
   args["dilation_w"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(1));
-
-  if (attrs.find("arg3_shape") != attrs.end()) {
-    auto arg3_shape = attrs["arg3_shape"].as<ArrayNode>();
-    args["residual_N"] = GetDimAsStr(arg3_shape->at(0));
-    args["residual_H"] = GetDimAsStr(arg3_shape->at(1));
-    args["residual_W"] = GetDimAsStr(arg3_shape->at(2));
-  }
   return args;
 }
 
@@ -336,16 +329,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
                "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
   CutlassPrint(conv2d_decl,
                "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n");
-
-  if (has_residual_block) {
-    ICHECK(attrs.at("P") == attrs.at("residual_H") && attrs.at("Q") == attrs.at("residual_W"));
-    CutlassPrint(conv2d_decl,
-                 "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
-  } else {
-    CutlassPrint(conv2d_decl,
-                 "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
-  }
-
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
   CutlassPrint(conv2d_decl,
                "TensorNHWC layout_D(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
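After this final patch, shape safety lives entirely in the Python-side check_conv2d_residual, and the C++ codegen emits one uniform layout path. For completeness, a sketch of how the residual path alters kernel selection in EmitConv2dInstance.emit; using "Fprop" as the conv kind name is an assumption for illustration:

# Illustrative toy of the template selection in emit(): the residual path
# swaps both the epilogue template and the kernel template postfix.
def kernel_name(conv_kind_name, residual_block_info=None):
    postfix = "WithBroadcast" if residual_block_info else ""
    return "cutlass::conv::kernel::DefaultConv2d" + conv_kind_name + postfix

assert kernel_name("Fprop") == "cutlass::conv::kernel::DefaultConv2dFprop"
assert (kernel_name("Fprop", {"unary_op": "..."})
        == "cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast")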