Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"method": "HOOKS",
"mode": "QUANTIZE",
"observer": "maxabs",
"scale_method": "ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2",
"allowlist": {"types": [], "names": []},
"blocklist": {"types": [], "names": []},
"dump_stats_path": "./hqt_output/measure",
"dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx"
}
Comment thread
libinta marked this conversation as resolved.
3 changes: 3 additions & 0 deletions examples/text-generation/run_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __init__(self, tokenizer, model, args, options):
self.model_inputs.update(
{
"attn_softmax_bf16": self.options.attn_softmax_bf16,
"use_flash_attention": self.options.use_flash_attention,
"flash_attention_recompute": self.options.flash_attention_recompute,
"flash_attention_causal_mask": self.options.flash_attention_causal_mask,
}
)
if args.warmup:
Expand Down
23 changes: 19 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import math
import warnings
from typing import List, Optional, Tuple, Union
Expand Down Expand Up @@ -140,6 +141,16 @@ def gaudi_llama_repeat_kv(
return query_states, key_states, value_states, attention_mask


# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale)


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -267,6 +278,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)

Expand Down Expand Up @@ -423,19 +435,22 @@ def pre_attn_forward(

if q_len == 1:
# next token
with ht.sdp_kernel(enable_recompute=False):
attn_output = FusedSDPA.apply(
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)

Expand Down