Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions optimum/habana/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# limitations under the License.
"""PyTorch Mistral model."""

import os
import math
from typing import List, Optional, Tuple, Union

Expand Down Expand Up @@ -47,7 +48,6 @@
_gaudi_prepare_4d_causal_attention_mask,
)


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE

Expand All @@ -68,6 +68,15 @@
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale)

class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -113,16 +122,6 @@ def gaudi_mistral_repeat_kv(

return query_states, key_states, value_states, attention_mask


def update_sincos_cache(self, seq_len):
# Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings
# This helps in avoiding creation of these caches during actual model forward pass and
# reduce memory consumption and improve performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
_, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)


def gaudi_mistral_rmsnorm_forward(self, hidden_states):
"""
Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py
Expand Down Expand Up @@ -153,6 +152,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
self.v_cache = KVCache()
self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)

Expand All @@ -177,6 +177,14 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
return (self.k_cache.cache.shape, self.v_cache.cache.shape)

def update_sincos_cache(self, seq_len):
# Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings
# This helps in avoiding creation of these caches during actual model forward pass and
# reduce memory consumption and improve performance.
if seq_len > self.max_position_embeddings:
self.max_position_embeddings = seq_len
_, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)

def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -263,19 +271,20 @@ def forward(
import habana_frameworks.torch.hpu as ht
if q_len == 1:
# next token
with ht.sdp_kernel(enable_recompute=False):
attn_output = FusedSDPA.apply(
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):#False):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
attn_output = self.fused_scaled_dot_product_attention(query_states, key_states, value_states, None, 0.0, True, None)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
else:
Expand Down Expand Up @@ -421,6 +430,10 @@ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers)

def update_sincos_cache(self, seq_len):
for layer in self.layers:
layer.update_sincos_cache(seq_len)

def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down Expand Up @@ -535,6 +548,7 @@ def forward(
and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
):
htcore.mark_step()

if output_hidden_states:
all_hidden_states += (hidden_states,)

Expand Down