84 changes: 55 additions & 29 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -266,7 +266,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_length: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -281,6 +281,13 @@ def forward(
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if q_len > 1:
# prefill
cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
else:
# decoding
cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)
Comment on lines +284 to +289
Collaborator
I am very much surprised that this now works in torch compile with reduce-overhead, as when testing this used to always create the same tensor (constant cache length) and the outputs were different. Would need investigation on which torch version supports this!
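
A minimal sketch of the kind of check this calls for, assuming a standalone `step` function stands in for the attention forward and that `cache_length` arrives as a plain Python int (both assumptions, not part of this diff):

```python
import torch

def step(x: torch.Tensor, cache_length: int) -> torch.Tensor:
    # Mirrors the decoding branch above: the position tensor is rebuilt
    # from a plain Python int on every call.
    cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=x.device)
    return x + cache_position

compiled = torch.compile(step, mode="reduce-overhead")
x = torch.zeros(1, dtype=torch.int32)

print(compiled(x, 5))  # expected: tensor([4], dtype=torch.int32)
print(compiled(x, 6))  # expected: tensor([5], dtype=torch.int32)
```

If the second print repeats the first value, the compiled graph has captured a constant tensor built from the first `cache_length`, which is the behaviour described above; matching the expected values suggests the torch version under test picks up the changed int (possibly via a recompile).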


if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -340,7 +347,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_length: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -366,6 +373,13 @@ def forward(
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if q_len > 1:
# prefill
cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
else:
# decoding
cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -531,7 +545,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_length: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -546,7 +560,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
cache_length=cache_length,
)

bsz, q_len, _ = hidden_states.size()
@@ -562,6 +576,13 @@ def forward(
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if q_len > 1:
# prefill
cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
else:
# decoding
cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -585,6 +606,11 @@ def forward(
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False

if cache_length > 0:
key_states = key_states[:, :, :cache_length, :]
value_states = value_states[:, :, :cache_length, :]
causal_mask = causal_mask[:, :, :, :cache_length] if causal_mask is not None else causal_mask

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@@ -628,7 +654,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
cache_length: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -662,7 +688,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
cache_length=cache_length,
)
hidden_states = residual + hidden_states

@@ -850,7 +876,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cache_length: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -878,17 +904,18 @@ def forward(
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_length is None:
cache_length = past_seen_tokens + inputs_embeds.shape[1]

if position_ids is None:
cache_position = torch.arange(
past_seen_tokens, cache_length, device=inputs_embeds.device
)
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
attention_mask, inputs_embeds, cache_length, past_key_values, output_attentions
)

# embed positions
@@ -925,7 +952,7 @@ def forward(
past_key_values,
output_attentions,
use_cache,
cache_position,
cache_length,
)
else:
layer_outputs = decoder_layer(
@@ -935,7 +962,7 @@ def forward(
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
cache_length=cache_length,
)

hidden_states = layer_outputs[0]
@@ -969,7 +996,7 @@ def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
cache_length: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
@@ -1017,12 +1044,12 @@ def forward(
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
else:
# This computation is only required when `sequence_length = 1` in the case of static cache.
causal_mask *= torch.arange(target_length, device=device) > cache_length
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
@@ -1090,7 +1117,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cache_length: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1134,7 +1161,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
cache_length=cache_length,
)

hidden_states = outputs[0]
@@ -1171,14 +1198,14 @@ def prepare_inputs_for_generation(
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
cached_length=None,
use_cache=True,
**kwargs,
):
past_length = 0
if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
past_length = past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
@@ -1223,15 +1250,14 @@ def prepare_inputs_for_generation(
model_inputs = {"input_ids": input_ids.contiguous()}

input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
if cache_position is None:
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
elif use_cache:
cache_position = cache_position[-input_length:]
if cached_length is None:
# It must be a Python int
cached_length = int(past_length + input_length)

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"cache_length": cached_length,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,