Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/language-modeling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ PT_TE_CUSTOM_OP=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py \
--torch_compile_backend hpu_backend \
--torch_compile \
--fp8 \
--use_flash_attention True \
--attn_implementation gaudi_fused_sdpa \
--flash_attention_causal_mask True \
--per_device_eval_batch_size 4 \
--cache_size_limit 64 \
Expand Down Expand Up @@ -435,7 +435,7 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--lora_rank 4 \
--lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
--validation_split_percentage 4 \
--use_flash_attention True \
--attn_implementation gaudi_fused_sdpa \
--flash_attention_causal_mask True \
--fp8 True
```
Expand Down Expand Up @@ -478,7 +478,7 @@ python3 ../gaudi_spawn.py --world_size 8 --use_mpi run_lora_clm.py \
--torch_compile_backend hpu_backend \
--torch_compile \
--gradient_accumulation_steps 2 \
--use_flash_attention True \
--attn_implementation gaudi_fused_sdpa \
--flash_attention_causal_mask True
```

Expand Down Expand Up @@ -647,7 +647,7 @@ PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py \
--eval_strategy epoch \
--pipelining_fwd_bwd \
--use_lazy_mode \
--use_flash_attention True \
--attn_implementation gaudi_fused_sdpa \
--deepspeed llama3_ds_zero1_config.json \
--num_train_epochs 3 \
--eval_delay 3 \
Expand Down
40 changes: 5 additions & 35 deletions examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,36 +146,6 @@ class ModelArguments:
)
},
)
use_flash_attention: bool = field(
default=False,
metadata={
"help": (
"Whether to use Habana flash attention for fine-tuning. The current support is limited to Llama only."
)
},
)
flash_attention_recompute: bool = field(
default=False,
metadata={
"help": (
"Whether to enable recompute in Habana flash attention for fine-tuning."
" It is applicable only when use_flash_attention is True."
)
},
)
flash_attention_causal_mask: bool = field(
default=False,
metadata={
"help": (
"Whether to enable causal mask in Habana flash attention for fine-tuning."
" It is applicable only when use_flash_attention is True."
)
},
)
flash_attention_fp8: bool = field(
default=False,
metadata={"help": ("Whether to enable flash attention in FP8.")},
)
use_fused_rope: bool = field(
default=True,
metadata={
Expand Down Expand Up @@ -513,7 +483,7 @@ def main():
"trust_remote_code": True if model_args.trust_remote_code else None,
"use_cache": False if training_args.gradient_checkpointing else model_args.use_cache,
"token": model_args.token,
"flash_attention_fp8": model_args.flash_attention_fp8,
"flash_attention_fp8": training_args.flash_attention_fp8,
}
if model_args.config_name:
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
Expand Down Expand Up @@ -707,12 +677,12 @@ def main():
model.generation_config.pad_token_id = model.generation_config.eos_token_id[0]
if model_args.attn_softmax_bf16:
model.generation_config.attn_softmax_bf16 = True
if model_args.use_flash_attention:
if training_args.attn_implementation == "gaudi_fused_sdpa":
model.generation_config.use_flash_attention = True
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask
model.generation_config.flash_attention_recompute = training_args.flash_attention_recompute
model.generation_config.flash_attention_causal_mask = training_args.flash_attention_causal_mask

if model_args.flash_attention_fp8:
if training_args.flash_attention_fp8:
import habana_frameworks.torch.hpu as hthpu

assert hthpu.get_device_name() == "GAUDI3", "Flash attention in FP8 is supported only on Gaudi3"
Expand Down
25 changes: 24 additions & 1 deletion optimum/habana/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,25 @@ class GaudiTrainingArguments(TrainingArguments):
default=False,
metadata={
"help": "Whether to use fast softmax for Habana flash attention."
" It is applicable only when --attn_implementation gaudi_fused_sdpa."
" It is applicable only when --attn_implementation gaudi_fused_sdpa."
},
)

flash_attention_causal_mask: bool = field(
default=False,
metadata={
"help": (
"Whether to enable causal mask in Habana flash attention for fine-tuning."
" It is applicable only when --attn_implementation gaudi_fused_sdpa."
)
},
)

flash_attention_fp8: bool = field(
default=False,
metadata={"help": ("Whether to enable flash attention in FP8.")},
)

sdp_on_bf16: bool = field(
default=False,
metadata={"help": "Allow pyTorch to use reduced precision in the SDPA math backend"},
Expand Down Expand Up @@ -942,6 +957,14 @@ def __post_init__(self):
"flash_attention_fast_softmax works only with --attn_implementation gaudi_fused_sdpa"
)
os.environ["FLASH_ATTENTION_FAST_SOFTMAX"] = "1"
if self.flash_attention_causal_mask:
assert self.attn_implementation == "gaudi_fused_sdpa", (
"flash_attention_causal_mask works only with --attn_implementation gaudi_fused_sdpa"
)
if self.flash_attention_fp8:
assert self.attn_implementation == "gaudi_fused_sdpa", (
"flash_attention_fp8 works only with --attn_implementation gaudi_fused_sdpa"
)

def __str__(self):
self_as_dict = asdict(self)
Expand Down
Loading