11 changes: 11 additions & 0 deletions src/transformers/models/sam/configuration_sam.py
@@ -46,6 +46,8 @@ class SamPromptEncoderConfig(PretrainedConfig):
The non-linear activation function in the encoder and pooler.
"""

base_config_key = "prompt_encoder_config"

def __init__(
self,
hidden_size=256,
@@ -102,6 +104,8 @@ class SamMaskDecoderConfig(PretrainedConfig):

"""

base_config_key = "mask_decoder_config"

def __init__(
self,
hidden_size=256,
@@ -181,6 +185,8 @@ class SamVisionConfig(PretrainedConfig):
hidden_size`.
"""

base_config_key = "vision_config"

def __init__(
self,
hidden_size=768,
@@ -278,6 +284,11 @@ class SamConfig(PretrainedConfig):
```"""

model_type = "sam"
sub_configs = {
"prompt_encoder_config": SamPromptEncoderConfig,
"mask_decoder_config": SamMaskDecoderConfig,
"vision_config": SamVisionConfig,
}

def __init__(
self,
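The `sub_configs` mapping added above gives `SamConfig` a uniform way of exposing its nested configurations. As a minimal sketch (illustrative only, not part of this diff; the shown values are assumptions), the composite config and its sub-configs relate as follows:

```python
from transformers import SamConfig, SamVisionConfig

# Build a composite config; unspecified sub-configs fall back to their defaults.
config = SamConfig(vision_config=SamVisionConfig(hidden_size=768))

# Each sub-config is reachable under the attribute named by its `base_config_key`.
print(type(config.vision_config).__name__)          # SamVisionConfig
print(type(config.prompt_encoder_config).__name__)  # SamPromptEncoderConfig
print(type(config.mask_decoder_config).__name__)    # SamMaskDecoderConfig
```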
167 changes: 160 additions & 7 deletions src/transformers/models/sam/modeling_sam.py
@@ -246,6 +246,47 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
return out


class SamSdpaAttention(SamAttention):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values. Uses PyTorch's scaled dot-product attention (SDPA) instead of the default eager attention.
"""

def __init__(self, config, downsample_rate=None):
super().__init__(config, downsample_rate)

def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)

point_batch_size = query.shape[1]
# Separate into heads
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)

# Scaled dot product attention
attn_mask = None
if attention_similarity is not None:
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

# Get output
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)

return out


SAM_ATTENTION_CLASSES = {
"eager": SamAttention,
"sdpa": SamSdpaAttention,
}
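
The `SAM_ATTENTION_CLASSES` mapping lets `config._attn_implementation` select between the eager and SDPA paths. A minimal sketch of why the swap is safe (illustrative only, not part of this diff): `F.scaled_dot_product_attention` fuses the scaling, softmax, and value multiplication that `SamAttention.forward` spells out, so the two paths agree numerically.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, tokens, head_dim)
query = torch.randn(2, 8, 16, 32)
key = torch.randn(2, 8, 16, 32)
value = torch.randn(2, 8, 16, 32)

# Eager attention, as spelled out in SamAttention.forward
attn = query @ key.transpose(-2, -1) / (query.shape[-1] ** 0.5)
attn = torch.softmax(attn, dim=-1)
eager_out = attn @ value

# Fused SDPA path, as used by SamSdpaAttention
sdpa_out = F.scaled_dot_product_attention(query, key, value)

torch.testing.assert_close(eager_out, sdpa_out, rtol=1e-4, atol=1e-4)
```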


class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
@@ -266,18 +307,21 @@ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps

self.self_attn = SamAttention(config, downsample_rate=1)
self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.mlp = SamMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)

self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe

def forward(
@@ -344,7 +388,7 @@ def __init__(self, config: SamMaskDecoderConfig):
for i in range(self.num_hidden_layers):
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

self.final_attn_token_to_image = SamAttention(config)
self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

def forward(
@@ -431,7 +475,7 @@ def forward(self, hidden_states):
class SamMaskDecoder(nn.Module):
def __init__(self, config: SamMaskDecoderConfig):
super().__init__()

self.config = config
self.hidden_size = config.hidden_size

self.num_multimask_outputs = config.num_multimask_outputs
@@ -856,11 +900,118 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
return outputs


class SamVisionSdpaAttention(SamVisionAttention):
"""
Multi-head Attention block with relative position embeddings.
    Uses PyTorch's scaled dot-product attention (SDPA) instead of the default eager attention.
"""

def __init__(self, config, window_size):
super().__init__(config, window_size)

def add_decomposed_rel_pos(
self,
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
This method is reimplemented to follow the implementation in:
https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
This implementation is more memory efficient when using SDPA in the forward method.
Args:
            query (Tensor): query in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

        Returns:
            Tuple of (rel_h, rel_w): decomposed relative position embeddings along the height and width axes,
            to be combined into the attention bias.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

batch_size, _, dim = query.shape
reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)

return rel_h, rel_w

def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = (
self.qkv(hidden_states)
.reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
.permute(2, 0, 3, 1, 4)
)
# q, k, v with shape (B * nHead, H * W, C)
query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

rel_h, rel_w = None, None
if self.use_rel_pos:
rel_h, rel_w = self.add_decomposed_rel_pos(
query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)

query = query.view(batch_size, self.num_attention_heads, height * width, -1)
key = key.view(batch_size, self.num_attention_heads, height * width, -1)
value = value.view(batch_size, self.num_attention_heads, height * width, -1)

        attn_bias = None
        if self.use_rel_pos:
rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
attn_bias = (rel_h + rel_w).view(
batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
)
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)

attn_output = (
attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
.permute(0, 2, 3, 1, 4)
.reshape(batch_size, height, width, -1)
)

attn_output = self.proj(attn_output)

if output_attentions:
# For output_attentions, calculate the attention weights
attn_weights = (query @ key.transpose(-2, -1)) * self.scale
if attn_bias is not None:
attn_weights = attn_weights + attn_bias
attn_weights = F.softmax(attn_weights, dim=-1)
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)

return outputs


SAM_VISION_ATTENTION_CLASSES = {
"eager": SamVisionAttention,
"sdpa": SamVisionSdpaAttention,
}
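
Keeping the height and width terms decomposed means the full per-pair bias is only materialized once, as the `attn_mask` handed to SDPA. A minimal shape sketch with made-up sizes (illustrative only, not part of this diff):

```python
import torch

batch_heads, height, width = 2, 4, 4  # batch_size * num_heads, feature map H, W

# Shapes produced by add_decomposed_rel_pos:
rel_h = torch.randn(batch_heads, height * width, height, 1)  # height term
rel_w = torch.randn(batch_heads, height * width, 1, width)   # width term

# Broadcasting the sum yields one bias value per (query position, key position) pair,
# which is then flattened into the (..., H*W, H*W) attn_mask passed to SDPA.
attn_bias = (rel_h + rel_w).reshape(batch_heads, height * width, height * width)
print(attn_bias.shape)  # torch.Size([2, 16, 16])
```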


class SamVisionLayer(nn.Module):
def __init__(self, config, window_size):
super().__init__()
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = SamVisionAttention(config, window_size)
self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SamMLPBlock(config)
self.window_size = window_size
@@ -1071,6 +1222,8 @@ class SamPreTrainedModel(PreTrainedModel):
base_model_prefix = "sam"
main_input_name = "pixel_values"
_no_split_modules = ["SamVisionAttention"]
supports_gradient_checkpointing = True
_supports_sdpa = True

def _init_weights(self, module):
std = self.config.initializer_range
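With `_supports_sdpa = True` on `SamPreTrainedModel`, the attention implementation can be chosen at load time and is propagated to the sub-configs, mirroring the checks in the test below. A minimal usage sketch (the checkpoint name `facebook/sam-vit-base` is just an example):

```python
from transformers import SamModel

# SDPA is picked automatically on recent PyTorch, or can be requested explicitly.
model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa")

# The requested implementation is replicated on the sub-configs as well.
print(model.config._attn_implementation)                 # "sdpa"
print(model.vision_encoder.config._attn_implementation)  # "sdpa"
print(model.mask_decoder.config._attn_implementation)    # "sdpa"
```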
83 changes: 69 additions & 14 deletions tests/models/sam/test_modeling_sam.py
@@ -14,12 +14,13 @@
# limitations under the License.
"""Testing suite for the PyTorch SAM model."""

import tempfile
import unittest

import requests

from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
@@ -295,6 +296,7 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
_is_composite = True

# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip(
@@ -311,22 +313,13 @@ def is_pipeline_test_to_skip(

def setUp(self):
self.model_tester = SamModelTester(self)
self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
self.prompt_encoder_config_tester = ConfigTester(
self,
config_class=SamPromptEncoderConfig,
has_text_modality=False,
num_attention_heads=12,
num_hidden_layers=2,
)
self.mask_decoder_config_tester = ConfigTester(
self, config_class=SamMaskDecoderConfig, has_text_modality=False
common_properties = ["initializer_range"]
self.config_tester = ConfigTester(
self, config_class=SamConfig, has_text_modality=False, common_properties=common_properties
)

def test_config(self):
self.vision_config_tester.run_common_tests()
self.prompt_encoder_config_tester.run_common_tests()
self.mask_decoder_config_tester.run_common_tests()
self.config_tester.run_common_tests()

@unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
def test_inputs_embeds(self):
@@ -450,6 +443,68 @@ def test_model_from_pretrained(self):
model = SamModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="SAM model can't be compiled dynamic yet")

@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
        Tests whether composite models dispatch correctly to SDPA or eager attention when requested at load time.
        This test only inspects layer class names, as SDPA layers are usually called "SdpaAttention".
        In contrast to the test above, it also checks `config._attn_implementation` on the loaded model and its
        sub-configs, because the requested attention implementation is manually replicated on each sub-config when
        loading. See https://github.com/huggingface/transformers/pull/32238 for more info.

        The test tries to cover the most general cases of composite models, i.e. VLMs with vision and text configs.
        Any model with a different set of sub-configs has to overwrite this test.
"""
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
model_sdpa = model_sdpa.eval().to(torch_device)

model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)

# Root model determines SDPA support
attn_impl = "sdpa" if model._supports_sdpa else "eager"

# Check config propagation to submodels that support it
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_sdpa.vision_encoder.config._attn_implementation == attn_impl)
self.assertTrue(model_sdpa.mask_decoder.config._attn_implementation == attn_impl)

self.assertTrue(model_eager.config._attn_implementation == "eager")
self.assertTrue(model_eager.vision_encoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager")

# Verify SDPA/eager layer presence
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break

if not has_sdpa and attn_impl == "sdpa":
raise ValueError("The SDPA model should have SDPA attention layers")

for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")


def prepare_image():
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"