62 changes: 62 additions & 0 deletions python/tvm/relay/op/contrib/cudnn.py
@@ -24,6 +24,7 @@
from tvm import te
from tvm.relay import transform
from tvm.contrib import cudnn
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import is_op, wildcard
from .te_target import lower_composite, relay_to_runtime
@@ -50,6 +51,8 @@ def partition_for_cudnn(
    tvm.IRModule
        The partitioned module.
    """
    if params:

Contributor: Given it's a one-liner, I've never figured out why folks want to fold that into every partition function. Cargo culting?

Member: In my case, I want to pattern match against a mod where batch norm is removed by constant folding + fold scale axis. Param binding is a prereq for these passes.
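
As a rough illustration of that workflow (a sketch only, not part of this diff; the helper name and the exact pass list are assumptions), binding the params first is what lets the later folding passes remove batch norm:

import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name

def simplify_for_pattern_matching(mod, params=None):
    # Hypothetical helper, not part of this change: fold the weights into the
    # graph so batch_norm disappears before the cuDNN patterns are matched.
    if params:
        # bind_params_by_name replaces the free parameter vars with constants,
        # which is what allows FoldConstant/FoldScaleAxis to act on them.
        mod["main"] = bind_params_by_name(mod["main"], params)
    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.SimplifyInference(),  # rewrites batch_norm into scale/shift ops
            relay.transform.FoldConstant(),
            relay.transform.FoldScaleAxis(),
        ]
    )
    return seq(mod)

With the params bound and the graph simplified this way, the conv2d pattern below matches the plain convolution that remains.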

mod["main"] = bind_params_by_name(mod["main"], params)

seq = tvm.transform.Sequential(
[
@@ -71,19 +74,78 @@ def softmax_pattern() -> relay.Pattern:
"""Create pattern for softmax."""
return is_op("nn.softmax")(wildcard())

def log_softmax_pattern() -> relay.Pattern:
"""Create pattern for log_softmax."""
return is_op("nn.log_softmax")(wildcard())

def conv2d_pattern() -> relay.Pattern:
"""Create pattern for conv2d."""
return is_op("nn.conv2d")(wildcard(), wildcard())

def check_softmax(matched: relay.Call) -> bool:
"""Check if softmax is supported by cuDNN."""
if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
return False

return True

def check_log_softmax(matched: relay.Call) -> bool:
"""Check if log_softmax is supported by cuDNN."""
if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
return False

if len(matched.args[0].checked_type.shape) != 2:
return False

if matched.attrs["axis"] not in (1, -1):
return False

return True

def check_conv2d(matched: relay.Call) -> bool:
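        """Check if conv2d is supported by cuDNN."""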
        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
            return False

        if matched.attrs["data_layout"] != "NCHW" or matched.attrs["kernel_layout"] != "OIHW":
            return False

        padding = matched.attrs["padding"]
        if padding[0] != padding[2] or padding[1] != padding[3]:
            return False

        return True

    return [
        ("cudnn.softmax", softmax_pattern(), check_softmax),
        ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
        ("cudnn.conv2d", conv2d_pattern(), check_conv2d),
    ]


@lower_composite("cudnn.softmax")
def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a softmax using cuDNN."""
return cudnn.softmax(inputs[0], axis=op.attrs["axis"])


@lower_composite("cudnn.log_softmax")
def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a log_softmax using cuDNN."""
return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])


@lower_composite("cudnn.conv2d")
def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
"""Lower a conv2d using cuDNN."""
return cudnn.conv_forward(
inputs[0],
inputs[1],
pad=op.attrs["padding"],
stride=op.attrs["strides"],
dilation=op.attrs["dilation"],
conv_mode=1,
tensor_format=0,
algo=1,
conv_dtype=op.checked_type.dtype,
groups=op.attrs["groups"],
)
66 changes: 65 additions & 1 deletion tests/python/contrib/test_cudnn.py
@@ -484,7 +484,7 @@ def _verify_cudnn_relay(expr):
    tvm.testing.assert_allclose(
        outputs[0],
        outputs[1],
        rtol=1e-3,
        rtol=1e-2,
    )


@@ -513,5 +513,69 @@ def test_relay_cudnn_softmax(shape, axis, dtype):
    _verify_cudnn_relay(softmax)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"shape,axis",
[
((32, 16), -1),
((13, 27), 1),
],
)
@pytest.mark.parametrize(
"dtype",
[
"float32",
"float16",
"float64",
],
)
def test_relay_cudnn_log_softmax(shape, axis, dtype):
    x = tvm.relay.var("x", tvm.relay.TensorType(shape, dtype))
    log_softmax = relay.op.nn.log_softmax(x, axis=axis)
    _verify_cudnn_relay(log_softmax)


@tvm.testing.requires_cuda
@pytest.mark.parametrize(
"n,h,w,ci,co,groups",
[
(1, 16, 20, 8, 16, 1),
(10, 17, 19, 16, 8, 4),
],
)
@pytest.mark.parametrize(
"kh,kw,padding",
[
(1, 1, (3, 1, 3, 1)),
(3, 3, (1, 2)),
(7, 2, (0, 0)),
],
)
@pytest.mark.parametrize(
"strides,dilation,dtype",
[
((1, 1), (1, 1), "float32"),
((2, 1), (2, 2), "float16"),
((3, 3), (1, 2), "float64"),
],
)
def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype):
    data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype))
    weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype))
    conv2d = relay.op.nn.conv2d(
        data,
        weight,
        groups=groups,
        channels=co,
        kernel_size=(kh, kw),
        strides=strides,
        dilation=dilation,
        padding=padding,
        data_layout="NCHW",
        kernel_layout="OIHW",
    )
    _verify_cudnn_relay(conv2d)


if __name__ == "__main__":
    sys.exit(pytest.main(sys.argv))