5 changes: 1 addition & 4 deletions src/transformers/models/aimv2/modeling_aimv2.py
@@ -72,10 +72,7 @@ class Aimv2Output(ModelOutput):
vision_model_output: BaseModelOutputWithPooling = None

def to_tuple(self) -> tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())


@use_kernel_forward_from_hub("RMSNorm")
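The `to_tuple()` change is identical in every output class this PR touches (Aimv2, AltCLIP, CLAP, CLIP): instead of hard-coding which field names hold nested outputs, the new comprehension flattens any value that is itself a `ModelOutput` through its own `to_tuple()`. A self-contained sketch of that behaviour with toy stand-in classes (the names here are hypothetical, not the library code):

```python
from collections import OrderedDict


class ToyOutput(OrderedDict):
    """Stand-in for ModelOutput: an ordered mapping with a to_tuple() helper."""

    def to_tuple(self):
        # Same shape as the new one-liner: nested outputs flatten themselves,
        # every other value is passed through unchanged.
        return tuple(v.to_tuple() if isinstance(v, ToyOutput) else v for v in self.values())


inner = ToyOutput(last_hidden_state=[1, 2, 3], pooler_output=[4])
outer = ToyOutput(logits=[0.5], vision_model_output=inner)
print(outer.to_tuple())  # ([0.5], ([1, 2, 3], [4]))
```

The old version only special-cased `text_model_output` / `vision_model_output` by name, so any newly added nested field would have been returned un-flattened; the `isinstance` check removes that foot-gun.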
136 changes: 28 additions & 108 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -34,7 +34,7 @@
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.generic import is_flash_attention_requested
from ...utils.generic import check_model_inputs, is_flash_attention_requested
from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig


@@ -85,10 +85,7 @@ class AltCLIPOutput(ModelOutput):
vision_model_output: BaseModelOutputWithPooling = None

def to_tuple(self) -> tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())


# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->AltRoberta
@@ -482,7 +479,7 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
causal_attention_mask: torch.Tensor | None = None,
output_attentions: bool | None = False,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Input shape: Batch x Time x Channel"""

@@ -518,12 +515,11 @@ def forward(
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights


@@ -557,26 +553,16 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: bool | None = False,
) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.FloatTensor, torch.Tensor | None]:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = residual + hidden_states

@@ -585,12 +571,7 @@ def forward(
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
outputs += (attn_weights,)

return outputs
return hidden_states, attn_weights
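The same change repeats through the AltCLIP encoder stack: instead of threading `output_attentions` (and friends) through every signature, each forward now accepts `**kwargs: Unpack[TransformersKwargs]` and hands them on unchanged. A toy sketch of that pass-through pattern (hypothetical functions, not the library code); the point is that an option named only at the call site still reaches the innermost attention computation:

```python
import torch
import torch.nn.functional as F


def toy_attention(states, dropout: float = 0.0, **kwargs):
    # Innermost consumer: it alone names the options it cares about.
    weights = torch.softmax(states @ states.transpose(-1, -2) / states.shape[-1] ** 0.5, dim=-1)
    weights = F.dropout(weights, p=dropout)
    return weights @ states, weights


def toy_layer(states, **kwargs):
    attn_out, attn_weights = toy_attention(states, **kwargs)  # forwarded verbatim
    return states + attn_out, attn_weights


def toy_encoder(states, **kwargs):
    for _ in range(2):
        states, _ = toy_layer(states, **kwargs)  # no per-flag plumbing at this level
    return states


print(toy_encoder(torch.randn(1, 4, 8), dropout=0.1).shape)  # torch.Size([1, 4, 8])
```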


class AltCLIPEncoder(nn.Module):
@@ -608,16 +589,13 @@ def __init__(self, config: AltCLIPConfig):
self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

@can_return_tuple
def forward(
self,
inputs_embeds,
attention_mask: torch.Tensor | None = None,
causal_attention_mask: torch.Tensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
) -> tuple | BaseModelOutput:
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -638,45 +616,20 @@ def forward(
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
**kwargs,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
last_hidden_state=hidden_states,
)


@@ -770,6 +723,10 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
base_model_prefix = "altclip"
input_modalities = ("image", "text")
supports_gradient_checkpointing = True
_can_record_outputs = {
"hidden_states": AltCLIPEncoderLayer,
"attentions": AltCLIPAttention,
}
_no_split_module = []

@torch.no_grad()
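The `_can_record_outputs` mapping added here is the counterpart to dropping those flags: as far as I can tell, the `check_model_inputs` decorator (imported above from `utils.generic`) uses it to capture per-layer hidden states and attentions from the outside, so the inner modules no longer accumulate them. The real implementation is more involved; below is only a self-contained sketch of the hook-based idea, with hypothetical names (`record_outputs`, `TinyEncoder`):

```python
import torch
from torch import nn


def record_outputs(forward_fn):
    """Collect the output of every submodule whose class is listed in `_can_record_outputs`."""

    def wrapper(self, *args, **kwargs):
        records = {name: [] for name in self._can_record_outputs}
        handles = []
        for name, layer_cls in self._can_record_outputs.items():
            for module in self.modules():
                if isinstance(module, layer_cls):
                    handles.append(
                        module.register_forward_hook(
                            lambda mod, inp, out, name=name: records[name].append(out)
                        )
                    )
        try:
            result = forward_fn(self, *args, **kwargs)
        finally:
            for handle in handles:
                handle.remove()
        return result, records

    return wrapper


class TinyEncoder(nn.Module):
    # One declaration on the class replaces flag checks inside every forward().
    _can_record_outputs = {"hidden_states": nn.Linear}

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))

    @record_outputs
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


final_state, recorded = TinyEncoder()(torch.randn(2, 4))
print(len(recorded["hidden_states"]))  # 3 tensors, one per layer, no flags threaded through forward()
```

The sketch always records; the advantage of the declarative mapping is that the decorator can register hooks only when a caller actually asks for those outputs.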
@@ -833,22 +790,13 @@ def __init__(self, config: AltCLIPVisionConfig):
self.encoder = AltCLIPEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@can_return_tuple
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
interpolate_pos_encoding: bool | None = False,
) -> tuple | BaseModelOutputWithPooling:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
if pixel_values is None:
raise ValueError("You have to specify pixel_values")

@@ -857,9 +805,7 @@ def forward(

encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs,
)

last_hidden_state = encoder_outputs[0]
@@ -869,8 +815,6 @@ def forward(
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@@ -888,16 +832,14 @@ def __init__(self, config: AltCLIPVisionConfig):
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
interpolate_pos_encoding: bool = False,
return_dict: bool | None = None,
**kwargs,
) -> tuple | BaseModelOutputWithPooling:
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""
Examples:

Expand All @@ -920,14 +862,10 @@ def forward(
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled CLS states
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
**kwargs,
)


@@ -1194,7 +1132,7 @@ def get_text_features(

return text_outputs

@can_return_tuple
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def get_image_features(
self,
@@ -1223,14 +1161,14 @@ def get_image_features(
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
**kwargs,
)
pooled_output = vision_outputs.pooler_output
vision_outputs.pooler_output = self.visual_projection(pooled_output)

return vision_outputs

@can_return_tuple
@auto_docstring
def forward(
self,
@@ -1240,11 +1178,8 @@ def forward(
position_ids: torch.LongTensor | None = None,
token_type_ids: torch.Tensor | None = None,
return_loss: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
interpolate_pos_encoding: bool = False,
return_dict: bool | None = None,
**kwargs,
**kwargs: Unpack[TransformersKwargs],
) -> tuple | AltCLIPOutput:
r"""
return_loss (`bool`, *optional*):
@@ -1270,29 +1205,18 @@ def forward(
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
```"""
# Use AltCLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)

vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
**kwargs,
)

image_embeds = vision_outputs[1]
@@ -1314,10 +1238,6 @@ def forward(
if return_loss:
loss = clip_loss(logits_per_text)

if not return_dict:
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return ((loss,) + output) if loss is not None else output

return AltCLIPOutput(
loss=loss,
logits_per_image=logits_per_image,
5 changes: 1 addition & 4 deletions src/transformers/models/clap/modeling_clap.py
@@ -175,10 +175,7 @@ class ClapOutput(ModelOutput):
audio_model_output: BaseModelOutputWithPooling = None

def to_tuple(self) -> tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "audio_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())


# Adapted from transformers.models.swin.modeling_swin.SwinDropPath
5 changes: 1 addition & 4 deletions src/transformers/models/clip/modeling_clip.py
@@ -132,10 +132,7 @@ class CLIPOutput(ModelOutput):
vision_model_output: BaseModelOutputWithPooling = None

def to_tuple(self) -> tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())


class CLIPVisionEmbeddings(nn.Module):