Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,6 @@ class ModelArguments:
)
},
)
flash_attention_fp8: bool = field(
default=False,
metadata={
"help": (
"Whether to enable flash attention in FP8."
)
},
)
use_fused_rope: bool = field(
default=True,
metadata={
Expand Down Expand Up @@ -595,7 +587,6 @@ def main():
model.generation_config.use_flash_attention = True
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask
model.generation_config.flash_attention_fp8 = model_args.flash_attention_fp8
if not model_args.use_fused_rope:
model.generation_config.use_fused_rope = False

Expand Down
5 changes: 0 additions & 5 deletions optimum/habana/accelerate/utils/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import functools

import torch
from optimum.habana.transformers.models.llama.modeling_llama import ModuleFusedSDPA


has_transformer_engine = False
Expand Down Expand Up @@ -76,10 +75,6 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True):
new_module.bias.copy_(module.bias)

setattr(model, name, new_module)
elif isinstance(module, ModuleFusedSDPA) and to_transformer_engine:
from habana_frameworks.torch.hpex.experimental.transformer_engine import FusedAttention as te_FusedAttention
module._hpu_kernel_fsdpa = te_FusedAttention(scale=module.scale, attention_dropout=module.attention_dropout, enable_recompute=False)
setattr(model, name, module)
else:
_convert_model(module, to_transformer_engine=to_transformer_engine, _convert_linear=_convert_linear)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to enable causal_mask if use Habana flash attention.
flash_attention_fast_softmax_mode (`bool`, *optional*):
Whether to use fast softmax with reduced precision if use Habana flash attention.
flash_attention_fp8 (`bool`, *optional*):
Whether to use flash attention in FP8.
"""

def __init__(self, **kwargs):
Expand All @@ -56,5 +54,4 @@ def __init__(self, **kwargs):
self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None)
self.flash_attention_fp8 = kwargs.get("flash_attention_fp8", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
76 changes: 15 additions & 61 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@
import habana_frameworks.torch.core as htcore


flash_attention_in_fp8 = False


def gaudi_llama_rmsnorm_forward(self, hidden_states):
"""
Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand Down Expand Up @@ -228,19 +225,12 @@ def gaudi_llama_repeat_kv(

# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute=False):
def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA
self.scale = scale
self.attention_dropout = attention_dropout
self.enable_recompute = enable_recompute

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, fast_softmax_mode):
from habana_frameworks.torch.hpex.experimental.transformer_engine import FusedAttention
if isinstance(self._hpu_kernel_fsdpa, FusedAttention):
return self._hpu_kernel_fsdpa(query, key, value, attn_mask, is_causal, fast_softmax_mode)
else:
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, fast_softmax_mode)

def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)


class Matmul(torch.nn.Module):
Expand Down Expand Up @@ -301,6 +291,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
if config.fused_qkv:
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
Expand All @@ -316,12 +307,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.v_proj = None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA,
scale=self.norm_factor,
attention_dropout=self.attention_dropout,
enable_recompute=False,
) if FusedSDPA else None

def get_k_proj_weight(self):
"""4bit quantization in GPTQ replaces the k_proj.weight with qweight."""
Expand Down Expand Up @@ -499,54 +484,27 @@ def pre_attn_forward(

softmax_mode = "fast" if flash_attention_fast_softmax else "None"

global flash_attention_in_fp8
if flash_attention_in_fp8 is True:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
query_states,
key_states,
value_states,
attention_mask,
self.num_key_value_groups,
)

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
if flash_attention_in_fp8 is True:
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
if flash_attention_in_fp8 is True:
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
else:
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
if flash_attention_in_fp8 is True:
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)

if flash_attention_in_fp8 is True:
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim)
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
Expand Down Expand Up @@ -1056,10 +1014,6 @@ def forward(
global has_fused_rope
has_fused_rope = False

if self.generation_config.flash_attention_fp8 is True:
global flash_attention_in_fp8
flash_attention_in_fp8 = True

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
Expand Down