diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 776993aca1..72853a6811 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -556,7 +556,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --lora_rank 4 \ --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \ --validation_split_percentage 4 \ - --use_flash_attention True + --use_flash_attention True \ + --flash_attention_causal_mask True ``` - Multi-card finetuning of Llama2-70B with FSDP and LoRA: diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index d057fc3c94..91d139ce6b 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -155,6 +155,15 @@ class ModelArguments: ) }, ) + flash_attention_causal_mask: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable causal mask in Habana flash attention for fine-tuning." + " It is applicable only when use_flash_attention is True.", + ) + }, + ) use_fused_rope: bool = field( default=True, metadata={ @@ -547,7 +556,8 @@ def main(): if model_args.use_flash_attention: model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute - if model_args.use_fused_rope is False: + model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + if not model_args.use_fused_rope: model.generation_config.use_fused_rope = False if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None: diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 0f9a2c7b16..e6563e433d 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -354,6 +354,30 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ ``` `--fp8` is required to enable quantization in fp8. +### Using Habana Flash Attention + +Habana Flash Attention addresses large sequence lengths on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes. + +Below example uses `flash_attention_recompute` mode in order to reduce memory consumption on prompt stage. Additionally since all sequences in a batch are of the same length it uses `flash_attention_causal_mask` which will further improve performance by taking advantage of specific lower-diagonal shape of inputs to softmax operation. + +```bash +python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--reuse_cache \ +--trim_logits \ +--attn_softmax_bf16 \ +--max_input_tokens 31744 \ +--max_new_tokens 1024 \ +--batch_size=12 \ +--use_flash_attention \ +--flash_attention_recompute \ +--flash_attention_causal_mask \ +--book_source +``` + +For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa). ## Language Model Evaluation Harness diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 1f503ed5e1..0b4ff8b5af 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -227,6 +227,21 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) + parser.add_argument( + "--flash_attention_recompute", + action="store_true", + help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.", + ) + parser.add_argument( + "--flash_attention_causal_mask", + action="store_true", + help="Whether to enable Habana Flash Attention in causal mode on first token generation.", + ) + parser.add_argument( + "--book_source", + action="store_true", + help="Whether to use project Guttenberg books data as input. Usefull for testing large sequence lenghts.", + ) parser.add_argument( "--torch_compile", action="store_true", @@ -272,6 +287,45 @@ def main(): # Benchmark over the prompts below if args.prompt: input_sentences = args.prompt + elif args.book_source: + + def download_book(book_id): + import os + + import requests + + url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt" + response = requests.get(url) + if response.status_code == 200: + pid = os.getpid() + save_path = f"/tmp/{book_id}_{pid}.txt" + with open(save_path, "wb") as file: + file.write(response.content) + print(f"Book downloaded and saved to: {save_path}") + return save_path + else: + print("Failed to download book! Exiting...") + import sys + + sys.exit() + + def assemble_prompt(prompt_size, book_path): + prompt = "" + counter = 0 + book_lines = open(book_path).readlines() + for line in book_lines: + for word in line.split(): + counter += 1 + prompt += word + " " + if counter == prompt_size: + return [prompt] * args.batch_size + + book_ids = [ + 2701, # Moby Dick; Or, The Whale + 1513, # Romeo and Juliet + 1342, # Pride and Prejudice + ] + input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0])) else: input_sentences = [ "DeepSpeed is a machine learning framework", diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index a83242e8b3..9b60fe4920 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -347,6 +347,8 @@ def setup_generation_config(args, model, tokenizer): if generation_config.reduce_recompile: assert generation_config.bucket_size > 0 generation_config.use_flash_attention = args.use_flash_attention + generation_config.flash_attention_recompute = args.flash_attention_recompute + generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask return generation_config diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index 93df1335db..61585b559f 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig): Whether to use flash attention optimization. flash_attention_recompute (`bool`, *optional*): Whether to enable recompute if use Habana flash attention. + flash_attention_causal_mask (`bool`, *optional*): + Whether to enable causal_mask if use Habana flash attention. """ def __init__(self, **kwargs): @@ -48,4 +50,5 @@ def __init__(self, **kwargs): self.reduce_recompile = kwargs.get("reduce_recompile", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) + self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index b46481da24..0d50470532 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -726,6 +726,8 @@ def generate( # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False + model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False + if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] if not generation_config.static_shapes and generation_config.max_new_tokens is not None: diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 4bcc32b17b..1381f30b1e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -313,6 +313,7 @@ def pre_attn_forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -325,6 +326,7 @@ def pre_attn_forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ bsz, q_len, _ = hidden_states.size() @@ -408,10 +410,15 @@ def pre_attn_forward( ) else: # first token - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( @@ -498,6 +505,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -509,6 +517,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask """ if "padding_mask" in kwargs: warnings.warn( @@ -529,6 +538,7 @@ def forward( reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, **kwargs, ) @@ -560,6 +570,7 @@ def pre_attn( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) @@ -576,6 +587,7 @@ def pre_attn( reuse_cache, use_flash_attention, flash_attention_recompute, + flash_attention_causal_mask, cache_idx=cache_idx, ) return hidden_states, attn_weights, present_key_value @@ -668,6 +680,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -679,6 +692,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -778,6 +792,7 @@ def forward( False, use_flash_attention, flash_attention_recompute, + flash_attention_causal_mask, ) else: layer_outputs = decoder_layer( @@ -793,6 +808,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, ) hidden_states = layer_outputs[0] @@ -864,6 +880,7 @@ def forward( reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -893,6 +910,7 @@ def forward( reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, lazy_mode=lazy_mode, ) @@ -1027,6 +1045,7 @@ def prepare_inputs_for_generation( "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), "cache_idx": kwargs.get("cache_idx"), "lazy_mode": kwargs.get("lazy_mode"), } diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index c91d0724f8..dc6e136a41 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -925,6 +925,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True # TODO: keep syncs for fast DDP? with self.accelerator.accumulate(model): @@ -1806,6 +1808,8 @@ def evaluation_loop( inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.flash_attention_causal_mask: + inputs["flash_attention_causal_mask"] = True # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)