diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index 6860cf30d884..3546b95e09cf 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -1080,7 +1080,7 @@ def forward(
         output_hidden_states: bool | None = None,
         return_dict: bool | None = None,
         **kwargs,
-    ) -> BaseModelOutput | tuple[torch.Tensor, ...]:
+    ) -> BaseModelOutputWithPooling | tuple[torch.Tensor, ...]:
         r"""

         Example:
@@ -1166,7 +1166,7 @@ def forward(
         if not return_dict:
             return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None)

-        return BaseModelOutput(
+        return BaseModelOutputWithPooling(
             last_hidden_state=hidden_state,
             hidden_states=hidden_states,
             attentions=attentions,
@@ -1321,7 +1321,7 @@ def forward(
             pixel_values=pixel_values,
             vision_feature_select_strategy=vision_feature_select_strategy,
             return_dict=True,
-        ).pooler_output
+        ).last_hidden_state
         vision_flat = image_features.view(-1, image_features.size(-1))
         projected_vision_flat = self.multi_modal_projector(vision_flat).to(
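
Downstream effect of the patch, as a minimal runnable sketch (not part of the diff): Llama4VisionModel.forward now returns a BaseModelOutputWithPooling, so the vision call site reads per-patch features from .last_hidden_state instead of .pooler_output. The tensor shapes below are illustrative assumptions, not model constants.

    import torch
    from transformers.modeling_outputs import BaseModelOutputWithPooling

    # Stand-in for what the patched vision model returns; shapes are made up.
    num_patches, seq_len, hidden_dim = 4, 10, 8
    vision_output = BaseModelOutputWithPooling(
        last_hidden_state=torch.randn(num_patches, seq_len, hidden_dim),
    )

    # Mirrors the patched call site: take last_hidden_state (not pooler_output)
    # and flatten to (num_patches * seq_len, hidden_dim) before projection.
    image_features = vision_output.last_hidden_state
    vision_flat = image_features.view(-1, image_features.size(-1))
    print(vision_flat.shape)  # torch.Size([40, 8])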