Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 12 additions & 47 deletions optimum/habana/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import torch
from transformers.cache_utils import Cache
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
Expand Down Expand Up @@ -797,9 +796,6 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

ignore_cache_position = True # Ignoring cache position for HPU
# use_new_cache = False # Ignoring new Cache path for HPU

past_seen_tokens = 0

if past_key_values is not None and use_cache: # kept for BC (cache positions)
Expand All @@ -812,50 +808,19 @@ def forward(
# HPU uses legacy cache path (use_new_cache = False)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None and cache_position:
position_ids = cache_position.unsqueeze(0)
else:
if position_ids is None:
position_ids = torch.arange(
past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device
)
position_ids = position_ids.unsqueeze(0)
cache_position = None

# HPU specific mask generation
if ignore_cache_position:
causal_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask,
input_ids.shape if input_ids is not None else (batch_size, seq_length),
inputs_embeds,
past_seen_tokens,
if position_ids is None:
position_ids = torch.arange(
past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device
)
else:
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_seen_tokens,
"position_ids": position_ids,
}
# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
}
# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
causal_mask = causal_mask_mapping
position_ids = position_ids.unsqueeze(0)
cache_position = None # HPU path ignores explicit cache positions

causal_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask,
input_ids.shape if input_ids is not None else (batch_size, seq_length),
inputs_embeds,
past_seen_tokens,
)

# embed positions
hidden_states = inputs_embeds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache, StaticCache
from transformers.cache_utils import Cache
from transformers.integrations.deepspeed import is_deepspeed_available
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
MoeCausalLMOutputWithPast,
Expand Down Expand Up @@ -894,9 +893,6 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

ignore_cache_position = True # Ignoring cache position for HPU
# use_new_cache = False # Ignoring new Cache path for HPU

past_seen_tokens = 0

if past_key_values is not None and use_cache: # kept for BC (cache positions)
Expand All @@ -910,44 +906,20 @@ def forward(
if past_key_values[0] is not None: ##added for (None, None)
past_seen_tokens = past_key_values[0][0].shape[2]

if ignore_cache_position is False:
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None and cache_position:
position_ids = cache_position.unsqueeze(0)

else:
if position_ids is None:
position_ids = torch.arange(
past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device
)
position_ids = position_ids.unsqueeze(0)
cache_position = None
if position_ids is None:
position_ids = torch.arange(
past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device
)
position_ids = position_ids.unsqueeze(0)
cache_position = None

# HPU specific mask generation
if ignore_cache_position:
causal_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask,
input_ids.shape if input_ids is not None else (batch_size, seq_length),
inputs_embeds,
past_seen_tokens,
)
else:
mask_function = (
create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
)
causal_mask = mask_function(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_seen_tokens,
position_ids=position_ids,
)
causal_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask,
input_ids.shape if input_ids is not None else (batch_size, seq_length),
inputs_embeds,
past_seen_tokens,
)

# embed positions
hidden_states = inputs_embeds
Expand Down
Loading