Skip to content
505 changes: 505 additions & 0 deletions optimum/habana/transformers/modeling_attn_mask_utils.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def forward(
if use_cache is True:
# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
present = (key.to(hidden_states.dtype), value)
#present = (key.to(hidden_states.dtype), value)
present = (key, value)
else:
present = None

Expand Down
8 changes: 4 additions & 4 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SDPContext = False

try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV1 as FusedRoPE
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
print("Not using HPU fused kernel for apply_rotary_pos_emb")
FusedRoPE = None
Expand All @@ -29,7 +29,7 @@
import habana_frameworks.torch.core as htcore
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers.modeling_attn_mask_utils import (
from optimum.habana.transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
Expand All @@ -54,8 +54,8 @@
# Apply rotary position embeddings to q and k.
# When q is on an "hpu" device and FusedRoPE is truthy, q and k are each passed
# through FusedRoPE.apply and returned as a 2-tuple; otherwise the call is
# delegated to apply_rotary_pos_emb(q, k, cos, sin, position_ids).
# NOTE(review): this span is a rendered diff hunk with no +/- markers; the two
# `return FusedRoPE.apply(` statements below look like the before/after versions
# of the same statement — confirm against the real file.
def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
# TODO: remove `.clone()` when SynapseAI v1.15 is released
return FusedRoPE.apply(q, cos.clone(), sin.clone(), position_ids), FusedRoPE.apply(
k, cos.clone(), sin.clone(), position_ids
# This variant adds two leading singleton dims to cos and sin (unsqueeze(0) twice) before cloning them.
return FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids), FusedRoPE.apply(
k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def prepare_inputs_for_generation(
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = torch.index_select(position_ids, 1, token_idx - 1).unsqueeze(-1)
position_ids = torch.index_select(position_ids, 1, token_idx - 1)#.unsqueeze(-1)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,6 @@ def prepare_inputs_for_generation(

# Apply rotary position embeddings to q and k.
# When q is on an "hpu" device and FusedRoPE is truthy, q and k are each passed
# through FusedRoPE.apply and returned as a 2-tuple; otherwise the call is
# delegated to apply_rotary_pos_emb(q, k, cos, sin, position_ids).
# NOTE(review): this span is a rendered diff hunk with no +/- markers; the two
# consecutive `return FusedRoPE.apply(...)` lines look like the before/after
# versions of the same statement — confirm against the real file.
def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
return FusedRoPE.apply(q, cos, sin, position_ids), FusedRoPE.apply(k, cos, sin, position_ids)
# This variant adds two leading singleton dims to cos and sin (unsqueeze(0) twice) before the fused call.
return FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
60 changes: 39 additions & 21 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
logger,
)
Expand Down Expand Up @@ -102,7 +103,6 @@ def forward(self, x, y):
class GaudiLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)

self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.past_key = None
Expand Down Expand Up @@ -224,6 +224,7 @@ def pre_attn_forward(
kv_seq_len = past_key_value[0][-2]
else:
kv_seq_len = past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)

Expand Down Expand Up @@ -354,6 +355,17 @@ def post_mlp_forward(self, x):


class GaudiLlamaDecoderLayer(LlamaDecoderLayer):
# Build one decoder layer: run the parent LlamaDecoderLayer initializer, then
# assign the attention, MLP and normalization sub-modules explicitly.
#   config:    LlamaConfig; hidden_size and rms_norm_eps are read from it here.
#   layer_idx: index of this layer, forwarded to the parent and to the attention module.
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.hidden_size = config.hidden_size

# Attention is the Gaudi subclass (GaudiLlamaAttention) rather than the stock LlamaAttention.
self.self_attn = GaudiLlamaAttention(config=config, layer_idx=layer_idx)

self.mlp = GaudiLlamaMLP(config)
# Two RMSNorm layers sized by hidden_size, sharing the epsilon from config.rms_norm_eps.
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8)

Expand Down Expand Up @@ -531,15 +543,22 @@ def forward(
)
use_cache = False

#seq_length_with_past = seq_length
past_key_values_length = 0
if use_cache:
if reuse_cache:
past_key_values_length = past_key_values[0][0][2]
else:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
use_legacy_cache = True
do_not_use_new_cache = True # Ignoring new Cache path for HPU
if past_key_values is not None:
if use_cache:
if reuse_cache:
past_key_values_length = past_key_values[0][2] #past_key_values[0][0][2]
else:
if not do_not_use_new_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
#seq_length_with_past = seq_length_with_past + past_key_values_length


if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
Expand All @@ -550,21 +569,20 @@ def forward(

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

key_value_length = seq_length + past_key_values_length
# Build the 4D causal attention mask, using the SDPA-specific helper when SDPA
# is enabled and attention weights are not requested.
# NOTE(review): this span is a rendered diff hunk with no +/- markers, so the
# last argument of each call appears twice (before/after versions) — confirm
# against the real file.
if self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
key_value_length,
# NOTE(review): `past_key_value_length` (no "s" after "value") is not assigned in any
# visible hunk; the counter set earlier is `past_key_values_length`, which the non-SDPA
# call below uses. Likely a typo that would raise NameError on this path — confirm.
past_key_value_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, key_value_length

attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# embed positions
Expand All @@ -573,9 +591,9 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
next_decoder_cache = () if do_not_use_new_cache else None

for decoder_layer in self.layers:
for layer_idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

Expand All @@ -585,7 +603,7 @@ def forward(
hidden_states,
attention_mask,
position_ids,
past_key_values,
None if past_key_values is None else past_key_values[layer_idx],
output_attentions,
use_cache,
attn_softmax_bf16=attn_softmax_bf16,
Expand All @@ -597,7 +615,7 @@ def forward(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
past_key_value=None if past_key_values is None else past_key_values[layer_idx],
output_attentions=output_attentions,
use_cache=use_cache,
token_idx=token_idx,
Expand All @@ -610,7 +628,7 @@ def forward(
hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

if output_attentions:
all_self_attns += (layer_outputs[1],)
Expand All @@ -623,7 +641,7 @@ def forward(

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if do_not_use_new_cache else (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
Expand Down Expand Up @@ -814,8 +832,8 @@ def prepare_inputs_for_generation(
# Apply rotary position embeddings to q and k.
# When q is on an "hpu" device and FusedRoPE is truthy, q and k are each passed
# through FusedRoPE.apply and returned as a 2-tuple; otherwise the call is
# delegated to apply_rotary_pos_emb(q, k, cos, sin, position_ids).
# NOTE(review): this span is a rendered diff hunk with no +/- markers; the two
# `return FusedRoPE.apply(` statements below look like the before/after versions
# of the same statement — confirm against the real file.
def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and FusedRoPE:
# TODO: remove `.clone()` when SynapseAI v1.15 is released
return FusedRoPE.apply(q, cos.clone(), sin.clone(), position_ids), FusedRoPE.apply(
k, cos.clone(), sin.clone(), position_ids
# This variant adds two leading singleton dims to cos and sin (unsqueeze(0) twice) before cloning them.
return FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids), FusedRoPE.apply(
k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
39 changes: 21 additions & 18 deletions optimum/habana/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
from optimum.habana.transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
Expand Down Expand Up @@ -77,13 +77,13 @@ def gaudi_mistral_attn_forward(
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
# `shp` is the cached length: read from the second-to-last dim of past_key_value[0]
# when past_key_value is exactly a tuple, otherwise obtained from
# past_key_value.get_usable_length(kv_seq_len, self.layer_idx).
# NOTE(review): `type(x) == type(tuple())` rejects tuple subclasses; `isinstance(past_key_value, tuple)`
# is the idiomatic check — confirm no caller relies on the exact-type comparison.
# NOTE(review): this span is a rendered diff hunk with no +/- markers; each branch below
# shows what look like the before and after versions of its assignment.
shp = past_key_value[0].shape[-2] if type(past_key_value) == type(tuple()) else past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if token_idx is not None:
kv_seq_len = past_key_value[0].shape[-2]
kv_seq_len = shp
else:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len += shp
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
if token_idx is not None:
past_key_value[0].index_copy_(2, token_idx - 1, key_states)
Expand All @@ -94,6 +94,7 @@ def gaudi_mistral_attn_forward(
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
Expand Down Expand Up @@ -234,12 +235,14 @@ def gaudi_mistral_model_forward(
use_cache = False

past_key_values_length = 0

if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
use_legacy_cache = True
do_not_use_new_cache = True
if past_key_values is not None:
if use_cache and not do_not_use_new_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
Expand Down Expand Up @@ -277,19 +280,20 @@ def gaudi_mistral_model_forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
next_decoder_cache = () if use_cache else None

for decoder_layer in self.layers:
for layer_idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)


if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
None if past_key_values is None else past_key_values[layer_idx],
output_attentions,
use_cache,
)
Expand All @@ -298,7 +302,7 @@ def gaudi_mistral_model_forward(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
past_key_value=None if past_key_values is None else past_key_values[layer_idx],
output_attentions=output_attentions,
use_cache=use_cache,
token_idx=token_idx,
Expand All @@ -307,7 +311,7 @@ def gaudi_mistral_model_forward(
hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

if output_attentions:
all_self_attns += (layer_outputs[1],)
Expand All @@ -319,9 +323,8 @@ def gaudi_mistral_model_forward(
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

if next_decoder_cache and use_cache:
next_cache = next_decoder_cache if do_not_use_new_cache else (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
Expand Down