Merged
14 changes: 12 additions & 2 deletions python/tvm/contrib/cudnn.py
@@ -826,10 +826,20 @@ def conv_backward_filter(
x.shape[0], tvm.tir.expr.IntImm
), "Dynamic batch is not supported for cudnn conv2d backwad filter yet."

ic_ind = 1 if tensor_format == 0 else 3

if groups > 1:
assert (
x_shape[ic_ind] == dy.shape[ic_ind] and x_shape[ic_ind] == groups
), "Only depthwise wgrad supported for groups > 1."
ic = 1
else:
ic = x_shape[ic_ind]

if tensor_format == 0:
- dw_shape = [dy.shape[1], x_shape[1], filter_h, filter_w]
dw_shape = [dy.shape[1], ic, filter_h, filter_w]
else:
- dw_shape = [dy.shape[3], filter_h, filter_w, x_shape[3]]
dw_shape = [dy.shape[3], filter_h, filter_w, ic]

algo = conv_backward_filter_find_algo(
tensor_format,
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -252,6 +252,8 @@ def select_op(
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
use_3xtf32,
profile_all_alignments,
# Use fp32 accumulation for wgrad to align with cuDNN
accumlator_dtype="float32" if conv_kind == ConvKind.Wgrad else out_dtype,
Member Author

Here, we force the accum dtype to be fp32 for wgrad.

Contributor

This makes a lot of sense. I'm even wondering whether we should force the accum dtype to be fp32 for every ConvKind instead of just wgrad.

Member Author

@masahi masahi Feb 7, 2022

Always using fp32 accum for the other conv2d kinds will probably cause a perf regression (cuDNN also allows fp16 accumulation for fprop and dgrad). Ideally we should add accumulation_dtype to Conv2dAttr to guide that decision. I thought about doing that, but realized that I would have to change a lot of topi to take it into account.

Also, we need to discuss what the interface for ToMixedPrecision should be if we want to allow changing the accumulation dtype. Right now we cannot flexibly change even the output dtype. @AndrewZhaoLuo
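
For illustration only (this is not code from the PR): a minimal NumPy sketch of why the accumulator dtype matters for long reductions such as wgrad, where each weight element is a sum over the batch and all output positions. The reduction length and the value 0.1 are arbitrary assumptions picked to make the effect visible, and real kernels accumulate in a different order, so treat this as a rough picture rather than a model of CUTLASS or cuDNN.

```python
import numpy as np

# Hypothetical reduction: 4096 terms, each about 0.1, stored as fp16.
terms = np.full(4096, 0.1, dtype=np.float16)


def accumulate(values, acc_dtype):
    """Sequentially sum `values`, rounding the running sum to `acc_dtype` after every add."""
    acc = acc_dtype(0)
    for v in values:
        acc = acc_dtype(acc + acc_dtype(v))
    return acc


exact = terms.astype(np.float64).sum()
sum_fp16 = accumulate(terms, np.float16)  # fp16 accumulator
sum_fp32 = accumulate(terms, np.float32)  # fp32 accumulator

err_fp16 = abs(float(sum_fp16) - exact)
err_fp32 = abs(float(sum_fp32) - exact)
print(f"exact={exact:.4f} fp16_acc={float(sum_fp16):.4f} fp32_acc={float(sum_fp32):.4f}")
```

The fp16 running sum stops growing once the spacing between representable fp16 values exceeds the size of each addend, while the fp32 running sum stays close to the exact value.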

Contributor

I see. Makes sense.
We discussed the accum dtype before, when @AndrewZhaoLuo was working on the ToMixedPrecision pass, but as you pointed out, this will involve lots of TOPI changes.

Contributor

@AndrewZhaoLuo AndrewZhaoLuo Feb 7, 2022

Hmm yeah, so if I'm understanding correctly for conv2d_winograd we want to accumulate to fp32 but if it's not winograd we are ok with accumulating to fp16.

ToMixedPrecision can configure accumulation and output dtypes for any call node but only using information from examining that node. I'm not sure implementation details like whether it's winograd can be transmitted here.

I will say that at the Relay level all we care about is type checking, imo, so just get the output_dtype correct. For example, accumulate all you like in fp32 internally, but make sure the output fits the expected type written in the interface. Perhaps an extraneous cast here is bad, but maybe we can repair it further down at the TOPI/TIR level.
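
A minimal NumPy sketch of that idea, not TVM code: accumulate internally in fp32 and cast once at the end, so the result still has the fp16 type the interface promises. `matmul_fp16_out` is a hypothetical helper used only for illustration, and a matmul stands in for conv2d to keep the example short.

```python
import numpy as np


def matmul_fp16_out(a, b):
    """Multiply fp16 inputs with fp32 accumulation, then cast the result to fp16."""
    acc = np.matmul(a.astype(np.float32), b.astype(np.float32))  # fp32 accumulation
    return acc.astype(np.float16)  # single explicit cast to the declared output dtype


rng = np.random.default_rng(0)
a = rng.standard_normal((8, 512)).astype(np.float16)
b = rng.standard_normal((512, 8)).astype(np.float16)

out = matmul_fp16_out(a, b)
# float64 reference computed from the same fp16 inputs
ref = np.matmul(a.astype(np.float64), b.astype(np.float64))
print(out.dtype, float(np.max(np.abs(out.astype(np.float64) - ref))))
```

The caller only ever sees an fp16 tensor; the wider accumulator is an internal detail, which is the property type checking relies on.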

Member Author

@masahi masahi Feb 7, 2022

Yeah, I agree that we can let out_dtype in conv2d essentially act as the accumulation data type, and add an explicit cast op if accum_dtype != out_dtype. That's fine as far as the ToMixedPrecision pass goes, but for CUTLASS BYOC we additionally need to pattern match against the added cast(fp32 -> fp16) op to know that this conv2d is fp32 accum -> fp16 out. And for cuDNN, which is not implemented as BYOC, this doesn't work, because all it sees is an fp32 accum -> fp32 out conv2d, and cuDNN wgrad doesn't support that dtype combination.

Hmm yeah, so if I'm understanding correctly for conv2d_winograd we want to accumulate to fp32 but if it's not winograd we are ok with accumulating to fp16.

@AndrewZhaoLuo Here wgrad means "conv2d gradient with respect to weight", not winograd :)
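
To make the term concrete, here is a small NumPy sketch of wgrad for the simplest case (NCHW data, OIHW weight, stride 1, no padding, groups = 1). It is an illustration under those assumptions, not the reference implementation added in this PR; `conv2d_nchw` and `conv2d_wgrad_nchw` are hypothetical helper names.

```python
import numpy as np


def conv2d_nchw(x, w):
    """Forward conv2d: y[n,k,p,q] = sum over c,r,s of x[n,c,p+r,q+s] * w[k,c,r,s]."""
    _, _, H, W = x.shape
    K, _, R, S = w.shape
    P, Q = H - R + 1, W - S + 1
    y = np.zeros((x.shape[0], K, P, Q), dtype=np.float64)
    for r in range(R):
        for s in range(S):
            patch = x[:, :, r:r + P, s:s + Q]  # input window seen by filter tap (r, s)
            y += np.einsum("ncpq,kc->nkpq", patch, w[:, :, r, s])
    return y


def conv2d_wgrad_nchw(dy, x, kernel_size):
    """wgrad: dw[k,c,r,s] = sum over n,p,q of dy[n,k,p,q] * x[n,c,p+r,q+s]."""
    R, S = kernel_size
    K, C = dy.shape[1], x.shape[1]
    P, Q = dy.shape[2], dy.shape[3]
    dw = np.zeros((K, C, R, S), dtype=np.float64)
    for r in range(R):
        for s in range(S):
            patch = x[:, :, r:r + P, s:s + Q]
            dw[:, :, r, s] = np.einsum("nkpq,ncpq->kc", dy, patch)
    return dw


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 6, 6))
w = rng.standard_normal((4, 3, 3, 3))
dy = rng.standard_normal((2, 4, 4, 4))
dw = conv2d_wgrad_nchw(dy, x, (3, 3))

# conv2d is linear in w, so <dy, conv2d(x, w)> must equal <dw, w>.
lhs = np.sum(dy * conv2d_nchw(x, w))
rhs = np.sum(dw * w)
```

Each element of `dw` is a reduction over the batch and every output position, which is why the accumulation dtype matters more here than the name might suggest.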

)

if not find_first_valid:
4 changes: 4 additions & 0 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -164,6 +164,8 @@ def get_default(
lambda align: align == 1, # Only request align1 kernels
use_3xtf32,
profile_all_alignments=True, # To include all align1 kernels
# TODO(masahi): Invesitigate when fp32 accumulation is needed for gemm
accumlator_dtype=out_dtype,
)

default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]
@@ -220,6 +222,8 @@ def select_op(
lambda align: all([dim % align == 0 for dim in [M, N, K]]),
use_3xtf32,
profile_all_alignments=profile_all_alignments,
# TODO(masahi): Invesitigate when fp32 accumulation is needed for gemm
accumlator_dtype=out_dtype,
)

if not find_first_valid:
19 changes: 17 additions & 2 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -51,7 +51,7 @@ def generate_tensor_op_common(
data_type = [
math_inst.element_a,
math_inst.element_b,
- math_inst.element_accumulator,
math_inst.element_c,
math_inst.element_accumulator,
]

@@ -63,7 +63,14 @@ def generate_tensor_op_common(


def generate_sm75_tensor_op_1688(
- out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False
out_dtype,
arg0_dtype,
arg1_dtype,
op_creator,
check_align,
_,
profile_all_alignments=False,
accumlator_dtype="float32",
):
"""Generate GEMM or Conv2D kernels for Turing."""
assert out_dtype in ["float32", "float16", "int32"]
@@ -77,6 +84,7 @@ def generate_sm75_tensor_op_1688(
DataType.f16,
DataType.f16,
dtype_map[out_dtype],
dtype_map[accumlator_dtype],
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
@@ -100,6 +108,7 @@ def generate_sm75_tensor_op_1688(
dtype_map[arg0_dtype],
dtype_map[arg1_dtype],
DataType.s32,
DataType.s32,
OpcodeClass.TensorOp,
MathOperation.multiply_add_saturate,
),
@@ -141,6 +150,7 @@ def generate_sm80_tensor_op_16816(
check_align,
use_3xtf32=True,
profile_all_alignments=False,
accumlator_dtype="float32",
):
"""Generate GEMM or Conv2D kernels for Ampere."""
min_cc = 80
@@ -176,6 +186,7 @@ def get_default_tile_descriptions(block_k_factor):
DataType.f16,
DataType.f16,
dtype_map[out_dtype],
dtype_map[accumlator_dtype],
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
@@ -189,6 +200,7 @@ def get_default_tile_descriptions(block_k_factor):
DataType.f32,
DataType.f32,
DataType.f32,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add_fast_f32 if use_3xtf32 else MathOperation.multiply_add,
),
@@ -221,6 +233,7 @@ def get_default_tile_descriptions(block_k_factor):
dtype_map[arg0_dtype],
dtype_map[arg1_dtype],
DataType.s32,
DataType.s32,
OpcodeClass.TensorOp,
MathOperation.multiply_add_saturate,
),
@@ -248,6 +261,7 @@ def get_tile_descriptions(math_inst):
check_align,
False,
profile_all_alignments,
accumlator_dtype=accumlator_dtype,
)
else:
# TF32 (float32 + float32 case) is only supported on sm80
@@ -292,6 +306,7 @@ def get_tile_descriptions(math_inst):
"cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
"cutlass.conv2d": (EpilogueFunctor.LinearCombination, False),
"cutlass.conv2d_transpose": (EpilogueFunctor.LinearCombination, False),
"cutlass.conv2d_backward_weight": (EpilogueFunctor.LinearCombination, False),
}


2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -266,13 +266,15 @@ def __init__(
instruction_shape,
element_a,
element_b,
element_c,
element_accumulator,
opcode_class,
math_operation=MathOperation.multiply_add,
):
self.instruction_shape = instruction_shape
self.element_a = element_a
self.element_b = element_b
self.element_c = element_c
self.element_accumulator = element_accumulator
self.opcode_class = opcode_class
self.math_operation = math_operation
13 changes: 13 additions & 0 deletions python/tvm/relay/op/contrib/cutlass.py
@@ -94,6 +94,10 @@ def make_conv2d_transpose_pattern():
return is_op("nn.conv2d_transpose")(wildcard(), wildcard())


def make_conv2d_backward_weight_pattern():
return is_op("nn.conv2d_backward_weight")(wildcard(), wildcard())


def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
"""Add pattern for residual blocks."""
residual_input = wildcard()
@@ -173,6 +177,10 @@ def check_conv2d_transpose(call):
return check_conv2d_common("nn.conv2d_transpose", "IHWO", call)


def check_conv2d_backward_weight(call):
return check_conv2d_common("nn.conv2d_backward_weight", "NHWC", call)


def check_conv2d_residual(call, binary_op):
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
conv2d = get_root_call(call, "nn.conv2d")
@@ -245,6 +253,11 @@ def partition_for_cutlass(mod, params=None):
# For now, no fusion for grad kernels
conv2d_grad_patterns = [
("cutlass.conv2d_transpose", make_conv2d_transpose_pattern(), check_conv2d_transpose),
(
"cutlass.conv2d_backward_weight",
make_conv2d_backward_weight_pattern(),
check_conv2d_backward_weight,
),
]

residual_block_patterns = []
35 changes: 35 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -1107,6 +1107,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch,
out_dtype=attrs.out_dtype,
)

# infer shape of backward_weight
@@ -1143,6 +1144,40 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
return backward_weight


@reg.register_convert_op_layout("nn.conv2d_backward_weight")
def convert_conv2d_backward_weight(attrs, inputs, _, desired_layouts):
"""Convert Layout pass registration for conv2d_backward_weight op.
Note that `desired_layouts` must be a pair [`data_layout`, `kernel_layouts`],
where `kernel_layouts` affects the output of this op (since the output of this op
is the weight gradient). The layout of the output gradient (the second input to this op)
is assumed to be the same as `data_layout`.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layouts : list of layout strings
List of layouts defining our desired
layout for the data and kernel inputs respectively.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
new_attrs = dict(attrs)
assert len(desired_layouts) == 2, "A desired layout is expected for both of data and gradient."
desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
assert desired_data_layout != "default", "Data layout cannot be default"
new_attrs["grad_layout"] = desired_data_layout
new_attrs["data_layout"] = desired_data_layout
new_attrs["kernel_layout"] = desired_kernel_layout
new_attrs.pop("out_layout")
return relay.nn.conv2d_backward_weight(inputs[0], inputs[1], **new_attrs)


#####################
# Shape functions #
#####################
8 changes: 7 additions & 1 deletion python/tvm/topi/cuda/conv2d.py
@@ -130,6 +130,12 @@ def conv2d_backward_weight_cudnn(
):
"""Compute conv2d wgrad using CuDNN library"""
assert layout in ["NCHW", "NHWC"]

if dy.dtype == "float16":
# cuDNN does not seem to support other combination.
assert output_dtype == "float16", "Only supports fp16 output for cuDNN fp16 wgrad."

conv_dtype = "float32" # Accumulation is always fp32
return cudnn.conv_backward_filter(
dy,
x,
@@ -139,6 +145,6 @@ def conv2d_backward_weight_cudnn(
dilation,
conv_mode=1,
tensor_format=0 if layout == "NCHW" else 1,
- conv_dtype=output_dtype,
conv_dtype=conv_dtype,
groups=groups,
)
43 changes: 37 additions & 6 deletions python/tvm/topi/testing/conv2d_backcward_weight_python.py
@@ -20,7 +20,9 @@


# Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
- def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
def conv2d_backward_weight_nchw_python(
dy_np, x_np, kernel_size, stride, padding, groups=1, channels=None
):
"""Gradient of the conv2d op with respect to weight, in NCHW layout.

Parameters
@@ -51,17 +53,34 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
R, S = kernel_size
pad_h, pad_w = padding
stride_h, stride_w = stride
- dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
is_depth_wise = C == K and C == groups

if is_depth_wise:
assert channels == groups, "Only channel_mult == 1 supported for now."
dw = np.zeros((K, 1, R, S)).astype(dy_np.dtype)
else:
assert groups == 1, "General grouped conv2d not supported for now."
dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)

for k in range(K):
for r in range(R):
for s in range(S):
- for c in range(C):
for c in range(dw.shape[1]):
acc = 0
for n in range(N):
for p in range(P):
for q in range(Q):
- coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)
if not is_depth_wise:
in_c = c
else:
in_c = k

coord = (
n,
in_c,
p * stride_h - pad_h + r,
q * stride_w - pad_w + s,
)

if (
coord[2] < H
@@ -76,7 +95,9 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
return dw


- def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
def conv2d_backward_weight_python(
dy_np, x_np, kernel_size, stride, padding, layout="NCHW", groups=1, channels=None
):
"""Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.

Parameters
@@ -99,20 +120,30 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay
layout: string
Layout of dy_np and x_np

groups: int
Number of groups for grouped convolution.

channels : int
Number of output channels of this convolution.

Returns
-------
dw_np : np.ndarray
Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout,
[num_filter, filter_height, filter_width, in_channel] for NHWC layout.
"""
if layout == "NCHW":
- return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)
return conv2d_backward_weight_nchw_python(
dy_np, x_np, kernel_size, stride, padding, groups, channels
)

dw_np_oihw = conv2d_backward_weight_nchw_python(
np.transpose(dy_np, [0, 3, 1, 2]),
np.transpose(x_np, [0, 3, 1, 2]),
kernel_size,
stride,
padding,
groups,
channels,
)
return np.transpose(dw_np_oihw, [0, 2, 3, 1])
5 changes: 5 additions & 0 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -615,6 +615,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d_transpose"});
return GenerateBody(conv2d_call, "cutlass_conv2d_transpose", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_), true, false));
} else if (pattern_name == "cutlass.conv2d_backward_weight") {
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d_backward_weight"});
return GenerateBody(conv2d_call, "cutlass_conv2d_backward_weight", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_), false, true));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;
25 changes: 21 additions & 4 deletions src/relay/op/nn/convolution.cc
@@ -639,11 +639,28 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
auto in_channels = dshape_nchw[1];
auto out_channels = grad_shape_nchw[1];

- Array<IndexExpr> wshape_oihw(
- {out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]});

auto in_channels_intimm = in_channels.as<IntImmNode>();
auto out_channels_intimm = out_channels.as<IntImmNode>();
ICHECK(in_channels_intimm);
ICHECK(out_channels_intimm);

IndexExpr weight_dim_i;
if (in_channels_intimm->value == out_channels_intimm->value &&
in_channels_intimm->value == param->groups) {
// depthwise
ICHECK(param->channels.defined())
<< "out_channels attribute not specified for depth wise conv2d.";
weight_dim_i = indexdiv(param->channels, param->groups);
} else {
weight_dim_i = indexdiv(in_channels, param->groups);
}

Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0],
param->kernel_size[1]};
auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
- reporter->Assign(types[2], TensorType(wshape, data->dtype));

const auto dw_dtype = param->out_dtype == DataType() ? grad->dtype : param->out_dtype;
reporter->Assign(types[2], TensorType(wshape, dw_dtype));
return true;
}
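
The weight-gradient shape rule in this hunk can be restated as a short Python sketch. `wgrad_shape_oihw` is a hypothetical helper, not a TVM API; it only mirrors the branch above for static channel counts and assumes OIHW ordering before any kernel-layout transform.

```python
def wgrad_shape_oihw(in_channels, out_channels, groups, kernel_size, channels=None):
    """Return the OIHW shape of the weight gradient, following the rule in the hunk above."""
    if in_channels == out_channels == groups:
        # Depthwise case: the `channels` attribute must be given explicitly.
        if channels is None:
            raise ValueError("channels must be specified for depthwise conv2d")
        weight_dim_i = channels // groups
    else:
        weight_dim_i = in_channels // groups
    return (out_channels, weight_dim_i, kernel_size[0], kernel_size[1])


depthwise_shape = wgrad_shape_oihw(32, 32, 32, (3, 3), channels=32)
regular_shape = wgrad_shape_oihw(16, 64, 1, (3, 3))
```

For a depthwise conv2d with a channel multiplier of 1 the second dimension collapses to 1, matching the `ic = 1` case in the cuDNN wrapper and the `(K, 1, R, S)` array in the Python reference.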
