45 changes: 35 additions & 10 deletions python/tvm/contrib/cutlass/conv2d_operation.py
@@ -150,6 +150,7 @@ def __init__(self):
${element_accumulator},
${element_epilogue}
>"""

self.epilogue_no_beta_scaling = """
${epilogue_functor}<
${element_c},
@@ -159,10 +160,22 @@ def __init__(self):
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""

self.epilogue_residual_block = """
${epilogue_functor}<
${element_c},
${element_accumulator},
${element_epilogue},
${element_c},
${epilogue_vector_length},
${activation},
${binary_op},
${unary_op}
>"""
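# Illustration only (hypothetical substitution values): with fp16 I/O, fp32
# accumulation, and a vector length of 8, a SiLU/add/ReLU residual block
# instantiates the template above to roughly
#   cutlass::epilogue::thread::LinearCombinationResidualBlock<
#       cutlass::half_t, float, float, cutlass::half_t, 8,
#       cutlass::epilogue::thread::SiLu, cutlass::plus,
#       cutlass::epilogue::thread::ReLu>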

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}${conv_kernel_postfix}<
${element_a},
${layout_a},
${element_b},
@@ -186,7 +199,7 @@ def __init__(self):
>::Kernel;
"""

def emit(self, operation, no_beta_scaling=False):
def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"""Instantiate a Conv2d kernel from the given `operation`."""
warp_shape = [
int(
@@ -246,14 +259,26 @@ def emit(self, operation, no_beta_scaling=False):
],
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
"conv_kernel_postfix": "",
}

template = substitute_template(
self.template,
{
"epilogue": self.epilogue_no_beta_scaling
if no_beta_scaling
else self.epilogue_default
},
)
if residual_block_info:
template = substitute_template(
self.template, {"epilogue": self.epilogue_residual_block}
)
values.update(
{
"unary_op": residual_block_info["unary_op"],
"binary_op": residual_block_info["binary_op"],
"activation": residual_block_info["activation"],
"conv_kernel_postfix": "WithBroadcast",
}
)
elif no_beta_scaling:
template = substitute_template(
self.template, {"epilogue": self.epilogue_no_beta_scaling}
)
else:
template = substitute_template(self.template, {"epilogue": self.epilogue_default})

return substitute_template(template, values)
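
For reference, a minimal usage sketch of the new argument (hedged: op and the dict values below are illustrative; in the real flow they are constructed by gen_conv2d.py, shown next):

    # Hypothetical: op is a Conv2dOperation produced by the generator path.
    residual_block_info = {
        "activation": "cutlass::epilogue::thread::SiLu",
        "binary_op": "cutlass::plus",
        "unary_op": "cutlass::epilogue::thread::ReLu",
    }
    opdef = EmitConv2dInstance().emit(op, residual_block_info=residual_block_info)
    # The emitted Fprop instance uses DefaultConv2dFpropWithBroadcast and the
    # LinearCombinationResidualBlock epilogue parameters above.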
31 changes: 29 additions & 2 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -39,7 +39,32 @@ def create_conv2d_operator_with_epilogue(
Instantiate a cutlass kernel from the given configuration,
along with the epilogue functor
"""
epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
if "residual" in op_type:
activation_map = {
"cutlass.conv2d_bias_hardswish": "cutlass::epilogue::thread::HardSwish",
"cutlass.conv2d_bias_silu": "cutlass::epilogue::thread::SiLu",
"cutlass.conv2d_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid",
"cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu",
"cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity",
}
prefix = op_type[: op_type.find("_residual")]
activation = activation_map[prefix]
binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus"
unary_op = (
"cutlass::epilogue::thread::ReLu"
if op_type.endswith("relu")
else "cutlass::epilogue::thread::Identity"
)
residual_block_info = {
"activation": activation,
"binary_op": binary_op,
"unary_op": unary_op,
}
epilogue = EpilogueFunctor.LinearCombinationResidualBlock
no_beta_scaling = False
else:
residual_block_info = None
epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]

element_a, element_b, element_c, element_epilogue = data_type

@@ -62,7 +87,9 @@
)

name = op.procedural_name()
opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling)
opdef = EmitConv2dInstance().emit(
op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info
)

return name, opdef

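A worked example of the string dispatch above (the op_type value is hypothetical):

    op_type = "cutlass.conv2d_bias_silu_residual_multiply_relu"
    prefix = op_type[: op_type.find("_residual")]  # "cutlass.conv2d_bias_silu"
    # activation -> "cutlass::epilogue::thread::SiLu"
    # binary_op  -> "cutlass::multiplies" (op_type contains "residual_multiply")
    # unary_op   -> "cutlass::epilogue::thread::ReLu" (op_type ends with "relu")
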
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -151,6 +151,7 @@ class EpilogueFunctor(enum.Enum):
LinearCombinationSigmoid = enum_auto()
LinearCombinationSilu = enum_auto()
LinearCombinationHardSwish = enum_auto()
LinearCombinationResidualBlock = enum_auto()


EpilogueFunctorTag = {
@@ -161,6 +162,7 @@ class EpilogueFunctor(enum.Enum):
EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish",
EpilogueFunctor.LinearCombinationResidualBlock: "cutlass::epilogue::thread::LinearCombinationResidualBlock",
}


48 changes: 47 additions & 1 deletion python/tvm/relay/op/contrib/cutlass.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Patterns supported by CUTLASS."""
from functools import partial
from tvm import relay
from tvm.ir.transform import Sequential, PassContext
from tvm.relay import transform
@@ -89,6 +90,19 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
return conv2d_out


def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
"""Add pattern for residual blocks."""
residual_input = wildcard()
binary_out = is_op(binary_op)(tensor_op_out, residual_input) | is_op(binary_op)(
residual_input, tensor_op_out
)

if with_act is not None and with_act == "relu":
return is_op("nn.relu")(binary_out)

return binary_out


def check_dtype(lhs, rhs):
"""Check if dtypes in the given workload are supported by CUTLASS."""
# Only fp16 inputs are supported for now.
@@ -139,6 +153,25 @@ def check_conv2d(call):
return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)


def check_conv2d_residual(call, binary_op):
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
conv2d = get_root_call(call, "nn.conv2d")
if not check_conv2d(call):
return False

residual_binop = get_root_call(call, binary_op)
lhs = residual_binop.args[0]
rhs = residual_binop.args[1]

# residual_input is pattern-matched as a wildcard. Make sure it does not sit between
# residual binary op and the root conv2d of this pattern.
# If the root conv2d is the parent of both lhs and rhs, we should reject this pattern.
if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs, "nn.conv2d") == conv2d:
return False

return all(x == y for (x, y) in zip(lhs.checked_type.shape, rhs.checked_type.shape))
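
# To make the rejection rule concrete (illustrative pseudocode, not part of the PR):
#
# Rejected: both inputs of the binary op reach the same root conv2d, so the
# wildcard residual input lies on a path through the conv2d itself:
#   conv = nn.conv2d(x, w)
#   out = add(conv, nn.relu(conv))
#
# Accepted: one input is an independent residual tensor of matching shape:
#   conv = nn.conv2d(x, w)
#   out = add(conv, residual)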


def partition_for_cutlass(mod, params=None):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
@@ -189,7 +222,20 @@ def partition_for_cutlass(mod, params=None):
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]

cutlass_patterns = dense_patterns + conv2d_patterns
residual_block_patterns = []

for with_act, postfix in [("relu", "_relu"), (None, "")]:
for name, pat, _ in conv2d_patterns[:-1]:
for bin_op in ["add", "multiply"]:
residual_block_patterns.append(
(
name + "_residual_" + bin_op + postfix,
make_residual_block_pattern(pat, bin_op, with_act=with_act),
partial(check_conv2d_residual, binary_op=bin_op),
)
)

cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns
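
# The loop above yields four residual variants per conv2d pattern (add or
# multiply, with or without the trailing relu). Assuming the earlier
# conv2d_patterns entries include names such as "cutlass.conv2d_bias_relu",
# the generated composite names look like (illustrative):
#   cutlass.conv2d_bias_relu_residual_add_relu
#   cutlass.conv2d_bias_relu_residual_multiply_relu
#   cutlass.conv2d_bias_relu_residual_add
#   cutlass.conv2d_bias_relu_residual_multiply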

if params is not None:
mod["main"] = bind_params_by_name(mod["main"], params)
79 changes: 70 additions & 9 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -258,7 +258,7 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
}

std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
const std::vector<std::string>& func_args, bool has_residual_block = false) {
bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid" &&
attrs.at("op_type") != "cutlass.conv2d_bias_silu" &&
@@ -268,8 +268,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
CutlassPrint(conv2d_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");

CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");

CutlassPrint(conv2d_decl, attrs.at("op_def"));
CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
" = cutlass::conv::device::ImplicitGemmConvolution<" +
@@ -308,14 +308,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
ICHECK(func_args.size() >= 2);
CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
if (has_bias) {
if (has_residual_block) {
ICHECK(func_args.size() >= 4);
CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] + "->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] + "->data);\n");
} else if (has_bias) {
ICHECK(func_args.size() >= 3);
CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
}

CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
if (has_bias && no_bias_scaling) {
if (has_bias && no_bias_scaling && !has_residual_block) {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
} else {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n");
@@ -326,24 +330,38 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl,
"TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n");
CutlassPrint(conv2d_decl,
"TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n");
"TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
CutlassPrint(conv2d_decl,
"TensorNHWC layout_D(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");

CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
if (has_bias) {

if (has_residual_block) {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_residual), layout_C},\n");
} else if (has_bias) {
CutlassPrint(
conv2d_decl,
" {static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n");
} else {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out), layout_C},\n");
}
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
if (has_bias && no_bias_scaling) {

CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_D},\n");

if (has_residual_block) {
CutlassPrint(conv2d_decl, "{alpha, beta},\n");
CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices
CutlassPrint(conv2d_decl, "static_cast<ElementOutput*>(ptr_bias),\n");
CutlassPrint(conv2d_decl, "nullptr, 0, K};\n");
} else if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, " {alpha}\n};\n");
} else {
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
}

CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");

CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
@@ -432,6 +450,21 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
return arg_names;
}

bool IsConv2dResidualBlock(const std::string& func_name) {
return func_name.find("conv2d") != std::string::npos &&
func_name.find("residual") != std::string::npos;
}

// Is node `x` an ancestor of `y`?
bool IsAncestor(const CallNode* x, const CallNode* y) {
if (x == y) return true;
for (auto arg : y->args) {
const CallNode* arg_ptr = arg.as<CallNode>();
if (arg_ptr && IsAncestor(x, arg_ptr)) return true;
}
return false;
}

GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
const CallNode* caller) {
const auto pattern_name = callee->GetAttr<runtime::String>(attr::kComposite);
@@ -515,6 +548,30 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (IsConv2dResidualBlock(pattern_name.value())) {
const CallNode* current_call = callee->body.as<CallNode>();
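// If the composite body ends in nn.relu, its root call has a single argument
// (the binary op); otherwise the root call is the binary op itself, with two args.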
bool has_relu = current_call->args.size() == 1;
const CallNode* binop = has_relu ? current_call->args[0].as<CallNode>() : current_call;
ICHECK(binop->args.size() == 2);
// Figure out which of the first or second argument corresponds to the residual input
// The root conv2d call can be reached via the other input of the binary op
int residual_index;
if (binop->args[1].as<VarNode>()) {
residual_index = 1;
} else if (binop->args[0].as<VarNode>()) {
residual_index = 0;
} else {
const CallNode* lhs = binop->args[0].as<CallNode>();
const CallNode* rhs = binop->args[1].as<CallNode>();
ICHECK(lhs && rhs);
// The residual input should be an ancestor of the non-residual input
residual_index = IsAncestor(rhs, lhs) ? 1 : 0;
}
const auto* non_residual_input = binop->args[!residual_index].as<CallNode>();
const auto* conv2d_call = GetRootCall(non_residual_input, "nn.conv2d");
ICHECK(conv2d_call);
return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -560,6 +617,8 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
} else if (IsConv2dResidualBlock(func_name)) {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args, true);
} else if (func_name.find("conv2d") != std::string::npos) {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}
@@ -623,6 +682,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_sigmoid.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_silu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_hardswish.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_residual_block.h>\n";
code_stream_ << "#include <cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h>\n";

ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCutlassFunc(Downcast<Function>(ref));
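
End to end, the new patterns are exercised through the existing partitioning entry point. A hedged sketch (shapes, layouts, and the exact bias spelling are illustrative; the pattern table in cutlass.py defines what actually matches):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cutlass import partition_for_cutlass

    # fp16 NHWC conv2d + bias + residual add + relu (values are illustrative).
    data = relay.var("data", shape=(1, 32, 32, 16), dtype="float16")
    weight = relay.var("weight", shape=(16, 3, 3, 16), dtype="float16")
    bias = relay.var("bias", shape=(16,), dtype="float16")
    residual = relay.var("residual", shape=(1, 32, 32, 16), dtype="float16")

    conv = relay.nn.conv2d(data, weight, padding=(1, 1), data_layout="NHWC",
                           kernel_layout="OHWI", out_dtype="float16")
    biased = relay.nn.bias_add(conv, bias, axis=-1)
    out = relay.nn.relu(relay.add(biased, residual))

    func = relay.Function(relay.analysis.free_vars(out), out)
    mod = partition_for_cutlass(tvm.IRModule.from_expr(func))
    # Matching subgraphs become composite functions named e.g.
    # cutlass.conv2d_bias_residual_add_relu, handled by the CUTLASS codegen.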