diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index 43d88232e364..3b8e1aba714c 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -507,8 +507,8 @@ def forward(
 
         # Expand per-image data in batch direction to be per-point
         image_embeddings = image_embeddings + dense_prompt_embeddings
-        image_embeddings = image_embeddings.repeat(point_batch_size, 1, 1, 1)
-        image_positional_embeddings = image_positional_embeddings.repeat(point_batch_size, 1, 1, 1)
+        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
 
         # Run the transformer, image_positional_embedding are consumed
         point_embedding, image_embeddings, attentions = self.transformer(
diff --git a/src/transformers/models/sam/modeling_tf_sam.py b/src/transformers/models/sam/modeling_tf_sam.py
index a47b091a09f6..48b25ae134a8 100644
--- a/src/transformers/models/sam/modeling_tf_sam.py
+++ b/src/transformers/models/sam/modeling_tf_sam.py
@@ -517,8 +517,8 @@ def call(
         point_embeddings = tf.cast(tokens, self.iou_token.dtype)
 
         image_embeddings = image_embeddings + dense_prompt_embeddings
-        image_embeddings = tf.tile(image_embeddings, [point_batch_size, 1, 1, 1])
-        image_positional_embeddings = tf.tile(image_positional_embeddings, [point_batch_size, 1, 1, 1])
+        image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
+        image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)
 
         point_embedding, image_embeddings, attentions = self.transformer(
            point_embeddings=point_embeddings,