Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ def generate(
)
model_kwargs["kv_cache_len"] = calculated_max_length

if self.config.model_type in ["llama", "falcon"]:
if self.config.model_type in ["llama", "falcon", "mistral"]:
if self.config.max_position_embeddings < calculated_max_length:
unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)

Expand Down
136 changes: 102 additions & 34 deletions optimum/habana/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""PyTorch Mistral model."""

import math
import os
from typing import List, Optional, Tuple, Union

import habana_frameworks.torch.core as htcore
Expand Down Expand Up @@ -65,6 +66,12 @@
print("Not using HPU fused kernel for RMSNorm")
FusedRMSNorm = None

try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -109,6 +116,15 @@ def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)


class ModuleFusedSDPA(torch.nn.Module):
    """Expose a fused scaled dot-product attention kernel as a ``torch.nn.Module``.

    The wrapped object only has to provide an ``apply`` callable; every
    ``forward`` call is delegated to it with the arguments unchanged.
    """

    def __init__(self, fusedSDPA):
        """Keep a reference to ``fusedSDPA``, the kernel object used by ``forward``."""
        super().__init__()
        self._hpu_kernel_fsdpa = fusedSDPA

    def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale):
        """Call the wrapped kernel's ``apply`` and return its result.

        All seven arguments are passed through positionally, in the order of
        this signature.

        NOTE(review): ``is_casual`` is presumably meant as "is_causal"; the
        spelling is left as-is so keyword callers keep working.
        """
        kernel_args = (query, key, value, attn_mask, dropout_p, is_casual, scale)
        return self._hpu_kernel_fsdpa.apply(*kernel_args)


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -181,8 +197,10 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
self.v_cache = KVCache()
self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
self.inp_seq_len = -1
self._init_rope()
self.norm_factor = 1.0 / math.sqrt(self.head_dim)

def _init_rope(self):
"""
Expand Down Expand Up @@ -255,7 +273,9 @@ def forward(
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
use_fused_rope: Optional[bool] = True,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I missed all of these `use_fused_rope` arguments.

use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Expand Down Expand Up @@ -322,39 +342,63 @@ def forward(
else:
past_key_value = None

# repeat k/v heads if n_kv_heads < n_heads
query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)
if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht

if attn_weights.size() not in [
(bsz, self.num_heads, q_len, kv_seq_len),
(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len),
]:
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or"
f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
else:
# repeat k/v heads if n_kv_heads < n_heads
query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor

if attention_mask is not None:
if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]:
if attn_weights.size() not in [
(bsz, self.num_heads, q_len, kv_seq_len),
(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len),
]:
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or"
f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

attn_weights = attn_weights + attention_mask
if attention_mask is not None:
if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]:
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)

if attn_softmax_bf16:
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = self.matmul_av(attn_weights, value_states)
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim)
attn_weights = attn_weights + attention_mask

if attn_softmax_bf16:
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = self.matmul_av(attn_weights, value_states)
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
Expand Down Expand Up @@ -405,7 +449,9 @@ def forward(
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
use_fused_rope: Optional[bool] = True,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Expand All @@ -429,7 +475,9 @@ def forward(
reuse_cache=reuse_cache,
cache_idx=cache_idx,
attn_softmax_bf16=attn_softmax_bf16,
use_fused_rope=use_fused_rope,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
)
hidden_states = residual + hidden_states

Expand Down Expand Up @@ -458,6 +506,10 @@ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers)

def update_sincos_cache(self, seq_len):
    """Forward ``seq_len`` to ``update_sincos_cache`` on every entry of ``self.layers``.

    Layers are visited in the order they appear in ``self.layers``; nothing
    is returned.
    """
    for decoder_layer in self.layers:
        decoder_layer.update_sincos_cache(seq_len)

def forward(
self,
input_ids: torch.LongTensor = None,
Expand All @@ -473,8 +525,10 @@ def forward(
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
use_fused_rope: Optional[bool] = True,
lazy_mode: Optional[bool] = True,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py
Expand Down Expand Up @@ -577,7 +631,12 @@ def forward(
output_attentions,
use_cache,
None,
use_fused_rope,
False,
cache_idx,
attn_softmax_bf16,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
)
else:
layer_outputs = decoder_layer(
Expand All @@ -591,7 +650,9 @@ def forward(
reuse_cache=reuse_cache,
cache_idx=cache_idx,
attn_softmax_bf16=attn_softmax_bf16,
use_fused_rope=use_fused_rope,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -652,8 +713,10 @@ def forward(
trim_logits: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
use_fused_rope: Optional[bool] = True,
lazy_mode: Optional[bool] = True,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py
Expand Down Expand Up @@ -686,8 +749,10 @@ def forward(
reuse_cache=reuse_cache,
cache_idx=cache_idx,
attn_softmax_bf16=attn_softmax_bf16,
use_fused_rope=use_fused_rope,
lazy_mode=lazy_mode,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
Expand Down Expand Up @@ -799,6 +864,9 @@ def prepare_inputs_for_generation(
"cache_idx": kwargs.get("cache_idx"),
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
"lazy_mode": kwargs.get("lazy_mode"),
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
}
)
return model_inputs
Expand Down