Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ class ModelArguments:
)
},
)
flash_attention_fp8: bool = field(
default=False,
metadata={"help": ("Whether to enable flash attention in FP8.")},
)
use_fused_rope: bool = field(
default=True,
metadata={
Expand Down Expand Up @@ -509,6 +513,7 @@ def main():
"trust_remote_code": True if model_args.trust_remote_code else None,
"use_cache": False if training_args.gradient_checkpointing else model_args.use_cache,
"token": model_args.token,
"flash_attention_fp8": model_args.flash_attention_fp8,
}
if model_args.config_name:
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
Expand Down Expand Up @@ -705,6 +710,11 @@ def main():
model.generation_config.use_flash_attention = True
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask

if model_args.flash_attention_fp8:
import habana_frameworks.torch.hpu as hthpu

assert hthpu.get_device_name() == "GAUDI3", "Flash attention in FP8 is supported only on Gaudi3"
if not model_args.use_fused_rope:
model.generation_config.use_fused_rope = False

Expand Down
20 changes: 20 additions & 0 deletions optimum/habana/accelerate/utils/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
"""
Recursively converts the linear layer of a model to their `transformers_engine` counterpart.
"""
from optimum.habana.transformers.models.llama.modeling_llama import ModuleFusedSDPA

if not is_fp8_available():
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
for name, module in model.named_children():
Expand Down Expand Up @@ -75,6 +77,24 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
new_module.bias.copy_(module.bias)

setattr(model, name, new_module)
elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine:
from habana_frameworks.torch.hpex.experimental.transformer_engine import (
FusedAttention as TE_FusedAttention,
)

class TE_ModuleFusedSDPA(torch.nn.Module):
def __init__(self):
super().__init__()
self._hpu_kernel_fsdpa = TE_FusedAttention(
scale=module.scale,
attention_dropout=module.attention_dropout,
enable_recompute=module.enable_recompute,
)

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, softmax_mode)

setattr(model, name, TE_ModuleFusedSDPA())
else:
_convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
mlp_bias=False,
fused_qkv=False,
parallel_strategy=None,
flash_attention_fp8=False,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -56,3 +57,4 @@ def __init__(

self.fused_qkv = fused_qkv
self.parallel_strategy = parallel_strategy
self.flash_attention_fp8 = flash_attention_fp8
18 changes: 16 additions & 2 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,13 @@ def gaudi_llama_repeat_kv(

# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA
self.scale = scale
self.attention_dropout = attention_dropout
self.enable_recompute = enable_recompute
self.flash_attention_fp8 = flash_attention_fp8

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
Expand Down Expand Up @@ -412,7 +416,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
if hasattr(config, "fused_qkv") and config.fused_qkv:
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
Expand All @@ -428,6 +431,17 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.v_proj = None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
self.fused_scaled_dot_product_attention = (
ModuleFusedSDPA(
FusedSDPA,
scale=self.norm_factor,
attention_dropout=self.attention_dropout,
enable_recompute=False,
flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not

Suggested change
flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
flash_attention_fp8=config.flash_attention_fp8,

?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because Harish Subramony noticed that in some cases this field is not set, so config.flash_attention_fp8 would fail with an AttributeError

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hsubramony Can you give the commands or describe the cases where this error happens please?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests/transformers/tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_eager_matches_sdpa_generate
tests/test_trainer.py::�[1mGaudiTrainerIntegrationTest::test_multiple_peft_adapters

  • AttributeError: 'LlamaConfig' object has no attribute 'flash_attention_fp8'

)
if FusedSDPA
else None
)

def get_k_proj_weight(self):
"""4bit quantization in GPTQ replaces the k_proj.weight with qweight."""
Expand Down