Fix and update GQA tests (#20522)
james77777778 authored Nov 20, 2024
1 parent c802ca6 commit 4895458
Showing 1 changed file with 58 additions and 10 deletions.
68 changes: 58 additions & 10 deletions keras/src/layers/attention/grouped_query_attention_test.py
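For orientation, here is a minimal usage sketch of the layer under test. It is not part of the diff below; the arguments mirror the ones exercised in these tests, and the shapes are illustrative assumptions.

import numpy as np
from keras import layers

# Illustrative sketch (not from the commit): two query heads share two
# key/value heads of dimension 2, as in the tests below.
layer = layers.GroupedQueryAttention(
    head_dim=2, num_query_heads=2, num_key_value_heads=2
)
query = np.random.random((2, 4, 8))  # (batch, target_seq_len, feature_dim)
value = np.random.random((2, 4, 8))  # (batch, source_seq_len, feature_dim)
output = layer(query=query, value=value)  # same shape as the query: (2, 4, 8)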
@@ -232,14 +232,25 @@ def test_initializer(self):
     )
     def test_query_mask_propagation(self):
         """Test automatic propagation of the query's mask."""
-        layer = layers.GroupedQueryAttention(
-            num_query_heads=2, num_key_value_heads=2, head_dim=2
-        )
-        self.assertTrue(layer.supports_masking)
-        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
-        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
-        value = np.random.normal(size=(3, 3, 8))
-        output = layer(query=masked_query, value=value)
+        try:
+            layer = layers.GroupedQueryAttention(
+                num_query_heads=2, num_key_value_heads=2, head_dim=2
+            )
+            self.assertTrue(layer.supports_masking)
+            query = np.array(
+                [[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]
+            )
+            masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
+            value = np.random.normal(size=(3, 3, 8))
+            output = layer(query=masked_query, value=value)
+        except RuntimeError as e:
+            if e.args[0].startswith(
+                "(*bias): last dimension must be contiguous"
+            ):
+                self.skipTest(
+                    "PyTorch errors out on GPU: issue to track bug is here "
+                    "https://github.com/keras-team/keras/issues/20459"
+                )
         self.assertAllClose(masked_query._keras_mask, output._keras_mask)
 
     @parameterized.named_parameters(("causal", True), ("not_causal", 0))
@@ -278,8 +289,7 @@ def test_masking(self, use_causal_mask):
         self.assertAllClose(output, output_with_manual_mask)
 
     @parameterized.named_parameters(
-        ("disable_flash_attention", False),
-        ("enable_flash_attention", True),
+        ("disable_flash_attention", False), ("enable_flash_attention", True)
     )
     def test_correctness(self, flash_attention):
         if flash_attention:
@@ -348,3 +358,41 @@ def test_correctness(self, flash_attention):
         )
         self.assertAllClose(output, expected_output, atol=1e-2)
         self.assertAllClose(scores, expected_score, atol=1e-2)
+
+    def test_flash_attention_with_errors(self):
+        if backend.backend() in ("numpy", "tensorflow"):
+            pytest.skip(
+                reason=(
+                    "Flash attention is not supported on tensorflow and numpy."
+                )
+            )
+        # Check `flash_attention=True` and `dropout=0.1`
+        with self.assertRaisesRegex(
+            ValueError,
+            "Dropout is not supported when flash attention is enabled.",
+        ):
+            layer = layers.GroupedQueryAttention(
+                head_dim=2,
+                num_query_heads=2,
+                num_key_value_heads=2,
+                flash_attention=True,
+                dropout=0.1,
+            )
+
+        # Check `flash_attention=True` and `return_attention_scores=True`
+        layer = layers.GroupedQueryAttention(
+            head_dim=2,
+            num_query_heads=2,
+            num_key_value_heads=2,
+            flash_attention=True,
+        )
+        self.assertTrue(layer._flash_attention)
+        query = np.random.random((2, 4, 8))
+        value = np.random.random((2, 4, 8))
+        with self.assertRaisesRegex(
+            ValueError,
+            "Returning attention scores is not supported when flash "
+            "attention is enabled. Please disable flash attention to access"
+            " attention scores.",
+        ):
+            layer(query=query, value=value, return_attention_scores=True)
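The new test_flash_attention_with_errors test encodes the layer's flash-attention constraints: it is skipped on the numpy and tensorflow backends, combining flash_attention=True with dropout raises a ValueError, and so does requesting attention scores while flash attention is enabled. As a hedged sketch, a configuration the test treats as valid might look like the following; actually executing the call still depends on backend and hardware support for flash attention, so treat it as illustrative only.

import numpy as np
from keras import layers

# Illustrative sketch (not from the commit): flash attention with no dropout
# and without requesting attention scores is the combination the test accepts;
# whether the call runs with flash attention depends on the backend/device.
layer = layers.GroupedQueryAttention(
    head_dim=2,
    num_query_heads=2,
    num_key_value_heads=2,
    flash_attention=True,
)
query = np.random.random((2, 4, 8))
value = np.random.random((2, 4, 8))
output = layer(query=query, value=value)  # return_attention_scores stays False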
