
Add test for the case where query_len > context_len
vanbasten23 committed Nov 4, 2024
1 parent 7c7ad4e commit 0a0600a
Showing 1 changed file with 65 additions and 0 deletions.
test/test_tpu_paged_attention_kernel.py: 65 additions & 0 deletions
@@ -72,6 +72,7 @@ def _ref_jax_extended_paged_attention(

    attn = jnp.einsum("qhd,khd->hqk", q[i], k)
    attn = attn.astype('float32')
    q_span = (kv_len - query_len) + jax.lax.broadcasted_iota(
        jnp.int32, (query_len, kv_len), 0)
    kv_span = jax.lax.broadcasted_iota(jnp.int32, (query_len, kv_len), 1)
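
For context, the q_span/kv_span grids above feed a causal mask. A minimal sketch of the usual continuation, assuming the standard jnp.where/-inf masking pattern (the actual lines fall outside this hunk, so treat this as illustrative rather than the file's exact code):

    # Assumed continuation (not shown in this hunk): mask out keys that lie
    # after each query position, normalize, and contract against the values.
    mask = jnp.where(q_span < kv_span, float("-inf"), 0.0)
    attn = attn + mask  # (num_heads, query_len, kv_len) broadcasts with (query_len, kv_len)
    attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
    out = jnp.einsum("hqk,khd->qhd", attn, v)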
@@ -183,5 +184,69 @@ def test_paged_attention(
    self.assertTrue(
        jnp.allclose(expected_output, actual_output, atol=atol, rtol=rtol))

  def test_paged_attention_query_len_longer_than_kv_seq_len(self):
    dtype = jnp.float32
    page_size = 16
    num_kv_heads = 8
    q_kv_head_ratio = 4
    head_dim = 256
    num_queries_per_compute_block = 32
    block_kv_size = 256

    max_kv_len = 2048
    # Set query_len > kv_seq_len so the query is longer than the context.
    query_len = num_queries_per_compute_block
    kv_seq_lens = jnp.array([3])
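    # With query_len = 32 and kv_len = 3, the reference's q_span rows take the
    # constant value 3 - 32 + i for row i; rows with a negative value mask out
    # every key, which is the query_len > kv_len regime this test targets.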

    batch_size = len(kv_seq_lens)
    pages_per_sequence = max_kv_len // page_size
    total_num_pages = batch_size * pages_per_sequence
    assert max_kv_len <= total_num_pages * page_size
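    # Worked numbers: batch_size = 1, pages_per_sequence = 2048 // 16 = 128,
    # total_num_pages = 1 * 128 = 128, so the assert checks
    # 2048 <= 128 * 16 = 2048 and holds with equality.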

    q, k_pages, v_pages, page_indices = _generate_qkv(
        kv_seq_lens,
        page_size,
        max_kv_len,
        query_len,
        num_kv_heads,
        num_kv_heads * q_kv_head_ratio,
        head_dim,
        jax.random.key(0),
        dtype,
    )
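    # _generate_qkv is a helper defined earlier in this file and not shown in
    # this diff. The shapes below are assumptions based on the call signature:
    # q ~ (batch_size, query_len, num_q_heads, head_dim) and k_pages/v_pages ~
    # (num_kv_heads, total_num_pages, page_size, head_dim).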

    print(f'Running paged_attention with {query_len=}')
    num_kv_pages_per_compute_block = block_kv_size // page_size
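    # block_kv_size = 256 and page_size = 16, so each compute block spans
    # 256 // 16 = 16 KV pages.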
    actual_output = paged_attention(
        q,
        k_pages,
        v_pages,
        kv_seq_lens,
        page_indices,
        num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
        num_queries_per_compute_block=num_queries_per_compute_block,
    )
    # Block here so any kernel failure surfaces before the reference run.
    actual_output = jax.block_until_ready(actual_output)

    # Run the reference implementation.
    expected_output = _ref_jax_extended_paged_attention(
        q,
        k_pages,
        v_pages,
        kv_seq_lens,
        page_indices,
    )

    self.assertEqual(actual_output.shape, expected_output.shape)

    atol = 1e-2
    rtol = 1e-2
    print(f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}')
    print(f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}')
    self.assertTrue(
        jnp.allclose(expected_output, actual_output, atol=atol, rtol=rtol))

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
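
To run the file after this change, a typical invocation is shown below; any TPU runtime flags your environment needs are not part of this commit:

    python test/test_tpu_paged_attention_kernel.py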
