20 changes: 10 additions & 10 deletions src/transformers/models/edgetam/modeling_edgetam.py
@@ -93,13 +93,13 @@ class EdgeTamVisionEncoderOutput(BaseModelOutputWithPooling):
fpn_hidden_states (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
- fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ fpn_position_embeddings (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
"""

fpn_hidden_states: torch.FloatTensor | None = None
- fpn_position_encoding: torch.FloatTensor | None = None
+ fpn_position_embeddings: torch.FloatTensor | None = None


def eager_attention_forward(
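Note on the public surface: `fpn_position_encoding` was a documented output field, so downstream code has to follow the rename. A minimal migration sketch, assuming a transformers build that includes this change (the constructed values and shapes are illustrative only):

```python
import torch
from transformers.models.edgetam.modeling_edgetam import EdgeTamVisionEncoderOutput

# Dummy output just to show the renamed field; shapes are illustrative.
feats = tuple(torch.randn(1, 256, s, s) for s in (64, 32, 16))
pos = tuple(torch.randn(1, 256, s, s) for s in (64, 32, 16))
out = EdgeTamVisionEncoderOutput(fpn_hidden_states=feats, fpn_position_embeddings=pos)

# Before this PR callers read `out.fpn_position_encoding`; after it they read:
assert out.fpn_position_embeddings[0].shape == (1, 256, 64, 64)
```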
@@ -392,7 +392,7 @@ def __init__(self, config: EdgeTamVisionConfig):

def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
fpn_hidden_states = ()
- fpn_position_encoding = ()
+ fpn_position_embeddings = ()

# forward in top-down order (from low to high resolution)
n = len(self.convs) - 1
@@ -416,9 +416,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...]
).to(prev_features.dtype)

fpn_hidden_states += (prev_features,)
- fpn_position_encoding += (prev_position_encoding,)
+ fpn_position_embeddings += (prev_position_encoding,)

- return fpn_hidden_states, fpn_position_encoding
+ return fpn_hidden_states, fpn_position_embeddings


@auto_docstring(
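For context on the loop above: the neck walks the backbone levels top-down and accumulates per-level feature maps together with their position embeddings, which is why both come back as parallel tuples. A self-contained sketch of that pattern, assuming 1x1 lateral convs and stride-2 levels (the layers and the zero position embedding are stand-ins, not the real EdgeTAM modules):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyFpnNeck(nn.Module):
    """Minimal top-down FPN sketch; not the real neck."""

    def __init__(self, in_channels: list[int], hidden_size: int = 256):
        super().__init__()
        self.convs = nn.ModuleList(nn.Conv2d(c, hidden_size, kernel_size=1) for c in in_channels)

    def forward(self, hidden_states: list[torch.Tensor]):
        fpn_hidden_states: tuple[torch.Tensor, ...] = ()
        fpn_position_embeddings: tuple[torch.Tensor, ...] = ()
        prev_features = None
        # Top-down order: lowest resolution first, as in this hunk.
        for i in range(len(self.convs) - 1, -1, -1):
            lateral = self.convs[i](hidden_states[i])
            if prev_features is not None:
                # Upsample the coarser map in fp32, cast back (cf. `.to(prev_features.dtype)`).
                lateral = lateral + F.interpolate(
                    prev_features.to(torch.float32), scale_factor=2.0, mode="nearest"
                ).to(lateral.dtype)
            prev_features = lateral
            fpn_hidden_states += (prev_features,)
            # Stand-in for the real sine position-embedding module.
            fpn_position_embeddings += (torch.zeros_like(prev_features),)
        return fpn_hidden_states, fpn_position_embeddings

neck = TinyFpnNeck(in_channels=[96, 192, 384])
feats = [torch.randn(1, 96, 64, 64), torch.randn(1, 192, 32, 32), torch.randn(1, 384, 16, 16)]
fpn, pos = neck(feats)
assert [f.shape[-1] for f in fpn] == [16, 32, 64]  # tuples come out low -> high resolution
```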
@@ -458,15 +458,15 @@ def forward(
intermediate_hidden_states = backbone_output.last_hidden_state
intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states]

- fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+ fpn_hidden_states, fpn_position_embeddings = self.neck(intermediate_hidden_states)
# Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
- fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+ fpn_position_embeddings = fpn_position_embeddings[-self.num_feature_levels :][::-1]

return EdgeTamVisionEncoderOutput(
last_hidden_state=intermediate_hidden_states[-1],
fpn_hidden_states=fpn_hidden_states,
- fpn_position_encoding=fpn_position_encoding,
+ fpn_position_embeddings=fpn_position_embeddings,
hidden_states=backbone_output.hidden_states,
)
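The double slice in this hunk is easy to misread, so here is a tiny worked example of `[-self.num_feature_levels :][::-1]` with illustrative level names (recall the neck emits levels from low to high resolution):

```python
# Neck output, ordered low -> high resolution:
fpn_hidden_states = ("res_32", "res_64", "res_128", "res_256")
num_feature_levels = 3

# Keep the last `num_feature_levels` entries, then flip to high -> low resolution:
selected = fpn_hidden_states[-num_feature_levels:][::-1]
assert selected == ("res_256", "res_128", "res_64")
```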

@@ -1215,7 +1215,7 @@ def get_image_features(
vision_outputs: EdgeTamVisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)

feature_maps = vision_outputs.fpn_hidden_states
- feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+ feature_maps_position_embeddings = vision_outputs.fpn_position_embeddings

# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
@@ -1230,7 +1230,7 @@ def get_image_features(
for feature_map_position_embedding in feature_maps_position_embeddings
]
vision_outputs.fpn_hidden_states = feature_maps
- vision_outputs.fpn_position_encoding = feature_maps_position_embeddings
+ vision_outputs.fpn_position_embeddings = feature_maps_position_embeddings

return vision_outputs
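On the caching comment above ("avoid running it again on every SAM click"): the two highest-resolution levels are projected once per image and the projected maps are written back into the output, so later prompts reuse them. A hedged sketch of that flow; the projection convs and every name here are invented stand-ins for illustration, not the real decoder modules:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the SAM decoder's level-0/1 projections.
proj0 = nn.Conv2d(256, 32, kernel_size=1)
proj1 = nn.Conv2d(256, 64, kernel_size=1)

feature_maps = (
    torch.randn(1, 256, 64, 64),  # level 0, highest resolution
    torch.randn(1, 256, 32, 32),  # level 1
    torch.randn(1, 256, 16, 16),  # level 2, passed through unprojected
)

# Done once per image, right after the vision encoder:
feature_maps = (proj0(feature_maps[0]), proj1(feature_maps[1])) + feature_maps[2:]

# Every later click reuses the cached tuple instead of re-running the projections:
for _click in range(3):
    decoder_inputs = feature_maps
```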

6 changes: 3 additions & 3 deletions src/transformers/models/edgetam/modular_edgetam.py
@@ -204,15 +204,15 @@ def forward(
intermediate_hidden_states = backbone_output.last_hidden_state
intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states]

- fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+ fpn_hidden_states, fpn_position_embeddings = self.neck(intermediate_hidden_states)
# Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
- fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+ fpn_position_embeddings = fpn_position_embeddings[-self.num_feature_levels :][::-1]

return EdgeTamVisionEncoderOutput(
last_hidden_state=intermediate_hidden_states[-1],
fpn_hidden_states=fpn_hidden_states,
- fpn_position_encoding=fpn_position_encoding,
+ fpn_position_embeddings=fpn_position_embeddings,
hidden_states=backbone_output.hidden_states,
)

@@ -134,13 +134,13 @@ class EdgeTamVideoVisionEncoderOutput(BaseModelOutputWithPooling):
fpn_hidden_states (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
- fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ fpn_position_embeddings (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
"""

fpn_hidden_states: torch.FloatTensor | None = None
- fpn_position_encoding: torch.FloatTensor | None = None
+ fpn_position_embeddings: torch.FloatTensor | None = None


class EdgeTamVideoVisionRotaryEmbedding(nn.Module):
@@ -2244,7 +2244,7 @@ def get_image_features(
vision_outputs: EdgeTamVideoVisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)

feature_maps = vision_outputs.fpn_hidden_states
- feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+ feature_maps_position_embeddings = vision_outputs.fpn_position_embeddings

# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
@@ -2259,7 +2259,7 @@ def get_image_features(
for feature_map_position_embedding in feature_maps_position_embeddings
]
vision_outputs.fpn_hidden_states = feature_maps
- vision_outputs.fpn_position_encoding = feature_maps_position_embeddings
+ vision_outputs.fpn_position_embeddings = feature_maps_position_embeddings

return vision_outputs

20 changes: 10 additions & 10 deletions src/transformers/models/sam2/modeling_sam2.py
@@ -74,13 +74,13 @@ class Sam2VisionEncoderOutput(BaseModelOutputWithPooling):
fpn_hidden_states (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
- fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ fpn_position_embeddings (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
"""

fpn_hidden_states: torch.FloatTensor | None = None
- fpn_position_encoding: torch.FloatTensor | None = None
+ fpn_position_embeddings: torch.FloatTensor | None = None


@dataclass
@@ -218,7 +218,7 @@ def __init__(self, config: Sam2VisionConfig):

def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
fpn_hidden_states = ()
- fpn_position_encoding = ()
+ fpn_position_embeddings = ()

# forward in top-down order (from low to high resolution)
n = len(self.convs) - 1
@@ -242,9 +242,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...]
).to(prev_features.dtype)

fpn_hidden_states += (prev_features,)
- fpn_position_encoding += (prev_position_encoding,)
+ fpn_position_embeddings += (prev_position_encoding,)

- return fpn_hidden_states, fpn_position_encoding
+ return fpn_hidden_states, fpn_position_embeddings


def eager_attention_forward(
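Same neck pattern as in EdgeTAM, but one detail in the hunk above deserves a callout: the cast back via `.to(prev_features.dtype)`. Assuming the interpolation is done in float32 (which the visible cast suggests), the cast prevents silent dtype promotion of the whole pyramid under bf16/fp16. A small demonstration of the promotion rule:

```python
import torch
import torch.nn.functional as F

lateral = torch.randn(1, 8, 4, 4, dtype=torch.bfloat16)
coarser = torch.randn(1, 8, 2, 2, dtype=torch.bfloat16)

up = F.interpolate(coarser.to(torch.float32), scale_factor=2.0, mode="nearest")
# Without the cast, bf16 + fp32 silently promotes the sum to fp32:
assert (lateral + up).dtype == torch.float32
# The `.to(prev_features.dtype)` in this hunk keeps everything in the lateral dtype:
assert (lateral + up.to(lateral.dtype)).dtype == torch.bfloat16
```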
@@ -684,15 +684,15 @@ def forward(
hidden_states = backbone_output.last_hidden_state
intermediate_hidden_states = backbone_output.intermediate_hidden_states

- fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+ fpn_hidden_states, fpn_position_embeddings = self.neck(intermediate_hidden_states)
# Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
- fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+ fpn_position_embeddings = fpn_position_embeddings[-self.num_feature_levels :][::-1]

return Sam2VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
- fpn_position_encoding=fpn_position_encoding,
+ fpn_position_embeddings=fpn_position_embeddings,
)


@@ -1580,7 +1580,7 @@ def get_image_features(
vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)

feature_maps = vision_outputs.fpn_hidden_states
- feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+ feature_maps_position_embeddings = vision_outputs.fpn_position_embeddings

# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
@@ -1595,7 +1595,7 @@ def get_image_features(
for feature_map_position_embedding in feature_maps_position_embeddings
]
vision_outputs.fpn_hidden_states = feature_maps
- vision_outputs.fpn_position_encoding = feature_maps_position_embeddings
+ vision_outputs.fpn_position_embeddings = feature_maps_position_embeddings

return vision_outputs

20 changes: 10 additions & 10 deletions src/transformers/models/sam2/modular_sam2.py
@@ -307,13 +307,13 @@ class Sam2VisionEncoderOutput(BaseModelOutputWithPooling):
fpn_hidden_states (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
- fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ fpn_position_embeddings (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
"""

fpn_hidden_states: torch.FloatTensor | None = None
- fpn_position_encoding: torch.FloatTensor | None = None
+ fpn_position_embeddings: torch.FloatTensor | None = None


@dataclass
@@ -408,7 +408,7 @@ def __init__(self, config: Sam2VisionConfig):

def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
fpn_hidden_states = ()
- fpn_position_encoding = ()
+ fpn_position_embeddings = ()

# forward in top-down order (from low to high resolution)
n = len(self.convs) - 1
@@ -432,9 +432,9 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...]
).to(prev_features.dtype)

fpn_hidden_states += (prev_features,)
- fpn_position_encoding += (prev_position_encoding,)
+ fpn_position_embeddings += (prev_position_encoding,)

- return fpn_hidden_states, fpn_position_encoding
+ return fpn_hidden_states, fpn_position_embeddings


def do_pool(x: torch.Tensor, query_stride: int | None = None) -> torch.Tensor:
@@ -789,15 +789,15 @@ def forward(
hidden_states = backbone_output.last_hidden_state
intermediate_hidden_states = backbone_output.intermediate_hidden_states

- fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+ fpn_hidden_states, fpn_position_embeddings = self.neck(intermediate_hidden_states)
# Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
- fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+ fpn_position_embeddings = fpn_position_embeddings[-self.num_feature_levels :][::-1]

return Sam2VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
- fpn_position_encoding=fpn_position_encoding,
+ fpn_position_embeddings=fpn_position_embeddings,
)


@@ -1255,7 +1255,7 @@ def get_image_features(
vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)

feature_maps = vision_outputs.fpn_hidden_states
- feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+ feature_maps_position_embeddings = vision_outputs.fpn_position_embeddings

# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
@@ -1270,7 +1270,7 @@ def get_image_features(
for feature_map_position_embedding in feature_maps_position_embeddings
]
vision_outputs.fpn_hidden_states = feature_maps
- vision_outputs.fpn_position_encoding = feature_maps_position_embeddings
+ vision_outputs.fpn_position_embeddings = feature_maps_position_embeddings

return vision_outputs

8 changes: 4 additions & 4 deletions src/transformers/models/sam2_video/modeling_sam2_video.py
@@ -1167,13 +1167,13 @@ class Sam2VideoVisionEncoderOutput(BaseModelOutputWithPooling):
fpn_hidden_states (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
- fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ fpn_position_embeddings (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
`(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
"""

fpn_hidden_states: torch.FloatTensor | None = None
- fpn_position_encoding: torch.FloatTensor | None = None
+ fpn_position_embeddings: torch.FloatTensor | None = None


class Sam2VideoMaskEmbedding(nn.Module):
@@ -1848,7 +1848,7 @@ def get_image_features(
vision_outputs: Sam2VideoVisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)

feature_maps = vision_outputs.fpn_hidden_states
- feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+ feature_maps_position_embeddings = vision_outputs.fpn_position_embeddings

# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
@@ -1863,7 +1863,7 @@ def get_image_features(
for feature_map_position_embedding in feature_maps_position_embeddings
]
vision_outputs.fpn_hidden_states = feature_maps
- vision_outputs.fpn_position_encoding = feature_maps_position_embeddings
+ vision_outputs.fpn_position_embeddings = feature_maps_position_embeddings

return vision_outputs

20 changes: 10 additions & 10 deletions src/transformers/models/sam3/modeling_sam3.py
@@ -60,12 +60,12 @@ class Sam3VisionEncoderOutput(BaseModelOutputWithPooling):
r"""
fpn_hidden_states (`tuple[torch.FloatTensor]`):
Tuple of multi-level FPN feature maps.
- fpn_position_encoding (`tuple[torch.FloatTensor]`):
+ fpn_position_embeddings (`tuple[torch.FloatTensor]`):
Tuple of position encodings for each FPN level.
"""

fpn_hidden_states: tuple[torch.FloatTensor, ...] = None
- fpn_position_encoding: tuple[torch.FloatTensor, ...] = None
+ fpn_position_embeddings: tuple[torch.FloatTensor, ...] = None


@dataclass
@@ -990,16 +990,16 @@ def __init__(self, config: Sam3VisionConfig):

def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
fpn_hidden_states = ()
- fpn_position_encoding = ()
+ fpn_position_embeddings = ()

for fpn_layer in self.fpn_layers:
fpn_output = fpn_layer(hidden_states)
fpn_hidden_states += (fpn_output,)
# Generate position encoding for this FPN level
pos_enc = self.position_encoding(fpn_output.shape, fpn_output.device, fpn_output.dtype)
- fpn_position_encoding += (pos_enc,)
+ fpn_position_embeddings += (pos_enc,)

- return fpn_hidden_states, fpn_position_encoding
+ return fpn_hidden_states, fpn_position_embeddings


@auto_docstring(
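Unlike the SAM2/EdgeTAM neck, the SAM3 neck above derives a position embedding per level from each FPN output's shape, device, and dtype. A self-contained sketch of that loop; the simplified sine embedding and the same-resolution 1x1 convs are stand-ins (the real FPN layers change resolution per level):

```python
import torch
import torch.nn as nn

def sine_position_embedding(shape, device, dtype, num_pos_feats: int = 128):
    """Simplified 2D sine/cosine embedding; stand-in for the real module."""
    _, _, height, width = shape
    y = torch.arange(height, device=device, dtype=dtype).view(1, height, 1).expand(1, height, width)
    x = torch.arange(width, device=device, dtype=dtype).view(1, 1, width).expand(1, height, width)
    dim_t = torch.arange(num_pos_feats, device=device, dtype=dtype)
    dim_t = 10000 ** (2 * (dim_t // 2) / num_pos_feats)
    pos = torch.cat([(y[..., None] / dim_t).sin(), (x[..., None] / dim_t).cos()], dim=-1)
    return pos.permute(0, 3, 1, 2)  # (1, 2 * num_pos_feats, H, W)

fpn_layers = nn.ModuleList(nn.Conv2d(256, 256, kernel_size=1) for _ in range(3))
hidden_states_spatial = torch.randn(1, 256, 64, 64)

fpn_hidden_states, fpn_position_embeddings = (), ()
for fpn_layer in fpn_layers:
    fpn_output = fpn_layer(hidden_states_spatial)
    fpn_hidden_states += (fpn_output,)
    # One embedding per level, matching that level's shape/device/dtype.
    fpn_position_embeddings += (
        sine_position_embedding(fpn_output.shape, fpn_output.device, fpn_output.dtype),
    )
```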
@@ -1043,12 +1043,12 @@ def forward(
height = pixel_values.shape[-2] // self.config.backbone_config.patch_size
width = pixel_values.shape[-1] // self.config.backbone_config.patch_size
hidden_states_spatial = hidden_states.view(batch_size, height, width, -1).permute(0, 3, 1, 2)
- fpn_hidden_states, fpn_position_encoding = self.neck(hidden_states_spatial)
+ fpn_hidden_states, fpn_position_embeddings = self.neck(hidden_states_spatial)

return Sam3VisionEncoderOutput(
last_hidden_state=hidden_states,
fpn_hidden_states=fpn_hidden_states,
- fpn_position_encoding=fpn_position_encoding,
+ fpn_position_embeddings=fpn_position_embeddings,
)
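The `view(...).permute(0, 3, 1, 2)` in the hunk above turns the ViT token sequence back into a patch-grid feature map before the neck. A tiny worked example with illustrative sizes:

```python
import torch

batch_size, image_size, patch_size, hidden_dim = 2, 1024, 16, 768
height = width = image_size // patch_size  # 64 patches per side

tokens = torch.randn(batch_size, height * width, hidden_dim)   # (B, N, C)
spatial = tokens.view(batch_size, height, width, -1).permute(0, 3, 1, 2)
assert spatial.shape == (2, 768, 64, 64)                       # (B, C, H, W)
```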


@@ -2280,7 +2280,7 @@ def forward(
vision_outputs = vision_embeds

fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1]
- fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1]
+ fpn_position_embeddings = vision_outputs.fpn_position_embeddings[:-1]

if text_embeds is None:
text_features = self.get_text_features(
@@ -2319,7 +2319,7 @@ def forward(
box_mask=box_mask,
box_labels=box_labels,
img_feats=fpn_hidden_states,
- img_pos_embeds=fpn_position_encoding,
+ img_pos_embeds=fpn_position_embeddings,
)

geometry_prompt_features = geometry_outputs.last_hidden_state
@@ -2352,7 +2352,7 @@ def forward(
encoder_outputs = self.detr_encoder(
vision_features=[fpn_hidden_states[-1]],
text_features=combined_prompt_features,
- vision_pos_embeds=[fpn_position_encoding[-1]],
+ vision_pos_embeds=[fpn_position_embeddings[-1]],
text_mask=combined_prompt_mask,
**kwargs,
)
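To spell out the level routing in this forward pass: everything but the last FPN level (plus matching position embeddings) goes to the geometry prompt encoder, and the DETR encoder then consumes only the last of those retained levels. A toy illustration of the `[:-1]` / `[-1]` split with illustrative names:

```python
fpn_hidden_states = ("level_0", "level_1", "level_2", "level_3")
fpn_position_embeddings = ("pos_0", "pos_1", "pos_2", "pos_3")

# Geometry prompt encoder sees every level except the last...
geometry_feats = fpn_hidden_states[:-1]        # ("level_0", "level_1", "level_2")
geometry_pos = fpn_position_embeddings[:-1]

# ...and the DETR encoder takes only the last retained level, as a one-element list.
detr_feats = [geometry_feats[-1]]              # ["level_2"]
detr_pos = [geometry_pos[-1]]
```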