unittest: improve the efficiency of xqa unittests #2075
@@ -31,38 +31,10 @@ def div_up(a, b):
beam_width = 1


class CacheSeq:
    def __init__(
        self,
        pool: torch.Tensor,
        page_indices: torch.Tensor,
        nb_heads: int,
        idx_head: int,
        tokens_per_page: int = 32,
        kv_layout: str = "NHD",
    ):
        self.pool = pool
        self.page_indices = page_indices
        self.nb_heads = nb_heads
        self.idx_head = idx_head
        self.tokens_per_page = tokens_per_page
        self.kv_layout = kv_layout

    def __getitem__(self, i: int) -> torch.Tensor:
        page_idx = self.page_indices[i // self.tokens_per_page].to(torch.int32)
        token_in_page = i % self.tokens_per_page
        if self.kv_layout == "NHD":
            # NHD layout: [page_idx, token_in_page, idx_head, :]
            return self.pool[page_idx, token_in_page, self.idx_head, :]
        else:  # HND
            # HND layout: [page_idx, idx_head, token_in_page, :]
            return self.pool[page_idx, self.idx_head, token_in_page, :]


def ref_attention(
    q,
    k_cache_seq,
    v_cache_seq,
    k_cache,  # Changed: now takes full tensor [seq_len, dim]
    v_cache,  # Changed: now takes full tensor [seq_len, dim]
    seq_len,
    q_scale,
    kv_scale,
@@ -89,18 +61,12 @@ def ref_attention(
    q_f32 = q.to(torch.float32)  # [head_grp_size, valid_elems_per_head]

    k_cache_f32 = torch.zeros(
        seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda"
    )
    # V cache: load only valid_elems_per_v_head dimensions
    v_cache_f32 = torch.zeros(
        seq_len, valid_elems_per_v_head, dtype=torch.float32, device="cuda"
    )

    for j in range(seq_len):
        k_cache_f32[j] = k_cache_seq[j].to(torch.float32)
        # For MLA: V cache storage is 576 but only first 512 elements are valid
        v_cache_f32[j] = v_cache_seq[j][:valid_elems_per_v_head].to(torch.float32)
    # Directly use the pre-assembled cache tensors
    k_cache_f32 = k_cache[:seq_len].to(torch.float32)  # [seq_len, valid_elems_per_head]
    # For MLA: V cache storage is 576 but only first 512 elements are valid
    v_cache_f32 = v_cache[:seq_len, :valid_elems_per_v_head].to(
        torch.float32
    )  # [seq_len, valid_elems_per_v_head]

    # q_f32: [head_grp_size, valid_elems_per_head]
    # k_cache_f32: [seq_len, valid_elems_per_head]
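As an aside, here is a tiny CPU-only sketch (toy shapes chosen purely for illustration) of why the slice-based conversion in the new `ref_attention` is a drop-in replacement for the removed per-token loop, including the MLA case where only the first 512 of 576 stored V elements are valid:

```python
import torch

# Hypothetical sizes for illustration only; the real test uses the MLA dims 576/512.
seq_len, valid_elems_per_head, valid_elems_per_v_head = 5, 576, 512
storage_len = 8  # the pre-assembled cache may hold more rows than seq_len

k_cache = torch.randn(storage_len, valid_elems_per_head, dtype=torch.float16)
v_cache = torch.randn(storage_len, valid_elems_per_head, dtype=torch.float16)

# Old approach: per-token copies into preallocated float32 buffers
k_ref = torch.zeros(seq_len, valid_elems_per_head, dtype=torch.float32)
v_ref = torch.zeros(seq_len, valid_elems_per_v_head, dtype=torch.float32)
for j in range(seq_len):
    k_ref[j] = k_cache[j].to(torch.float32)
    v_ref[j] = v_cache[j][:valid_elems_per_v_head].to(torch.float32)

# New approach: one slice plus one dtype cast per tensor
k_new = k_cache[:seq_len].to(torch.float32)
v_new = v_cache[:seq_len, :valid_elems_per_v_head].to(torch.float32)

assert torch.equal(k_ref, k_new) and torch.equal(v_ref, v_new)
```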
@@ -223,12 +189,12 @@ def test_xqa(
    )
    q_heads.normal_(0, 1)
    if use_attention_sinks:
        attention_sinks = torch.zeros(
            nb_k_heads, head_grp_size, dtype=torch.float32, device="cuda"
        # Vectorized creation of attention_sinks
        j_indices = torch.arange(head_grp_size, device="cuda")
        attention_sinks = 2.0 + (j_indices % 4).float()
        attention_sinks = (
            attention_sinks.unsqueeze(0).expand(nb_k_heads, head_grp_size).contiguous()
        )
        for i in range(nb_k_heads):
            for j in range(head_grp_size):
                attention_sinks[i, j] = 2.0 + float(j % 4)
    else:
        attention_sinks = None
    if use_sliding_window:
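For reference, a minimal standalone check (hypothetical sizes, CPU instead of CUDA) that the broadcast construction of `attention_sinks` above reproduces the values the removed nested loop assigned:

```python
import torch

nb_k_heads, head_grp_size = 4, 8  # hypothetical sizes for illustration

# Removed approach: element-wise assignment in a nested Python loop
loop_sinks = torch.zeros(nb_k_heads, head_grp_size, dtype=torch.float32)
for i in range(nb_k_heads):
    for j in range(head_grp_size):
        loop_sinks[i, j] = 2.0 + float(j % 4)

# New approach: build one row and broadcast it across all K heads
j_indices = torch.arange(head_grp_size)
vec_sinks = 2.0 + (j_indices % 4).float()
vec_sinks = vec_sinks.unsqueeze(0).expand(nb_k_heads, head_grp_size).contiguous()

assert torch.equal(loop_sinks, vec_sinks)
```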
@@ -287,65 +253,63 @@ def test_xqa(
    # and prevent overflow during computation. The factor 4.0 is chosen empirically.
    cache_k_heads /= 4.0
    cache_v_heads /= 4.0
    page_list_arg = torch.zeros(
        batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda"
    # Vectorized page list initialization
    total_pages = batch_size * nb_pages_per_seq
    page_list_arg = torch.arange(total_pages, dtype=torch.int32, device="cuda").view(
        batch_size, nb_pages_per_seq
    )

    # Initialize page list sequentially
    page_idx = 0
    for batch in range(batch_size):
        for page in range(nb_pages_per_seq):
            page_list_arg[batch, page] = page_idx
            page_idx += 1

    # Shuffle page indices
    flattened = page_list_arg.flatten()
    indices = torch.randperm(flattened.numel())
    indices = torch.randperm(flattened.numel(), device="cuda")
    shuffled_flat = flattened[indices]
    page_list_arg = shuffled_flat.view(page_list_arg.shape)

    def cache_head_at(
        batch,
        is_k,
        idx_kv_head,
        pos,
        cache_k_heads,
        cache_v_heads,
        page_list,
        beam_width,
        nb_k_heads,
        tokens_per_page,
        kv_layout,
    ):
        # K and V share page indices
        page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32)
        token_in_page = pos % tokens_per_page

        cache = cache_k_heads if is_k else cache_v_heads
        if kv_layout == "NHD":
            # NHD layout: [page_idx, token_in_page, idx_kv_head, :]
            return cache[page_idx, token_in_page, idx_kv_head, :]
        else:  # HND
            # HND layout: [page_idx, idx_kv_head, token_in_page, :]
            return cache[page_idx, idx_kv_head, token_in_page, :]

    for batch in range(batch_size):
        for kv in range(2):
            for idx_kv_head in range(nb_k_heads):
                for pos in range(seq_len, max_seq_len):
                    cache_head = cache_head_at(
                        batch,
                        kv == 0,
                        idx_kv_head,
                        pos,
                        cache_k_heads,
                        cache_v_heads,
                        page_list_arg,
                        beam_width,
                        nb_k_heads,
                        tokens_per_page,
                        kv_layout,
    page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)

    # Vectorized zeroing of unused cache positions using advanced indexing
    if seq_len < max_seq_len:
        # Collect all (page_id, token_pos) pairs that need to be zeroed across all batches
        start_page = seq_len // tokens_per_page
        end_page = nb_pages_per_seq

        if start_page < end_page:
            # Get all page IDs that need partial/full zeroing: [batch_size, num_pages_to_zero]
            pages_to_zero = page_list_arg[
                :, start_page:end_page
            ]  # [batch_size, num_pages_to_zero]

            # For the first page (start_page), zero from [seq_len % tokens_per_page, tokens_per_page)
            # For subsequent pages, zero entirely [0, tokens_per_page)
            first_page_ids = pages_to_zero[:, 0]  # [batch_size]
            token_start_in_first_page = seq_len % tokens_per_page

            if token_start_in_first_page > 0:
                # Zero partial first page for all batches at once
                if kv_layout == "NHD":
                    cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = (
                        0.0
                    )
                    cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = (
                        0.0
                    )
                else:  # HND
                    cache_k_heads[first_page_ids, :, token_start_in_first_page:, :] = (
                        0.0
                    )
                    cache_v_heads[first_page_ids, :, token_start_in_first_page:, :] = (
                        0.0
                    )
                    cache_head.fill_(0.0)

            # Zero all subsequent full pages (if any) for all batches at once
            if pages_to_zero.shape[1] > 1:
                remaining_page_ids = pages_to_zero[
                    :, 1:
                ].flatten()  # Flatten all remaining pages
                if kv_layout == "NHD":
                    cache_k_heads[remaining_page_ids, :, :, :] = 0.0
                    cache_v_heads[remaining_page_ids, :, :, :] = 0.0
                else:  # HND
                    cache_k_heads[remaining_page_ids, :, :, :] = 0.0
                    cache_v_heads[remaining_page_ids, :, :, :] = 0.0

    seq_len_list = torch.zeros(
        batch_size, beam_width, dtype=torch.uint32, device="cuda"
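A small sketch (hypothetical sizes, CPU instead of CUDA) of the page-table changes above: `torch.arange(...).view(...)` builds the same sequential page list the removed nested loop did, and generating the permutation with `device="cuda"`, as the diff now does, keeps the shuffle indices on the same device as the tensor they index:

```python
import torch

batch_size, nb_pages_per_seq = 2, 3  # hypothetical sizes

# Removed approach: fill the page table sequentially with a nested Python loop
loop_pages = torch.zeros(batch_size, nb_pages_per_seq, dtype=torch.int32)
page_idx = 0
for batch in range(batch_size):
    for page in range(nb_pages_per_seq):
        loop_pages[batch, page] = page_idx
        page_idx += 1

# New approach: arange covers the same 0..total_pages-1 range in one call
total_pages = batch_size * nb_pages_per_seq
vec_pages = torch.arange(total_pages, dtype=torch.int32).view(batch_size, nb_pages_per_seq)
assert torch.equal(loop_pages, vec_pages)

# Shuffle page indices; the permutation is generated once and applied to the flat view
indices = torch.randperm(vec_pages.numel())
page_list = vec_pages.flatten()[indices].view(batch_size, nb_pages_per_seq)
```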
@@ -385,30 +349,36 @@ def cache_head_at(
    for req in range(batch_size):
        for b in range(beam_width):
            for idx_k_head in range(nb_k_heads):
                # K and V use separate pools but share page indices
                k_cache_seq = CacheSeq(
                    pool=cache_k_heads,
                    page_indices=page_list_arg[req],
                    nb_heads=nb_k_heads,
                    idx_head=idx_k_head,
                    tokens_per_page=tokens_per_page,
                    kv_layout=kv_layout,
                )
                v_cache_seq = CacheSeq(
                    pool=cache_v_heads,
                    page_indices=page_list_arg[req],
                    nb_heads=nb_k_heads,
                    idx_head=idx_k_head,
                    tokens_per_page=tokens_per_page,
                    kv_layout=kv_layout,
                )
                # Assemble contiguous K/V cache from paged memory using advanced indexing
                num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
                pages = page_list_arg[req, :num_pages]  # [num_pages]

                # Gather all pages at once
                if kv_layout == "NHD":
                    # [num_pages, tokens_per_page, nb_k_heads, head_dim]
                    k_pages = cache_k_heads[
                        pages, :, idx_k_head, :
                    ]  # [num_pages, tokens_per_page, head_dim]
                    v_pages = cache_v_heads[pages, :, idx_k_head, :]
                else:  # HND
                    # [num_pages, nb_k_heads, tokens_per_page, head_dim]
                    k_pages = cache_k_heads[
                        pages, idx_k_head, :, :
                    ]  # [num_pages, tokens_per_page, head_dim]
                    v_pages = cache_v_heads[pages, idx_k_head, :, :]

                # Reshape to contiguous sequence
                k_cache = k_pages.reshape(
                    -1, valid_elems_per_head
                )  # [num_pages*tokens_per_page, head_dim]
                v_cache = v_pages.reshape(-1, valid_elems_per_head)

                ref_output = ref_attention(
                    q=q_heads[req][b][
                        idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size
                    ],
                    k_cache_seq=k_cache_seq,
                    v_cache_seq=v_cache_seq,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    seq_len=seq_len,
                    q_scale=q_scale,
                    kv_scale=kv_cache_scale,
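To make the gather-and-reshape above concrete, here is a toy NHD-layout sketch (hypothetical sizes, CPU) showing that one advanced-indexing call over the whole page list yields the same per-token values the removed `CacheSeq.__getitem__` lookups produced:

```python
import torch

tokens_per_page, nb_k_heads, head_dim = 4, 2, 8  # hypothetical sizes
num_total_pages, seq_len = 6, 10
idx_k_head = 1

# Paged pool in NHD layout: [page, token_in_page, head, head_dim]
pool = torch.randn(num_total_pages, tokens_per_page, nb_k_heads, head_dim)
num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
page_list = torch.randperm(num_total_pages)[:num_pages]

# Gather all pages of one head at once: [num_pages, tokens_per_page, head_dim]
k_pages = pool[page_list, :, idx_k_head, :]
# Flatten pages into a contiguous [num_pages * tokens_per_page, head_dim] sequence
k_cache = k_pages.reshape(-1, head_dim)

# Per-token reference lookup, equivalent to the removed CacheSeq indexing
for i in range(seq_len):
    ref = pool[page_list[i // tokens_per_page], i % tokens_per_page, idx_k_head, :]
    assert torch.equal(k_cache[i], ref)
```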
@@ -520,59 +490,41 @@ def test_xqa_mla(
    cache_k_heads /= 4.0
    cache_v_heads /= 4.0

    page_list_arg = torch.zeros(
        batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda"
    # Vectorized page list initialization
    total_pages = batch_size * nb_pages_per_seq
    page_list_arg = torch.arange(total_pages, dtype=torch.int32, device="cuda").view(
        batch_size, nb_pages_per_seq
    )

    # Initialize page list sequentially
    page_idx = 0
    for batch in range(batch_size):
        for page in range(nb_pages_per_seq):
            page_list_arg[batch, page] = page_idx
            page_idx += 1

    # Shuffle page indices
    flattened = page_list_arg.flatten()
    indices = torch.randperm(flattened.numel())
    indices = torch.randperm(flattened.numel(), device="cuda")
    shuffled_flat = flattened[indices]
    page_list_arg = shuffled_flat.view(page_list_arg.shape)

    def cache_head_at(
        batch,
        is_k,
        idx_kv_head,
        pos,
        cache_k_heads,
        cache_v_heads,
        page_list,
        beam_width,
        nb_k_heads,
        tokens_per_page,
    ):
        # K and V share page indices
        page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32)
        token_in_page = pos % tokens_per_page

        # NHD layout: [page_idx, token_in_page, idx_kv_head, :]
        cache = cache_k_heads if is_k else cache_v_heads
        return cache[page_idx, token_in_page, idx_kv_head, :]

    for batch in range(batch_size):
        for kv in range(2):
            for idx_kv_head in range(nb_k_heads):
                for pos in range(seq_len, max_seq_len):
                    cache_head = cache_head_at(
                        batch,
                        kv == 0,
                        idx_kv_head,
                        pos,
                        cache_k_heads,
                        cache_v_heads,
                        page_list_arg,
                        beam_width,
                        nb_k_heads,
                        tokens_per_page,
                    )
                    cache_head.fill_(0.0)
    page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)

    # Vectorized zeroing of unused cache positions (NHD layout only for MLA)
    if seq_len < max_seq_len:
        start_page = seq_len // tokens_per_page
        end_page = nb_pages_per_seq

        if start_page < end_page:
            pages_to_zero = page_list_arg[
                :, start_page:end_page
            ]  # [batch_size, num_pages_to_zero]

            first_page_ids = pages_to_zero[:, 0]  # [batch_size]
            token_start_in_first_page = seq_len % tokens_per_page

            if token_start_in_first_page > 0:
                # Zero partial first page for all batches at once (NHD layout)
                cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0
                cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0

            # Zero all subsequent full pages (if any) for all batches at once
            if pages_to_zero.shape[1] > 1:
                remaining_page_ids = pages_to_zero[:, 1:].flatten()
                cache_k_heads[remaining_page_ids, :, :, :] = 0.0
                cache_v_heads[remaining_page_ids, :, :, :] = 0.0
Comment on lines +518 to +527

Contributor

This section has the same logical bug as in `test_xqa`. I'm providing a similar fix to ensure all unused pages are correctly zeroed out in this case as well.

            if token_start_in_first_page > 0:
                # Zero partial first page for all batches at once (NHD layout)
                cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0
                cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0
                pages_to_zero_fully = pages_to_zero[:, 1:]
            else:  # token_start_in_first_page == 0
                pages_to_zero_fully = pages_to_zero
            # Zero all subsequent full pages (if any) for all batches at once
            if pages_to_zero_fully.numel() > 0:
                remaining_page_ids = pages_to_zero_fully.flatten()
                cache_k_heads[remaining_page_ids, :, :, :] = 0.0
                cache_v_heads[remaining_page_ids, :, :, :] = 0.0

    seq_len_list = torch.zeros(
        batch_size, beam_width, dtype=torch.uint32, device="cuda"
@@ -608,28 +560,26 @@ def cache_head_at(
    for req in range(batch_size):
        for b in range(beam_width):
            for idx_k_head in range(nb_k_heads):
                # K and V use separate pools but share page indices
                k_cache_seq = CacheSeq(
                    pool=cache_k_heads,
                    page_indices=page_list_arg[req],
                    nb_heads=nb_k_heads,
                    idx_head=idx_k_head,
                    tokens_per_page=tokens_per_page,
                )
                v_cache_seq = CacheSeq(
                    pool=cache_v_heads,
                    page_indices=page_list_arg[req],
                    nb_heads=nb_k_heads,
                    idx_head=idx_k_head,
                    tokens_per_page=tokens_per_page,
                )
                # Assemble contiguous K/V cache from paged memory using advanced indexing
                num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
                pages = page_list_arg[req, :num_pages]  # [num_pages]

                # NHD layout: [num_pages, tokens_per_page, nb_k_heads, head_dim]
                k_pages = cache_k_heads[
                    pages, :, idx_k_head, :
                ]  # [num_pages, tokens_per_page, head_dim]
                v_pages = cache_v_heads[pages, :, idx_k_head, :]

                # Reshape to contiguous sequence
                k_cache = k_pages.reshape(-1, valid_elems_per_head_qk)
                v_cache = v_pages.reshape(-1, valid_elems_per_head_qk)

                ref_output = ref_attention(
                    q=q_heads[req][b][
                        idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size
                    ],
                    k_cache_seq=k_cache_seq,
                    v_cache_seq=v_cache_seq,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    seq_len=seq_len,
                    q_scale=q_scale * math.sqrt(576),
                    kv_scale=kv_cache_scale,
There's a logic error in how unused cache positions are zeroed out. When `seq_len` is a multiple of `tokens_per_page`, `token_start_in_first_page` becomes 0. In this scenario, the current code skips zeroing the first page that should be cleared and only processes subsequent pages. This leaves stale data in the cache, which can lead to incorrect test results.

The suggested change corrects this by ensuring that when `token_start_in_first_page` is 0, all pages from `start_page` onwards are correctly identified and zeroed out.
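A concrete illustration of the edge case (hypothetical numbers: `tokens_per_page = 32`, `seq_len = 64`, `nb_pages_per_seq = 4`): `start_page` is 2 and `token_start_in_first_page` is 64 % 32 = 0, so the partial-page branch is skipped and `pages_to_zero[:, 1:]` only reaches page slot 3, leaving page slot 2 with stale data. The sketch below shows how the suggested fix selects the full set of pages in that situation:

```python
import torch

tokens_per_page, nb_pages_per_seq = 32, 4  # hypothetical sizes
seq_len = 64                               # exact multiple of tokens_per_page

start_page = seq_len // tokens_per_page                 # 2
token_start_in_first_page = seq_len % tokens_per_page   # 0

# Page slots past the valid sequence, per batch (single batch here for brevity)
pages_to_zero = torch.arange(nb_pages_per_seq).unsqueeze(0)[:, start_page:]  # [[2, 3]]

# Original logic: with token_start_in_first_page == 0 the partial-page branch is
# skipped, and only pages_to_zero[:, 1:] ([[3]]) gets zeroed -> slot 2 keeps stale data.

# Suggested fix: treat every listed page as a full page when the offset is 0.
if token_start_in_first_page > 0:
    pages_to_zero_fully = pages_to_zero[:, 1:]  # first page already partially zeroed
else:
    pages_to_zero_fully = pages_to_zero
assert pages_to_zero_fully.flatten().tolist() == [2, 3]
```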