-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TensorRT] Fixed access to pad_value #8131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
By the way, are there suitable codeplaces of dynamic_cast util functions to relay ops from python? |
comaniac
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test case?
Also cc @trevor-m
|
Thanks @akmaru https://github.com/apache/tvm/blob/main/src/runtime/contrib/tensorrt/tensorrt_ops.cc#L1060 |
| run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]], pad_value=-1.0e30)) | ||
| run_and_verify_func(get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]])) | ||
|
|
||
| run_and_verify_func(get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], pad_value=-1.0e30)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added the test cases.
Umm, I haven't encountered runtime error in my running. Though, it should be fix, so I'll do it! |
| class PadOpConverter : public TensorRTOpConverter { | ||
| public: | ||
| PadOpConverter() : TensorRTOpConverter({kTensor}) {} | ||
| PadOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added pad_value input into PadOpConverter.
| assert isinstance(args[1], relay.Constant) | ||
| if len(args[1].checked_type.shape) == 0 and args[1].data.numpy().item() != 0.0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's better to check type of pad_value with if instead of assert because it may be Expr.
I suggest the following code, including the fix of checking the shape of pad_value,
| assert isinstance(args[1], relay.Constant) | |
| if len(args[1].checked_type.shape) == 0 and args[1].data.numpy().item() != 0.0: | |
| if (not isinstance(args[1], relay.Constant) or | |
| len(args[1].checked_type.shape) != 0 or | |
| args[1].data.numpy().item() != 0.0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It shouldn't be Expr anyways if the pad mode is constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think pad_mode="constant" just means padding by pad_value, not type of pad_value is relay.Constant.
nn.pad allows Expr for pad_value.
tvm/python/tvm/relay/op/nn/nn.py
Lines 1657 to 1658 in 03c8a6f
| pad_value: float, or tvm.relay.Expr, optional, default=0 | |
| The value used for padding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm you're right. I didn't notice that constant mode doesn't enforce constant node. @akmaru please change accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed by e50bc43, please review again.
|
gentle ping @akmaru to resolve the lint error |
|
@akmaru not sure if the CI failure is related to this PR. Please fix it or retrigger the CI otherwise. |
|
Gentle ping @akmaru |
|
@akmaru Could you resolve the conflict? |
|
Gentle ping @akmaru please resolve the conflict and pass the CI. |
|
closing due to inactivity, but feel free to reopen! |
Fixed the omission modify pad_value from Attr to Args of `relay.nn.pad ' by the following commit.
78657e1