[MXNET-882] Support for N-d arrays added to diag op. #12430
src/operator/tensor/diag_op-inl.h
Outdated
if (ishape.ndim() == 1) {
  auto s = ishape[0] + std::abs(k);
  return TShape({s, s});
}

auto h = ishape[0];
auto w = ishape[1];
auto h = ishape[axis1];
Related to the previous comment: shouldn't we check that the provided axis1 and axis2 values are within the acceptable range [0, ndim-1], so that users don't get nasty segfaults on incorrect inputs?
I see there is a check later on, at line 85/86, so maybe consider reorganizing the code so that this lookup happens after the check.
I think you are right. These checks should be done before axis1 and axis2 are used. I also think it is necessary to add some tests that cover negative axes.
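The ordering agreed on above (validate first, index the shape second) can be sketched as follows. This is only an illustration, not the PR's code: `normalize_axis` and `diag_plane_dims` are hypothetical names, a plain `std::vector<int64_t>` stands in for `TShape`, and accepting negative axes in [-ndim, -1] is an assumption based on the mention of negative-axis tests.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical helper: maps an axis in [-ndim, ndim) onto [0, ndim)
// and throws for anything outside that range.
inline int normalize_axis(int axis, int ndim) {
  if (axis < -ndim || axis >= ndim) {
    throw std::out_of_range("axis out of range");
  }
  return axis < 0 ? axis + ndim : axis;
}

// Returns (h, w), the sizes of the two axes that span the 2-D sub-arrays.
// The shape is only indexed after both axes have been validated.
inline std::pair<int64_t, int64_t> diag_plane_dims(
    const std::vector<int64_t>& shape, int axis1, int axis2) {
  const int ndim = static_cast<int>(shape.size());
  // The 1-D case is handled by a separate branch in the diff above.
  assert(ndim >= 2);
  const int a1 = normalize_axis(axis1, ndim);
  const int a2 = normalize_axis(axis2, ndim);
  if (a1 == a2) {
    throw std::invalid_argument("axis1 and axis2 must differ");
  }
  return {shape[a1], shape[a2]};
}
```

With a shape of (2, 3, 4), `diag_plane_dims(shape, 0, -1)` yields (2, 4), while an axis of 3 is rejected before any element of the shape is read.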
@samskalicky Thanks for your review! I have accordingly made modifications to the code.
I don't understand what has happened. It looks quite unrelated to my modifications. It seems that several recent updates have failed at the same place.
@jasonyu1996 Regarding this error, I wonder if it's transient. Can you try pushing an empty change to restart the CI pipeline? It's possible it will pass a second time. If that still doesn't work, can you try to debug it? Sorry, the message doesn't look very meaningful. Does that test fail if you run it yourself on your branch? If you run the same test on the master branch, does it fail the same way? You might have to re-apply your changes one at a time and see when it starts breaking. There are some instructions here to set up debugging (if you're using a Mac): https://cwiki.apache.org/confluence/display/MXNET/MXNet+Developer+Setup+on+Mac but I'm also working on improving this and putting together a debug guide, so please post any steps or processes that you try (whether they worked or not) and I'll be sure to include those.
@samskalicky Thanks! I also think the error is transient. Possibly I should wait a while until this problem gets solved. Apparently it has not been solved yet: every update since noon on the 6th (UTC) has failed, including one that changes only a markdown file, and that one failed at exactly the same place as ours. UPD: Somebody has opened an issue regarding this problem: #12473
@jasonyu1996 There is a PR to fix this issue, dmlc/nnvm#525, and an associated change on mxnet, #12479. I know that this issue is blocking everyone else, so hopefully this gets through soon and we can get your changes merged too. Thanks for all of your explanations of the code (especially the code you didn't write)! I hope you don't mind if I ask some more while we wait for the nnvm fix. Also, can you explain the backward pass of the diagonal operator? What exactly is happening there: the diagonal comes in and the original input matrix gets updated? What is the ML meaning in the diag case; what error is there to correct in this calculation?
@samskalicky Thanks! Since the diag operator basically copies some elements from the input directly to the output, the backward pass is quite straightforward: it simply copies the gradients from the output back to the corresponding elements of the input. This is achieved with a template parameter that indicates whether the pass is forward or backward and so controls the direction of the assignment in the copy.
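A minimal sketch of that idea for the plain 2-D case, with a row-major h x w matrix and offset k, might look like the following. The function name `diag_copy`, its signature, and the use of `std::vector<float>` are assumptions made for illustration; the actual kernel in the PR is not reproduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// back == false: forward pass, copies the k-th diagonal of mat into diag.
// back == true:  backward pass, copies diag (the output gradient) into the
//                matching positions of mat (the input gradient).
template <bool back>
void diag_copy(std::vector<float>& diag, std::vector<float>& mat,
               std::size_t h, std::size_t w, int k) {
  assert(mat.size() == h * w);
  // The k-th diagonal starts at (0, k) for k >= 0 and at (-k, 0) for k < 0.
  const std::size_t r0 = k < 0 ? static_cast<std::size_t>(-k) : 0;
  const std::size_t c0 = k > 0 ? static_cast<std::size_t>(k) : 0;
  for (std::size_t i = 0; i < diag.size(); ++i) {
    assert(r0 + i < h && c0 + i < w);
    const std::size_t j = (r0 + i) * w + (c0 + i);
    if (back) {
      mat[j] = diag[i];  // gradient flows from output back to input
    } else {
      diag[i] = mat[j];  // value flows from input to output
    }
  }
}
```

In the backward direction the caller is assumed to have zero-initialized the input gradient, since every element off the selected diagonal contributes nothing to the output and therefore receives a gradient of zero.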
LGTM
@anirudh2290 Please review/merge
@anirudh2290 @samskalicky Thanks!
@jasonyu1996 Thanks for your contribution. It looks good overall. I have a few questions and suggested some small changes. Thanks!
@apeforest Thanks for your meticulous review! I have made the changes according to your suggestions.
@samskalicky @apeforest Thanks again!
@samskalicky @apeforest It seems that
@apeforest @samskalicky Am I expected to do something more to get this PR merged?
Sorry @jasonyu1996, I'll take a look at this today and give you a status update. Thanks for your patience!
LGTM. Many thanks for your contribution!
@sandeep-krishnamurthy @nswamy @anirudh2290 - looks like this one is ready to go, can you please review/merge? Thanks!
LGTM. Thanks.
* Support for N-d arrays added to diag op.
* Doc for diag updated. Sanity fix. Unit test for N-d diag op added.
* Unwanted print in diag test removed.
* Bad negative axis support in diag fixed.
* Bad negative axis support in diag fixed.
* Index overflow in diag fixed. Exemplars for Nd diag added.
* A bug in diag op fixed.
* Types of axis number and dim size changed to dim_t in diag.
* A bug in params of diag fixed.
* Poisonous dim_t removed from diag parameters.
* Diag impl details documented in comments.
Description
Github Issue #12327
Changes
The diag operator now supports N-d arrays where N > 2. In this case it behaves in the same way as the diagonal operator provided by numpy (https://www.numpy.org/devdocs/reference/generated/numpy.diagonal.html). Axes of the sub-arrays can be specified with axis1 and axis2.
Comments
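To make the N-d behaviour described under Changes concrete, here is a rough sketch of the output-shape rule documented for numpy.diagonal: axis1 and axis2 are removed from the input shape and the diagonal length is appended as the last dimension. The function name `diagonal_shape` is hypothetical, axes are assumed to be already normalized to [0, ndim), and clamping the length to zero for out-of-range offsets follows numpy rather than anything stated in this PR.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Shape of the result of taking the k-th diagonal over (axis1, axis2).
std::vector<int64_t> diagonal_shape(const std::vector<int64_t>& shape,
                                    int axis1, int axis2, int64_t k) {
  const int ndim = static_cast<int>(shape.size());
  assert(ndim >= 2);
  assert(0 <= axis1 && axis1 < ndim && 0 <= axis2 && axis2 < ndim);
  assert(axis1 != axis2);
  const int64_t h = shape[axis1];
  const int64_t w = shape[axis2];
  // Length of the k-th diagonal of an h x w matrix (may come out negative
  // when the offset lies outside the matrix; clamped to zero below).
  const int64_t len = k >= 0 ? std::min(h, w - k) : std::min(h + k, w);
  std::vector<int64_t> out;
  for (int i = 0; i < ndim; ++i) {
    if (i != axis1 && i != axis2) out.push_back(shape[i]);
  }
  out.push_back(std::max<int64_t>(len, 0));
  return out;
}
```

For an input of shape (2, 3, 4), taking the main diagonal over axes 0 and 1 gives shape (4, 2), and offset 1 over axes 1 and 2 gives shape (2, 3).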