diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index f3e86bdef8..b48a46a86a 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -20,7 +20,6 @@ import torch from transformers.cache_utils import Cache -from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -797,9 +796,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - ignore_cache_position = True # Ignoring cache position for HPU - # use_new_cache = False # Ignoring new Cache path for HPU - past_seen_tokens = 0 if past_key_values is not None and use_cache: # kept for BC (cache positions) @@ -812,50 +808,19 @@ def forward( # HPU uses legacy cache path (use_new_cache = False) past_seen_tokens = past_key_values[0][0].shape[2] - if ignore_cache_position is False: - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None and cache_position: - position_ids = cache_position.unsqueeze(0) - else: - if position_ids is None: - position_ids = torch.arange( - past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device - ) - position_ids = position_ids.unsqueeze(0) - cache_position = None - - # HPU specific mask generation - if ignore_cache_position: - causal_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, - input_ids.shape if input_ids is not None else (batch_size, seq_length), - inputs_embeds, - past_seen_tokens, + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device ) - else: - # It may already have been prepared by e.g. `generate` - if not isinstance(causal_mask_mapping := attention_mask, dict): - # Prepare mask arguments - mask_kwargs = { - "config": self.config, - "input_embeds": inputs_embeds, - "attention_mask": attention_mask, - "cache_position": cache_position, - "past_key_values": past_seen_tokens, - "position_ids": position_ids, - } - # Create the masks - causal_mask_mapping = { - "full_attention": create_causal_mask(**mask_kwargs), - } - # The sliding window alternating layers are not always activated depending on the config - if self.has_sliding_layers: - causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) - causal_mask = causal_mask_mapping + position_ids = position_ids.unsqueeze(0) + cache_position = None # HPU path ignores explicit cache positions + + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ) # embed positions hidden_states = inputs_embeds diff --git a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 2baad24cca..9033c29aae 100644 --- a/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/optimum/habana/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -25,9 +25,8 @@ import torch import torch.nn.functional as F from torch import nn -from transformers.cache_utils import Cache, StaticCache +from transformers.cache_utils import Cache from transformers.integrations.deepspeed import is_deepspeed_available -from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask from transformers.modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, @@ -894,9 +893,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - ignore_cache_position = True # Ignoring cache position for HPU - # use_new_cache = False # Ignoring new Cache path for HPU - past_seen_tokens = 0 if past_key_values is not None and use_cache: # kept for BC (cache positions) @@ -910,44 +906,20 @@ def forward( if past_key_values[0] is not None: ##added for (None, None) past_seen_tokens = past_key_values[0][0].shape[2] - if ignore_cache_position is False: - if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None and cache_position: - position_ids = cache_position.unsqueeze(0) - - else: - if position_ids is None: - position_ids = torch.arange( - past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device - ) - position_ids = position_ids.unsqueeze(0) - cache_position = None + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None # HPU specific mask generation - if ignore_cache_position: - causal_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, - input_ids.shape if input_ids is not None else (batch_size, seq_length), - inputs_embeds, - past_seen_tokens, - ) - else: - mask_function = ( - create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask - ) - causal_mask = mask_function( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_seen_tokens, - position_ids=position_ids, - ) + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ) # embed positions hidden_states = inputs_embeds