From b9dbaa733ac4683f30df869b44e1c430b3a874c9 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 10 Nov 2025 23:40:11 -0500 Subject: [PATCH] upd --- tests/attention/test_xqa.py | 322 ++++++++++------------- tests/attention/test_xqa_batch_decode.py | 123 +++++---- 2 files changed, 213 insertions(+), 232 deletions(-) diff --git a/tests/attention/test_xqa.py b/tests/attention/test_xqa.py index 5701bdc1b8..172135c571 100644 --- a/tests/attention/test_xqa.py +++ b/tests/attention/test_xqa.py @@ -31,38 +31,10 @@ def div_up(a, b): beam_width = 1 -class CacheSeq: - def __init__( - self, - pool: torch.Tensor, - page_indices: torch.Tensor, - nb_heads: int, - idx_head: int, - tokens_per_page: int = 32, - kv_layout: str = "NHD", - ): - self.pool = pool - self.page_indices = page_indices - self.nb_heads = nb_heads - self.idx_head = idx_head - self.tokens_per_page = tokens_per_page - self.kv_layout = kv_layout - - def __getitem__(self, i: int) -> torch.Tensor: - page_idx = self.page_indices[i // self.tokens_per_page].to(torch.int32) - token_in_page = i % self.tokens_per_page - if self.kv_layout == "NHD": - # NHD layout: [page_idx, token_in_page, idx_head, :] - return self.pool[page_idx, token_in_page, self.idx_head, :] - else: # HND - # HND layout: [page_idx, idx_head, token_in_page, :] - return self.pool[page_idx, self.idx_head, token_in_page, :] - - def ref_attention( q, - k_cache_seq, - v_cache_seq, + k_cache, # Changed: now takes full tensor [seq_len, dim] + v_cache, # Changed: now takes full tensor [seq_len, dim] seq_len, q_scale, kv_scale, @@ -89,18 +61,12 @@ def ref_attention( q_f32 = q.to(torch.float32) # [head_grp_size, valid_elems_per_head] - k_cache_f32 = torch.zeros( - seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda" - ) - # V cache: load only valid_elems_per_v_head dimensions - v_cache_f32 = torch.zeros( - seq_len, valid_elems_per_v_head, dtype=torch.float32, device="cuda" - ) - - for j in range(seq_len): - k_cache_f32[j] = k_cache_seq[j].to(torch.float32) - # For MLA: V cache storage is 576 but only first 512 elements are valid - v_cache_f32[j] = v_cache_seq[j][:valid_elems_per_v_head].to(torch.float32) + # Directly use the pre-assembled cache tensors + k_cache_f32 = k_cache[:seq_len].to(torch.float32) # [seq_len, valid_elems_per_head] + # For MLA: V cache storage is 576 but only first 512 elements are valid + v_cache_f32 = v_cache[:seq_len, :valid_elems_per_v_head].to( + torch.float32 + ) # [seq_len, valid_elems_per_v_head] # q_f32: [head_grp_size, valid_elems_per_head] # k_cache_f32: [seq_len, valid_elems_per_head] @@ -223,12 +189,12 @@ def test_xqa( ) q_heads.normal_(0, 1) if use_attention_sinks: - attention_sinks = torch.zeros( - nb_k_heads, head_grp_size, dtype=torch.float32, device="cuda" + # Vectorized creation of attention_sinks + j_indices = torch.arange(head_grp_size, device="cuda") + attention_sinks = 2.0 + (j_indices % 4).float() + attention_sinks = ( + attention_sinks.unsqueeze(0).expand(nb_k_heads, head_grp_size).contiguous() ) - for i in range(nb_k_heads): - for j in range(head_grp_size): - attention_sinks[i, j] = 2.0 + float(j % 4) else: attention_sinks = None if use_sliding_window: @@ -287,65 +253,63 @@ def test_xqa( # and prevent overflow during computation. The factor 4.0 is chosen empirically. cache_k_heads /= 4.0 cache_v_heads /= 4.0 - page_list_arg = torch.zeros( - batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda" + # Vectorized page list initialization + total_pages = batch_size * nb_pages_per_seq + page_list_arg = torch.arange(total_pages, dtype=torch.int32, device="cuda").view( + batch_size, nb_pages_per_seq ) - # Initialize page list sequentially - page_idx = 0 - for batch in range(batch_size): - for page in range(nb_pages_per_seq): - page_list_arg[batch, page] = page_idx - page_idx += 1 - + # Shuffle page indices flattened = page_list_arg.flatten() - indices = torch.randperm(flattened.numel()) + indices = torch.randperm(flattened.numel(), device="cuda") shuffled_flat = flattened[indices] - page_list_arg = shuffled_flat.view(page_list_arg.shape) - - def cache_head_at( - batch, - is_k, - idx_kv_head, - pos, - cache_k_heads, - cache_v_heads, - page_list, - beam_width, - nb_k_heads, - tokens_per_page, - kv_layout, - ): - # K and V share page indices - page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32) - token_in_page = pos % tokens_per_page - - cache = cache_k_heads if is_k else cache_v_heads - if kv_layout == "NHD": - # NHD layout: [page_idx, token_in_page, idx_kv_head, :] - return cache[page_idx, token_in_page, idx_kv_head, :] - else: # HND - # HND layout: [page_idx, idx_kv_head, token_in_page, :] - return cache[page_idx, idx_kv_head, token_in_page, :] - - for batch in range(batch_size): - for kv in range(2): - for idx_kv_head in range(nb_k_heads): - for pos in range(seq_len, max_seq_len): - cache_head = cache_head_at( - batch, - kv == 0, - idx_kv_head, - pos, - cache_k_heads, - cache_v_heads, - page_list_arg, - beam_width, - nb_k_heads, - tokens_per_page, - kv_layout, + page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq) + + # Vectorized zeroing of unused cache positions using advanced indexing + if seq_len < max_seq_len: + # Collect all (page_id, token_pos) pairs that need to be zeroed across all batches + start_page = seq_len // tokens_per_page + end_page = nb_pages_per_seq + + if start_page < end_page: + # Get all page IDs that need partial/full zeroing: [batch_size, num_pages_to_zero] + pages_to_zero = page_list_arg[ + :, start_page:end_page + ] # [batch_size, num_pages_to_zero] + + # For the first page (start_page), zero from [seq_len % tokens_per_page, tokens_per_page) + # For subsequent pages, zero entirely [0, tokens_per_page) + first_page_ids = pages_to_zero[:, 0] # [batch_size] + token_start_in_first_page = seq_len % tokens_per_page + + if token_start_in_first_page > 0: + # Zero partial first page for all batches at once + if kv_layout == "NHD": + cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = ( + 0.0 + ) + cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = ( + 0.0 + ) + else: # HND + cache_k_heads[first_page_ids, :, token_start_in_first_page:, :] = ( + 0.0 + ) + cache_v_heads[first_page_ids, :, token_start_in_first_page:, :] = ( + 0.0 ) - cache_head.fill_(0.0) + + # Zero all subsequent full pages (if any) for all batches at once + if pages_to_zero.shape[1] > 1: + remaining_page_ids = pages_to_zero[ + :, 1: + ].flatten() # Flatten all remaining pages + if kv_layout == "NHD": + cache_k_heads[remaining_page_ids, :, :, :] = 0.0 + cache_v_heads[remaining_page_ids, :, :, :] = 0.0 + else: # HND + cache_k_heads[remaining_page_ids, :, :, :] = 0.0 + cache_v_heads[remaining_page_ids, :, :, :] = 0.0 seq_len_list = torch.zeros( batch_size, beam_width, dtype=torch.uint32, device="cuda" @@ -385,30 +349,36 @@ def cache_head_at( for req in range(batch_size): for b in range(beam_width): for idx_k_head in range(nb_k_heads): - # K and V use separate pools but share page indices - k_cache_seq = CacheSeq( - pool=cache_k_heads, - page_indices=page_list_arg[req], - nb_heads=nb_k_heads, - idx_head=idx_k_head, - tokens_per_page=tokens_per_page, - kv_layout=kv_layout, - ) - v_cache_seq = CacheSeq( - pool=cache_v_heads, - page_indices=page_list_arg[req], - nb_heads=nb_k_heads, - idx_head=idx_k_head, - tokens_per_page=tokens_per_page, - kv_layout=kv_layout, - ) + # Assemble contiguous K/V cache from paged memory using advanced indexing + num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page + pages = page_list_arg[req, :num_pages] # [num_pages] + + # Gather all pages at once + if kv_layout == "NHD": + # [num_pages, tokens_per_page, nb_k_heads, head_dim] + k_pages = cache_k_heads[ + pages, :, idx_k_head, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, :, idx_k_head, :] + else: # HND + # [num_pages, nb_k_heads, tokens_per_page, head_dim] + k_pages = cache_k_heads[ + pages, idx_k_head, :, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, idx_k_head, :, :] + + # Reshape to contiguous sequence + k_cache = k_pages.reshape( + -1, valid_elems_per_head + ) # [num_pages*tokens_per_page, head_dim] + v_cache = v_pages.reshape(-1, valid_elems_per_head) ref_output = ref_attention( q=q_heads[req][b][ idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size ], - k_cache_seq=k_cache_seq, - v_cache_seq=v_cache_seq, + k_cache=k_cache, + v_cache=v_cache, seq_len=seq_len, q_scale=q_scale, kv_scale=kv_cache_scale, @@ -520,59 +490,41 @@ def test_xqa_mla( cache_k_heads /= 4.0 cache_v_heads /= 4.0 - page_list_arg = torch.zeros( - batch_size, nb_pages_per_seq, dtype=torch.int32, device="cuda" + # Vectorized page list initialization + total_pages = batch_size * nb_pages_per_seq + page_list_arg = torch.arange(total_pages, dtype=torch.int32, device="cuda").view( + batch_size, nb_pages_per_seq ) - # Initialize page list sequentially - page_idx = 0 - for batch in range(batch_size): - for page in range(nb_pages_per_seq): - page_list_arg[batch, page] = page_idx - page_idx += 1 - + # Shuffle page indices flattened = page_list_arg.flatten() - indices = torch.randperm(flattened.numel()) + indices = torch.randperm(flattened.numel(), device="cuda") shuffled_flat = flattened[indices] - page_list_arg = shuffled_flat.view(page_list_arg.shape) - - def cache_head_at( - batch, - is_k, - idx_kv_head, - pos, - cache_k_heads, - cache_v_heads, - page_list, - beam_width, - nb_k_heads, - tokens_per_page, - ): - # K and V share page indices - page_idx = page_list[batch][pos // tokens_per_page].to(torch.int32) - token_in_page = pos % tokens_per_page - - # NHD layout: [page_idx, token_in_page, idx_kv_head, :] - cache = cache_k_heads if is_k else cache_v_heads - return cache[page_idx, token_in_page, idx_kv_head, :] - - for batch in range(batch_size): - for kv in range(2): - for idx_kv_head in range(nb_k_heads): - for pos in range(seq_len, max_seq_len): - cache_head = cache_head_at( - batch, - kv == 0, - idx_kv_head, - pos, - cache_k_heads, - cache_v_heads, - page_list_arg, - beam_width, - nb_k_heads, - tokens_per_page, - ) - cache_head.fill_(0.0) + page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq) + + # Vectorized zeroing of unused cache positions (NHD layout only for MLA) + if seq_len < max_seq_len: + start_page = seq_len // tokens_per_page + end_page = nb_pages_per_seq + + if start_page < end_page: + pages_to_zero = page_list_arg[ + :, start_page:end_page + ] # [batch_size, num_pages_to_zero] + + first_page_ids = pages_to_zero[:, 0] # [batch_size] + token_start_in_first_page = seq_len % tokens_per_page + + if token_start_in_first_page > 0: + # Zero partial first page for all batches at once (NHD layout) + cache_k_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0 + cache_v_heads[first_page_ids, token_start_in_first_page:, :, :] = 0.0 + + # Zero all subsequent full pages (if any) for all batches at once + if pages_to_zero.shape[1] > 1: + remaining_page_ids = pages_to_zero[:, 1:].flatten() + cache_k_heads[remaining_page_ids, :, :, :] = 0.0 + cache_v_heads[remaining_page_ids, :, :, :] = 0.0 seq_len_list = torch.zeros( batch_size, beam_width, dtype=torch.uint32, device="cuda" @@ -608,28 +560,26 @@ def cache_head_at( for req in range(batch_size): for b in range(beam_width): for idx_k_head in range(nb_k_heads): - # K and V use separate pools but share page indices - k_cache_seq = CacheSeq( - pool=cache_k_heads, - page_indices=page_list_arg[req], - nb_heads=nb_k_heads, - idx_head=idx_k_head, - tokens_per_page=tokens_per_page, - ) - v_cache_seq = CacheSeq( - pool=cache_v_heads, - page_indices=page_list_arg[req], - nb_heads=nb_k_heads, - idx_head=idx_k_head, - tokens_per_page=tokens_per_page, - ) + # Assemble contiguous K/V cache from paged memory using advanced indexing + num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page + pages = page_list_arg[req, :num_pages] # [num_pages] + + # NHD layout: [num_pages, tokens_per_page, nb_k_heads, head_dim] + k_pages = cache_k_heads[ + pages, :, idx_k_head, : + ] # [num_pages, tokens_per_page, head_dim] + v_pages = cache_v_heads[pages, :, idx_k_head, :] + + # Reshape to contiguous sequence + k_cache = k_pages.reshape(-1, valid_elems_per_head_qk) + v_cache = v_pages.reshape(-1, valid_elems_per_head_qk) ref_output = ref_attention( q=q_heads[req][b][ idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size ], - k_cache_seq=k_cache_seq, - v_cache_seq=v_cache_seq, + k_cache=k_cache, + v_cache=v_cache, seq_len=seq_len, q_scale=q_scale * math.sqrt(576), kv_scale=kv_cache_scale, diff --git a/tests/attention/test_xqa_batch_decode.py b/tests/attention/test_xqa_batch_decode.py index fbeac45354..7a2bd3356a 100644 --- a/tests/attention/test_xqa_batch_decode.py +++ b/tests/attention/test_xqa_batch_decode.py @@ -143,28 +143,33 @@ def create_kv_cache( def create_page_table(batch_size, seq_lens, page_size): + # Ensure seq_lens is on GPU and calculate page_per_seq on GPU + seq_lens = seq_lens.to(GPU_DEVICE) page_per_seq = (seq_lens + page_size - 1) // page_size max_num_pages_per_seq = torch.max(page_per_seq).item() - # Generate random but unique page IDs for all sequences + # Generate sequential page IDs total_pages_needed = torch.sum(page_per_seq).item() - all_page_ids = torch.randperm( + all_page_ids = torch.arange( total_pages_needed, dtype=torch.int32, device=GPU_DEVICE ) - # Generate unique page IDs for all sequences - page_tables = torch.zeros( - (batch_size, max_num_pages_per_seq), dtype=torch.int32, device=GPU_DEVICE + # Use cumsum to create page offsets for each sequence + page_offsets = torch.cat( + [ + torch.tensor([0], device=GPU_DEVICE, dtype=torch.int32), + torch.cumsum(page_per_seq[:-1], dim=0, dtype=torch.int32), + ] ) - # Populate page tables and track page assignments - page_id = 0 - for i in range(batch_size): - num_pages_needed = page_per_seq[i] - page_tables[i, :num_pages_needed] = all_page_ids[ - page_id : page_id + num_pages_needed - ] - page_id += num_pages_needed + # Create page tables using broadcasting + page_idx_range = torch.arange( + max_num_pages_per_seq, device=GPU_DEVICE, dtype=torch.int32 + ).unsqueeze(0) + page_tables = ( + page_offsets.unsqueeze(1) + page_idx_range + ) # [batch_size, max_num_pages_per_seq] + return page_tables, all_page_ids, page_per_seq @@ -179,43 +184,69 @@ def flatten_paged_kv( """Build flat K/V and token-level indptr from paged KV cache and page table. Supports both NHD and HND layouts. + Optimized to avoid loops using vectorized operations. """ device = ref_kv_cache.device batch_size = int(page_table.shape[0]) - # Move loop-control tensors to CPU to avoid GPU sync in loops - page_table_cpu = page_table.cpu() - seq_lens_cpu = seq_lens.cpu() - kv_last_page_len_cpu = kv_last_page_len.cpu() - page_per_seq = (seq_lens_cpu + page_size - 1) // page_size - k_list = [] - v_list = [] - for i in range(batch_size): - pages_i = int(page_per_seq[i].item()) - last_len_i = int(kv_last_page_len_cpu[i].item()) - for j in range(pages_i): - page_id = int(page_table_cpu[i, j].item()) - if kv_layout == "NHD": - # NHD: [page_id, 0/1, page_size, num_heads, head_dim] - k_page = ref_kv_cache[page_id, 0] # [page_size, num_heads, head_dim] - v_page = ref_kv_cache[page_id, 1] - if j == pages_i - 1: - k_page = k_page[:last_len_i, :, :] - v_page = v_page[:last_len_i, :, :] - else: # HND - # HND: [page_id, 0/1, num_heads, page_size, head_dim] - k_page = ref_kv_cache[page_id, 0] # [num_heads, page_size, head_dim] - v_page = ref_kv_cache[page_id, 1] - if j == pages_i - 1: - k_page = k_page[:, :last_len_i, :] - v_page = v_page[:, :last_len_i, :] - # Transpose to NHD: [num_heads, page_size, head_dim] -> [page_size, num_heads, head_dim] - k_page = k_page.transpose(0, 1) - v_page = v_page.transpose(0, 1) - k_list.append(k_page) - v_list.append(v_page) - k_flat = torch.cat(k_list, dim=0) - v_flat = torch.cat(v_list, dim=0) + # Calculate number of pages per sequence + page_per_seq = (seq_lens + page_size - 1) // page_size + max_pages = int(page_per_seq.max().item()) + + # Gather all pages at once using advanced indexing + # page_table shape: [batch_size, max_pages] + if kv_layout == "NHD": + # ref_kv_cache: [num_pages_total, 2, page_size, num_heads, head_dim] + # Gather: [batch_size, max_pages, page_size, num_heads, head_dim] + k_pages = ref_kv_cache[ + page_table, 0 + ] # [batch_size, max_pages, page_size, num_heads, head_dim] + v_pages = ref_kv_cache[page_table, 1] + else: # HND + # ref_kv_cache: [num_pages_total, 2, num_heads, page_size, head_dim] + # Gather: [batch_size, max_pages, num_heads, page_size, head_dim] + k_pages = ref_kv_cache[ + page_table, 0 + ] # [batch_size, max_pages, num_heads, page_size, head_dim] + v_pages = ref_kv_cache[page_table, 1] + # Transpose to NHD: [batch_size, max_pages, num_heads, page_size, head_dim] -> [batch_size, max_pages, page_size, num_heads, head_dim] + k_pages = k_pages.transpose(2, 3) + v_pages = v_pages.transpose(2, 3) + + # Reshape to [batch_size, max_pages * page_size, num_heads, head_dim] + num_heads = k_pages.shape[-2] + head_dim = k_pages.shape[-1] + k_pages = k_pages.reshape(batch_size, max_pages * page_size, num_heads, head_dim) + v_pages = v_pages.reshape(batch_size, max_pages * page_size, num_heads, head_dim) + + # Create token indices for each sequence using vectorized operations + # For each batch, we need to extract [:seq_len] tokens + max_seq_len = seq_lens.max().item() + token_idx = torch.arange(max_seq_len, device=device, dtype=torch.int32).unsqueeze( + 0 + ) # [1, max_seq_len] + token_mask = token_idx < seq_lens.unsqueeze(1) # [batch_size, max_seq_len] + + # Gather valid tokens for all sequences at once + # Expand k_pages and v_pages to max_seq_len, then mask + k_gathered = k_pages[ + :, :max_seq_len, :, : + ] # [batch_size, max_seq_len, num_heads, head_dim] + v_gathered = v_pages[ + :, :max_seq_len, :, : + ] # [batch_size, max_seq_len, num_heads, head_dim] + + # Flatten and filter by mask + k_gathered_flat = k_gathered.reshape( + -1, num_heads, head_dim + ) # [batch_size * max_seq_len, num_heads, head_dim] + v_gathered_flat = v_gathered.reshape(-1, num_heads, head_dim) + token_mask_flat = token_mask.reshape(-1) # [batch_size * max_seq_len] + + # Keep only valid tokens + k_flat = k_gathered_flat[token_mask_flat] + v_flat = v_gathered_flat[token_mask_flat] + kv_indptr_tokens = torch.cat( [ torch.tensor([0], dtype=torch.int32, device=device),