diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index a71cda507957..992112139842 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -737,14 +737,17 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.pad is supported by TensorRT.""" attrs, args = expr.attrs, expr.args + pad_value = args[1] + assert isinstance(pad_value, relay.Constant) + pad_value = pad_value.data.numpy().item() if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False if attrs.pad_mode != "constant": logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode) return False - if float(attrs.pad_value) != 0.0: - logger.info("nn.pad: pad value is %f but must be 0.0.", float(attrs.pad_value)) + if pad_value > 0.0: + logger.info("nn.pad: pad value is %f but must be 0.0.", pad_value) return False if len(attrs.pad_width) not in [4, 5]: logger.info("nn.pad: can only pad 4D or 5D inputs")