Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-798] Fix the dtype cast from non float32 in Gradient computation #12290

Merged
merged 14 commits into from
Sep 14, 2018

Conversation

apeforest
Copy link
Contributor

@apeforest apeforest commented Aug 22, 2018

Description

This PR fixes the issues #9067 and #8799 where gradient computation for operators with multiple output fails in ndarray if the dtype is not float32.

The root cause of the issue is that a _zeros operator was added for the other don't care output. The _zeros operator uses float32 dtype by default and it will cause conflict if the dtype in ndarray is not float32. My solution is to create a new _zeros_without_dtype operator that does not take any default dtype and use it to replace the _zeros operator in the computation graph. This change solves the dtype conflict problem and should be backward compatible.

A unit test is added to test this fix.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Change the way to infer type for auto-derived zero operator in nnvm::Graph
  • Added a unittest for operators with multioutput.

Comments

  • This seems to be a general problem for all multioutput operators when computing gradient in imperative mode. A simple example is copied from the original issue below:
  • Although the change is small, the impact could be large. Thus thorough review is solicited.
import mxnet as mx
from mxnet import autograd


data = mx.nd.arange(16, dtype='float64').reshape((4, 4))
data.attach_grad()

with autograd.record():
    y = mx.nd.split(data, axis=0, num_outputs=2)
y[0].backward()
print(data.grad)

@apeforest
Copy link
Contributor Author

apeforest commented Aug 22, 2018

@eric-haibin-lin @piiswrong @haojin2 I will appreciate your review.



if __name__ == "__main__":
test_infer_multiout_op()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should go to something like test_operator.py instead of creating a separate file for it? And, please see https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L7017-L7018 for how to use nosetests.

test64.backward()
assert_almost_equal(data64.grad.asnumpy().all(), data32.grad.asnumpy().all())


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should go to something like test_operator.py instead of creating a separate file for it? And, please see https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L7017-L7018 for how to use nosetests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not to test the functionality of the operator but a general type casting issue for all multioutput operators. I inclined to add it in the infer type tests but would like to hear more suggestions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed test to run nose runmodule

@apeforest apeforest changed the title [MXNET-798] Fix the dtype cast from non float32 in Gradient computation [MXNET-798][WIP] Fix the dtype cast from non float32 in Gradient computation Aug 22, 2018
@apeforest
Copy link
Contributor Author

Change to [WIP] to fix some platform dependent unit test failure.

@@ -254,7 +254,8 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
dispatch_mode = &dispatch_modes[nid];
if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
}
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
auto finfer = (inode.source->op() == Op::Get("_zeros")) ? fdefault :
finfer_shape.get(inode.source->op(), fdefault);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure about this? This affects all _zero ops, not just for the case you mentioned.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, this is breaking some unit test (however, due to unittest of master branch is broken in MacOS, I wan't able to verify before checkin). I have changed the PR to WIP.

@apeforest apeforest changed the title [MXNET-798][WIP] Fix the dtype cast from non float32 in Gradient computation [MXNET-798] Fix the dtype cast from non float32 in Gradient computation Sep 12, 2018
@apeforest
Copy link
Contributor Author

@eric-haibin-lin Please review this new implementation. Thanks for your suggestion!

@eric-haibin-lin
Copy link
Member

What's up with the build?

@apeforest
Copy link
Contributor Author

@eric-haibin-lin Not sure exactly. An earlier build passed dcc5f78). After I renamed some variables the build on ARM7 failed. I can submit an empty change to trigger the build again.

Copy link
Member

@eric-haibin-lin eric-haibin-lin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

with autograd.record():
test64 = test_func(data64)
test64.backward()
assert_almost_equal(data64.grad.asnumpy().all(), data32.grad.asnumpy().all())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you set rtol and atol to some bigger value than default here ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why increase the rtol and atol if the unit test can pass with the default one?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be flaky. you are comparing a float32 numpy to a float64 numpy and the atol and rtol defaults are small.

@anirudh2290
Copy link
Member

Also,maybe we should add zeros to APIs that may be good to break for 2.0 #9686

@apeforest
Copy link
Contributor Author

@anirudh2290 The _zeros_without_dtype operator is a private operator used only in building nnvm graph. It is not meant to be exposed to users.

@anirudh2290
Copy link
Member

@apeforest what i meant is we can change the dtype default to -1 for zeros operator for 2.0.

@apeforest
Copy link
Contributor Author

@anirudh2290 Thanks for the clarification. I have increased atol and rtol values in unit test. As to changing the dtype default to -1 for zeros, I think it is not related to this PR and may cause a backward compatibility issue with old models. Therefore, I would prefer doing that in a separate PR. Please let me know what you think. Thanks.

@anirudh2290
Copy link
Member

Not suggesting to do it in this PR. Just wanted to document it in the APIs to break for 2.0 and we can do it before 2.0 release.

@anirudh2290 anirudh2290 merged commit 8209906 into apache:master Sep 14, 2018
anirudh2290 pushed a commit to anirudh2290/mxnet that referenced this pull request Sep 19, 2018
…on (apache#12290)

* Fix the dtype mismatch in derived _zeros node

* Add unittest for infer dtype

* Add one more unit test

* Add nose runmodule

* Add a zero operator with no default dtype

* Rename variables

* fix a bug: rename operator for gpu

* Increase atol and rtol to avoid flakiness
@apeforest apeforest deleted the bugfix/dtype-cast branch January 7, 2020 22:49
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants