From 231dee4adeddeb0bd7d31f9c82c608a34c2abfda Mon Sep 17 00:00:00 2001
From: raushan
Date: Mon, 12 May 2025 13:03:24 +0200
Subject: [PATCH 1/4] fix attn mask

---
 .../models/gemma3/modeling_gemma3.py              | 16 +++++++++++++++-
 src/transformers/models/gemma3/modular_gemma3.py  | 16 +++++++++++++++-
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index d9fbe9a41a12..6156868a9134 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -1063,9 +1063,23 @@ def _update_causal_mask(
             token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
             token_type_mask[token_type_ids == 0] = False  # if text token do not change anything
             token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
+            # Find where a new image block starts: 1 if image and previous not image
+            # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+            is_image = token_type_ids == 1
+            padded = nn.functional.pad(is_image, (1, 0), value=0)
+            new_image_start = is_image & ~padded[:, :-1]
+            image_group_id = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_positions = torch.where(is_image, image_group_id, torch.full_like(token_type_ids, -1))
+            same_image_mask = (image_group_positions.unsqueeze(1) == image_group_positions.unsqueeze(2)) & (
+                image_group_positions.unsqueeze(1) != -1
+            )
+            same_image_mask = same_image_mask.unsqueeze(1)
+
+            image_mask = token_type_mask & same_image_mask
             causal_mask = causal_mask.clone()
             causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
-                token_type_mask, 0.0
+                image_mask, 0.0
             )

             if attention_mask is not None:
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index ab1db5eb74e7..67cb40948e5c 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -782,9 +782,23 @@ def _update_causal_mask(
             token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
             token_type_mask[token_type_ids == 0] = False  # if text token do not change anything
             token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+
+            # Find where a new image block starts: 1 if image and previous not image
+            # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+            is_image = token_type_ids == 1
+            padded = nn.functional.pad(is_image, (1, 0), value=0)
+            new_image_start = is_image & ~padded[:, :-1]
+            image_group_id = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_positions = torch.where(is_image, image_group_id, torch.full_like(token_type_ids, -1))
+            same_image_mask = (image_group_positions.unsqueeze(1) == image_group_positions.unsqueeze(2)) & (
+                image_group_positions.unsqueeze(1) != -1
+            )
+            same_image_mask = same_image_mask.unsqueeze(1)
+
+            image_mask = token_type_mask & same_image_mask
             causal_mask = causal_mask.clone()
             causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
-                token_type_mask, 0.0
+                image_mask, 0.0
             )

             if attention_mask is not None:
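Side note (not part of the patch): the grouping logic added above can be exercised in isolation. The snippet below is a sketch with an illustrative token_type_ids tensor (1 = image token, 0 = text token); only torch is assumed.

import torch
import torch.nn.functional as F

# Toy sequence: an image block of 3 tokens and an image block of 2 tokens, with text around them.
token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 1, 1, 0]])

is_image = token_type_ids == 1
# True where an image block starts (an image token whose predecessor is not an image token)
padded = F.pad(is_image, (1, 0), value=0)
new_image_start = is_image & ~padded[:, :-1]
# Consecutive integer id per image block, -1 for text tokens
image_group_id = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_positions = torch.where(is_image, image_group_id, torch.full_like(token_type_ids, -1))
print(image_group_positions)  # tensor([[-1,  0,  0,  0, -1,  1,  1, -1]])

# True only for pairs of tokens that belong to the same image block
same_image_mask = (image_group_positions.unsqueeze(1) == image_group_positions.unsqueeze(2)) & (
    image_group_positions.unsqueeze(1) != -1
)

assert same_image_mask[0, 1, 3]      # tokens 1 and 3 sit in the same image: unmasked bidirectionally
assert not same_image_mask[0, 1, 5]  # tokens 1 and 5 sit in different images: stays causal
assert not same_image_mask[0, 0, 4]  # text-text pairs are untouched

In the patch this boolean map is ANDed with token_type_mask and passed to masked_fill(..., 0.0), so image tokens are unmasked only within their own image block instead of across all images.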
From 444f23f54ae2b9ba10d389fffe1e09cd0d4ca9ff Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 20 May 2025 13:34:36 +0200
Subject: [PATCH 2/4] attn viz doesn't show yellow cubes between images

---
 .../utils/attention_visualizer.py | 23 +++++++++++++++----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/transformers/utils/attention_visualizer.py b/src/transformers/utils/attention_visualizer.py
index bad6a7471e54..ac52dd46ce4e 100644
--- a/src/transformers/utils/attention_visualizer.py
+++ b/src/transformers/utils/attention_visualizer.py
@@ -36,7 +36,9 @@
 WHITE_SQUARE = "⬚"


-def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
+def generate_attention_matrix_from_mask(
+    words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
+):
     """
     Generates an attention matrix from a given attention mask.

@@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
         for j in range(n)
     )

+    if token_type_ids is not None:
+        is_special = token_type_ids == 1
+        token_type_buckets = torch.where(
+            (token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
+        )
+        boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
+        token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
+
     # Print headers
     legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal)   {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
     output.append(" " + legend)
@@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
             if sliding_window is not None
             else ""
         )
-
     for i, word in enumerate(words):
         word_repr = repr(word).ljust(max_word_length)
         colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
@@ -121,7 +130,9 @@
         if sliding_window is not None:
             sliding_window_row = " ".join(
                 f"{YELLOW}{BLACK_SQUARE}{RESET}"
-                if img_token in words[j] and img_token in words[i]
+                if img_token in words[j]
+                and img_token in words[i]
+                and token_type_buckets[0, i] == token_type_buckets[0, j]
                 else f"{GREEN}{BLACK_SQUARE}{RESET}"
                 if i == j
                 else BLACK_SQUARE
@@ -170,7 +181,8 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
         if self.config.model_type in PROCESSOR_MAPPING_NAMES:
            img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
            img = Image.open(requests.get(img, stream=True).raw)
-           processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5)
+           image_seq_length = 5
+           processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
            if hasattr(processor, "image_token"):
                image_token = processor.image_token
            else:
@@ -179,7 +191,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
            if image_token:
                input_sentence = input_sentence.replace("<img>", image_token)

-           inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt")
+           inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")

            self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
@@ -223,6 +235,7 @@ def visualize_attention_mask(self, input_sentence: str, suffix=""):
            img_token=self.image_token,
            sliding_window=getattr(self.config, "sliding_window", None),
            token_type_ids=kwargs.get("token_type_ids", None),
+           image_seq_length=image_seq_length,
        )
        print(f_string)
        print(f"{top_bottom_border}")
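Side note (not part of the patch): the display-side bucketing added above can be checked in isolation. The sequence below is illustrative, with two images of exactly image_seq_length = 5 tokens each, matching both the hard-coded % 5 and the image_seq_length=5 that visualize_attention_mask passes to the processor.

import torch

image_seq_length = 5
token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0]])

is_special = token_type_ids == 1
# Cumulative count of image tokens, zeroed for text tokens that follow a complete image
token_type_buckets = torch.where(
    (token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
)
boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)  # tensor([0, 5])
token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)

print(token_type_buckets)
# tensor([[0, 1, 1, 1, 1, 1, 0, 0, 2, 2, 2, 2, 2, 0]])
# Tokens of the first image share bucket 1 and tokens of the second image share bucket 2,
# so the yellow block is only drawn inside each image, not across different images.

Note that both the % 5 and the boundaries assume a fixed number of image tokens per image, so this grouping does not generalize to images that produce a different number of crops or tokens.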
From 9ed9c17cdafb5f648d9171ade4cbf8e8ff45a1f8 Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 20 May 2025 14:07:24 +0200
Subject: [PATCH 3/4] bucketize made it hard with different number of crops

---
 .../models/gemma3/modeling_gemma3.py | 17 +++++++----------
 .../models/gemma3/modular_gemma3.py  | 16 +++++++---------
 2 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 6156868a9134..9aeba5f18aa7 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -1062,21 +1062,18 @@ def _update_causal_mask(
         if token_type_ids is not None and sequence_length != 1:
             token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
             token_type_mask[token_type_ids == 0] = False  # if text token do not change anything
-            token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)

             # Find where a new image block starts: 1 if image and previous not image
             # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
             is_image = token_type_ids == 1
-            padded = nn.functional.pad(is_image, (1, 0), value=0)
-            new_image_start = is_image & ~padded[:, :-1]
-            image_group_id = torch.cumsum(new_image_start.int(), dim=1) - 1
-            image_group_positions = torch.where(is_image, image_group_id, torch.full_like(token_type_ids, -1))
-            same_image_mask = (image_group_positions.unsqueeze(1) == image_group_positions.unsqueeze(2)) & (
-                image_group_positions.unsqueeze(1) != -1
-            )
-            same_image_mask = same_image_mask.unsqueeze(1)
+            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
+
+            same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
+            same_image_mask[image_group_ids == -1] = False  # remove non-image
+            image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)

-            image_mask = token_type_mask & same_image_mask
             causal_mask = causal_mask.clone()
             causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
                 image_mask, 0.0
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index 67cb40948e5c..694687bb78fd 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -786,16 +786,14 @@ def _update_causal_mask(
             # Find where a new image block starts: 1 if image and previous not image
             # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
             is_image = token_type_ids == 1
-            padded = nn.functional.pad(is_image, (1, 0), value=0)
-            new_image_start = is_image & ~padded[:, :-1]
-            image_group_id = torch.cumsum(new_image_start.int(), dim=1) - 1
-            image_group_positions = torch.where(is_image, image_group_id, torch.full_like(token_type_ids, -1))
-            same_image_mask = (image_group_positions.unsqueeze(1) == image_group_positions.unsqueeze(2)) & (
-                image_group_positions.unsqueeze(1) != -1
-            )
-            same_image_mask = same_image_mask.unsqueeze(1)
+            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
+
+            same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
+            same_image_mask[image_group_ids == -1] = False  # remove non-image
+            image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)

-            image_mask = token_type_mask & same_image_mask
             causal_mask = causal_mask.clone()
             causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
                 image_mask, 0.0
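Side note (not part of the patch): a small self-contained check of the final mask construction above, applied to image blocks of different lengths, which is the "different number of crops" case this commit targets. It is a sketch only; the toy causal mask uses the same min-dtype convention as the modeling code, but the real _update_causal_mask additionally folds in the 2D attention/padding mask and cache offsets.

import torch

token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 1, 1]])  # image blocks of length 3 and 2
batch, seq_len = token_type_ids.shape
min_dtype = torch.finfo(torch.float32).min

# Plain causal mask: 0.0 = attend, min_dtype = masked
causal_mask = torch.full((seq_len, seq_len), min_dtype).triu(diagonal=1)
causal_mask = causal_mask[None, None, :, :].expand(batch, 1, -1, -1).clone()

token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False

is_image = token_type_ids == 1
new_image_start = is_image & ~torch.nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))

same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
same_image_mask[image_group_ids == -1] = False
image_mask = (token_type_mask & same_image_mask).unsqueeze(1)

causal_mask = causal_mask.masked_fill(image_mask, 0.0)

assert causal_mask[0, 0, 1, 3] == 0.0        # token 1 attends "forward" inside its own image block
assert causal_mask[0, 0, 1, 5] == min_dtype  # but not into the later image block
assert causal_mask[0, 0, 0, 4] == min_dtype  # text tokens stay strictly causal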
From a96c73dd8a52e23e69aa696329e2aa703cd2eb29 Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 20 May 2025 14:28:38 +0200
Subject: [PATCH 4/4] fixup

---
 src/transformers/models/gemma3/modular_gemma3.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index 694687bb78fd..495fe167d79c 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -781,7 +781,6 @@ def _update_causal_mask(
         if token_type_ids is not None and sequence_length != 1:
             token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
             token_type_mask[token_type_ids == 0] = False  # if text token do not change anything
-            token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)

             # Find where a new image block starts: 1 if image and previous not image
             # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally