Enable fast softmax mode in FusedSDPA #159
|
This change to the ModuleFusedSDPA forward API will require changes in the HQT patched module for quantization, which means it will break nightly testing once merged. I'm not sure we support regular softmax for 8-bit. We need to decide on the appropriate behavior when a user requests both quantization and regular softmax: should we ignore the quantization, ignore the softmax setting, or assert on that configuration? |
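To make the three options concrete, here is a minimal sketch of how such a check could look. It assumes, as the comment above suspects but does not confirm, that quantization combined with regular softmax is the unsupported case. The function name `resolve_attention_config` and the policy names are hypothetical and do not come from HQT or this PR.

```python
from typing import Tuple


def resolve_attention_config(
    quantized: bool, fast_softmax: bool, policy: str = "error"
) -> Tuple[bool, bool]:
    """Return the effective (quantized, fast_softmax) pair.

    Assumption: only "quantized with regular softmax" is a conflict.
    """
    conflict = quantized and not fast_softmax
    if not conflict:
        return quantized, fast_softmax

    if policy == "error":
        # Option 3: assert on the configuration.
        raise ValueError(
            "quantization with regular softmax is not supported in this sketch"
        )
    if policy == "keep_quantization":
        # Option 2: ignore the softmax request and fall back to fast softmax.
        return True, True
    if policy == "keep_softmax":
        # Option 1: ignore the quantization request and keep regular softmax.
        return False, False
    raise ValueError(f"unknown policy: {policy!r}")
```

Raising by default keeps the conflict visible to the user; the two fallback policies silently change what was requested, so they would probably need at least a warning.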
|
Discussed offline. Relevant change for quantization toolkit: https://gerrit.habana-labs.com/#/c/411008/ pushed by @dudilester is in review. |
|
Change in https://gerrit.habana-labs.com/#/c/411008/ is merged. @dvarshney-habana @puneeshkhanna @dudilester I think we can merge this PR now. |
|
The change in https://gerrit.habana-labs.com/#/c/411008/ has not passed promotion yet; we need to wait until it passes before merging this PR. |
|
FYI, commit https://gerrit.habana-labs.com/#/c/411008/ has been promoted since my previous comment and is included in builds starting with the CD 1.16.0-328 release build.
|
Thanks, I was not tracking it closely. @dvarshney-habana can we merge it? |
* Enable fast softmax mode in FusedSDPA
* Add fast_softmax parameter to _gradient_checkpointing_func
|
upstreamed in: huggingface#972 |
Adds support for setting fast softmax mode in the FusedSDPA operator. This is a trade-off between performance and accuracy.
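As a rough illustration of how a boolean `fast_softmax` option might be threaded through to the fused kernel, here is a minimal sketch. It is not the code from this PR: `FusedSDPASketch`, `softmax_mode_from_flag`, and `recording_kernel` are hypothetical names, the `"fast"` and `"None"` mode strings are assumptions, and a plain Python function stands in for the real kernel so the example runs without HPU hardware.

```python
from typing import Any, Callable


def softmax_mode_from_flag(fast_softmax: bool) -> str:
    # Assumption: the kernel selects its softmax variant by string;
    # "fast" and "None" are illustrative values only.
    return "fast" if fast_softmax else "None"


class FusedSDPASketch:
    """Hypothetical stand-in for a module wrapping a fused SDPA kernel."""

    def __init__(self, kernel: Callable[..., Any]) -> None:
        self._kernel = kernel

    def forward(self, query, key, value, attn_mask, dropout_p,
                is_causal, scale, softmax_mode):
        # The new argument is forwarded unchanged to the kernel.
        return self._kernel(query, key, value, attn_mask, dropout_p,
                            is_causal, scale, softmax_mode)


def recording_kernel(query, key, value, attn_mask, dropout_p,
                     is_causal, scale, softmax_mode):
    # Placeholder kernel: computes nothing and reports the mode it received.
    return {"softmax_mode": softmax_mode}


sdpa = FusedSDPASketch(recording_kernel)
result = sdpa.forward(None, None, None, None, 0.0, True, None,
                      softmax_mode_from_flag(True))
print(result["softmax_mode"])  # prints: fast
```

Because `forward` gains a positional argument, any module that patches or wraps it (such as a quantization wrapper) has to accept and pass on the same argument.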
Data on performance:
Data on accuracy (using mlperf test from: https://gerrit.habana-labs.com/plugins/gitiles/mlperf_inference/+/refs/heads/master_next/code/llama/llama_greedy.py
and https://gerrit.habana-labs.com/plugins/gitiles/mlperf_inference/+/refs/heads/master_next/code/llama/evaluation.py):