This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Tvm broadcast backward #15938
Merged
Merged
Tvm broadcast backward #15938
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If my understanding is correct, there seems to be an assumption that
ograd.ndim = igrad.ndim
, which is not necessarily true. I think you need to prepend axes beforeigrad
ifigrad.ndim < ograd.ndim
and then use the logic here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes,
igrad.ndim = ograd.ndim
is assumed.@yzhliu suggests padding the input to 5-dim, which is the largest possible dim supported by this op. The padding will 1) reduce the number of kernels (by a factor of 5) 2) handle the
igrad.ndim < ograd.ndim
issue. But there may be loss in performance.I think prepending axes before
igrad
to make itograd.dim
requires more kernels, but the performance is better. It is a tradeoff.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please correct me if my understanding is wrong, but don't you still need kernels generated for
ndims < 5
since you will collapse consecutive dimensions where reduction is performed? For example, given a 5d shape(2, 3, 4, 5, 6)
, and perform reduction onaxis=(1, 2)
, thetblob
will be first reshaped into(2, 12, 30)
, and then reduce onaxis=1
. In this case, do you need a kernel generated for 3D shapes?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can pad the shape after dimension collapse. In this case, the
tblob
will be reshaped into(2, 12, 30, 1, 1)
and then reduce onaxis=[1, 3]
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. I am in favor of the approach with less kernels generated. We can revisit the performance concern if that turns out to be an issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I pushed a new version, where the inputs and outputs are padded to 5 dim.