Flash attention enhancement of repeatKV#626
Conversation
|
@regisss - Don't merge this yet. Let me add some testing results too in this week. Fused SDPA will internally handle repeat KV logic in 1.14 release hence we don't need repeat KV logic when flash attention is enabled. We should see perf boost for some configs such as TP-4. Can you please add tag of synapse 1.14 to this PR mean while. |
|
World size 8 - no impact on performance. Significant perf improvements will be seen for world size 4 or world size 2 with this change when flash attention is enabled. Few readings taken below. python ../gaudi_spawn.py --use_deepspeed --world_size 4 run_generation.py --model_name_or_path /software/data/llama_inference/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 128 --max_new_tokens 128 --batch_size 8 --limit_hpu_graphs --attn_softmax_bf16 --trim_logits --bf16 --reuse_cache --warmup 2 --n_iterations 2 --use_flash_attention Without fix - With fix - python ../gaudi_spawn.py --use_deepspeed --world_size 4 run_generation.py --model_name_or_path /software/data/llama_inference/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 512 --max_new_tokens 512 --batch_size 16 --limit_hpu_graphs --attn_softmax_bf16 --trim_logits --bf16 --reuse_cache --warmup 2 --n_iterations 2 --use_flash_attention Without fix - With fix - python ../gaudi_spawn.py --use_deepspeed --world_size 4 run_generation.py --model_name_or_path /software/data/llama_inference/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 1024 --max_new_tokens 1024 --batch_size 8 --limit_hpu_graphs --attn_softmax_bf16 --trim_logits --bf16 --reuse_cache --warmup 2 --n_iterations 2 --use_flash_attention Without fix - With fix - |
regisss
left a comment
There was a problem hiding this comment.
Very nice speedups!
So we have to wait for the release of SynapseAI v1.14 to merge this one right?
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Yes @regisss - I think we l need to wait for it. |
|
@puneeshkhanna we tested the PR with a 4x finetuning, but no performance improvement was seen. is the perf improvement limited to inference, or did I miss something? |
|
@schoi-habana - Just to double check that you did use "flash attention" for finetuning too and also finetuning 70B model right ? I m not sure how flash attention flag is passed for finetuning. Perf improvement is surely there for inference of 70B model. Change is not applicable for 7b and 13 b models as they don't use GQA. We should also see some improvement in memory usage too. |
|
@puneeshkhanna yes i tested 4x llama2-70b with --use_flash_attention. the sec/iter improvement is about 2-3% and the mem usage didn't improve. |
|
Thanks for the update @schoi-habana . 2-3% improvement in finetuning is good too I guess. We should be good to merge this with 1.14 synapse release. I will try to fix the non-flash attention path too in next week in separate PR. |
|
We should first merge this PR and then #639 to avoid any merge conflicts hopefully. |
…uggingface#626) Co-authored-by: Piotr Bielak <pbielak@users.noreply.github.com> Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
What does this PR do?
Fixes # (issue)
Before submitting