Skip to content

Enable Flash Attention in recompute and causal modes (#21)#862

Merged
regisss merged 13 commits into
huggingface:mainfrom
wszczurekhabana:flash_attention_causal_recompute
Apr 5, 2024
Merged

Enable Flash Attention in recompute and causal modes (#21)#862
regisss merged 13 commits into
huggingface:mainfrom
wszczurekhabana:flash_attention_causal_recompute

Conversation

@wszczurekhabana
Copy link
Copy Markdown
Contributor

Cherry-pick from: HabanaAI#21

Original description:
This is a follow-up on: #623

where main issue is that when running with Flash Attention in causal mode (required for performance and memory optimizations on 1st token) it will generate a triangular attention mask - same on each batch of the input data.
If we have more than one batch of sequences that have different lengths, tokenizer will include a padding for those sentences that have smaller sequence lengths than 'max input tokens'. In the case of Flash Attention in causal mode, where triangular attention mask is applied to all the inputs, this will effectively mean that padding tokens are also attended to, which will in turn result in junk output being generated

This PR propagates the modes of running to run_generation.py so that optimization can still be used in case of the same length inputs in a batch. This is controlled through: --flash_attention_recompute and --flash_attention_causal_mask.

Additionally this PR provides a way to pass real input data to the model from Project Gutenberg Books for easier testing of large sequence lengths.

Below are the throughput measurements for different ratios of prompt to max seq length:

Ratio Max input tokens Max new tokens Batch size Throughput [tokens/s]
97% 31744 1042 12 85.54
75% 24576 8192 16 336.29
50% 16384 16384 24 521.39
25% 8192 24576 36 708.42

example of the command tested:
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ --num_beams 1 --attn_softmax_bf16 --model_name_or_path meta-llama/Llama-2-70b-hf \ --warmup 2 --n_iterations 3 --use_hpu_graphs --use_kv_cache --max_input_tokens 31744 --max_new_tokens 1042 --bf16 --batch_size 12 --reuse_cache --trim_logits --limit_hpu_graphs --use_flash_attention --flash_attention_recompute --flash_attention_causal_mask --book_source

Results on finetuning:

No Flash Attention:
'train_runtime': 2499.5658, 'train_samples_per_second': 2.626

Flash Attention:
'train_runtime': 2487.0323, 'train_samples_per_second': 2.636

Flash Attention Causal:
'train_runtime': 2449.3563, 'train_samples_per_second': 2.686

regisss and others added 8 commits March 29, 2024 23:08
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
* Enable Flash Attention in recompute and causal modes

* Add flash_attention_causal_mask to generation utils

* Propagate Flash Attention causal_mask to finetuning example

* Modify README example and provide additional description

* Add flash_attention_causal_mask to FT README
Copy link
Copy Markdown
Collaborator

@regisss regisss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why adding use_fused_rope in the changes of this PR? You need it for recompute and causal mode?

Comment thread optimum/habana/transformers/generation/utils.py Outdated
Comment thread examples/text-generation/run_generation.py
Comment thread optimum/habana/transformers/models/llama/modeling_llama.py Outdated
Comment thread examples/text-generation/README.md Outdated
Comment thread examples/text-generation/README.md Outdated
Comment thread optimum/habana/transformers/trainer.py Outdated
@regisss regisss changed the base branch from v1.11-release to main April 5, 2024 09:34
@regisss
Copy link
Copy Markdown
Collaborator

regisss commented Apr 5, 2024

@wszczurekhabana I changed the target of this PR to main because we first need to merge this change there and then I'll cherry-pick it in the release branch.
Can you make sure your main branch is up to date, run git merge main in your working branch and then solve the merge conflicts please?

Comment thread optimum/habana/version.py Outdated
Comment thread setup.py Outdated
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@regisss regisss added the run-test Run CI for PRs from external contributors label Apr 5, 2024
@regisss regisss merged commit 8bfda75 into huggingface:main Apr 5, 2024
regisss added a commit that referenced this pull request Apr 5, 2024
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
Co-authored-by: Libin Tang <litang@habana.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

run-test Run CI for PRs from external contributors

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants