From e35222badfcb9262679339370801ab89201a687f Mon Sep 17 00:00:00 2001 From: Urszula Date: Fri, 17 Oct 2025 11:08:52 +0300 Subject: [PATCH] Remove repeating flash_attention options. Some flash attention options were duplicated in training args and in run_lora_clm script which caused an ArgumentError. This is a quick fix to remove the error. A broader unification of the arguments would be nice later. Signed-off-by: Urszula --- examples/language-modeling/README.md | 8 ++-- examples/language-modeling/run_lora_clm.py | 40 +++----------------- optimum/habana/transformers/training_args.py | 25 +++++++++++- 3 files changed, 33 insertions(+), 40 deletions(-) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 8f32cb9e8c..7bfde95871 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -391,7 +391,7 @@ PT_TE_CUSTOM_OP=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py \ --torch_compile_backend hpu_backend \ --torch_compile \ --fp8 \ - --use_flash_attention True \ + --attn_implementation gaudi_fused_sdpa \ --flash_attention_causal_mask True \ --per_device_eval_batch_size 4 \ --cache_size_limit 64 \ @@ -435,7 +435,7 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \ --lora_rank 4 \ --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \ --validation_split_percentage 4 \ - --use_flash_attention True \ + --attn_implementation gaudi_fused_sdpa \ --flash_attention_causal_mask True \ --fp8 True ``` @@ -478,7 +478,7 @@ python3 ../gaudi_spawn.py --world_size 8 --use_mpi run_lora_clm.py \ --torch_compile_backend hpu_backend \ --torch_compile \ --gradient_accumulation_steps 2 \ - --use_flash_attention True \ + --attn_implementation gaudi_fused_sdpa \ --flash_attention_causal_mask True ``` @@ -647,7 +647,7 @@ PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py \ --eval_strategy epoch \ --pipelining_fwd_bwd \ --use_lazy_mode \ - --use_flash_attention True \ + --attn_implementation 
gaudi_fused_sdpa \ --deepspeed llama3_ds_zero1_config.json \ --num_train_epochs 3 \ --eval_delay 3 \ diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index b22eabba44..9517c1fb2d 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -146,36 +146,6 @@ class ModelArguments: ) }, ) - use_flash_attention: bool = field( - default=False, - metadata={ - "help": ( - "Whether to use Habana flash attention for fine-tuning. The current support is limited to Llama only." - ) - }, - ) - flash_attention_recompute: bool = field( - default=False, - metadata={ - "help": ( - "Whether to enable recompute in Habana flash attention for fine-tuning." - " It is applicable only when use_flash_attention is True." - ) - }, - ) - flash_attention_causal_mask: bool = field( - default=False, - metadata={ - "help": ( - "Whether to enable causal mask in Habana flash attention for fine-tuning." - " It is applicable only when use_flash_attention is True." 
- ) - }, - ) - flash_attention_fp8: bool = field( - default=False, - metadata={"help": ("Whether to enable flash attention in FP8.")}, - ) use_fused_rope: bool = field( default=True, metadata={ @@ -513,7 +483,7 @@ def main(): "trust_remote_code": True if model_args.trust_remote_code else None, "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache, "token": model_args.token, - "flash_attention_fp8": model_args.flash_attention_fp8, + "flash_attention_fp8": training_args.flash_attention_fp8, } if model_args.config_name: config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) @@ -707,12 +677,12 @@ def main(): model.generation_config.pad_token_id = model.generation_config.eos_token_id[0] if model_args.attn_softmax_bf16: model.generation_config.attn_softmax_bf16 = True - if model_args.use_flash_attention: + if training_args.attn_implementation == "gaudi_fused_sdpa": model.generation_config.use_flash_attention = True - model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute - model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + model.generation_config.flash_attention_recompute = training_args.flash_attention_recompute + model.generation_config.flash_attention_causal_mask = training_args.flash_attention_causal_mask - if model_args.flash_attention_fp8: + if training_args.flash_attention_fp8: import habana_frameworks.torch.hpu as hthpu assert hthpu.get_device_name() == "GAUDI3", "Flash attention in FP8 is supported only on Gaudi3" diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 163cc51b8e..21395bf7df 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -395,10 +395,25 @@ class GaudiTrainingArguments(TrainingArguments): default=False, metadata={ "help": "Whether to use fast softmax for Habana flash attention." 
- " It is applicable only when --attn_implementation gaudi_fused_sdpa." + " It is applicable only when --attn_implementation gaudi_fused_sdpa." }, ) + flash_attention_causal_mask: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable causal mask in Habana flash attention for fine-tuning." + " It is applicable only when --attn_implementation gaudi_fused_sdpa." + ) + }, + ) + + flash_attention_fp8: bool = field( + default=False, + metadata={"help": ("Whether to enable flash attention in FP8.")}, + ) + sdp_on_bf16: bool = field( default=False, metadata={"help": "Allow pyTorch to use reduced precision in the SDPA math backend"}, @@ -942,6 +957,14 @@ def __post_init__(self): "flash_attention_fast_softmax works only with --attn_implementation gaudi_fused_sdpa" ) os.environ["FLASH_ATTENTION_FAST_SOFTMAX"] = "1" + if self.flash_attention_causal_mask: + assert self.attn_implementation == "gaudi_fused_sdpa", ( + "flash_attention_causal_mask works only with --attn_implementation gaudi_fused_sdpa" + ) + if self.flash_attention_fp8: + assert self.attn_implementation == "gaudi_fused_sdpa", ( + "flash_attention_fp8 works only with --attn_implementation gaudi_fused_sdpa" + ) def __str__(self): self_as_dict = asdict(self)