diff --git a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py index fe8352d563..172a5f218d 100755 --- a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py +++ b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py @@ -899,7 +899,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, + num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, @@ -956,12 +956,7 @@ def forward( else: hidden_states = hidden_states[:, -1, :] - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if self.config.final_logit_softcapping is not None: - logits = logits / self.config.final_logit_softcapping - logits = torch.tanh(logits) - logits = logits * self.config.final_logit_softcapping + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: