Flash attention enhancement of repeatKV #626

Merged
regisss merged 1 commit into huggingface:main from puneeshkhanna:repeatKV on Jan 23, 2024

@puneeshkhanna
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@puneeshkhanna puneeshkhanna requested a review from a user January 8, 2024 06:29
@puneeshkhanna
Contributor Author

@regisss - Please don't merge this yet. I will add some test results this week.

Fused SDPA will handle the repeat KV logic internally in the 1.14 release, so we don't need the explicit repeat KV logic when flash attention is enabled. We should see a perf boost for some configs, such as TP-4.

Can you please add the Synapse 1.14 tag to this PR in the meantime?
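
For readers unfamiliar with repeat KV, here is a minimal pure-Python sketch of why the explicit expansion is redundant when the attention kernel can map query heads to shared KV heads itself. This is not the fused SDPA kernel or the code changed in this PR; the helper names (`repeat_kv`, `attend`, `attention_with_repeat`, `attention_grouped`) and the toy data are assumptions made for illustration only.

```python
import math

def repeat_kv(kv_heads, n_rep):
    """Duplicate each KV head n_rep times (h0, h0, ..., h1, h1, ...)."""
    return [head for head in kv_heads for _ in range(n_rep)]

def attend(q, k, v):
    """Single-head scaled dot-product attention on plain lists.
    q: [q_len][d], k and v: [kv_len][d]."""
    d = len(q[0])
    out = []
    for q_row in q:
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d) for k_row in k]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) / total
                    for j in range(d)])
    return out

def attention_with_repeat(q_heads, k_heads, v_heads):
    """Explicit path: materialise the repeated KV heads, then attend head by head."""
    n_rep = len(q_heads) // len(k_heads)
    k_rep, v_rep = repeat_kv(k_heads, n_rep), repeat_kv(v_heads, n_rep)
    return [attend(q, k, v) for q, k, v in zip(q_heads, k_rep, v_rep)]

def attention_grouped(q_heads, k_heads, v_heads):
    """Grouped path: query head i reads KV head i // n_rep, with no copies made."""
    n_rep = len(q_heads) // len(k_heads)
    return [attend(q, k_heads[i // n_rep], v_heads[i // n_rep])
            for i, q in enumerate(q_heads)]

# Toy example (hypothetical values): 4 query heads share 2 KV heads, head_dim 2.
q_heads = [[[0.1 * (h + 1), 0.2]] for h in range(4)]
k_heads = [[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]], [[0.3, 0.7], [0.9, 0.1], [0.2, 0.2]]]
v_heads = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[0.5, 0.5], [1.5, 2.5], [2.0, 1.0]]]

same = attention_with_repeat(q_heads, k_heads, v_heads) == attention_grouped(q_heads, k_heads, v_heads)
print("outputs identical:", same)
```

Both paths compute the same result; the grouped path simply avoids materialising the extra copies of the key and value heads.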

@puneeshkhanna
Contributor Author

puneeshkhanna commented Jan 10, 2024

World size 8: no impact on performance.

With flash attention enabled, this change should give significant perf improvements for world size 4 or world size 2. A few readings are below.
I will try to fix the non-flash-attention path to handle world size 4 or world size 2 in a separate PR.

python ../gaudi_spawn.py --use_deepspeed --world_size 4 run_generation.py --model_name_or_path /software/data/llama_inference/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 128 --max_new_tokens 128 --batch_size 8 --limit_hpu_graphs --attn_softmax_bf16 --trim_logits --bf16 --reuse_cache --warmup 2 --n_iterations 2 --use_flash_attention

Without fix -
Stats:
Throughput (including tokenization) = 277.43501473449646 tokens/second
Number of HPU graphs = 19
Memory allocated = 32.88 GB
Max memory allocated = 36.85 GB
Total memory available = 94.62 GB
Graph compilation duration = 9.24757074000081 seconds

With fix -
Throughput (including tokenization) = 327.1925781797973 tokens/second
Number of HPU graphs = 19
Memory allocated = 32.86 GB
Max memory allocated = 36.94 GB
Total memory available = 94.62 GB
Graph compilation duration = 8.126413758000126 seconds

python ../gaudi_spawn.py --use_deepspeed --world_size 4 run_generation.py --model_name_or_path /software/data/llama_inference/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 512 --max_new_tokens 512 --batch_size 16 --limit_hpu_graphs --attn_softmax_bf16 --trim_logits --bf16 --reuse_cache --warmup 2 --n_iterations 2 --use_flash_attention

Without fix -
Stats:
Throughput (including tokenization) = 288.127251016057 tokens/second
Number of HPU graphs = 19
Memory allocated = 34.23 GB
Max memory allocated = 44.86 GB
Total memory available = 94.62 GB
Graph compilation duration = 58.593736156995874 seconds

With fix -
Stats:
Throughput (including tokenization) = 589.2105646981188 tokens/second
Number of HPU graphs = 19
Memory allocated = 34.2 GB
Max memory allocated = 44.83 GB
Total memory available = 94.62 GB
Graph compilation duration = 29.483050564012956 seconds

python ../gaudi_spawn.py --use_deepspeed --world_size 4 run_generation.py --model_name_or_path /software/data/llama_inference/Llama-2-70b-hf/ --use_hpu_graphs --use_kv_cache --max_input_tokens 1024 --max_new_tokens 1024 --batch_size 8 --limit_hpu_graphs --attn_softmax_bf16 --trim_logits --bf16 --reuse_cache --warmup 2 --n_iterations 2 --use_flash_attention

Without fix -
Stats:
Throughput (including tokenization) = 140.24620942590755 tokens/second
Number of HPU graphs = 19
Memory allocated = 34.28 GB
Max memory allocated = 44.92 GB
Total memory available = 94.62 GB
Graph compilation duration = 118.88620932200865 seconds

With fix -
Stats:
Throughput (including tokenization) = 313.4460591743231 tokens/second
Number of HPU graphs = 19
Memory allocated = 34.22 GB
Max memory allocated = 44.86 GB
Total memory available = 94.62 GB
Graph compilation duration = 54.29132983699674 seconds
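
As a quick summary of the readings above, the following sketch computes the throughput ratio (with fix divided by without fix) for each of the three world size 4 runs. The throughput values are copied from the stats above; the `RUNS` table, its labels, and the `speedup` helper are names introduced here for illustration.

```python
# Throughput in tokens/second, copied from the world size 4 runs above.
# Each entry: (label, without fix, with fix).
RUNS = [
    ("128 in / 128 out, batch 8", 277.43501473449646, 327.1925781797973),
    ("512 in / 512 out, batch 16", 288.127251016057, 589.2105646981188),
    ("1024 in / 1024 out, batch 8", 140.24620942590755, 313.4460591743231),
]

def speedup(before, after):
    """Ratio of throughput with the fix to throughput without it."""
    return after / before

SPEEDUPS = [speedup(before, after) for _, before, after in RUNS]

for (label, _, _), ratio in zip(RUNS, SPEEDUPS):
    print(f"{label}: {ratio:.2f}x")
```

The ratio grows with sequence length in these three runs, from under 1.2x for the shortest configuration to over 2x for the two longer ones.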

Collaborator

@regisss regisss left a comment

Very nice speedups!

So we have to wait for the release of SynapseAI v1.14 to merge this one right?

@regisss regisss added the run-test Run CI for PRs from external contributors label Jan 11, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@puneeshkhanna
Contributor Author

Yes @regisss - I think we'll need to wait for it.

@schoi-habana
Collaborator

@puneeshkhanna we tested the PR with a 4x finetuning run, but saw no performance improvement. Is the perf improvement limited to inference, or did I miss something?

@puneeshkhanna
Contributor Author

puneeshkhanna commented Jan 12, 2024

@schoi-habana - Just to double-check: you used flash attention for finetuning too, and you were finetuning the 70B model, right? I'm not sure how the flash attention flag is passed for finetuning. The perf improvement is definitely there for inference of the 70B model. The change is not applicable to the 7B and 13B models, as they don't use GQA. We should also see some improvement in memory usage.
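
To make the GQA point concrete, here is a small sketch of the head counts involved. The attention and KV head counts are the published Llama 2 configuration values (7B: 32/32, 13B: 40/40, 70B: 64/8). The even split of heads across tensor-parallel ranks, and the `LLAMA2_HEADS` table and `heads_per_rank` helper, are assumptions made for illustration and are not taken from this PR.

```python
# (num_attention_heads, num_key_value_heads) from the published Llama 2 configs.
LLAMA2_HEADS = {"7b": (32, 32), "13b": (40, 40), "70b": (64, 8)}

def heads_per_rank(model, world_size):
    """Return (query heads, KV heads, repeat factor) on one rank.

    Assumes heads are split evenly across ranks (an illustrative
    simplification of tensor parallelism)."""
    q_heads, kv_heads = LLAMA2_HEADS[model]
    q_local = q_heads // world_size
    kv_local = max(kv_heads // world_size, 1)
    return q_local, kv_local, q_local // kv_local

for model, world_size in [("7b", 1), ("13b", 1), ("70b", 2), ("70b", 4), ("70b", 8)]:
    q_local, kv_local, n_rep = heads_per_rank(model, world_size)
    print(f"{model} world_size={world_size}: "
          f"{q_local} query heads, {kv_local} KV heads, repeat factor {n_rep}")
```

With a repeat factor of 1, as in the 7B and 13B models, there is nothing to repeat, so the change has no effect on them. The 70B model keeps a repeat factor of 8 at every world size under this assumed split.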

@schoi-habana
Collaborator

schoi-habana commented Jan 12, 2024

@puneeshkhanna yes, I tested 4x llama2-70b with --use_flash_attention. The sec/iter improvement is about 2-3% and the mem usage didn't improve.

@puneeshkhanna
Contributor Author

Thanks for the update @schoi-habana. A 2-3% improvement in finetuning is good too, I guess. We should be good to merge this with the Synapse 1.14 release. I will try to fix the non-flash-attention path too next week, in a separate PR.

@puneeshkhanna
Contributor Author

We should merge this PR first and then #639, which should hopefully avoid any merge conflicts.

@regisss regisss merged commit 1cca12a into huggingface:main Jan 23, 2024
jychen21 pushed a commit to jychen21/optimum-habana that referenced this pull request Feb 27, 2024
gplutop7 pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Oct 15, 2025
…uggingface#626)

Co-authored-by: Piotr Bielak <pbielak@users.noreply.github.com>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
