
Commit 7fd73b2

[CUTLASS] Initial support for conv2d wgrad (#10177)
* [CUTLASS] Add wgrad support (without split-k)
* run black
* wgrad tests now work under pytest
* dw conv2d properly supported for wgrad
* all tests work
* fixed for sm75
* cpplint
* fix conv2d grad test
1 parent 2ea2f5a commit 7fd73b2

13 files changed: +335 -42 lines changed

python/tvm/contrib/cudnn.py

Lines changed: 12 additions & 2 deletions
@@ -826,10 +826,20 @@ def conv_backward_filter(
         x.shape[0], tvm.tir.expr.IntImm
     ), "Dynamic batch is not supported for cudnn conv2d backwad filter yet."

+    ic_ind = 1 if tensor_format == 0 else 3
+
+    if groups > 1:
+        assert (
+            x_shape[ic_ind] == dy.shape[ic_ind] and x_shape[ic_ind] == groups
+        ), "Only depthwise wgrad supported for groups > 1."
+        ic = 1
+    else:
+        ic = x_shape[ic_ind]
+
     if tensor_format == 0:
-        dw_shape = [dy.shape[1], x_shape[1], filter_h, filter_w]
+        dw_shape = [dy.shape[1], ic, filter_h, filter_w]
     else:
-        dw_shape = [dy.shape[3], filter_h, filter_w, x_shape[3]]
+        dw_shape = [dy.shape[3], filter_h, filter_w, ic]

     algo = conv_backward_filter_find_algo(
         tensor_format,
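For orientation, a minimal sketch (not part of the patch; the toy shapes are invented) of how the dw_shape logic above resolves for a depthwise NCHW wgrad:

# Hypothetical NCHW (tensor_format == 0) depthwise example of the shape logic above.
dy_shape = [8, 32, 56, 56]    # (N, K, P, Q): output gradient
x_shape = [8, 32, 112, 112]   # (N, C, H, W): fprop input
groups = 32
filter_h, filter_w = 3, 3

ic_ind = 1  # channel axis for NCHW
if groups > 1:
    # depthwise: C == K == groups, so each filter gradient covers one input channel
    assert x_shape[ic_ind] == dy_shape[ic_ind] == groups
    ic = 1
else:
    ic = x_shape[ic_ind]

dw_shape = [dy_shape[1], ic, filter_h, filter_w]
print(dw_shape)  # [32, 1, 3, 3]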

python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 2 additions & 0 deletions
@@ -252,6 +252,8 @@ def select_op(
            lambda align: all([dim % align == 0 for dim in [IC, OC]]),
            use_3xtf32,
            profile_all_alignments,
+            # Use fp32 accumulation for wgrad to align with cuDNN
+            accumlator_dtype="float32" if conv_kind == ConvKind.Wgrad else out_dtype,
        )

        if not find_first_valid:
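Why force fp32 accumulation only for wgrad? The wgrad GEMM reduces over N * P * Q, which is usually far longer than the C * R * S reduction of fprop, so fp16 accumulation drifts. A small numpy illustration (not part of the patch; the sizes are made up):

import numpy as np

# Reduction length comparable to N * P * Q of a small wgrad problem.
n = 8 * 56 * 56
a = np.random.uniform(-1, 1, n).astype("float16")
b = np.random.uniform(-1, 1, n).astype("float16")

acc_fp16 = np.float16(0.0)
for x, y in zip(a, b):  # accumulate in fp16, as a kernel would without the override
    acc_fp16 += x * y

acc_fp32 = np.dot(a.astype("float32"), b.astype("float32"))
print(abs(float(acc_fp16) - acc_fp32))  # fp16 accumulation error is typically visible here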

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 4 additions & 0 deletions
@@ -164,6 +164,8 @@ def get_default(
            lambda align: align == 1,  # Only request align1 kernels
            use_3xtf32,
            profile_all_alignments=True,  # To include all align1 kernels
+            # TODO(masahi): Invesitigate when fp32 accumulation is needed for gemm
+            accumlator_dtype=out_dtype,
        )

        default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]
@@ -220,6 +222,8 @@ def select_op(
            lambda align: all([dim % align == 0 for dim in [M, N, K]]),
            use_3xtf32,
            profile_all_alignments=profile_all_alignments,
+            # TODO(masahi): Invesitigate when fp32 accumulation is needed for gemm
+            accumlator_dtype=out_dtype,
        )

        if not find_first_valid:

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 17 additions & 2 deletions
@@ -51,7 +51,7 @@ def generate_tensor_op_common(
     data_type = [
         math_inst.element_a,
         math_inst.element_b,
-        math_inst.element_accumulator,
+        math_inst.element_c,
         math_inst.element_accumulator,
     ]

@@ -63,7 +63,14 @@ def generate_tensor_op_common(


 def generate_sm75_tensor_op_1688(
-    out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False
+    out_dtype,
+    arg0_dtype,
+    arg1_dtype,
+    op_creator,
+    check_align,
+    _,
+    profile_all_alignments=False,
+    accumlator_dtype="float32",
 ):
     """Generate GEMM or Conv2D kernels for Turing."""
     assert out_dtype in ["float32", "float16", "int32"]
@@ -77,6 +84,7 @@ def generate_sm75_tensor_op_1688(
                DataType.f16,
                DataType.f16,
                dtype_map[out_dtype],
+                dtype_map[accumlator_dtype],
                OpcodeClass.TensorOp,
                MathOperation.multiply_add,
            )
@@ -100,6 +108,7 @@ def generate_sm75_tensor_op_1688(
                dtype_map[arg0_dtype],
                dtype_map[arg1_dtype],
                DataType.s32,
+                DataType.s32,
                OpcodeClass.TensorOp,
                MathOperation.multiply_add_saturate,
            ),
@@ -141,6 +150,7 @@ def generate_sm80_tensor_op_16816(
    check_align,
    use_3xtf32=True,
    profile_all_alignments=False,
+    accumlator_dtype="float32",
 ):
     """Generate GEMM or Conv2D kernels for Ampere."""
     min_cc = 80
@@ -176,6 +186,7 @@ def get_default_tile_descriptions(block_k_factor):
            DataType.f16,
            DataType.f16,
            dtype_map[out_dtype],
+            dtype_map[accumlator_dtype],
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        )
@@ -189,6 +200,7 @@ def get_default_tile_descriptions(block_k_factor):
            DataType.f32,
            DataType.f32,
            DataType.f32,
+            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_fast_f32 if use_3xtf32 else MathOperation.multiply_add,
        ),
@@ -221,6 +233,7 @@ def get_default_tile_descriptions(block_k_factor):
            dtype_map[arg0_dtype],
            dtype_map[arg1_dtype],
            DataType.s32,
+            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
@@ -248,6 +261,7 @@ def get_tile_descriptions(math_inst):
            check_align,
            False,
            profile_all_alignments,
+            accumlator_dtype=accumlator_dtype,
        )
    else:
        # TF32 (float32 + float32 case) is only supported on sm80
@@ -292,6 +306,7 @@ def get_tile_descriptions(math_inst):
    "cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
    "cutlass.conv2d": (EpilogueFunctor.LinearCombination, False),
    "cutlass.conv2d_transpose": (EpilogueFunctor.LinearCombination, False),
+    "cutlass.conv2d_backward_weight": (EpilogueFunctor.LinearCombination, False),
 }

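Downstream, the data_type list handed to the kernel emitters is now [A, B, C, accumulator], so the epilogue/output dtype can differ from the accumulator dtype. A sketch of the fp16-in, fp16-out, fp32-accumulate combination used for wgrad (illustrative only; the 16x8x8 instruction shape and the import path are assumptions based on library.py):

from tvm.contrib.cutlass.library import (
    DataType,
    MathInstruction,
    MathOperation,
    OpcodeClass,
)

math_inst = MathInstruction(
    [16, 8, 8],    # instruction shape
    DataType.f16,  # element_a
    DataType.f16,  # element_b
    DataType.f16,  # element_c: output / epilogue dtype
    DataType.f32,  # element_accumulator: accumulate in fp32
    OpcodeClass.TensorOp,
    MathOperation.multiply_add,
)

data_type = [
    math_inst.element_a,
    math_inst.element_b,
    math_inst.element_c,           # used to be element_accumulator before this change
    math_inst.element_accumulator,
]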

python/tvm/contrib/cutlass/library.py

Lines changed: 2 additions & 0 deletions
@@ -266,13 +266,15 @@ def __init__(
        instruction_shape,
        element_a,
        element_b,
+        element_c,
        element_accumulator,
        opcode_class,
        math_operation=MathOperation.multiply_add,
    ):
        self.instruction_shape = instruction_shape
        self.element_a = element_a
        self.element_b = element_b
+        self.element_c = element_c
        self.element_accumulator = element_accumulator
        self.opcode_class = opcode_class
        self.math_operation = math_operation

python/tvm/relay/op/contrib/cutlass.py

Lines changed: 13 additions & 0 deletions
@@ -94,6 +94,10 @@ def make_conv2d_transpose_pattern():
     return is_op("nn.conv2d_transpose")(wildcard(), wildcard())


+def make_conv2d_backward_weight_pattern():
+    return is_op("nn.conv2d_backward_weight")(wildcard(), wildcard())
+
+
 def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
     """Add pattern for residual blocks."""
     residual_input = wildcard()
@@ -173,6 +177,10 @@ def check_conv2d_transpose(call):
     return check_conv2d_common("nn.conv2d_transpose", "IHWO", call)


+def check_conv2d_backward_weight(call):
+    return check_conv2d_common("nn.conv2d_backward_weight", "NHWC", call)
+
+
 def check_conv2d_residual(call, binary_op):
     """Check if the given conv2d workload can be offloaded to CUTLASS."""
     conv2d = get_root_call(call, "nn.conv2d")
@@ -245,6 +253,11 @@ def partition_for_cutlass(mod, params=None):
     # For now, no fusion for grad kernels
     conv2d_grad_patterns = [
         ("cutlass.conv2d_transpose", make_conv2d_transpose_pattern(), check_conv2d_transpose),
+        (
+            "cutlass.conv2d_backward_weight",
+            make_conv2d_backward_weight_pattern(),
+            check_conv2d_backward_weight,
+        ),
     ]

     residual_block_patterns = []
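As a usage sketch (not from the patch; shapes, dtypes, and attribute values are invented for illustration), an NHWC wgrad expression can now be offloaded via partition_for_cutlass:

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

# dy: output gradient, x: fprop input, both NHWC fp16.
dy = relay.var("dy", shape=(8, 28, 28, 64), dtype="float16")
x = relay.var("x", shape=(8, 56, 56, 32), dtype="float16")
wgrad = relay.nn.conv2d_backward_weight(
    dy,
    x,
    strides=(2, 2),
    padding=(1, 1),
    kernel_size=(3, 3),
    grad_layout="NHWC",
    data_layout="NHWC",
    kernel_layout="OHWI",
    out_dtype="float32",
)
mod = tvm.IRModule.from_expr(wgrad)
mod = partition_for_cutlass(mod)
print(mod)  # the call is wrapped in a "cutlass.conv2d_backward_weight" composite function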

python/tvm/relay/op/nn/_nn.py

Lines changed: 35 additions & 0 deletions
@@ -1107,6 +1107,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
        padding=attrs.padding,
        dilation=attrs.strides,
        groups=in_channel * batch,
+        out_dtype=attrs.out_dtype,
    )

    # infer shape of backward_weight
@@ -1143,6 +1144,40 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
     return backward_weight


+@reg.register_convert_op_layout("nn.conv2d_backward_weight")
+def convert_conv2d_backward_weight(attrs, inputs, _, desired_layouts):
+    """Convert Layout pass registration for conv2d_backward_weight op.
+    Note that `desired_layouts` must be a pair [`data_layout`, `kernel_layouts`],
+    where `kernel_layouts` affects the output of this op (since the output of this op
+    is the weight gradient). The layout of the output gradient (the second input to this op)
+    is assumed to be the same as `data_layout`.
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current op
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of layout strings
+        List of layouts defining our desired
+        layout for the data and kernel inputs respectively.
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    new_attrs = dict(attrs)
+    assert len(desired_layouts) == 2, "A desired layout is expected for both of data and gradient."
+    desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+    new_attrs["grad_layout"] = desired_data_layout
+    new_attrs["data_layout"] = desired_data_layout
+    new_attrs["kernel_layout"] = desired_kernel_layout
+    new_attrs.pop("out_layout")
+    return relay.nn.conv2d_backward_weight(inputs[0], inputs[1], **new_attrs)
+
+
 #####################
 # Shape functions #
 #####################
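A minimal sketch of exercising this registration (illustrative; the NCHW workload and the NHWC/OHWI target layouts are assumptions, mirroring what the CUTLASS path expects):

import tvm
from tvm import relay

dy = relay.var("dy", shape=(8, 64, 28, 28), dtype="float16")  # NCHW output gradient
x = relay.var("x", shape=(8, 32, 56, 56), dtype="float16")    # NCHW input
wgrad = relay.nn.conv2d_backward_weight(
    dy, x, strides=(2, 2), padding=(1, 1), kernel_size=(3, 3), out_dtype="float32"
)
mod = tvm.IRModule.from_expr(wgrad)

# desired_layouts is [data_layout, kernel_layout]; the kernel layout determines
# the layout of the produced weight gradient.
with tvm.transform.PassContext(opt_level=3):
    mod = relay.transform.ConvertLayout({"nn.conv2d_backward_weight": ["NHWC", "OHWI"]})(mod)
print(mod)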

python/tvm/topi/cuda/conv2d.py

Lines changed: 7 additions & 1 deletion
@@ -130,6 +130,12 @@ def conv2d_backward_weight_cudnn(
 ):
     """Compute conv2d wgrad using CuDNN library"""
     assert layout in ["NCHW", "NHWC"]
+
+    if dy.dtype == "float16":
+        # cuDNN does not seem to support other combination.
+        assert output_dtype == "float16", "Only supports fp16 output for cuDNN fp16 wgrad."
+
+    conv_dtype = "float32"  # Accumulation is always fp32
     return cudnn.conv_backward_filter(
         dy,
         x,
@@ -139,6 +145,6 @@ def conv2d_backward_weight_cudnn(
         dilation,
         conv_mode=1,
         tensor_format=0 if layout == "NCHW" else 1,
-        conv_dtype=output_dtype,
+        conv_dtype=conv_dtype,
         groups=groups,
     )
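Summarizing the dtype handling above as a standalone helper (illustrative sketch, not part of the patch):

def cudnn_wgrad_conv_dtype(dy_dtype, output_dtype):
    """Mirror of the checks above: accumulation is always fp32."""
    if dy_dtype == "float16":
        # cuDNN fp16 wgrad is only exercised with fp16 output here.
        assert output_dtype == "float16", "Only supports fp16 output for cuDNN fp16 wgrad."
    return "float32"

print(cudnn_wgrad_conv_dtype("float16", "float16"))  # float32
print(cudnn_wgrad_conv_dtype("float32", "float32"))  # float32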

python/tvm/topi/testing/conv2d_backcward_weight_python.py

Lines changed: 37 additions & 6 deletions
@@ -20,7 +20,9 @@


 # Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
-def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
+def conv2d_backward_weight_nchw_python(
+    dy_np, x_np, kernel_size, stride, padding, groups=1, channels=None
+):
     """Gradient of the conv2d op with respect to weight, in NCHW layout.

     Parameters
@@ -51,17 +53,34 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
     R, S = kernel_size
     pad_h, pad_w = padding
     stride_h, stride_w = stride
-    dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
+    is_depth_wise = C == K and C == groups
+
+    if is_depth_wise:
+        assert channels == groups, "Only channel_mult == 1 supported for now."
+        dw = np.zeros((K, 1, R, S)).astype(dy_np.dtype)
+    else:
+        assert groups == 1, "General grouped conv2d not supported for now."
+        dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)

     for k in range(K):
         for r in range(R):
             for s in range(S):
-                for c in range(C):
+                for c in range(dw.shape[1]):
                     acc = 0
                     for n in range(N):
                         for p in range(P):
                             for q in range(Q):
-                                coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)
+                                if not is_depth_wise:
+                                    in_c = c
+                                else:
+                                    in_c = k
+
+                                coord = (
+                                    n,
+                                    in_c,
+                                    p * stride_h - pad_h + r,
+                                    q * stride_w - pad_w + s,
+                                )

                                 if (
                                     coord[2] < H
@@ -76,7 +95,9 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
     return dw


-def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
+def conv2d_backward_weight_python(
+    dy_np, x_np, kernel_size, stride, padding, layout="NCHW", groups=1, channels=None
+):
     """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.

     Parameters
@@ -99,20 +120,30 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay
     layout: string
         Layout of dy_np and x_np

+    groups: int
+        Number of groups for grouped convolution.
+
+    channels : int
+        Number of output channels of this convolution.
+
     Returns
     -------
     dw_np : np.ndarray
         Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout,
         [num_filter, filter_height, filter_width, in_channel] for NHWC layout.
     """
     if layout == "NCHW":
-        return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)
+        return conv2d_backward_weight_nchw_python(
+            dy_np, x_np, kernel_size, stride, padding, groups, channels
+        )

     dw_np_oihw = conv2d_backward_weight_nchw_python(
         np.transpose(dy_np, [0, 3, 1, 2]),
         np.transpose(x_np, [0, 3, 1, 2]),
         kernel_size,
         stride,
         padding,
+        groups,
+        channels,
     )
     return np.transpose(dw_np_oihw, [0, 2, 3, 1])
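A small usage sketch of the updated reference (shapes invented; assumes the function is re-exported from tvm.topi.testing as in the tests). The depthwise case exercises the new groups/channels path:

import numpy as np
from tvm.topi.testing import conv2d_backward_weight_python

# Depthwise NCHW wgrad: C == K == groups, channel multiplier 1.
dy_np = np.random.randn(2, 16, 14, 14).astype("float32")  # (N, K, P, Q)
x_np = np.random.randn(2, 16, 28, 28).astype("float32")   # (N, C, H, W)
dw_np = conv2d_backward_weight_python(
    dy_np, x_np, (3, 3), (2, 2), (1, 1), layout="NCHW", groups=16, channels=16
)
print(dw_np.shape)  # (16, 1, 3, 3)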

src/relay/backend/contrib/cutlass/codegen.cc

Lines changed: 5 additions & 0 deletions
@@ -615,6 +615,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
          GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d_transpose"});
      return GenerateBody(conv2d_call, "cutlass_conv2d_transpose", GetArgumentNames(caller),
                          Conv2dArgs(std::ref(attrs_), true, false));
+    } else if (pattern_name == "cutlass.conv2d_backward_weight") {
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d_backward_weight"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_backward_weight", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_), false, true));
    }

    LOG(FATAL) << "Unknown composite function: " << pattern_name;
