Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 35 additions & 27 deletions optimum/habana/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""PyTorch Gemma model."""

import math
import os
from typing import List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -214,7 +215,7 @@ def pre_attn_forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
cache_idx: Optional[int] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Expand Down Expand Up @@ -289,7 +290,8 @@ def pre_attn_forward(

if q_len == 1:
# next token
with ht.sdp_kernel(enable_recompute=False):
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
Expand Down Expand Up @@ -407,23 +409,23 @@ def pre_attn(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
cache_idx: Optional[int] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
token_idx,
attn_softmax_bf16,
reuse_cache,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)
return hidden_states, attn_weights, present_key_value
Expand All @@ -443,7 +445,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
cache_idx: Optional[int] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Copied from GemmaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
Expand All @@ -453,16 +455,16 @@ def forward(
residual = hidden_states

hidden_states, self_attn_weights, present_key_value = self.pre_attn(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
token_idx,
attn_softmax_bf16,
reuse_cache,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
Expand Down Expand Up @@ -717,6 +719,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
reuse_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
Expand Down Expand Up @@ -746,6 +749,7 @@ def forward(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
reuse_cache=reuse_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
Expand Down Expand Up @@ -859,9 +863,13 @@ def prepare_inputs_for_generation(
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"reuse_cache": kwargs.get("reuse_cache"),
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
"token_idx": token_idx,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
}
)
return model_inputs