diff --git a/keras_hub/src/models/gemma3/gemma3_attention.py b/keras_hub/src/models/gemma3/gemma3_attention.py
index b21a81ea5f..7288166252 100644
--- a/keras_hub/src/models/gemma3/gemma3_attention.py
+++ b/keras_hub/src/models/gemma3/gemma3_attention.py
@@ -46,6 +46,7 @@ def __init__(
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -61,6 +62,7 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self._kernel_initializer = keras.initializers.get(
@@ -240,12 +242,58 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
 
+    def _compute_bidirectional_sliding_mask(self, batch_size, sequence_length):
+        """Computes a bidirectional sliding window attention mask.
+
+        Query position `i` may attend to key position `j` when
+        `-w_right <= i - j <= w_left`, with `w_right = window // 2` and
+        `w_left = window - w_right - 1` (`window` is `sliding_window_size`),
+        so an even window gets its extra position on the right. This mask
+        is used in embedding models like `EmbeddingGemma`.
+
+        Args:
+            batch_size: The batch size for the mask.
+            sequence_length: The length of the sequence.
+
+        Returns:
+            A boolean attention mask with shape
+            `(batch_size, sequence_length, sequence_length)`.
+        """
+        i = ops.expand_dims(ops.arange(sequence_length, dtype="int32"), axis=1)
+        j = ops.arange(sequence_length, dtype="int32")
+
+        # If sliding window size is 4, the token in question attends to 1
+        # token before and 2 tokens after.
+        w_right = self.sliding_window_size // 2
+        w_left = self.sliding_window_size - w_right - 1
+
+        # Calculate the relative distance.
+        distance = i - j
+
+        mask = ops.logical_and(distance <= w_left, distance >= -w_right)
+
+        mask = ops.expand_dims(mask, axis=0)
+        return ops.broadcast_to(
+            mask, (batch_size, sequence_length, sequence_length)
+        )
+
     def _mask_sliding_window(
         self,
         attention_mask,
         cache_update_index=0,
     ):
         batch_size, query_len, key_len = ops.shape(attention_mask)
+
+        if self.use_bidirectional_attention:
+            bidirectional_sliding_mask = (
+                self._compute_bidirectional_sliding_mask(
+                    batch_size=batch_size,
+                    # `query_len = key_len` for embedding models
+                    sequence_length=query_len,
+                )
+            )
+            return ops.logical_and(attention_mask, bidirectional_sliding_mask)
+
         # Compute the sliding window for square attention.
         all_ones = ops.ones((key_len, key_len), "bool")
         if keras.config.backend() == "tensorflow":
diff --git a/keras_hub/src/models/gemma3/gemma3_backbone.py b/keras_hub/src/models/gemma3/gemma3_backbone.py
index df0c3c4d44..0a3a5f027a 100644
--- a/keras_hub/src/models/gemma3/gemma3_backbone.py
+++ b/keras_hub/src/models/gemma3/gemma3_backbone.py
@@ -196,6 +196,7 @@ def __init__(
         global_rope_scaling_factor=1.0,
         vision_encoder=None,
         layer_norm_epsilon=1e-6,
+        use_bidirectional_attention=False,
         dropout=0,
         dtype=None,
         **kwargs,
@@ -251,6 +252,7 @@ def __init__(
                 sliding_window_size=sliding_window_size,
                 rope_wavelength=rope_wavelength,
                 rope_scaling_factor=rope_scaling_factor,
+                use_bidirectional_attention=use_bidirectional_attention,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -357,6 +359,7 @@ def __init__(
         self.sliding_window_size = sliding_window_size
         self.local_rope_scaling_factor = local_rope_scaling_factor
         self.global_rope_scaling_factor = global_rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
 
@@ -396,6 +399,7 @@ def get_config(self):
             "vision_encoder": None
             if self.vision_encoder is None
             else keras.layers.serialize(self.vision_encoder),
+            "use_bidirectional_attention": self.use_bidirectional_attention,
             "layer_norm_epsilon": self.layer_norm_epsilon,
             "dropout": self.dropout,
         }
diff --git a/keras_hub/src/models/gemma3/gemma3_decoder_block.py b/keras_hub/src/models/gemma3/gemma3_decoder_block.py
index 673969bc47..2789e4ce18 100644
--- a/keras_hub/src/models/gemma3/gemma3_decoder_block.py
+++ b/keras_hub/src/models/gemma3/gemma3_decoder_block.py
@@ -45,6 +45,7 @@ def __init__(
         layer_norm_epsilon=1e-6,
         rope_wavelength=10_000.0,
         rope_scaling_factor=1.0,
+        use_bidirectional_attention=False,
         dropout=0,
         **kwargs,
     ):
@@ -66,6 +67,7 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.rope_wavelength = rope_wavelength
         self.rope_scaling_factor = rope_scaling_factor
+        self.use_bidirectional_attention = use_bidirectional_attention
         self.dropout = dropout
 
         self.pre_attention_norm = RMSNormalization(
@@ -93,6 +95,7 @@ def __init__(
             rope_wavelength=rope_wavelength,
             rope_scaling_factor=rope_scaling_factor,
             dropout=dropout,
+            use_bidirectional_attention=use_bidirectional_attention,
             dtype=self.dtype_policy,
             name="attention",
         )
@@ -209,6 +212,14 @@ def _compute_attention_mask(
         if cache is not None:
             input_length = ops.shape(cache)[2]
 
+        if self.use_bidirectional_attention:
+            # `output_length` and `input_length` will be the same in this case
+            # because we use bidirectional attention for models like
+            # `EmbeddingGemma` which aren't used for text generation.
+            mask_1 = decoder_mask
+            mask_2 = ops.transpose(mask_1, (0, 2, 1))
+            return mask_1 * mask_2
+
         causal_mask = compute_causal_mask(
             batch_size=batch_size,
             input_length=input_length,
@@ -304,6 +315,7 @@ def get_config(self):
                 "dropout": self.dropout,
                 "rope_wavelength": self.rope_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "use_bidirectional_attention": self.use_bidirectional_attention,
             }
         )
         return config