Commit 98fc649

[CUDNN] Add partitioning support for conv2d and log_softmax (#10961)
1 parent 11d22bd commit 98fc649
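
For readers who want to try the new partitioning entry point, a minimal, hypothetical usage sketch follows (not part of this commit). The module, shapes, dtype, and parameter values are made-up placeholders; `partition_for_cudnn` and its optional `params` binding come from the patched `python/tvm/relay/op/contrib/cudnn.py` below, and actually running the partitioned module would need a CUDA/cuDNN-enabled build.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cudnn import partition_for_cudnn

    # A small NCHW/OIHW conv2d, which the check_conv2d predicate added below accepts.
    data = relay.var("data", shape=(1, 8, 16, 20), dtype="float32")
    weight = relay.var("weight", shape=(16, 8, 3, 3), dtype="float32")
    out = relay.nn.conv2d(
        data,
        weight,
        channels=16,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW",
        kernel_layout="OIHW",
    )
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

    # Optional constant binding: with this patch, params are bound by name before partitioning.
    params = {"weight": tvm.nd.array(np.ones((16, 8, 3, 3), dtype="float32"))}

    # Supported conv2d/log_softmax/softmax calls get wrapped into "cudnn.*" composite
    # functions and offloaded via the lowerings added in this commit.
    partitioned = partition_for_cudnn(mod, params=params)
    print(partitioned)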

2 files changed: 127 additions, 1 deletion

python/tvm/relay/op/contrib/cudnn.py

Lines changed: 62 additions & 0 deletions
@@ -24,6 +24,7 @@
 from tvm import te
 from tvm.relay import transform
 from tvm.contrib import cudnn
+from tvm.relay.build_module import bind_params_by_name

 from ...dataflow_pattern import is_op, wildcard
 from .te_target import lower_composite, relay_to_runtime
@@ -50,6 +51,8 @@ def partition_for_cudnn(
     tvm.IRModule
         The partitioned module.
     """
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)

     seq = tvm.transform.Sequential(
         [
@@ -71,19 +74,78 @@ def softmax_pattern() -> relay.Pattern:
         """Create pattern for softmax."""
         return is_op("nn.softmax")(wildcard())

+    def log_softmax_pattern() -> relay.Pattern:
+        """Create pattern for log_softmax."""
+        return is_op("nn.log_softmax")(wildcard())
+
+    def conv2d_pattern() -> relay.Pattern:
+        """Create pattern for conv2d."""
+        return is_op("nn.conv2d")(wildcard(), wildcard())
+
     def check_softmax(matched: relay.Call) -> bool:
         """Check if softmax is supported by cuDNN."""
         if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
             return False

         return True

+    def check_log_softmax(matched: relay.Call) -> bool:
+        """Check if log_softmax is supported by cuDNN."""
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        if len(matched.args[0].checked_type.shape) != 2:
+            return False
+
+        if matched.attrs["axis"] not in (1, -1):
+            return False
+
+        return True
+
+    def check_conv2d(matched: relay.Call) -> bool:
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        if matched.attrs["data_layout"] != "NCHW" or matched.attrs["kernel_layout"] != "OIHW":
+            return False
+
+        padding = matched.attrs["padding"]
+        if padding[0] != padding[2] or padding[1] != padding[3]:
+            return False
+
+        return True
+
     return [
         ("cudnn.softmax", softmax_pattern(), check_softmax),
+        ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
+        ("cudnn.conv2d", conv2d_pattern(), check_conv2d),
     ]


 @lower_composite("cudnn.softmax")
 def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a softmax using cuDNN."""
     return cudnn.softmax(inputs[0], axis=op.attrs["axis"])
+
+
+@lower_composite("cudnn.log_softmax")
+def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a log_softmax using cuDNN."""
+    return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])
+
+
+@lower_composite("cudnn.conv2d")
+def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a conv2d using cuDNN."""
+    return cudnn.conv_forward(
+        inputs[0],
+        inputs[1],
+        pad=op.attrs["padding"],
+        stride=op.attrs["strides"],
+        dilation=op.attrs["dilation"],
+        conv_mode=1,
+        tensor_format=0,
+        algo=1,
+        conv_dtype=op.checked_type.dtype,
+        groups=op.attrs["groups"],
+    )
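
A note on the fixed integer arguments in `_lower_conv2d` above: in `tvm.contrib.cudnn`, `conv_mode=1` selects cross-correlation (the convention `relay.nn.conv2d` follows), `tensor_format=0` selects NCHW (matching the layout restriction in `check_conv2d`), and `algo=1` corresponds to cuDNN's implicit precomputed-GEMM forward algorithm. A minimal, hypothetical sketch of the equivalent direct TE-level call, with made-up shapes and assuming a cuDNN-enabled TVM build:

    from tvm import te
    from tvm.contrib import cudnn

    data = te.placeholder((1, 8, 16, 20), name="data", dtype="float32")    # NCHW
    kernel = te.placeholder((16, 8, 3, 3), name="kernel", dtype="float32")  # OIHW
    conv = cudnn.conv_forward(
        data,
        kernel,
        pad=(1, 1),
        stride=(1, 1),
        dilation=(1, 1),
        conv_mode=1,      # cross-correlation
        tensor_format=0,  # NCHW
        algo=1,           # implicit precomp GEMM
        conv_dtype="float32",
        groups=1,
    )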

tests/python/contrib/test_cudnn.py

Lines changed: 65 additions & 1 deletion
@@ -484,7 +484,7 @@ def _verify_cudnn_relay(expr):
     tvm.testing.assert_allclose(
         outputs[0],
         outputs[1],
-        rtol=1e-3,
+        rtol=1e-2,
     )


@@ -513,5 +513,69 @@ def test_relay_cudnn_softmax(shape, axis, dtype):
     _verify_cudnn_relay(softmax)


+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "shape,axis",
+    [
+        ((32, 16), -1),
+        ((13, 27), 1),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "float32",
+        "float16",
+        "float64",
+    ],
+)
+def test_relay_cudnn_log_softmax(shape, axis, dtype):
+    x = tvm.relay.var("x", tvm.relay.TensorType(shape, dtype))
+    log_softmax = relay.op.nn.log_softmax(x, axis=axis)
+    _verify_cudnn_relay(log_softmax)
+
+
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "n,h,w,ci,co,groups",
+    [
+        (1, 16, 20, 8, 16, 1),
+        (10, 17, 19, 16, 8, 4),
+    ],
+)
+@pytest.mark.parametrize(
+    "kh,kw,padding",
+    [
+        (1, 1, (3, 1, 3, 1)),
+        (3, 3, (1, 2)),
+        (7, 2, (0, 0)),
+    ],
+)
+@pytest.mark.parametrize(
+    "strides,dilation,dtype",
+    [
+        ((1, 1), (1, 1), "float32"),
+        ((2, 1), (2, 2), "float16"),
+        ((3, 3), (1, 2), "float64"),
+    ],
+)
+def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype):
+    data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype))
+    weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype))
+    conv2d = relay.op.nn.conv2d(
+        data,
+        weight,
+        groups=groups,
+        channels=co,
+        kernel_size=(kh, kw),
+        strides=strides,
+        dilation=dilation,
+        padding=padding,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+    )
+    _verify_cudnn_relay(conv2d)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
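
On a CUDA-enabled build (the new tests are gated by `@tvm.testing.requires_cuda`), they can be run with the usual pytest filter, e.g. `pytest tests/python/contrib/test_cudnn.py -k "log_softmax or conv2d"`.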
