
[MXNET-1327] Allow RNN Layers to be initialized to fp16 #14219

Merged: 9 commits into apache:master on Mar 12, 2019

Conversation

@ThomasDelteil (Contributor) commented Feb 21, 2019

Description

Currently, if you hybridize an RNN layer, it doesn't work with fp16 unless you explicitly pass fp16 states as the initial states, which makes the layers cumbersome to use.
This PR updates the RNN layers to carry a dtype that is used for the default states and is updated upon casting.

Currently

import mxnet as mx
from mxnet import gluon

net = gluon.rnn.LSTM(5, bidirectional=True)
net.collect_params().initialize(ctx=mx.gpu())
net.cast('float16')
net.hybridize()
net(mx.nd.ones((1, 5, 5), dtype='float16', ctx=mx.gpu()))

gives

MXNetError: Error in operator lstm5_rnn0: [01:23:01] src/operator/./rnn-inl.h:750: Check failed: (*in_type)[i] == dtype (0 vs. 2) This layer requires uniform type. Expected 'float16' v.s. given 'float32' at 'state'

Stack trace returned 10 entries:
[bt] (0) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x40e86a) [0x7fd8952bf86a]
[bt] (1) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x40eea1) [0x7fd8952bfea1]

After the fix

[[[ 0.001968  0.002743 -0.03683   0.015366 -0.02847   0.04007   0.00813
    0.002373  0.014656  0.03268 ]
  [ 0.001968  0.002743 -0.03683   0.015366 -0.02847   0.04007   0.00813
    0.002373  0.014656  0.03268 ]
  [ 0.001968  0.002743 -0.03683   0.015366 -0.02847   0.04007   0.00813
    0.002373  0.014656  0.03268 ]
  [ 0.001968  0.002743 -0.03683   0.015366 -0.02847   0.04007   0.00813
    0.002373  0.014656  0.03268 ]
  [ 0.001968  0.002743 -0.03683   0.015366 -0.02847   0.04007   0.00813
    0.002373  0.014656  0.03268 ]]]
<NDArray 1x5x10 @gpu(0)>
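The PR also adds a dtype parameter to the RNN layers (see Changes below). Assuming that parameter is accepted at construction time and defaults to float32, the explicit cast can presumably be skipped; a minimal sketch of that usage:

import mxnet as mx
from mxnet import gluon

# hypothetical usage of the new dtype parameter: the layer and its
# default begin states are created directly in fp16
net = gluon.rnn.LSTM(5, bidirectional=True, dtype='float16')
net.collect_params().initialize(ctx=mx.gpu())
net.hybridize()
net(mx.nd.ones((1, 5, 5), dtype='float16', ctx=mx.gpu()))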

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

Changes

  • Added a non-breaking dtype param to RNN layers (sketched below)
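A minimal sketch, not the actual diff, of the mechanism the description outlines: the layer remembers a dtype, creates its default begin states in that dtype, and keeps it in sync when the block is cast. The class name and the state shape below are simplified placeholders:

import mxnet as mx
from mxnet import gluon

class DtypeAwareRNNLayer(gluon.Block):  # hypothetical, simplified
    def __init__(self, hidden_size, dtype='float32', **kwargs):
        super(DtypeAwareRNNLayer, self).__init__(**kwargs)
        self._hidden_size = hidden_size
        self._dtype = dtype  # used for the default initial states

    def cast(self, dtype):
        # keep the default-state dtype in sync with the parameters
        self._dtype = dtype
        super(DtypeAwareRNNLayer, self).cast(dtype)

    def begin_state(self, batch_size=0, func=mx.nd.zeros, **kwargs):
        kwargs.setdefault('dtype', self._dtype)
        # one state array per layer/direction; shape simplified here
        return [func(shape=(1, batch_size, self._hidden_size), **kwargs)]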

@haojin2 (Contributor) commented Feb 21, 2019

I think you've linked the wrong JIRA ticket for this PR... the title of the MXNET-939 JIRA ticket is actually "Variance Operator".

@ThomasDelteil ThomasDelteil changed the title [MXNET-939] Allow RNN Layers to be initialized to fp16 [MXNET-1327] Allow RNN Layers to be initialized to fp16 Feb 21, 2019
@roywei (Member) left a comment

Thanks for the contribution, LGTM!

@roywei (Member) commented Feb 23, 2019

@mxnet-label-bot add[Gluon, RNN, FP16]

@eric-haibin-lin (Member) commented

Could you fix the test? Thanks!

@eric-haibin-lin eric-haibin-lin merged commit 6aa8c27 into apache:master Mar 12, 2019
vdantu pushed a commit to vdantu/incubator-mxnet that referenced this pull request Mar 31, 2019
* update rnn for fp16

* fix typo in test

* fix tests

* fix tests

* fix gpu tests

* Update test_gluon_rnn.py

* Update test_gluon_rnn.py

* trigger

* try removing checks for unix
nswamy pushed a commit that referenced this pull request Apr 5, 2019 (same nine commits as above)
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019 (same nine commits as above)