Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from transformers.trainer_utils import is_main_process

from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
from optimum.habana.utils import set_seed
from optimum.habana.utils import is_gaudi3, set_seed


try:
Expand Down Expand Up @@ -158,6 +158,14 @@ class ModelArguments:
)
},
)
flash_attention_fp8: bool = field(
default=False,
metadata={
"help": (
"Whether to enable flash attention in FP8."
)
},
)
use_fused_rope: bool = field(
default=True,
metadata={
Expand Down Expand Up @@ -432,6 +440,7 @@ def main():
"trust_remote_code": True if model_args.trust_remote_code else None,
"use_cache": False if training_args.gradient_checkpointing else model_args.use_cache,
"token": model_args.token,
"flash_attention_fp8": model_args.flash_attention_fp8,
}
if model_args.config_name:
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
Expand Down Expand Up @@ -587,6 +596,9 @@ def main():
model.generation_config.use_flash_attention = True
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask

if model_args.flash_attention_fp8:
assert is_gaudi3(), "Flash attention in FP8 is supported only on Gaudi3"
if not model_args.use_fused_rope:
model.generation_config.use_fused_rope = False

Expand Down
10 changes: 10 additions & 0 deletions optimum/habana/accelerate/utils/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
"""
Recursively converts the linear layer of a model to their `transformers_engine` counterpart.
"""
from optimum.habana.transformers.models.llama.modeling_llama import ModuleFusedSDPA

if not is_fp8_available():
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
for name, module in model.named_children():
Expand Down Expand Up @@ -75,6 +77,14 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
new_module.bias.copy_(module.bias)

setattr(model, name, new_module)
elif isinstance(module, ModuleFusedSDPA) and module.flash_attention_fp8 and to_transformer_engine:
from habana_frameworks.torch.hpex.experimental.transformer_engine import FusedAttention as te_FusedAttention
module._hpu_kernel_fsdpa = te_FusedAttention(
scale=module.scale,
attention_dropout=module.attention_dropout,
enable_recompute=module.enable_recompute,
)
setattr(model, name, module)
else:
_convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
mlp_bias=False,
fused_qkv=False,
parallel_strategy=None,
flash_attention_fp8=False,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -57,3 +58,4 @@ def __init__(
self.mlp_bias = mlp_bias
self.fused_qkv = fused_qkv
self.parallel_strategy = parallel_strategy
self.flash_attention_fp8 = flash_attention_fp8
37 changes: 24 additions & 13 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,20 @@ def gaudi_llama_repeat_kv(

# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side="left"):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side)
self.scale = scale
Comment thread
vivekgoe marked this conversation as resolved.
self.attention_dropout = attention_dropout
self.enable_recompute = enable_recompute
self.flash_attention_fp8 = flash_attention_fp8

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side="left"):
from habana_frameworks.torch.hpex.experimental.transformer_engine import FusedAttention as FusedAttentionTE
if isinstance(self._hpu_kernel_fsdpa, FusedAttentionTE):
return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, softmax_mode)
else:
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side)


class Matmul(torch.nn.Module):
Expand Down Expand Up @@ -332,7 +340,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
if config.fused_qkv:
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
Expand All @@ -348,6 +355,13 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.v_proj = None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA,
scale=self.norm_factor,
attention_dropout=self.attention_dropout,
enable_recompute=False,
flash_attention_fp8=config.flash_attention_fp8,
) if FusedSDPA else None

def get_k_proj_weight(self):
"""4bit quantization in GPTQ replaces the k_proj.weight with qweight."""
Expand Down Expand Up @@ -522,26 +536,24 @@ def pre_attn_forward(
past_key_value = None

if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"

if q_len == 1:
# next token
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode, False, None, "None"
query_states, key_states, value_states, attention_mask, 0.0, False, None, "None", False, None, "None"
)
else:
softmax_mode = "fast" if flash_attention_fast_softmax else "None"

# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode, flash_attention_recompute, valid_sequence_lengths, "left"
)
else:
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode,flash_attention_recompute, None, "None"
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode, flash_attention_recompute, None, "None"
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
Expand Down Expand Up @@ -597,7 +609,6 @@ def post_attn_forward(self, attn_output):
self.o_proj.post_all_reduce(attn_output)
return attn_output


class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule):
def __init__(
self,
Expand Down
5 changes: 5 additions & 0 deletions optimum/habana/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,8 @@ def get_device_name():
return "gaudi2"
else:
raise ValueError(f"Unsupported device: the device type is {device_type}.")


def is_gaudi3():
import habana_frameworks.torch as htorch
return htorch.hpu.get_device_name() == "GAUDI3"