Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/tvm/relay/op/contrib/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,8 +723,12 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
if attrs.pad_mode != "constant":
logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
return False
if float(attrs.pad_value) != 0.0:
logger.info("nn.pad: pad value is %f but must be 0.0.", float(attrs.pad_value))
if (
isinstance(args[1], relay.Constant)
and len(args[1].checked_type.shape) == 0
and args[1].data.numpy().item() != 0.0
):
logger.info("nn.pad: pad value is %s but must be 0.0.", args[1])
return False
if len(attrs.pad_width) not in [4, 5]:
logger.info("nn.pad: can only pad 4D or 5D inputs")
Expand Down
8 changes: 5 additions & 3 deletions tests/python/contrib/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,17 +793,19 @@ def get_graph(x_shape=(1, 16)):


def test_pad():
def get_graph(x_shape, pad_width):
def get_graph(x_shape, pad_width, pad_value=0.0):
x = relay.var("x", shape=(x_shape), dtype="float32")
out = relay.nn.pad(x, pad_width=pad_width)
out = relay.nn.pad(x, pad_width=pad_width, pad_value=pad_value)
f = relay.Function([x], out)
return f, {"x": x_shape}, []

run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0]]))
run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]]))
run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]], pad_value=-1.0e30))
run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]]))
run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]], pad_value=-1.0e30))
run_and_verify_func(get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]))

run_and_verify_func(get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], pad_value=-1.0e30))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the test cases.


def test_softmax():
def get_graph(x_shape, axis):
Expand Down