diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 5e2cafeba1..b8acd70e19 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -99,14 +99,12 @@ "starcoder2", "persimmon", "qwen2", - "starcoder2", "llava", "llava_next", "stablelm", "mamba", "deci", "qwen2_moe", - "gemma", "whisper", ] diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py index 9f82fae821..1c270b62f6 100755 --- a/optimum/habana/transformers/models/gemma/modeling_gemma.py +++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py @@ -41,6 +41,7 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) +from ...modeling_rope_utils import GaudiRotaryEmbedding try: @@ -141,6 +142,7 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) self.block_size = 4096 + self.rotary_emb = GaudiRotaryEmbedding(config=self.config) def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) @@ -155,7 +157,7 @@ def update_sincos_cache(self, seq_len): # reduce memory consumption and improve performance. if seq_len > self.max_position_embeddings: self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + self.rotary_emb._set_cos_sin_cache(seq_len, self.k_proj.weight.device, self.k_proj.weight.dtype) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) @@ -252,8 +254,8 @@ def pre_attn_forward( else: kv_seq_len = past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos[position_ids], sin[position_ids]) if use_cache: # reuse k, v, self_attention @@ -697,6 +699,15 @@ def forward( class GaudiGemmaForCausalLM(GemmaForCausalLM): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.model.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.model.update_sincos_cache(seq_len) + def forward( self, input_ids: torch.LongTensor = None,