[v0.10.x] Softmax optimization & bertpass refactor #1565
Conversation
LGTM
The documentation website for preview: http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR1565/d13d37d19e549bb13984e855b9f3e6cb24a4bbc6/index.html
The documentation website for preview: http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR1565/ad9185846d06baf328878fa7b37a5356a6439c89/index.html
@szha Can you help with CI? I'm not sure why it's failing, and I don't know how to rerun it.
Hi @bgawrych, could you try merging with v0.10.x? We ported the new CI settings from v0.x to v0.10.x. Thanks!
@barry-jin, @szha there is still an issue with the notebook, which is a little strange since the bert.md file hasn't been changed for a year - should I fix it, or is it a CI issue?
@barry-jin I see the following error in the log, pointing out that there's a CI issue:
Hi @bgawrych, we have ported the changes in bert.md from the v0.x branch; you could try merging with the current v0.10.x.
Thanks, I will fix this issue. |
Will be fixed in #1575 |
The documentation website for preview: http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR1565/1644e8fe25e66b05300042510ced0603d1dd4098/index.html
Merged. Thanks @bgawrych! |
I'm surprised that the elimination of the 24x mask tensor creation gave you any speedup (as opposed to using masked softmax, which should) - MXNet already has a common expression elimination pass (I wrote it: apache/mxnet#15657). Does that not work for you?
@ptrendx I didn't know about this feature, but I wrote this small graph pass to test it:
The overhead from these operators is negligible, but it seems that it doesn't work in this case:
This PR adds a graph pass to optimize the CPU softmax in BERT.
Currently, for BERT-large, the length tensor is created by the following sequence of operations: expand_dims -> broadcast_axis -> Reshape,
and there are 24 such tensor creations. This pass replaces softmax (with length) with a regular softmax (but with masked input) - the mask is created only once and then passed to elemwise_sum to mask the input. Applying the pass in the scripts is optional.
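The replacement works because adding a large negative bias (the mask) to the padded positions drives their softmax weights to zero, so a regular softmax over the masked input matches a softmax restricted to the valid length. A minimal NumPy sketch of the equivalence (illustrative only; these function names are hypothetical and this is not the actual graph-pass code):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def softmax_with_length(x, length):
    # original form: only the first `length` positions participate
    out = np.zeros_like(x)
    out[:length] = softmax(x[:length])
    return out

def masked_softmax(x, length, neg=-1e18):
    # replacement form: add a large negative bias to the padded
    # positions, then run a regular softmax over the whole row
    mask = np.where(np.arange(x.shape[-1]) < length, 0.0, neg)
    return softmax(x + mask)

scores = np.random.randn(8)
a = softmax_with_length(scores, 5)
b = masked_softmax(scores, 5)
assert np.allclose(a, b)  # padded positions underflow to exactly 0
```

Since the mask depends only on the length tensor, it can be built once and reused by every layer's elemwise_sum instead of being rebuilt for each of the 24 attention blocks.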
Original:
Masked softmax:
Throughput in samples/s:
Accuracy:
There is also a bug fix in the interleaved MHA pass.
Accuracy without mha_interleave bug fix: {'exact_match': 79.62157048249763, 'f1': 87.75497143592598}