Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions keras_hub/src/models/gemma3/gemma3_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
layer_norm_epsilon=1e-6,
rope_wavelength=10_000.0,
rope_scaling_factor=1.0,
use_bidirectional_attention=False,
dropout=0,
**kwargs,
):
Expand All @@ -61,6 +62,7 @@ def __init__(
self.layer_norm_epsilon = layer_norm_epsilon
self.rope_wavelength = rope_wavelength
self.rope_scaling_factor = rope_scaling_factor
self.use_bidirectional_attention = use_bidirectional_attention
self.dropout = dropout

self._kernel_initializer = keras.initializers.get(
Expand Down Expand Up @@ -240,12 +242,58 @@ def _compute_attention(
results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
return ops.reshape(results, (b, q_len, self.num_query_heads, h))

def _compute_bidirectional_sliding_mask(self, batch_size, sequence_length):
    """Computes a bidirectional sliding window attention mask.

    A token can attend to any other token if their relative distance is
    within the sliding window: up to `w_left` tokens before it and
    `w_right` tokens after it, where the two half-widths sum (with the
    token itself) to `self.sliding_window_size`. This mask is used in
    embedding models like `EmbeddingGemma`, which are not autoregressive
    and therefore attend in both directions.

    Args:
        batch_size: The batch size for the mask.
        sequence_length: The length of the sequence.

    Returns:
        A boolean attention mask with shape
        `(batch_size, sequence_length, sequence_length)`.
    """
    # NOTE: use the file-level `ops` alias (as in `_mask_sliding_window`
    # and `_compute_attention`) instead of `keras.ops` for consistency.
    # Query positions as a column vector, key positions as a row vector,
    # so `i - j` broadcasts to a (sequence_length, sequence_length) grid
    # of relative distances.
    i = ops.expand_dims(ops.arange(sequence_length, dtype="int32"), axis=1)
    j = ops.arange(sequence_length, dtype="int32")

    # If sliding window size is 4, the token in question attends to 1
    # token before and 2 tokens after.
    w_right = self.sliding_window_size // 2
    w_left = self.sliding_window_size - w_right - 1

    # Relative distance: positive when the key is before the query,
    # negative when it is after.
    distance = i - j

    # Keep keys within `w_left` tokens before and `w_right` tokens after
    # each query position.
    mask = ops.logical_and(distance <= w_left, distance >= -w_right)

    # Add a batch axis and tile the (seq, seq) mask across the batch.
    mask = ops.expand_dims(mask, axis=0)
    return ops.broadcast_to(
        mask, (batch_size, sequence_length, sequence_length)
    )

def _mask_sliding_window(
self,
attention_mask,
cache_update_index=0,
):
batch_size, query_len, key_len = ops.shape(attention_mask)

if self.use_bidirectional_attention:
bidirectional_sliding_mask = (
self._compute_bidirectional_sliding_mask(
batch_size=batch_size,
# `query_len = key_len` for embedding models
sequence_length=query_len,
)
)
return ops.logical_and(attention_mask, bidirectional_sliding_mask)

# Compute the sliding window for square attention.
all_ones = ops.ones((key_len, key_len), "bool")
if keras.config.backend() == "tensorflow":
Expand Down
4 changes: 4 additions & 0 deletions keras_hub/src/models/gemma3/gemma3_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
global_rope_scaling_factor=1.0,
vision_encoder=None,
layer_norm_epsilon=1e-6,
use_bidirectional_attention=False,
dropout=0,
dtype=None,
**kwargs,
Expand Down Expand Up @@ -251,6 +252,7 @@ def __init__(
sliding_window_size=sliding_window_size,
rope_wavelength=rope_wavelength,
rope_scaling_factor=rope_scaling_factor,
use_bidirectional_attention=use_bidirectional_attention,
dropout=dropout,
dtype=dtype,
name=f"decoder_block_{i}",
Expand Down Expand Up @@ -357,6 +359,7 @@ def __init__(
self.sliding_window_size = sliding_window_size
self.local_rope_scaling_factor = local_rope_scaling_factor
self.global_rope_scaling_factor = global_rope_scaling_factor
self.use_bidirectional_attention = use_bidirectional_attention
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout

Expand Down Expand Up @@ -396,6 +399,7 @@ def get_config(self):
"vision_encoder": None
if self.vision_encoder is None
else keras.layers.serialize(self.vision_encoder),
"use_bidirectional_attention": self.use_bidirectional_attention,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
}
Expand Down
12 changes: 12 additions & 0 deletions keras_hub/src/models/gemma3/gemma3_decoder_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
layer_norm_epsilon=1e-6,
rope_wavelength=10_000.0,
rope_scaling_factor=1.0,
use_bidirectional_attention=False,
dropout=0,
**kwargs,
):
Expand All @@ -66,6 +67,7 @@ def __init__(
self.layer_norm_epsilon = layer_norm_epsilon
self.rope_wavelength = rope_wavelength
self.rope_scaling_factor = rope_scaling_factor
self.use_bidirectional_attention = use_bidirectional_attention
self.dropout = dropout

self.pre_attention_norm = RMSNormalization(
Expand Down Expand Up @@ -93,6 +95,7 @@ def __init__(
rope_wavelength=rope_wavelength,
rope_scaling_factor=rope_scaling_factor,
dropout=dropout,
use_bidirectional_attention=use_bidirectional_attention,
dtype=self.dtype_policy,
name="attention",
)
Expand Down Expand Up @@ -209,6 +212,14 @@ def _compute_attention_mask(
if cache is not None:
input_length = ops.shape(cache)[2]

if self.use_bidirectional_attention:
# `output_length` and `input_length` will be the same in this case
# because we use bidirectional attention for models like
# `EmbeddingGemma` which aren't used for text generation.
mask_1 = decoder_mask
mask_2 = ops.transpose(mask_1, (0, 2, 1))
return mask_1 * mask_2

causal_mask = compute_causal_mask(
batch_size=batch_size,
input_length=input_length,
Expand Down Expand Up @@ -304,6 +315,7 @@ def get_config(self):
"dropout": self.dropout,
"rope_wavelength": self.rope_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"use_bidirectional_attention": self.use_bidirectional_attention,
}
)
return config
Loading