diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index bda8738516..cf5fa6f2c0 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -518,7 +518,7 @@ def forward( class GaudiMistralForCausalLM(MistralForCausalLM): - def allocate_kv_cache(self, batch_size, seq_len, _, __): + def allocate_kv_cache(self, batch_size, seq_len, _): self.model.allocate_kv_cache(batch_size, seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor):