
Commit 1cf0c0a

[CUDNN] Add partitioning support for fused conv2d+bias+act (#10997)
cuDNN has kernel support for the fused conv2d+bias+activation pattern, although as of cuDNN v8 only ReLU is supported as the activation.
1 parent 96616b7 commit 1cf0c0a
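
For reference, a minimal sketch of how the new partitioning can be exercised from Python (the shapes and the three-op network are made up for illustration; this assumes a CUDA-enabled TVM build with cuDNN):

import tvm
from tvm import relay
from tvm.relay.op.contrib.cudnn import partition_for_cudnn

# Build a conv2d -> bias_add -> relu chain, the shape of graph this
# commit teaches the cuDNN BYOC flow to fuse and offload.
data = relay.var("data", relay.TensorType((1, 8, 16, 20), "float32"))
weight = relay.var("weight", relay.TensorType((16, 8, 3, 3), "float32"))
bias = relay.var("bias", relay.TensorType((16,), "float32"))
out = relay.nn.relu(relay.nn.bias_add(relay.nn.conv2d(data, weight, padding=(1, 1)), bias))

mod = tvm.IRModule.from_expr(out)
mod = partition_for_cudnn(mod)
print(mod)  # the chain should appear as a "cudnn.conv2d_bias_act" composite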

File tree

5 files changed: +186 −12 lines

python/tvm/relay/op/contrib/cudnn.py

Lines changed: 70 additions & 9 deletions
@@ -16,15 +16,14 @@
 # under the License.
 # pylint: disable=unused-argument
 """cuDNN Relay integration."""
-from typing import Callable, List, Tuple, Dict, Optional
+from typing import Callable, List, Tuple
 
 import tvm
 import tvm.ir
 from tvm import relay
 from tvm import te
 from tvm.relay import transform
 from tvm.contrib import cudnn
-from tvm.relay.build_module import bind_params_by_name
 
 from ...dataflow_pattern import is_op, wildcard
 from .te_target import lower_composite, relay_to_runtime
@@ -34,25 +33,19 @@
 tvm._ffi.register_func("relay.ext.cudnn", relay_to_runtime(tvm.target.cuda()))
 
 
-def partition_for_cudnn(
-    mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
-) -> tvm.IRModule:
+def partition_for_cudnn(mod: tvm.IRModule) -> tvm.IRModule:
     """Partition the graph to offload for cuDNN.
 
     Parameters
     ----------
     mod : tvm.IRModule
         The module to partition.
-    params : Optional[Dict[str, tvm.runtime.NDArray]]
-        Constant input parameters.
 
     Returns
     -------
     tvm.IRModule
         The partitioned module.
     """
-    if params:
-        mod["main"] = bind_params_by_name(mod["main"], params)
 
     seq = tvm.transform.Sequential(
         [
@@ -82,6 +75,12 @@ def conv2d_pattern() -> relay.Pattern:
         """Create pattern for conv2d."""
         return is_op("nn.conv2d")(wildcard(), wildcard())
 
+    def conv2d_bias_act_pattern() -> relay.Pattern:
+        """Create pattern for fused conv2d+bias+activation."""
+        conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+        bias = is_op("nn.bias_add")(conv2d, wildcard())
+        return bias.optional(is_op("nn.relu"))
+
     def check_softmax(matched: relay.Call) -> bool:
         """Check if softmax is supported by cuDNN."""
         if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
@@ -115,9 +114,13 @@ def check_conv2d(matched: relay.Call) -> bool:
 
         return True
 
+    def check_conv2d_bias_act(matched: relay.Call) -> bool:
+        return True
+
     return [
        ("cudnn.softmax", softmax_pattern(), check_softmax),
        ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
+       ("cudnn.conv2d_bias_act", conv2d_bias_act_pattern(), check_conv2d_bias_act),
        ("cudnn.conv2d", conv2d_pattern(), check_conv2d),
    ]
@@ -134,6 +137,64 @@ def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])
 
 
+@lower_composite("cudnn.conv2d_bias_act")
+def _lower_conv2d_bias_act(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a fused conv2d+bias+activation using cuDNN."""
+    conv_dtype = op.checked_type.dtype
+    if op.op.name == "nn.relu":
+        activation_mode = 1  # Relu
+        conv2d = op.args[0].args[0]
+    else:
+        activation_mode = 5  # Identity
+        conv2d = op.args[0]
+
+    conv_mode = 1
+    tensor_format = 0
+    algo = 1
+    pad = conv2d.attrs["padding"]
+    strides = conv2d.attrs["strides"]
+    dilation = conv2d.attrs["dilation"]
+    groups = conv2d.attrs["groups"]
+
+    oshape = cudnn.conv_output_shape(
+        tensor_format,
+        pad,
+        strides,
+        dilation,
+        inputs[0].shape,
+        inputs[1].shape,
+        inputs[0].dtype,
+        conv_dtype,
+        groups,
+    )
+
+    return te.extern(
+        oshape,
+        inputs,
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.conv2d+bias+act.forward",
+            conv_mode,
+            tensor_format,
+            algo,
+            pad[0],
+            pad[1],
+            strides[0],
+            strides[1],
+            dilation[0],
+            dilation[1],
+            activation_mode,
+            0,
+            ins[0],
+            ins[1],
+            ins[2],
+            outs[0],
+            conv_dtype,
+            groups,
+        ),
+        name="y",
+    )
+
+
 @lower_composite("cudnn.conv2d")
 def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a conv2d using cuDNN."""

src/runtime/contrib/cudnn/conv_forward.cc

Lines changed: 62 additions & 0 deletions
@@ -60,6 +60,44 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
                                       entry_ptr->conv_entry.output_desc, y->data));
 }
 
+void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims, int groups, int act,
+                                      double coef, const int pad[], const int stride[],
+                                      const int dilation[], DLTensor* x, DLTensor* w, DLTensor* y,
+                                      DLTensor* bias, const std::string& conv_dtype) {
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  // Set Mode
+  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
+  CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->conv_entry.activation_desc,
+                                          static_cast<cudnnActivationMode_t>(act),
+                                          cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, coef));
+  CUDNN_CALL(cudnnSetTensor4dDescriptor(
+      entry_ptr->conv_entry.bias_desc, entry_ptr->conv_entry.tensor_format,
+      CuDNNDataType::DLTypeToCuDNNType(bias->dtype), 1, static_cast<int>(w->shape[0]), 1, 1));
+
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
+                     y->shape, x->dtype, conv_dtype);
+  // Set Device
+  entry_ptr->conv_entry.device = x->device;
+  // Set Algo
+  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
+
+  // Set workspace
+  size_t workspace_size = 0;
+  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
+      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
+      entry_ptr->conv_entry.fwd_algo, &workspace_size));
+  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
+  CUDNN_CALL(cudnnConvolutionBiasActivationForward(
+      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
+      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
+      entry_ptr->conv_entry.workspace, workspace_size,
+      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
+      entry_ptr->conv_entry.output_desc, y->data, entry_ptr->conv_entry.bias_desc, bias->data,
+      entry_ptr->conv_entry.activation_desc, entry_ptr->conv_entry.output_desc, y->data));
+}
+
 void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
               const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
               const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
@@ -126,6 +164,30 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
                            conv_dtype);
     });
 
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      int mode = args[0];
+      int format = args[1];
+      int algo = args[2];
+      int pad_v[2], stride_v[2], dilation_v[2];
+      for (int i = 0; i < 2; i++) {
+        pad_v[i] = args[3 + i];
+        stride_v[i] = args[5 + i];
+        dilation_v[i] = args[7 + i];
+      }
+      int act = args[9];
+      double coef = args[10];
+      DLTensor* x = args[11];
+      DLTensor* w = args[12];
+      DLTensor* bias = args[13];
+      DLTensor* y = args[14];
+      std::string conv_dtype = args[15];
+      int groups = args[16];
+
+      ConvolutionBiasActivationForward(mode, format, algo, 2, groups, act, coef, pad_v, stride_v,
+                                       dilation_v, x, w, y, bias, conv_dtype);
+    });
+
 TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
     .set_body([](TVMArgs args, TVMRetValue* ret) {
       int mode = args[0];
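
The lowering in cudnn.py reaches this kernel through TVM's global function registry, passing its seventeen arguments in exactly the order unpacked by the set_body lambda above. A quick sketch of the lookup (lookup only; invoking it requires GPU-resident tensors, which the lowered code provides):

import tvm

# The packed function is registered under the name used by tvm.tir.call_packed
# in _lower_conv2d_bias_act; a None result would mean the runtime was built
# without cuDNN support.
f = tvm.get_global_func("tvm.contrib.cudnn.conv2d+bias+act.forward", allow_missing=True)
print("cuDNN fused conv2d available:", f is not None)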

src/runtime/contrib/cudnn/cudnn_utils.cc

Lines changed: 4 additions & 0 deletions
@@ -140,13 +140,17 @@ ConvEntry::ConvEntry() {
   CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc));
   CUDNN_CALL(cudnnCreateTensorDescriptor(&input_desc));
   CUDNN_CALL(cudnnCreateTensorDescriptor(&output_desc));
+  CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc));
+  CUDNN_CALL(cudnnCreateActivationDescriptor(&activation_desc));
 }
 
 ConvEntry::~ConvEntry() {
   CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc));
   CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
   CUDNN_CALL(cudnnDestroyTensorDescriptor(input_desc));
   CUDNN_CALL(cudnnDestroyTensorDescriptor(output_desc));
+  CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc));
+  CUDNN_CALL(cudnnDestroyActivationDescriptor(activation_desc));
   CleanWorkspace();
 }

src/runtime/contrib/cudnn/cudnn_utils.h

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,8 @@ struct ConvEntry {
   cudnnTensorFormat_t tensor_format;
   cudnnTensorDescriptor_t input_desc;
   cudnnFilterDescriptor_t filter_desc;
+  cudnnTensorDescriptor_t bias_desc;
+  cudnnActivationDescriptor_t activation_desc;
   cudnnTensorDescriptor_t output_desc;
   cudnnConvolutionFwdAlgo_t fwd_algo;
   cudnnConvolutionBwdDataAlgo_t bwd_data_algo;

tests/python/contrib/test_cudnn.py

Lines changed: 48 additions & 3 deletions
@@ -461,10 +461,12 @@ def _verify_cudnn_relay(expr):
     for param in func.params:
         shape = [int(x) for x in param.checked_type.shape]
         input_data.append(
-            (param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype))
+            (
+                param.name_hint,
+                np.random.uniform(-32, 32, size=shape).astype(param.checked_type.dtype),
+            )
         )
 
-    # Test against CPU reference
     cuda_config = (tvm.target.cuda(), tvm.cuda(), cudnn_mod)
     cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod)
     outputs = []
@@ -484,7 +486,8 @@ def _verify_cudnn_relay(expr):
     tvm.testing.assert_allclose(
         outputs[0],
         outputs[1],
-        rtol=1e-2,
+        rtol=1e-3,
+        atol=30,
     )
@@ -577,5 +580,47 @@ def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding,
     _verify_cudnn_relay(conv2d)
 
 
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "n,h,w,ci,co,groups",
+    [
+        (1, 16, 20, 8, 16, 1),
+        (10, 17, 19, 16, 8, 4),
+    ],
+)
+@pytest.mark.parametrize(
+    "kh,kw,padding,strides,dilation,dtype",
+    [
+        (1, 1, (3, 1, 3, 1), (1, 1), (1, 1), "float32"),
+        (3, 3, (1, 2), (2, 1), (2, 2), "float16"),
+        (7, 2, (0, 0), (3, 3), (1, 2), "float64"),
+    ],
+)
+@pytest.mark.parametrize("activation", [True, False])
+def test_relay_cudnn_conv2d_bias_act(
+    n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype, activation
+):
+    data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype))
+    weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype))
+    bias = relay.var("bias", relay.TensorType((co,), dtype))
+    conv2d = relay.op.nn.conv2d(
+        data,
+        weight,
+        groups=groups,
+        channels=co,
+        kernel_size=(kh, kw),
+        strides=strides,
+        dilation=dilation,
+        padding=padding,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+    )
+    out = relay.op.nn.bias_add(conv2d, bias)
+    if activation:
+        out = relay.op.nn.relu(out)
+
+    _verify_cudnn_relay(out)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
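
Assuming a CUDA- and cuDNN-enabled build, the new cases can be run in isolation with pytest's keyword filter, e.g.:

import sys
import pytest

# Select only the fused conv2d+bias+act tests added in this commit.
sys.exit(pytest.main(["tests/python/contrib/test_cudnn.py", "-k", "conv2d_bias_act"]))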
