Skip to content

Commit b4c74a9

Browse files
authored
[TensorRT] Fix pad_value access (removed from PadAttrs) (#9858)
1 parent 9258997 commit b4c74a9

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,14 +737,17 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
737737
"""Check if nn.pad is supported by TensorRT."""
738738

739739
attrs, args = expr.attrs, expr.args
740+
pad_value = args[1]
741+
assert isinstance(pad_value, relay.Constant)
742+
pad_value = pad_value.data.numpy().item()
740743
if any([x.checked_type.dtype != "float32" for x in args]):
741744
logger.info("Only float32 inputs are supported for TensorRT.")
742745
return False
743746
if attrs.pad_mode != "constant":
744747
logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
745748
return False
746-
if float(attrs.pad_value) != 0.0:
747-
logger.info("nn.pad: pad value is %f but must be 0.0.", float(attrs.pad_value))
749+
if pad_value > 0.0:
750+
logger.info("nn.pad: pad value is %f but must be 0.0.", pad_value)
748751
return False
749752
if len(attrs.pad_width) not in [4, 5]:
750753
logger.info("nn.pad: can only pad 4D or 5D inputs")

0 commit comments

Comments
 (0)