Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
"observer": "maxabs",
"scale_method": "maxabs_hw",
"blocklist": {"types": [], "names": [
"matmul_qk",
"matmul_av",
"lm_head"
]},
"dump_stats_path": "./hqt_output/measure"
Expand Down
86 changes: 66 additions & 20 deletions optimum/habana/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,36 @@ def gaudi_gemma_repeat_kv(
return query_states, key_states, value_states, attention_mask


class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(
self,
query,
key,
value,
attn_mask,
dropout_p,
is_casual,
scale,
enable_recompute,
):
import habana_frameworks.torch.hpu as ht

with ht.sdp_kernel(enable_recompute=enable_recompute):
return self._hpu_kernel_fsdpa.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_casual,
scale,
)


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -145,6 +175,8 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
self.block_size = 4096
self.rotary_emb = GaudiRotaryEmbedding(config=self.config)

self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
device = self.k_proj.weight.device
Expand Down Expand Up @@ -174,7 +206,9 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
return (self.k_cache.cache.shape, self.v_cache.cache.shape)

def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size):
def gaudi_flash_attn_v1(
self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size, enable_recompute
):
"""
Gaudi version of Flash Attention V1 to support long sequence at prompt phase
Causal mask is not supported in this optimization
Expand All @@ -191,7 +225,9 @@ def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mas
s, e = i * q_block_size, (i + 1) * q_block_size
row_q = query_layer[:, :, s:e, :]
row_mask = attention_mask[:, :, s:e, :]
attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None)
attn_output_partial = self.fused_scaled_dot_product_attention(
row_q, key_layer, value_layer, row_mask, dropout_rate, False, None, enable_recompute
)
row_o_list.append(attn_output_partial)
attn_output = torch.cat(row_o_list, dim=-2)

Expand Down Expand Up @@ -286,32 +322,42 @@ def pre_attn_forward(
past_key_value = None

if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, use_recompute
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, flash_attention_recompute
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
if q_len > 16384:
attn_output = self.gaudi_flash_attn_v1(
query_states, key_states, value_states, attention_mask, 0.0, self.block_size
)
htcore.mark_step()
else:
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
if q_len > 16384:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this 16384 number from? If it's copied it might needs to adjusted based on the model.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess previous who enabling this flash attention might be copy the code from Starcoder2 model.

attn_output = self.gaudi_flash_attn_v1(
query_states,
key_states,
value_states,
attention_mask,
0.0,
self.block_size,
flash_attention_recompute,
)
htcore.mark_step()
else:
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
flash_attention_recompute,
)

else:
query_states, key_states, value_states, attention_mask = gaudi_gemma_repeat_kv(
Expand Down