diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index d9fbe9a41a12..9aeba5f18aa7 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -1062,10 +1062,21 @@ def _update_causal_mask(
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
- token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
+ # Find where a new image block starts: 1 if image and previous not image
+ # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+ is_image = token_type_ids == 1
+ new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+ image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+ image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
+
+ same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
+ same_image_mask[image_group_ids == -1] = False # remove non-image
+ image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
- token_type_mask, 0.0
+ image_mask, 0.0
)
if attention_mask is not None:
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index ab1db5eb74e7..495fe167d79c 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -781,10 +781,21 @@ def _update_causal_mask(
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
- token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
+ # Find where a new image block starts: 1 if image and previous not image
+ # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+ is_image = token_type_ids == 1
+ new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+ image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+ image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
+
+ same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
+ same_image_mask[image_group_ids == -1] = False # remove non-image
+ image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
- token_type_mask, 0.0
+ image_mask, 0.0
)
if attention_mask is not None:
diff --git a/src/transformers/utils/attention_visualizer.py b/src/transformers/utils/attention_visualizer.py
index bad6a7471e54..ac52dd46ce4e 100644
--- a/src/transformers/utils/attention_visualizer.py
+++ b/src/transformers/utils/attention_visualizer.py
@@ -36,7 +36,9 @@
WHITE_SQUARE = "⬚"
-def generate_attention_matrix_from_mask(words, mask, img_token="
", sliding_window=None, token_type_ids=None):
+def generate_attention_matrix_from_mask(
+ words, mask, img_token="
", sliding_window=None, token_type_ids=None, image_seq_length=None
+):
"""
Generates an attention matrix from a given attention mask.
@@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(words, mask, img_token="
", sliding_
for j in range(n)
)
+ if token_type_ids is not None:
+ is_special = token_type_ids == 1
+ token_type_buckets = torch.where(
+ (token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
+ )
+ boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
+ token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
+
# Print headers
legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
output.append(" " + legend)
@@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(words, mask, img_token="
", sliding_
if sliding_window is not None
else ""
)
-
for i, word in enumerate(words):
word_repr = repr(word).ljust(max_word_length)
colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
@@ -121,7 +130,9 @@ def generate_attention_matrix_from_mask(words, mask, img_token="
", sliding_
if sliding_window is not None:
sliding_window_row = " ".join(
f"{YELLOW}{BLACK_SQUARE}{RESET}"
- if img_token in words[j] and img_token in words[i]
+ if img_token in words[j]
+ and img_token in words[i]
+ and token_type_buckets[0, i] == token_type_buckets[0, j]
else f"{GREEN}{BLACK_SQUARE}{RESET}"
if i == j
else BLACK_SQUARE
@@ -170,7 +181,8 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
if self.config.model_type in PROCESSOR_MAPPING_NAMES:
img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
img = Image.open(requests.get(img, stream=True).raw)
- processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5)
+ image_seq_length = 5
+ processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
if hasattr(processor, "image_token"):
image_token = processor.image_token
else:
@@ -179,7 +191,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
if image_token:
input_sentence = input_sentence.replace("
", image_token)
- inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt")
+ inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")
self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
@@ -223,6 +235,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
img_token=self.image_token,
sliding_window=getattr(self.config, "sliding_window", None),
token_type_ids=kwargs.get("token_type_ids", None),
+ image_seq_length=image_seq_length,
)
print(f_string)
print(f"{top_bottom_border}")