diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index 7f303f965a..198dd8e512 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -344,6 +344,7 @@ def forward( ) if FusedSDPA and use_flash_attention: + attn_weights = None import habana_frameworks.torch.hpu as ht if q_len == 1: @@ -433,6 +434,7 @@ def forward( past_key_value = [key_states, value_states] if FusedSDPA and use_flash_attention: + attn_weights = None import habana_frameworks.torch.hpu as ht if q_len == 1: @@ -877,6 +879,13 @@ def __init__(self, config: MllamaConfig): # sdpa is better for vision model in HPU config._attn_implementation = "sdpa" super().__init__(config) + self.multi_modal_projector = self.model.multi_modal_projector + self.hidden_size = config.text_config.hidden_size + self._language_model = GaudiMllamaForCausalLM._from_config(config.text_config) + + @property + def language_model(self): + return self._language_model def forward( self,