From ff27bfaa6e5b7bc9ea5b60663c715ed757beadcb Mon Sep 17 00:00:00 2001 From: Witold Szczurek Date: Fri, 2 Feb 2024 13:12:53 +0200 Subject: [PATCH 1/5] Enable Flash Attention in recompute and causal modes --- examples/text-generation/README.md | 20 +++++++ examples/text-generation/run_generation.py | 54 +++++++++++++++++++ examples/text-generation/utils.py | 2 + .../models/llama/modeling_llama.py | 27 ++++++++-- 4 files changed, 99 insertions(+), 4 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index b57bf49045..254a9f568f 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -296,6 +296,26 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ ``` `--fp8` is required to enable quantization in fp8. +### Using Habana Flash Attention + +Habana Flash Attention addresses large sequence lenghts on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes. + +```bash +python run_generation.py \ +--model_name_or_path meta-llama/Llama-2-7b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--reuse_cache \ +--trim_logits \ +--attn_softmax_bf16 \ +--max_input_tokens 31744 \ +--max_new_tokens 1024 \ +--batch_size=2 \ +--use_flash_attention \ +--flash_attention_recompute \ +--flash_attention_causal_mask \ +--book_source +``` ## Language Model Evaluation Harness diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index d2345c711c..048ef827dd 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -232,6 +232,21 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) + parser.add_argument( + "--flash_attention_recompute", + action="store_true", + help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.", + ) + parser.add_argument( + "--flash_attention_causal_mask", + action="store_true", + help="Whether to enable Habana Flash Attention in causal mode on first token generation.", + ) + parser.add_argument( + "--book_source", + action="store_true", + help="Whether to use project Guttenberg books data as input. Usefull for testing large sequence lenghts.", + ) parser.add_argument( "--torch_compile", action="store_true", @@ -271,6 +286,45 @@ def main(): # Benchmark over the prompts below if args.prompt: input_sentences = args.prompt + elif args.book_source: + + def download_book(book_id): + import os + + import requests + + url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt" + response = requests.get(url) + if response.status_code == 200: + pid = os.getpid() + save_path = f"/tmp/{book_id}_{pid}.txt" + with open(save_path, "wb") as file: + file.write(response.content) + print(f"Book downloaded and saved to: {save_path}") + return save_path + else: + print("Failed to download book! Exiting...") + import sys + + sys.exit() + + def assemble_prompt(prompt_size, book_path): + prompt = "" + counter = 0 + book_lines = open(book_path).readlines() + for line in book_lines: + for word in line.split(): + counter += 1 + prompt += word + " " + if counter == prompt_size: + return [prompt] * args.batch_size + + book_ids = [ + 2701, # Moby Dick; Or, The Whale + 1513, # Romeo and Juliet + 1342, # Pride and Prejudice + ] + input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0])) else: input_sentences = [ "DeepSpeed is a machine learning framework", diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 4bd8f27bb5..fc7f042223 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -344,6 +344,8 @@ def setup_generation_config(args, model, tokenizer): assert generation_config.bucket_size > 0 generation_config.kv_cache_fp8 = args.kv_cache_fp8 generation_config.use_flash_attention = args.use_flash_attention + generation_config.flash_attention_recompute = args.flash_attention_recompute + generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask return generation_config diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index ce55b283be..2622d832bd 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -199,6 +199,7 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -211,6 +212,7 @@ def pre_attn_forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ bsz, q_len, _ = hidden_states.size() @@ -289,10 +291,15 @@ def pre_attn_forward( ) else: # first token - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same lenght + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( @@ -424,6 +431,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -435,6 +443,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ residual = hidden_states output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( @@ -449,6 +458,7 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -479,6 +489,7 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -495,6 +506,7 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -545,6 +557,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -556,6 +569,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -637,6 +651,7 @@ def custom_forward(*inputs): attn_softmax_bf16=attn_softmax_bf16, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, use_fused_rope=use_fused_rope, ) @@ -658,6 +673,7 @@ def custom_forward(*inputs): reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -727,6 +743,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -751,6 +768,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) @@ -838,6 +856,7 @@ def prepare_inputs_for_generation( "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), "cache_idx": kwargs.get("cache_idx"), } ) From 363fc7c4e8ed5a28d93b5588d34aacd24b806dd7 Mon Sep 17 00:00:00 2001 From: Witold Szczurek Date: Wed, 7 Feb 2024 11:48:55 +0200 Subject: [PATCH 2/5] Add flash_attention_causal_mask to generation utils --- optimum/habana/transformers/generation/configuration_utils.py | 1 + optimum/habana/transformers/generation/utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 57f12810db..75f2b2bc64 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -49,4 +49,5 @@ def __init__(self, **kwargs): self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) + self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index edf9afc4f2..f1fdf0748c 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -708,6 +708,7 @@ def generate( # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False + model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False model_kwargs["use_fused_rope"] = False if not generation_config.use_fused_rope else True if not self.config.is_encoder_decoder: From d0d7403ed786bbceed9f0f0125e1de8019e47a18 Mon Sep 17 00:00:00 2001 From: Witold Szczurek Date: Wed, 7 Feb 2024 22:27:41 +0200 Subject: [PATCH 3/5] Propagate Flash Attention causal_mask to finetuning example --- examples/language-modeling/run_lora_clm.py | 10 ++++++++++ .../transformers/generation/configuration_utils.py | 2 ++ optimum/habana/transformers/trainer.py | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index 47da3af150..ba3244e57f 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -156,6 +156,15 @@ class ModelArguments: ) }, ) + flash_attention_causal_mask: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable causal mask in Habana flash attention for fine-tuning." + " It is applicable only when use_flash_attention is True.", + ) + }, + ) use_fused_rope: bool = field( default=True, metadata={ @@ -545,6 +554,7 @@ def main(): if model_args.use_flash_attention: model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute + model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask if not model_args.use_fused_rope: model.generation_config.use_fused_rope = False diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 75f2b2bc64..2e72342263 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig): Whether to use flash attention optimization. flash_attention_recompute (`bool`, *optional*): Whether to enable recompute if use Habana flash attention. + flash_attention_causal_mask (`bool`, *optional*): + Whether to enable causal_mask if use Habana flash attention. """ def __init__(self, **kwargs): diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 2f217a64b9..99514c295b 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -874,6 +874,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True if not self.model.generation_config.use_fused_rope: inputs["use_fused_rope"] = False @@ -1628,6 +1630,8 @@ def evaluation_loop( inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True if not self.model.generation_config.use_fused_rope: inputs["use_fused_rope"] = False From e20451f26c8aa487dd39a3d8a76f2f85a6cc7658 Mon Sep 17 00:00:00 2001 From: Witold Szczurek Date: Wed, 7 Feb 2024 22:42:30 +0200 Subject: [PATCH 4/5] Modify README example and provide additional description --- examples/text-generation/README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 254a9f568f..332d117e2f 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -300,9 +300,11 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ Habana Flash Attention addresses large sequence lenghts on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes. +Below example uses `flash_attention_recompute` mode in order to reduce memory consumption on prompt stage. Additionally since all sequences in a batch are of the same lenght it uses `flash_attention_causal_mask` which will further improve performance by taking advantage of specific lower-diagonal shape of inputs to softmax operation. + ```bash -python run_generation.py \ ---model_name_or_path meta-llama/Llama-2-7b-hf \ +python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ --use_hpu_graphs \ --use_kv_cache \ --reuse_cache \ @@ -310,13 +312,15 @@ python run_generation.py \ --attn_softmax_bf16 \ --max_input_tokens 31744 \ --max_new_tokens 1024 \ ---batch_size=2 \ +--batch_size=12 \ --use_flash_attention \ --flash_attention_recompute \ --flash_attention_causal_mask \ --book_source ``` +For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa). + ## Language Model Evaluation Harness The evaluation of LLMs can be done using the `lm_eval.py` script. It utilizes the [LM evaluation harness](https://github.com/EleutherAI/lm-evaluation-harness) From 347a0bb54c9242d6886d089f5fa7c91197531255 Mon Sep 17 00:00:00 2001 From: Witold Szczurek Date: Thu, 8 Feb 2024 10:23:22 +0200 Subject: [PATCH 5/5] Add flash_attention_causal_mask to FT README --- examples/language-modeling/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 909593427d..ac4a74ab69 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -550,7 +550,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --lora_rank 4 \ --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \ --validation_split_percentage 4 \ - --use_flash_attention True + --use_flash_attention True \ + --flash_attention_causal_mask True ``` - Multi-card finetuning of Falcon-180B: