322 changes: 136 additions & 186 deletions tests/attention/test_xqa.py
@@ -31,38 +31,10 @@ def div_up(a, b):
beam_width = 1


class CacheSeq:
def __init__(
self,
pool: torch.Tensor,
page_indices: torch.Tensor,
nb_heads: int,
idx_head: int,
tokens_per_page: int = 32,
kv_layout: str = "NHD",
):
self.pool = pool
self.page_indices = page_indices
self.nb_heads = nb_heads
self.idx_head = idx_head
self.tokens_per_page = tokens_per_page
self.kv_layout = kv_layout

def __getitem__(self, i: int) -> torch.Tensor:
page_idx = self.page_indices[i // self.tokens_per_page].to(torch.int32)
token_in_page = i % self.tokens_per_page
if self.kv_layout == "NHD":
# NHD layout: [page_idx, token_in_page, idx_head, :]
return self.pool[page_idx, token_in_page, self.idx_head, :]
else: # HND
# HND layout: [page_idx, idx_head, token_in_page, :]
return self.pool[page_idx, self.idx_head, token_in_page, :]


def ref_attention(
q,
k_cache_seq,
v_cache_seq,
k_cache, # Changed: now takes full tensor [seq_len, dim]
v_cache, # Changed: now takes full tensor [seq_len, dim]
seq_len,
q_scale,
kv_scale,
@@ -89,18 +61,12 @@ def ref_attention(

q_f32 = q.to(torch.float32) # [head_grp_size, valid_elems_per_head]

k_cache_f32 = torch.zeros(
seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda"
)
# V cache: load only valid_elems_per_v_head dimensions
v_cache_f32 = torch.zeros(
seq_len, valid_elems_per_v_head, dtype=torch.float32, device="cuda"
)

for j in range(seq_len):
k_cache_f32[j] = k_cache_seq[j].to(torch.float32)
# For MLA: V cache storage is 576 but only first 512 elements are valid
v_cache_f32[j] = v_cache_seq[j][:valid_elems_per_v_head].to(torch.float32)
# Directly use the pre-assembled cache tensors
k_cache_f32 = k_cache[:seq_len].to(torch.float32) # [seq_len, valid_elems_per_head]
# For MLA: V cache storage is 576 but only first 512 elements are valid
v_cache_f32 = v_cache[:seq_len, :valid_elems_per_v_head].to(
torch.float32
) # [seq_len, valid_elems_per_v_head]

# q_f32: [head_grp_size, valid_elems_per_head]
# k_cache_f32: [seq_len, valid_elems_per_head]
@@ -223,12 +189,12 @@ def test_xqa(
)
q_heads.normal_(0, 1)
if use_attention_sinks:
attention_sinks = torch.zeros(
nb_k_heads, head_grp_size, dtype=torch.float32, device="cuda"
# Vectorized creation of attention_sinks
j_indices = torch.arange(head_grp_size, device="cuda")
attention_sinks = 2.0 + (j_indices % 4).float()
attention_sinks = (
attention_sinks.unsqueeze(0).expand(nb_k_heads, head_grp_size).contiguous()
)
for i in range(nb_k_heads):
for j in range(head_grp_size):
attention_sinks[i, j] = 2.0 + float(j % 4)
else:
attention_sinks = None
if use_sliding_window:
@@ -287,65 +253,63 @@ def test_xqa(
# and prevent overflow during computation. The factor 4.0 is chosen empirically.
cache_k_heads /= 4.0
cache_v_heads /= 4.0
page_list_arg = torch.zeros(
batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda"
# Vectorized page list initialization
total_pages = batch_size * nb_pages_per_seq
page_list_arg = torch.arange(total_pages, dtype=torch.int32, device="cuda").view(
batch_size, nb_pages_per_seq
)

# Initialize page list sequentially
page_idx = 0
for batch in range(batch_size):
for page in range(nb_pages_per_seq):
page_list_arg[batch, page] = page_idx
page_idx += 1

# Shuffle page indices
flattened = page_list_arg.flatten()
indices = torch.randperm(flattened.numel())
indices = torch.randperm(flattened.numel(), device="cuda")
shuffled_flat = flattened[indices]
page_list_arg = shuffled_flat.view(page_list_arg.shape)

def cache_head_at(
batch,
is_k,
idx_kv_head,
pos,
cache_k_heads,
cache_v_heads,
page_list,
beam_width,
nb_k_heads,
tokens_per_page,
kv_layout,
):
# K and V share page indices
page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32)
token_in_page = pos % tokens_per_page

cache = cache_k_heads if is_k else cache_v_heads
if kv_layout == "NHD":
# NHD layout: [page_idx, token_in_page, idx_kv_head, :]
return cache[page_idx, token_in_page, idx_kv_head, :]
else: # HND
# HND layout: [page_idx, idx_kv_head, token_in_page, :]
return cache[page_idx, idx_kv_head, token_in_page, :]

for batch in range(batch_size):
for kv in range(2):
for idx_kv_head in range(nb_k_heads):
for pos in range(seq_len, max_seq_len):
cache_head = cache_head_at(
batch,
kv == 0,
idx_kv_head,
pos,
cache_k_heads,
cache_v_heads,
page_list_arg,
beam_width,
nb_k_heads,
tokens_per_page,
kv_layout,
page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)

# Vectorized zeroing of unused cache positions using advanced indexing
if seq_len < max_seq_len:
# Collect all (page_id, token_pos) pairs that need to be zeroed across all batches
start_page = seq_len // tokens_per_page
end_page = nb_pages_per_seq

if start_page < end_page:
# Get all page IDs that need partial/full zeroing: [batch_size, num_pages_to_zero]
pages_to_zero = page_list_arg[
:, start_page:end_page
] # [batch_size, num_pages_to_zero]

# For the first page (start_page), zero from [seq_len % tokens_per_page, tokens_per_page)
# For subsequent pages, zero entirely [0, tokens_per_page)
first_page_ids = pages_to_zero[:, 0] # [batch_size]
token_start_in_first_page = seq_len % tokens_per_page

if token_start_in_first_page > 0:
# Zero partial first page for all batches at once
if kv_layout == "NHD":
cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = (
0.0
)
cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = (
0.0
)
else: # HND
cache_k_heads[first_page_ids, :, token_start_in_first_page:, :] = (
0.0
)
cache_v_heads[first_page_ids, :, token_start_in_first_page:, :] = (
0.0
)
cache_head.fill_(0.0)

# Zero all subsequent full pages (if any) for all batches at once
if pages_to_zero.shape[1] > 1:
remaining_page_ids = pages_to_zero[
:, 1:
].flatten() # Flatten all remaining pages
if kv_layout == "NHD":
cache_k_heads[remaining_page_ids, :, :, :] = 0.0
cache_v_heads[remaining_page_ids, :, :, :] = 0.0
else: # HND
cache_k_heads[remaining_page_ids, :, :, :] = 0.0
cache_v_heads[remaining_page_ids, :, :, :] = 0.0
Comment on lines +285 to +312
Contributor


critical

There's a logic error in how unused cache positions are zeroed out. When seq_len is a multiple of tokens_per_page, token_start_in_first_page becomes 0. In this scenario, the current code skips zeroing the first page that should be cleared and only processes subsequent pages. This leaves stale data in the cache, which can lead to incorrect test results.

The suggested change corrects this by ensuring that when token_start_in_first_page is 0, all pages from start_page onwards are correctly identified and zeroed out.

            if token_start_in_first_page > 0:
                # Zero partial first page for all batches at once
                if kv_layout == "NHD":
                    cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0
                    cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0
                else:  # HND
                    cache_k_heads[first_page_ids, :, token_start_in_first_page:, :] = 0.0
                    cache_v_heads[first_page_ids, :, token_start_in_first_page:, :] = 0.0
                pages_to_zero_fully = pages_to_zero[:, 1:]
            else:  # token_start_in_first_page == 0
                pages_to_zero_fully = pages_to_zero

            # Zero all subsequent full pages (if any) for all batches at once
            if pages_to_zero_fully.numel() > 0:
                remaining_page_ids = pages_to_zero_fully.flatten()
                if kv_layout == "NHD":
                    cache_k_heads[remaining_page_ids, :, :, :] = 0.0
                    cache_v_heads[remaining_page_ids, :, :, :] = 0.0
                else:  # HND
                    cache_k_heads[remaining_page_ids, :, :, :] = 0.0
                    cache_v_heads[remaining_page_ids, :, :, :] = 0.0
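
For concreteness, here is a small standalone sketch of the arithmetic behind this fix. The values (tokens_per_page = 32, nb_pages_per_seq = 4, seq_len of 40 vs. 64) are illustrative only, not the test's actual parameters; it shows that when seq_len lands exactly on a page boundary, the first unused page must be zeroed in full rather than skipped.

    # Illustrative check of the page-zeroing arithmetic (hypothetical values).
    tokens_per_page = 32
    nb_pages_per_seq = 4

    for seq_len in (40, 64):  # 64 is an exact multiple of tokens_per_page
        start_page = seq_len // tokens_per_page
        token_start_in_first_page = seq_len % tokens_per_page
        pages_needing_zeroing = list(range(start_page, nb_pages_per_seq))

        if token_start_in_first_page > 0:
            # First page is only partially unused; the rest are fully unused.
            pages_to_zero_fully = pages_needing_zeroing[1:]
        else:
            # Page boundary: every page from start_page onward is fully unused.
            pages_to_zero_fully = pages_needing_zeroing

        print(seq_len, start_page, token_start_in_first_page, pages_to_zero_fully)
        # seq_len=40 -> start_page=1, partial page 1, fully zero pages [2, 3]
        # seq_len=64 -> start_page=2, fully zero pages [2, 3]
        #               (the old `[:, 1:]` slice would have dropped page 2)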


seq_len_list = torch.zeros(
batch_size, beam_width, dtype=torch.uint32, device="cuda"
@@ -385,30 +349,36 @@ def cache_head_at(
for req in range(batch_size):
for b in range(beam_width):
for idx_k_head in range(nb_k_heads):
# K and V use separate pools but share page indices
k_cache_seq = CacheSeq(
pool=cache_k_heads,
page_indices=page_list_arg[req],
nb_heads=nb_k_heads,
idx_head=idx_k_head,
tokens_per_page=tokens_per_page,
kv_layout=kv_layout,
)
v_cache_seq = CacheSeq(
pool=cache_v_heads,
page_indices=page_list_arg[req],
nb_heads=nb_k_heads,
idx_head=idx_k_head,
tokens_per_page=tokens_per_page,
kv_layout=kv_layout,
)
# Assemble contiguous K/V cache from paged memory using advanced indexing
num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
pages = page_list_arg[req, :num_pages] # [num_pages]

# Gather all pages at once
if kv_layout == "NHD":
# [num_pages, tokens_per_page, nb_k_heads, head_dim]
k_pages = cache_k_heads[
pages, :, idx_k_head, :
] # [num_pages, tokens_per_page, head_dim]
v_pages = cache_v_heads[pages, :, idx_k_head, :]
else: # HND
# [num_pages, nb_k_heads, tokens_per_page, head_dim]
k_pages = cache_k_heads[
pages, idx_k_head, :, :
] # [num_pages, tokens_per_page, head_dim]
v_pages = cache_v_heads[pages, idx_k_head, :, :]

# Reshape to contiguous sequence
k_cache = k_pages.reshape(
-1, valid_elems_per_head
) # [num_pages*tokens_per_page, head_dim]
v_cache = v_pages.reshape(-1, valid_elems_per_head)

ref_output = ref_attention(
q=q_heads[req][b][
idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size
],
k_cache_seq=k_cache_seq,
v_cache_seq=v_cache_seq,
k_cache=k_cache,
v_cache=v_cache,
seq_len=seq_len,
q_scale=q_scale,
kv_scale=kv_cache_scale,
@@ -520,59 +490,41 @@ def test_xqa_mla(
cache_k_heads /= 4.0
cache_v_heads /= 4.0

page_list_arg = torch.zeros(
batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda"
# Vectorized page list initialization
total_pages = batch_size * nb_pages_per_seq
page_list_arg = torch.arange(total_pages, dtype=torch.int32, device="cuda").view(
batch_size, nb_pages_per_seq
)

# Initialize page list sequentially
page_idx = 0
for batch in range(batch_size):
for page in range(nb_pages_per_seq):
page_list_arg[batch, page] = page_idx
page_idx += 1

# Shuffle page indices
flattened = page_list_arg.flatten()
indices = torch.randperm(flattened.numel())
indices = torch.randperm(flattened.numel(), device="cuda")
shuffled_flat = flattened[indices]
page_list_arg = shuffled_flat.view(page_list_arg.shape)

def cache_head_at(
batch,
is_k,
idx_kv_head,
pos,
cache_k_heads,
cache_v_heads,
page_list,
beam_width,
nb_k_heads,
tokens_per_page,
):
# K and V share page indices
page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32)
token_in_page = pos % tokens_per_page

# NHD layout: [page_idx, token_in_page, idx_kv_head, :]
cache = cache_k_heads if is_k else cache_v_heads
return cache[page_idx, token_in_page, idx_kv_head, :]

for batch in range(batch_size):
for kv in range(2):
for idx_kv_head in range(nb_k_heads):
for pos in range(seq_len, max_seq_len):
cache_head = cache_head_at(
batch,
kv == 0,
idx_kv_head,
pos,
cache_k_heads,
cache_v_heads,
page_list_arg,
beam_width,
nb_k_heads,
tokens_per_page,
)
cache_head.fill_(0.0)
page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)

# Vectorized zeroing of unused cache positions (NHD layout only for MLA)
if seq_len < max_seq_len:
start_page = seq_len // tokens_per_page
end_page = nb_pages_per_seq

if start_page < end_page:
pages_to_zero = page_list_arg[
:, start_page:end_page
] # [batch_size, num_pages_to_zero]

first_page_ids = pages_to_zero[:, 0] # [batch_size]
token_start_in_first_page = seq_len % tokens_per_page

if token_start_in_first_page > 0:
# Zero partial first page for all batches at once (NHD layout)
cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0
cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0

# Zero all subsequent full pages (if any) for all batches at once
if pages_to_zero.shape[1] > 1:
remaining_page_ids = pages_to_zero[:, 1:].flatten()
cache_k_heads[remaining_page_ids, :, :, :] = 0.0
cache_v_heads[remaining_page_ids, :, :, :] = 0.0
Comment on lines +518 to +527
Contributor


critical

This section has the same logical bug as in test_xqa. When seq_len is a multiple of tokens_per_page, token_start_in_first_page is 0, and the logic incorrectly skips zeroing out the first page that should be completely cleared. This can cause test failures due to stale data in the cache.

I'm providing a similar fix to ensure all unused pages are correctly zeroed out in this case as well.

            if token_start_in_first_page > 0:
                # Zero partial first page for all batches at once (NHD layout)
                cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0
                cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0
                pages_to_zero_fully = pages_to_zero[:, 1:]
            else:  # token_start_in_first_page == 0
                pages_to_zero_fully = pages_to_zero

            # Zero all subsequent full pages (if any) for all batches at once
            if pages_to_zero_fully.numel() > 0:
                remaining_page_ids = pages_to_zero_fully.flatten()
                cache_k_heads[remaining_page_ids, :, :, :] = 0.0
                cache_v_heads[remaining_page_ids, :, :, :] = 0.0
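
As a lightweight sanity check in the same spirit (not part of the proposed change), one could walk the page list position by position and assert that every token at or beyond seq_len reads back as zero. A minimal sketch with made-up shapes, assuming the NHD layout [page, token_in_page, head, dim] used in the MLA test; the helper name and shapes are hypothetical.

    import torch

    def assert_tail_zeroed(cache, page_list, seq_len, max_seq_len, tokens_per_page):
        # Walk every unused position through the page table and check it is zero.
        for batch in range(page_list.shape[0]):
            for pos in range(seq_len, max_seq_len):
                page = page_list[batch, pos // tokens_per_page]
                tok = pos % tokens_per_page
                assert torch.all(cache[page, tok] == 0), f"stale data at batch={batch}, pos={pos}"

    # Tiny smoke test with made-up shapes; seq_len is a multiple of
    # tokens_per_page to exercise exactly the edge case discussed above.
    tokens_per_page, nb_pages, nb_heads, head_dim = 4, 3, 2, 8
    cache = torch.zeros(nb_pages, tokens_per_page, nb_heads, head_dim)
    page_list = torch.arange(nb_pages, dtype=torch.int32).unsqueeze(0)  # one batch
    assert_tail_zeroed(cache, page_list, seq_len=4, max_seq_len=12, tokens_per_page=tokens_per_page)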


seq_len_list = torch.zeros(
batch_size, beam_width, dtype=torch.uint32, device="cuda"
@@ -608,28 +560,26 @@ def cache_head_at(
for req in range(batch_size):
for b in range(beam_width):
for idx_k_head in range(nb_k_heads):
# K and V use separate pools but share page indices
k_cache_seq = CacheSeq(
pool=cache_k_heads,
page_indices=page_list_arg[req],
nb_heads=nb_k_heads,
idx_head=idx_k_head,
tokens_per_page=tokens_per_page,
)
v_cache_seq = CacheSeq(
pool=cache_v_heads,
page_indices=page_list_arg[req],
nb_heads=nb_k_heads,
idx_head=idx_k_head,
tokens_per_page=tokens_per_page,
)
# Assemble contiguous K/V cache from paged memory using advanced indexing
num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
pages = page_list_arg[req, :num_pages] # [num_pages]

# NHD layout: [num_pages, tokens_per_page, nb_k_heads, head_dim]
k_pages = cache_k_heads[
pages, :, idx_k_head, :
] # [num_pages, tokens_per_page, head_dim]
v_pages = cache_v_heads[pages, :, idx_k_head, :]

# Reshape to contiguous sequence
k_cache = k_pages.reshape(-1, valid_elems_per_head_qk)
v_cache = v_pages.reshape(-1, valid_elems_per_head_qk)

ref_output = ref_attention(
q=q_heads[req][b][
idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size
],
k_cache_seq=k_cache_seq,
v_cache_seq=v_cache_seq,
k_cache=k_cache,
v_cache=v_cache,
seq_len=seq_len,
q_scale=q_scale * math.sqrt(576),
kv_scale=kv_cache_scale,