[API] use softmax with length, and interleaved matmul for BERT #1091
Conversation
Codecov Report
@@            Coverage Diff             @@
##             master    #1091    +/-  ##
==========================================
  Coverage          ?   88.77%
==========================================
  Files             ?       67
  Lines             ?     6351
  Branches          ?        0
==========================================
  Hits              ?     5638
  Misses            ?      713
  Partials          ?        0
==========================================
@leezu I had a temporary gluon parameter workaround due to apache/mxnet#17220, by overriding …
@muhyun this helps GPT-2, too.
@eric-haibin-lin Will we achieve similar speed-up if we fuse the kernel + using …
Merged master due to #1096. Should be possible to merge now.
Job PR-1091/5 is complete.
Job PR-1091/7 is complete.
Job PR-1091/9 is complete.
@leezu could you help build the latest mxnet 1.6.0.rc1 for our CI pipeline? Thanks!
@eric-haibin-lin done
Job PR-1091/10 is complete.
Have you tried training the model?
Yes. The one I am running contains the new dataset loader and this change. So far the loss looks normal.
Job PR-1091/11 is complete.
…1091)
* use softmax with length, and interleaved matmul
* push backward compatibility fix
* fix failing unittests for output_all_encodings, and valid-len=None
* fix lint
* Update bert.py
* amp patch
* Update MXNet 1.6 pre-release version tested on CI
* Update bert.py
Co-authored-by: Leonard Lausen <[email protected]>
* [API] use softmax with length, and interleaved matmul for BERT (dmlc#1091)
* Add fused attn and softmax
* remove amp patch
* add test
* test for checkpoints
* Update files.py
* py3.5 compatibility
Co-authored-by: Leonard Lausen <[email protected]>
Description
This PR changes the input layout of BERTEncoder from NTC to TNC, so that we can adopt the fast fused self-attention op introduced by @Caenorst and @TaoLv.
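For intuition only, here is a minimal NumPy sketch of time-major (TNC) self-attention with an interleaved QKV projection, where a single matmul emits queries, keys, and values together. This is not the actual MXNet fused kernel; every name and shape below is a hypothetical illustration of the idea.

```python
import numpy as np

# Illustrative sizes (hypothetical, not from the PR benchmarks).
seq_len, batch, channels = 4, 2, 8
rng = np.random.default_rng(0)

x_ntc = rng.random((batch, seq_len, channels)).astype(np.float32)
x_tnc = x_ntc.transpose(1, 0, 2)            # NTC -> TNC: (T, N, C)

# One "interleaved" projection produces Q, K and V in a single matmul,
# instead of three separate projections.
w_qkv = rng.random((channels, 3 * channels)).astype(np.float32)
qkv = x_tnc @ w_qkv                         # (T, N, 3*C)
q, k, v = np.split(qkv, 3, axis=-1)         # each (T, N, C)

# Scaled dot-product attention scores: (N, T_query, T_key).
scores = np.einsum('tnc,snc->nts', q, k) / np.sqrt(channels)
e = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

# Weighted sum of values, back in TNC layout.
out = np.einsum('nts,snc->tnc', weights, v)  # (T, N, C)
```

With the time-major layout, Q/K/V for the whole batch stay contiguous per time step, which is what makes a single fused projection practical.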
For BERTModel API, the input layout remains unchanged. If users obtain the BERT model via the get_model API, they don't need to make any code change to run the optimized version (other than upgrading the gluon-nlp version).
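For intuition on the "softmax with length" part of the title, here is a hedged NumPy sketch: positions beyond each sequence's valid length get zero attention weight, so padding never contributes. The function name and shapes are illustrative, not the GluonNLP/MXNet API, and it assumes every valid length is at least 1.

```python
import numpy as np

def softmax_with_length(scores, valid_len):
    """Softmax over the last axis, restricted to the first `valid_len`
    entries of each row; padded positions receive weight 0.
    Assumes valid_len >= 1 for every row (hypothetical helper)."""
    t = scores.shape[-1]
    mask = np.arange(t) < np.asarray(valid_len)[..., None]
    # Masked positions are sent to -inf so exp() maps them to 0.
    shifted = np.where(mask, scores, -np.inf)
    shifted = shifted - shifted.max(axis=-1, keepdims=True)
    e = np.exp(shifted) * mask
    return e / e.sum(axis=-1, keepdims=True)

# Row 0 attends to 3 of 5 positions, row 1 to all 5.
probs = softmax_with_length(np.zeros((2, 5)), np.array([3, 5]))
```

Fusing this masking into the softmax avoids materializing a separate additive mask tensor for every attention head.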
The tests won't pass yet: they require an MXNet nightly build from Jan 4th or later, since earlier builds are missing the CPU op.
On p3.16xlarge, BERT base, seq_len=512, batch_size=256
Checklist
Essentials
Changes
Comments