diff --git a/cpp/tensorrt_llm/thop/IndexerTopKOp.cpp b/cpp/tensorrt_llm/thop/IndexerTopKOp.cpp
index 471ee19be9b..8a5003238c7 100644
--- a/cpp/tensorrt_llm/thop/IndexerTopKOp.cpp
+++ b/cpp/tensorrt_llm/thop/IndexerTopKOp.cpp
@@ -57,7 +57,6 @@ void indexer_topk_decode(
     TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
     TORCH_CHECK(next_n > 0, "next_n must be greater than 0");
-    TORCH_CHECK(index_topk == 2048, "index_topk must be 2048 for now");

     int32_t num_rows = static_cast<int32_t>(numRows64);
     int32_t num_columns = static_cast<int32_t>(numColumns64);
@@ -95,7 +94,6 @@ void indexer_topk_prefill(th::Tensor const& logits, th::Tensor const& row_starts
     TORCH_CHECK(indices.dim() == 2, "indices must be a 2D Tensor");
     TORCH_CHECK(logits.dim() == 2, "logits must be a 2D Tensor");
-    TORCH_CHECK(index_topk == 2048, "index_topk must be 2048 for now");

     auto const inputSize = logits.sizes();
     auto const numRows64 = inputSize[0];
diff --git a/examples/llm-api/llm_sparse_attention.py b/examples/llm-api/llm_sparse_attention.py
index 2739ecaa54c..3ebe4dcb61a 100644
--- a/examples/llm-api/llm_sparse_attention.py
+++ b/examples/llm-api/llm_sparse_attention.py
@@ -44,6 +44,7 @@ def parse_arguments():
         type=str,
         default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
     )
+
     # Build config
     parser.add_argument('--algo',
                         type=str,
@@ -53,6 +54,8 @@ def parse_arguments():
                         type=str,
                         default='TRTLLM',
                         choices=['VANILLA', 'TRTLLM'])
+
+    # RocketKV config
     parser.add_argument('--window_size',
                         type=int,
                         default=32,
@@ -65,6 +68,14 @@ def parse_arguments():
                         type=int,
                         default=2048,
                         help="The prompt budget for RocketKV.")
+    parser.add_argument('--topk',
+                        type=int,
+                        default=64,
+                        help='Top-k for RocketKV')
+    parser.add_argument('--kt_cache_dtype',
+                        type=str,
+                        default='float8_e5m2',
+                        choices=['bfloat16', 'float8_e5m2'])
     parser.add_argument('--index_max_chunk_size',
                         type=int,
                         default=32768,
@@ -106,6 +117,7 @@ def parse_arguments():
     # KV cache
     parser.add_argument('--kv_cache_dtype', type=str, default='auto')
     parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
+    parser.add_argument('--tokens_per_block', type=int, default=32)
     parser.add_argument('--num_samples', type=int, default=10)

     # Runtime
@@ -139,8 +151,8 @@ def run_llm(args, sparse_attention_config):
         enable_block_reuse=
         False,  # sparse attention does not support kv cache reuse now
         free_gpu_memory_fraction=args.kv_cache_fraction,
+        tokens_per_block=args.tokens_per_block,
         dtype=args.kv_cache_dtype,
-        tokens_per_block=64,
     )

     cuda_graph_config = CudaGraphConfig(
@@ -191,6 +203,8 @@ def run_RocketKV(args):
         window_size=args.window_size,
         kernel_size=args.kernel_size,
         prompt_budget=args.prompt_budget,
+        topk=args.topk,
+        kt_cache_dtype=args.kt_cache_dtype,
     )

     run_llm(args, sparse_attention_config)
diff --git a/examples/longbench/eval_longbench_v1.py b/examples/longbench/eval_longbench_v1.py
index 5db743b1424..696971055ba 100644
--- a/examples/longbench/eval_longbench_v1.py
+++ b/examples/longbench/eval_longbench_v1.py
@@ -150,16 +150,22 @@ def parse_arguments() -> argparse.Namespace:
                         type=int,
                         default=63,
                         help='Kernel size for RocketKV')
-    parser.add_argument('--topr',
+    parser.add_argument('--topk',
                         type=int,
-                        default=90,
-                        help='Top-r for RocketKV')
+                        default=64,
+                        help='Top-k for RocketKV')
+    parser.add_argument('--kt_cache_dtype',
+                        type=str,
+                        default='float8_e5m2',
+                        choices=['bfloat16', 'float8_e5m2'],
+                        help='KT cache data type')

     # KV cache configuration
     parser.add_argument('--kv_cache_dtype',
                         type=str,
                         default='auto',
                         help='KV cache data 
type') + parser.add_argument('--tokens_per_block', type=int, default=32) parser.add_argument('--kv_cache_fraction', type=float, default=0.7, @@ -320,6 +326,7 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]: # sparse attention doesn't support KV cache reuse enable_block_reuse=False, free_gpu_memory_fraction=args.kv_cache_fraction, + tokens_per_block=args.tokens_per_block, ) # Configure CUDA graph @@ -335,7 +342,8 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]: window_size=args.window_size, kernel_size=args.kernel_size, prompt_budget=args.token_budget, - topr=args.topr, + topk=args.topk, + kt_cache_dtype=args.kt_cache_dtype, ) logger.info(f"Using RocketKV sparse attention") else: @@ -427,6 +435,14 @@ def evaluate_single_dataset( formatted_prompt = format_prompt_style(sample, prompt_format, chat_template, dataset, tokenizer) + # Truncate prompt if it's too long + token_ids = tokenizer.encode(formatted_prompt, truncation=False) + if len(token_ids) > args.max_seq_len: + half = (args.max_seq_len - max_new_tokens) // 2 + formatted_prompt = tokenizer.decode( + token_ids[:half], skip_special_tokens=True) + tokenizer.decode( + token_ids[-half:], skip_special_tokens=True) + prompts.append(formatted_prompt) if len(prompts) == 0: diff --git a/examples/longbench/eval_longbench_v2.py b/examples/longbench/eval_longbench_v2.py index 4b33940e06b..5f8214dbc88 100644 --- a/examples/longbench/eval_longbench_v2.py +++ b/examples/longbench/eval_longbench_v2.py @@ -121,10 +121,15 @@ def parse_arguments() -> argparse.Namespace: type=int, default=63, help='Kernel size for RocketKV') - parser.add_argument('--topr', + parser.add_argument('--topk', type=int, - default=90, - help='Top-r for RocketKV') + default=64, + help='Top-k for RocketKV') + parser.add_argument('--kt_cache_dtype', + type=str, + default='float8_e5m2', + choices=['bfloat16', 'float8_e5m2'], + help='KT cache data type') # KV cache configuration parser.add_argument('--kv_cache_dtype', @@ -356,7 +361,8 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]: window_size=args.window_size, kernel_size=args.kernel_size, prompt_budget=args.token_budget, - topr=args.topr, + topk=args.topk, + kt_cache_dtype=args.kt_cache_dtype, ) logger.info(f"Using RocketKV sparse attention") else: diff --git a/tensorrt_llm/_torch/attention_backend/sparse/kernel.py b/tensorrt_llm/_torch/attention_backend/sparse/kernel.py index 28dc31836e8..3ca6f8ddd27 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/kernel.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/kernel.py @@ -163,6 +163,7 @@ def triton_rocket_qk_split( input_lens_cumsum: torch.Tensor, valid_seq_indices: torch.Tensor, k_output_offsets: torch.Tensor, + total_rocket_k_ctx_tokens: int, num_heads: int, num_kv_heads: int, head_dim: int, @@ -182,6 +183,7 @@ def triton_rocket_qk_split( input_lens_cumsum: Cumulative sum of context lengths [batch_size + 1] valid_seq_indices: Indices of valid sequences [valid_batch_size] k_output_offsets: Offset for each valid sequence [valid_batch_size] + total_rocket_k_ctx_tokens: Total number of RocketKV key context tokens num_heads: Number of query heads num_kv_heads: Number of key/value heads head_dim: Dimension of each head @@ -194,7 +196,7 @@ def triton_rocket_qk_split( """ q_total_output_tokens = window_size * valid_batch_size - k_total_output_tokens = k_output_offsets[valid_batch_size].item() + k_total_output_tokens = total_rocket_k_ctx_tokens q_window = torch.empty((num_heads, 
q_total_output_tokens, head_dim), device=input_tensor.device, @@ -503,14 +505,18 @@ def triton_softmax( Apply softmax to flattened input tensor. Args: - input_tensor: Input tensor [num_heads, q_len_per_seq, total_k_tokens] + input_tensor: Input tensor [num_heads, len_per_seq, total_k_tokens] or [num_heads, total_k_tokens] cum_lens: Cumulative lengths [batch_size + 1] batch_size: Number of batches Returns: - output: Softmax results [num_heads, q_len_per_seq, total_k_tokens] + output: Softmax results, shape is like input_tensor """ - num_heads, q_len_per_seq, total_k_tokens = input_tensor.shape + if input_tensor.ndim == 2: + num_heads, total_k_tokens = input_tensor.shape + len_per_seq = 1 + else: + num_heads, len_per_seq, total_k_tokens = input_tensor.shape output = torch.empty_like(input_tensor, dtype=input_tensor.dtype, @@ -518,7 +524,7 @@ def triton_softmax( BLOCK_SIZE = 512 - grid = (num_heads, batch_size * q_len_per_seq) + grid = (num_heads, batch_size * len_per_seq) softmax_kernel[grid]( input_tensor, @@ -526,7 +532,7 @@ def triton_softmax( cum_lens, batch_size, num_heads, - q_len_per_seq, + len_per_seq, total_k_tokens, BLOCK_SIZE=BLOCK_SIZE, ) @@ -666,7 +672,8 @@ def rocket_batch_to_flatten_kernel( token_mask = token_offsets < prefix_budget # Load from prefix_indices - prefix_indices = valid_idx_in_selected * num_kv_heads * prefix_budget + head_idx * prefix_budget + token_offsets + flattened_idx = valid_idx_in_selected * num_kv_heads + head_idx + prefix_indices = flattened_idx * prefix_budget + token_offsets prefix_values = tl.load(prefix_indices_ptr + prefix_indices, mask=token_mask, other=0) @@ -711,14 +718,15 @@ def triton_rocket_batch_to_flatten( prefix_indices: torch.Tensor, input_lens: torch.Tensor, valid_seq_indices: torch.Tensor, output_offsets: torch.Tensor, batch_size: int, total_output_tokens: int, window_size: int, - prompt_budget: int) -> tuple[torch.Tensor, torch.Tensor]: + prompt_budget: int, + num_kv_heads: int) -> tuple[torch.Tensor, torch.Tensor]: """ Flatten indices considering both valid and invalid batches. For valid sequences, combines prefix_indices with dynamically computed window indices. For invalid sequences, generates sequential indices. 
Args: - prefix_indices: Selected prefix indices [valid_batch_size, num_kv_heads, prefix_budget] + prefix_indices: Selected prefix indices [valid_batch_size * num_kv_heads, prefix_budget] input_lens: Lengths for all sequences [batch_size] valid_seq_indices: Valid sequence indices [valid_batch_size] output_offsets: Offset for each batch [batch_size + 1] @@ -726,11 +734,13 @@ def triton_rocket_batch_to_flatten( total_output_tokens: Total number of output tokens window_size: Size of sliding window at the end prompt_budget: Total number of tokens for valid sequences (prefix_budget + window_size) + num_kv_heads: Number of KV heads Returns: sparse_indices: Flattened sparse indices [num_kv_heads, total_output_tokens] """ - valid_batch_size, num_kv_heads, prefix_budget = prefix_indices.shape + total_tasks, prefix_budget = prefix_indices.shape + valid_batch_size = total_tasks // num_kv_heads # Create output tensor sparse_indices = torch.empty((num_kv_heads, total_output_tokens), @@ -774,6 +784,8 @@ def rocket_update_kt_cache_gen_kernel( kt_page_size, tokens_per_block, max_kt_blocks_per_seq, + k_stride_0, + k_stride_1, DIM_BLOCK_SIZE: tl.constexpr, ): batch_idx = tl.program_id(0) @@ -788,8 +800,8 @@ def rocket_update_kt_cache_gen_kernel( dim_indices = dim_block_start + dim_offsets dim_mask = dim_indices < head_dim - k_base = batch_idx * num_kv_heads * head_dim + kv_head_idx * head_dim - k_indices = k_base + dim_indices + k_base = batch_idx * k_stride_0 + kv_head_idx * head_dim * k_stride_1 + k_indices = k_base + dim_indices * k_stride_1 k_values = tl.load(k_ptr + k_indices, mask=dim_mask, other=0.0) kv_len = tl.load(kv_lens_ptr + batch_idx) @@ -860,6 +872,8 @@ def triton_rocket_update_kt_cache_gen( grid = (num_gen_tokens, num_kv_heads, 1) + DIM_BLOCK_SIZE = triton.next_power_of_2(head_dim) + rocket_update_kt_cache_gen_kernel[grid](k, kt_cache_tensor, kt_cache_block_offsets, @@ -870,7 +884,9 @@ def triton_rocket_update_kt_cache_gen( kt_page_size, tokens_per_block, max_kt_blocks_per_seq, - DIM_BLOCK_SIZE=128) + k.stride(0), + k.stride(1), + DIM_BLOCK_SIZE=DIM_BLOCK_SIZE) @triton.jit @@ -885,7 +901,7 @@ def rocket_update_kt_cache_ctx_kernel( num_heads, num_kv_heads, head_dim, - kt_page_size, + kt_page_size: tl.constexpr, tokens_per_block, max_kt_blocks_per_seq, total_sparse_tokens, @@ -897,9 +913,7 @@ def rocket_update_kt_cache_ctx_kernel( """ batch_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) - - if batch_idx >= batch_size or kv_head_idx >= num_kv_heads: - return + kt_block_idx = tl.program_id(2) context_start = tl.load(context_cumsum_ptr + batch_idx) @@ -910,100 +924,71 @@ def rocket_update_kt_cache_ctx_kernel( if num_sparse_tokens <= 0: return - # Calculate number of kt_tokens for this batch - num_kt_tokens = (num_sparse_tokens + kt_page_size - 1) // kt_page_size - q_hidden_size = num_heads * head_dim kv_hidden_size = num_kv_heads * head_dim k_dim_base = q_hidden_size + kv_head_idx * head_dim - # Process kt_tokens and dimensions in blocks - for kt_block_start in tl.range(0, num_kt_tokens, BLOCK_SIZE_KT): - # Get kt_token indices for this block [BLOCK_SIZE_KT] - kt_offsets = kt_block_start + tl.arange(0, BLOCK_SIZE_KT) - kt_mask = kt_offsets < num_kt_tokens - - # Calculate page boundaries for all kt_tokens in this block - page_starts = sparse_start + kt_offsets * kt_page_size - page_ends = tl.minimum(page_starts + kt_page_size, sparse_end) - - for dim_block_start in tl.range(0, head_dim, BLOCK_SIZE_DIM): - dim_offsets = tl.arange(0, BLOCK_SIZE_DIM) - dim_indices = dim_block_start + 
dim_offsets - dim_mask = dim_indices < head_dim - - k_min = tl.full((BLOCK_SIZE_KT, BLOCK_SIZE_DIM), - float('inf'), - dtype=tl.float32) - k_max = tl.full((BLOCK_SIZE_KT, BLOCK_SIZE_DIM), - float('-inf'), - dtype=tl.float32) - - # Iterate through all tokens in the page - for page_offset in range(kt_page_size): - # Calculate token indices within sparse range [BLOCK_SIZE_KT] - token_indices = page_starts + page_offset - token_mask = (token_indices < page_ends) & kt_mask - - # Load sparse indices for all valid tokens [BLOCK_SIZE_KT] - sparse_idx_offsets = kv_head_idx * total_sparse_tokens + token_indices - kv_token_indices = tl.load(sparse_kv_indices_ptr + - sparse_idx_offsets, - mask=token_mask, - other=0) - - # Broadcast for 2D operations [BLOCK_SIZE_KT, BLOCK_SIZE_DIM] - valid_mask_2d = token_mask[:, None] & dim_mask[None, :] - - # Calculate indices for loading keys [BLOCK_SIZE_KT, BLOCK_SIZE_DIM] - k_base_indices = (kv_token_indices[:, None] + context_start) * ( - q_hidden_size + 2 * kv_hidden_size) + k_dim_base - k_indices = k_base_indices + dim_indices[None, :] - - # Load key values [BLOCK_SIZE_KT, BLOCK_SIZE_DIM] - k_values = tl.load(k_ptr + k_indices, - mask=valid_mask_2d, - other=0.0) - - k_min = tl.where(valid_mask_2d, tl.minimum(k_min, k_values), - k_min) - k_max = tl.where(valid_mask_2d, tl.maximum(k_max, k_values), - k_max) - - k_min = k_min.to(kt_cache_tensor_ptr.dtype.element_ty) - k_max = k_max.to(kt_cache_tensor_ptr.dtype.element_ty) - - # Calculate cache locations [BLOCK_SIZE_KT] - block_offsets_in_seq = kt_offsets // tokens_per_block - valid_block_mask = (block_offsets_in_seq - < max_kt_blocks_per_seq) & kt_mask - - # Load block indices [BLOCK_SIZE_KT] - block_offset_addrs = batch_idx * max_kt_blocks_per_seq + block_offsets_in_seq - block_indices = tl.load(kt_cache_block_offsets_ptr + - block_offset_addrs, - mask=valid_block_mask, - other=0) + BLOCK_SIZE_KV: tl.constexpr = kt_page_size * BLOCK_SIZE_KT + + total_kt_tokens = (num_sparse_tokens + kt_page_size - 1) // kt_page_size + kt_offsets = kt_block_idx * BLOCK_SIZE_KT + tl.arange(0, BLOCK_SIZE_KT) + kt_mask = kt_offsets < total_kt_tokens + + kv_start = kt_block_idx * BLOCK_SIZE_KT * kt_page_size + kv_offsets = kv_start + tl.arange(0, BLOCK_SIZE_KV) + kv_mask = kv_offsets < num_sparse_tokens + kv_indices = kv_head_idx * total_sparse_tokens + sparse_start + kv_offsets + + for dim_block_start in tl.range(0, head_dim, BLOCK_SIZE_DIM): + dim_offsets = tl.arange(0, BLOCK_SIZE_DIM) + dim_indices = dim_block_start + dim_offsets + dim_mask = dim_indices < head_dim + + kv_token_indices = tl.load(sparse_kv_indices_ptr + kv_indices, + mask=kv_mask, + other=0) + # Calculate indices for loading keys [BLOCK_SIZE_DIM, BLOCK_SIZE_KV] + k_base_indices = (kv_token_indices[None, :] + context_start) * ( + q_hidden_size + 2 * kv_hidden_size) + k_dim_base + k_indices = k_base_indices + dim_indices[:, None] + + combined_mask = kv_mask[None, :] & dim_mask[:, None] - tokens_in_block = kt_offsets % tokens_per_block + # Load key values [BLOCK_SIZE_DIM, BLOCK_SIZE_KV] + k_values = tl.load(k_ptr + k_indices, mask=combined_mask, other=0.0) - # Calculate cache base addresses [BLOCK_SIZE_KT] - cache_bases = ( - (block_indices * tokens_per_block + tokens_in_block) * - num_kv_heads * 2 * head_dim + kv_head_idx * 2 * head_dim) + k_values = tl.reshape(k_values, + (BLOCK_SIZE_DIM, BLOCK_SIZE_KT, kt_page_size)) - cache_min_addrs = cache_bases[:, None] + dim_indices[None, :] - cache_max_addrs = cache_bases[:, None] + head_dim + dim_indices[ - None, :] + k_min = 
tl.min(k_values, + axis=-1).to(kt_cache_tensor_ptr.dtype.element_ty) + k_max = tl.max(k_values, + axis=-1).to(kt_cache_tensor_ptr.dtype.element_ty) - store_mask = valid_block_mask[:, None] & dim_mask[None, :] + # Calculate cache locations [BLOCK_SIZE_KT] + block_offsets_in_seq = kt_offsets // tokens_per_block + valid_block_mask = (block_offsets_in_seq + < max_kt_blocks_per_seq) & kt_mask - tl.store(kt_cache_tensor_ptr + cache_min_addrs, - k_min, - mask=store_mask) - tl.store(kt_cache_tensor_ptr + cache_max_addrs, - k_max, - mask=store_mask) + # Load block indices [BLOCK_SIZE_KT] + block_offset_addrs = batch_idx * max_kt_blocks_per_seq + block_offsets_in_seq + block_indices = tl.load(kt_cache_block_offsets_ptr + block_offset_addrs, + mask=valid_block_mask, + other=0) + + tokens_in_block = kt_offsets % tokens_per_block + + # Calculate cache base addresses [BLOCK_SIZE_KT] + cache_bases = ((block_indices * tokens_per_block + tokens_in_block) * + num_kv_heads * 2 * head_dim + kv_head_idx * 2 * head_dim) + + cache_min_addrs = cache_bases[None, :] + dim_indices[:, None] + cache_max_addrs = cache_min_addrs + head_dim + + store_mask = valid_block_mask[None, :] & dim_mask[:, None] + + tl.store(kt_cache_tensor_ptr + cache_min_addrs, k_min, mask=store_mask) + tl.store(kt_cache_tensor_ptr + cache_max_addrs, k_max, mask=store_mask) def triton_rocket_update_kt_cache_ctx( @@ -1017,6 +1002,7 @@ def triton_rocket_update_kt_cache_ctx( num_kv_heads: int, head_dim: int, kt_page_size: int, + prompt_budget: int, tokens_per_block: int, max_kt_blocks_per_seq: int, ): @@ -1033,16 +1019,20 @@ def triton_rocket_update_kt_cache_ctx( num_kv_heads: Number of KV heads head_dim: Head dimension kt_page_size: Page size for KT tokens + prompt_budget: Prompt budget tokens_per_block: Tokens per cache block max_kt_blocks_per_seq: Maximum KT blocks per sequence """ batch_size = sparse_kv_offsets.size(0) - 1 total_sparse_tokens = sparse_kv_indices.size(1) - BLOCK_SIZE_KT = 128 - BLOCK_SIZE_DIM = 128 + total_kt_tokens = (prompt_budget + kt_page_size - 1) // kt_page_size - grid = (batch_size, num_kv_heads) + BLOCK_SIZE_KT = 8 + BLOCK_SIZE_DIM = triton.next_power_of_2(head_dim) + + grid = (batch_size, num_kv_heads, + (total_kt_tokens + BLOCK_SIZE_KT - 1) // BLOCK_SIZE_KT) rocket_update_kt_cache_ctx_kernel[grid]( qkv_input, @@ -1074,56 +1064,62 @@ def rocket_paged_kt_cache_bmm_kernel( q_ptr, kt_cache_tensor_ptr, kt_cache_block_offsets_ptr, - dim_pos_ptr, kv_lens_ptr, output_ptr, output_offsets_ptr, num_gen_tokens, num_kv_heads, - num_heads_per_kv, + num_heads_per_kv: tl.constexpr, head_dim, kt_page_size, tokens_per_block, max_kt_blocks_per_seq, total_kt_tokens, sm_scale, + q_stride_0, + q_stride_1, + q_stride_2, + q_stride_3, + Q_BLOCK_SIZE: tl.constexpr, KT_BLOCK_SIZE: tl.constexpr, DIM_BLOCK_SIZE: tl.constexpr, ): batch_idx = tl.program_id(0) - global_head_idx = tl.program_id(1) + kv_head_idx = tl.program_id(1) - if batch_idx >= num_gen_tokens or global_head_idx >= num_kv_heads * num_heads_per_kv: + if batch_idx >= num_gen_tokens or kv_head_idx >= num_kv_heads: return kv_len = tl.load(kv_lens_ptr + batch_idx) num_kt_tokens = (kv_len + kt_page_size - 1) // kt_page_size - kv_head_idx = global_head_idx // num_heads_per_kv - q_head_idx = global_head_idx % num_heads_per_kv + q_base = batch_idx * q_stride_0 + kv_head_idx * q_stride_1 - q_base = (batch_idx * num_kv_heads * num_heads_per_kv * head_dim + - kv_head_idx * num_heads_per_kv * head_dim + q_head_idx * head_dim) - dim_pos_base = (batch_idx * num_kv_heads * head_dim + - kv_head_idx 
* head_dim)
+    q_head_offsets = tl.arange(0, Q_BLOCK_SIZE)
+    q_head_mask = q_head_offsets < num_heads_per_kv

     output_offset = tl.load(output_offsets_ptr + batch_idx)

     dim_indices = tl.arange(0, DIM_BLOCK_SIZE)
     dim_mask = dim_indices < head_dim

-    q_indices = q_base + dim_indices
-    q_values = tl.load(q_ptr + q_indices, mask=dim_mask, other=0.0)
-    q_values = tl.broadcast_to(q_values[None, :],
-                               (KT_BLOCK_SIZE, DIM_BLOCK_SIZE))
+    q_indices = q_base + q_head_offsets[:, None] * q_stride_2 + dim_indices[
+        None, :] * q_stride_3
+    q_values = tl.load(q_ptr + q_indices,
+                       mask=q_head_mask[:, None] & dim_mask[None, :])

-    dim_pos_indices = dim_pos_base + dim_indices
-    dim_pos_values = tl.load(dim_pos_ptr + dim_pos_indices,
-                             mask=dim_mask,
-                             other=0)
+    dim_pos_values = tl.sum(q_values, axis=0) > 0
     dim_pos_values = tl.broadcast_to(dim_pos_values[None, :],
                                      (KT_BLOCK_SIZE, DIM_BLOCK_SIZE))

-    for kt_block_idx_start in tl.range(0, num_kt_tokens, KT_BLOCK_SIZE):
+    q_values = q_values.to(kt_cache_tensor_ptr.dtype.element_ty)
+
+    for kt_block_idx_start in tl.range(
+            0,
+            num_kt_tokens,
+            KT_BLOCK_SIZE,
+    ):
+        kt_block_idx_start = tl.multiple_of(kt_block_idx_start, KT_BLOCK_SIZE)
+
         kt_token_indices = kt_block_idx_start + tl.arange(0, KT_BLOCK_SIZE)
         kt_token_mask = kt_token_indices < num_kt_tokens
@@ -1156,19 +1152,23 @@ def rocket_paged_kt_cache_bmm_kernel(
         kt_cache_values = tl.where(dim_pos_values > 0, kt_cache_values_max,
                                    kt_cache_values_min)

-        results = tl.sum(q_values * kt_cache_values, axis=1)
+        results = tl.dot(q_values,
+                         kt_cache_values.T)  # [Q_BLOCK_SIZE, KT_BLOCK_SIZE]
+
+        output_mask = q_head_mask[:, None] & kt_token_mask[None, :]
+        output_indices = (kv_head_idx * num_heads_per_kv * total_kt_tokens +
+                          q_head_offsets[:, None] * total_kt_tokens +
+                          output_offset + kt_token_indices[None, :])

-        output_indices = global_head_idx * total_kt_tokens + output_offset + kt_token_indices
         tl.store(output_ptr + output_indices,
                  results * sm_scale,
-                 mask=kt_token_mask)
+                 mask=output_mask)


 def triton_rocket_paged_kt_cache_bmm(
     q: torch.Tensor,
     kt_cache_tensor: torch.Tensor,
     kt_cache_block_offsets: torch.Tensor,
-    dim_pos: torch.Tensor,
     kv_lens: torch.Tensor,
     output_offsets: torch.Tensor,
     kt_page_size: int,
@@ -1184,7 +1184,6 @@ def triton_rocket_paged_kt_cache_bmm(
         q: Query tensor [num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim]
         kt_cache_tensor: KT cache tensor
         kt_cache_block_offsets: Block offsets [num_gen_tokens, max_kt_blocks_per_seq]
-        dim_pos: Dimension offsets [num_gen_tokens, num_kv_heads, head_dim] (0 or head_dim for each dim)
         kv_lens: Sequence lengths [num_gen_tokens]
         output_offsets: Output offsets [num_gen_tokens + 1]
         kt_page_size: Page size for KT tokens
@@ -1193,34 +1192,29 @@ def triton_rocket_paged_kt_cache_bmm(
         total_kt_tokens: Total number of KT tokens (fixed size)
         sm_scale: Scale factor for softmax
     Returns:
-        output: BMM results [num_heads, 1, total_kt_tokens]
+        output: BMM results [num_kv_heads, num_heads_per_kv, total_kt_tokens]
     """
     num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim = q.shape
-    total_num_heads = num_kv_heads * num_heads_per_kv

     if sm_scale is None:
         sm_scale = 1.0 / math.sqrt(head_dim)

-    # Create output tensor with shape [num_heads, 1, total_kt_tokens]
-    output = torch.empty((total_num_heads, 1, total_kt_tokens),
+    # Create output tensor with shape [num_kv_heads, num_heads_per_kv, total_kt_tokens]
+    output = torch.empty((num_kv_heads, num_heads_per_kv, total_kt_tokens),
                          dtype=torch.float32,
                          device=q.device)

-    grid = lambda meta: 
(num_gen_tokens, total_num_heads) + grid = lambda meta: (num_gen_tokens, num_kv_heads) - KT_BLOCK_SIZE = 128 - if head_dim <= 128: - DIM_BLOCK_SIZE = 128 - elif head_dim <= 256: - DIM_BLOCK_SIZE = 256 - else: - assert False, f"Unsupported head_dim: {head_dim}" + Q_BLOCK_SIZE = num_heads_per_kv + KT_BLOCK_SIZE = 64 + DIM_BLOCK_SIZE = triton.next_power_of_2(head_dim) rocket_paged_kt_cache_bmm_kernel[grid]( q, kt_cache_tensor, kt_cache_block_offsets, - dim_pos, kv_lens, output, output_offsets, @@ -1233,6 +1227,11 @@ def triton_rocket_paged_kt_cache_bmm( max_kt_blocks_per_seq, total_kt_tokens, sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + Q_BLOCK_SIZE=Q_BLOCK_SIZE, KT_BLOCK_SIZE=KT_BLOCK_SIZE, DIM_BLOCK_SIZE=DIM_BLOCK_SIZE, ) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py index 607063b8bb9..c40908a7c3d 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py @@ -45,6 +45,9 @@ def __post_init__(self): self.page_size = self.sparse_attention_config.page_size self.topk = self.sparse_attention_config.topk + assert self.page_size == next_power_of_2( + self.page_size), "Page size must be a power of 2" + capture_graph = torch.cuda.is_current_stream_capturing() # Cumulative valid sequence lengths for query and key @@ -72,10 +75,24 @@ def __post_init__(self): dtype=torch.int32) # Context length of RocketKV key for each valid sequence - self.k_context_lens = torch.empty( - self.max_num_sequences, - device='cpu', + self.k_context_lens_cuda = self.get_empty( + self.cuda_graph_buffers, + (self.max_num_sequences, ), dtype=torch.int32, + cache_name="k_context_lens_cuda", + capture_graph=capture_graph, + ) + self.k_context_lens = torch.zeros_like(self.k_context_lens_cuda, + device='cpu', + dtype=torch.int32) + + # Start index of RocketKV key for each valid sequence + self.k_context_start_cuda = self.get_empty( + None, + (self.max_num_sequences, ), + dtype=torch.int32, + cache_name="k_context_start_cuda", + capture_graph=capture_graph, ) # Cumulative context lengths for each sequence @@ -228,13 +245,8 @@ def prepare(self): # Only consider sequences that are long enough for sparse kv indices prediction in context phase self.k_context_lens[:valid_batch_size] = self.prompt_lens_cpu[ valid_seq_indices] - self.window_size - if valid_batch_size > 0: - # Maximum context length of RocketKV key for valid sequences for padding - self.max_rocket_k_ctx_len = self.k_context_lens[: - valid_batch_size].max( - ).item() - else: - self.max_rocket_k_ctx_len = 0 + self.k_context_lens_cuda[:valid_batch_size].copy_( + self.k_context_lens[:valid_batch_size], non_blocking=True) sparse_counts_ctx = torch.zeros(self.num_contexts, dtype=torch.int32, @@ -259,6 +271,17 @@ def prepare(self): self.k_cu_seqlens_cuda[:valid_batch_size + 1].copy_( self.k_cu_seqlens[:valid_batch_size + 1], non_blocking=True) + if valid_batch_size > 0: + # Maximum context length of RocketKV key for valid sequences for padding + self.max_rocket_k_ctx_len = self.k_context_lens[: + valid_batch_size].max( + ).item() + self.total_rocket_k_ctx_tokens = self.k_cu_seqlens[ + valid_batch_size].item() + else: + self.max_rocket_k_ctx_len = 0 + self.total_rocket_k_ctx_tokens = 0 + self.valid_batch_size = valid_batch_size self.total_sparse_ctx_indices = self.sparse_offsets_ctx[ self.num_contexts].item() @@ -357,6 +380,7 @@ def sparse_kv_predict( metadata.context_cumsum_cuda, 
metadata.valid_seq_indices_cuda, metadata.k_cu_seqlens_cuda, + metadata.total_rocket_k_ctx_tokens, self.num_heads, self.num_kv_heads, self.head_dim, @@ -391,12 +415,32 @@ def sparse_kv_predict( padding=self.kernel_size // 2, stride=1) - selected_prefix_indices = scores.topk( - self.prompt_budget - self.window_size, - dim=-1).indices.sort().values.to(torch.int32) + # Use indexer topk prefill to select topk prefix indices + total_tasks = metadata.valid_batch_size * self.num_kv_heads + + selected_prefix_indices = torch.empty( + (total_tasks, self.prompt_budget - self.window_size), + device=qkv_input.device, + dtype=torch.int32) + + scores = scores.view(total_tasks, -1) + + row_starts = metadata.k_context_start_cuda[:metadata. + valid_batch_size].repeat_interleave( + self.num_kv_heads) + row_ends = metadata.k_context_lens_cuda[:metadata. + valid_batch_size].repeat_interleave( + self.num_kv_heads) + torch.ops.trtllm.indexer_topk_prefill( + scores, row_starts, row_ends, selected_prefix_indices, + self.prompt_budget - self.window_size) + + # Sort selected prefix indices to keep topk indices in ascending order + selected_prefix_indices = torch.sort(selected_prefix_indices, + dim=-1).values else: selected_prefix_indices = torch.empty( - (0, self.num_kv_heads, self.prompt_budget - self.window_size), + (0, self.prompt_budget - self.window_size), device=qkv_input.device, dtype=torch.int32) @@ -408,7 +452,7 @@ def sparse_kv_predict( selected_prefix_indices, metadata.prompt_lens_cuda, metadata.valid_seq_indices_cuda, sparse_kv_offsets, metadata.num_contexts, metadata.total_sparse_ctx_indices, - self.window_size, self.prompt_budget) + self.window_size, self.prompt_budget, self.num_kv_heads) # Update KT cache kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers( @@ -425,6 +469,7 @@ def sparse_kv_predict( self.num_kv_heads, self.head_dim, self.page_size, + self.prompt_budget, metadata.kt_tokens_per_block, metadata.kv_cache_manager.max_kt_blocks_per_seq, ) @@ -453,20 +498,15 @@ def preprocess_for_gen( q = q.view(-1, self.num_kv_heads, self.num_heads // self.num_kv_heads, self.head_dim) - q_abs = torch.abs(q) - q_mask = torch.zeros_like(q) + return q, k - i1 = torch.topk(q_abs.mean(dim=2, keepdim=True), self.topr, + @torch.compile(dynamic=True) + def topr_filter(self, q: torch.Tensor) -> torch.Tensor: + i1 = torch.topk(q.abs().sum(dim=2, keepdim=True), self.topr, dim=-1).indices - + q_mask = torch.zeros_like(q) q_mask.scatter_(-1, i1.expand_as(q[..., :self.topr]), 1) - - q_valid = q * q_mask - - dim_pos = torch.where(q_valid.sum(dim=2) > 0, self.head_dim, - 0).to(torch.int32) - - return q_valid, k, dim_pos + return q * q_mask def sparse_attn_predict( self, @@ -478,7 +518,10 @@ def sparse_attn_predict( if metadata.num_generations == 0: return None, None - q, k, dim_pos = self.preprocess_for_gen(q, k, metadata) + q, k = self.preprocess_for_gen(q, k, metadata) + + if self.topr < self.head_dim: + q = self.topr_filter(q) kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers( self.layer_idx) @@ -501,7 +544,6 @@ def sparse_attn_predict( q, kt_cache_tensor, metadata.kt_cache_block_offsets[metadata.num_contexts:], - dim_pos, metadata.kv_lens_cuda_runtime[metadata.num_contexts:], metadata.cum_kt_lens_cuda, metadata.page_size, @@ -611,6 +653,7 @@ def __init__( self.window_size = sparse_attention_config.window_size self.kernel_size = sparse_attention_config.kernel_size self.page_size = sparse_attention_config.page_size + assert sparse_attention_config.kt_cache_dtype == 'bfloat16', "Only bfloat16 kt cache is 
supported for Vanilla RocketKV"

     def _single_request_sparse_kv_predict(
             self, q: Optional[Tensor], k: Optional[Tensor], v: Optional[Tensor],
@@ -877,6 +920,7 @@ def __init__(
         assert not kv_cache_config.enable_block_reuse, "RocketKV cache requires block reuse to be disabled in KV cache config"
         self.kt_tokens_per_block = next_power_of_2(
             math.ceil(tokens_per_block / sparse_attn_config.page_size))
+        self.kt_cache_dtype = torch.bfloat16 if sparse_attn_config.kt_cache_dtype == 'bfloat16' else torch.float8_e5m2

         super().__init__(
             kv_cache_config,
@@ -910,7 +954,7 @@ def __init__(
             torch.empty((self.num_blocks, self.kt_tokens_per_block,
                          num_kv_heads, head_dim * 2),
                         device="cuda",
-                        dtype=torch.bfloat16)
+                        dtype=self.kt_cache_dtype)
             for _ in range(self.num_local_layers)
         ]
         self.max_kt_blocks_per_seq = self.num_blocks
@@ -1026,13 +1070,19 @@ def get_cache_size_per_token(model_config: ModelConfig, mapping: Mapping,
             sparse_attn_config = model_config.sparse_attention_config
             kt_tokens_per_block = next_power_of_2(
                 math.ceil(tokens_per_block / sparse_attn_config.page_size))
-            kv_factor = 2 + 2 * kt_tokens_per_block / tokens_per_block
+            kt_factor = 2
+            if sparse_attn_config.kt_cache_dtype == "float8_e5m2":
+                kt_factor = 1
+            kv_factor = 2 + kt_factor * kt_tokens_per_block / tokens_per_block
             mem_per_token *= kv_factor
         return mem_per_token

     def get_cache_bytes_per_token(self):
-        # 2 for K and V, 2 * kt_tokens_per_block / tokens_per_block for KT cache
-        kv_factor = self.kv_factor + 2 * self.kt_tokens_per_block / self.tokens_per_block
+        # 2 for K and V, kt_factor * kt_tokens_per_block / tokens_per_block for KT cache
+        kt_factor = 2
+        if self.kt_cache_dtype == torch.float8_e5m2:
+            kt_factor = 1
+        kv_factor = self.kv_factor + kt_factor * self.kt_tokens_per_block / self.tokens_per_block
         cache_size_per_token = math.ceil(
             kv_factor * sum(self.num_kv_heads_per_layer) * self.head_dim)
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 337127fa4ae..616f20488c1 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -225,14 +225,18 @@ class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
     """
    algorithm: ClassVar[str] = "rocket"
    window_size: Optional[int] = Field(
-        default=None, description="The window size for snap KV.")
+        default=32, description="The window size for snap KV.")
    kernel_size: Optional[int] = Field(
-        default=None, description="The kernel size for snap KV.")
-    topr: Optional[Union[int, float]] = Field(default=76, description="Top-r")
-    topk: Optional[int] = Field(default=128, description="Top-k")
-    prompt_budget: Optional[int] = Field(default=1266,
+        default=63, description="The kernel size for snap KV.")
+    topr: Optional[Union[int, float]] = Field(default=128, description="Top-r")
+    topk: Optional[int] = Field(default=64, description="Top-k")
+    prompt_budget: Optional[int] = Field(default=2048,
                                          description="Prompt budget")
-    page_size: Optional[int] = Field(default=3, description="Page size")
+    page_size: Optional[int] = Field(default=4, description="Page size")
+    kt_cache_dtype: Optional[str] = Field(
+        default='float8_e5m2',
+        description="KT cache dtype, either 'bfloat16' or 'float8_e5m2'",
+    )

    @classmethod
    def from_dict(cls, data: dict):
diff --git a/tests/unittest/_torch/attention/sparse/test_rocketkv.py b/tests/unittest/_torch/attention/sparse/test_rocketkv.py
index c1b54bf4f75..0112fa5489c 100644
--- a/tests/unittest/_torch/attention/sparse/test_rocketkv.py
+++ b/tests/unittest/_torch/attention/sparse/test_rocketkv.py
@@ -4,6 +4,7 @@
 import pytest
 import torch
 from utils.llm_data import llm_models_root
+from 
utils.util import getSMVersion import tensorrt_llm from tensorrt_llm import LLM, SamplingParams @@ -16,6 +17,8 @@ from tensorrt_llm.mapping import Mapping +@pytest.mark.skipif(getSMVersion() < 100, + reason="RocketKV requires SM100 (Blackwell)") @pytest.mark.parametrize("backend", ["pytorch"]) @pytest.mark.parametrize("model_name", ["llama-3.1-model/Llama-3.1-8B-Instruct"]) @@ -27,10 +30,13 @@ def test_model(backend, model_name, attention_backend): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, enable_block_reuse=False) + kt_cache_dtype = 'float8_e5m2' if attention_backend == "TRTLLM" else 'bfloat16' + sparse_attention_config = RocketSparseAttentionConfig( window_size=32, kernel_size=63, prompt_budget=2048, + kt_cache_dtype=kt_cache_dtype, ) cuda_graph_config = CudaGraphConfig( @@ -164,7 +170,8 @@ def test_sparse_kv_predict(batch_size, num_contexts): window_size=32, kernel_size=3, prompt_budget=256, - page_size=3, + page_size=4, + kt_cache_dtype='bfloat16', ) # Create sequence lengths - mix short and long sequences in context phase @@ -367,17 +374,19 @@ def test_sparse_attn_predict(batch_size, num_contexts): window_size=32, kernel_size=3, prompt_budget=256, - page_size=3, + page_size=2, topk=128, topr=96, + kt_cache_dtype='bfloat16', ) sparse_attn_config_trtllm = RocketSparseAttentionConfig( window_size=32, kernel_size=3, prompt_budget=256, - page_size=3, - topk=43, + page_size=2, + topk=64, topr=96, + kt_cache_dtype='bfloat16', ) # Create sequence lengths @@ -459,7 +468,12 @@ def test_sparse_attn_predict(batch_size, num_contexts): trtllm_kt_buf = trtllm_kv_cache_manager.get_kt_buffers(layer_idx) vanilla_kt_buf = vanilla_kv_cache_manager.get_kt_buffers(layer_idx) - torch.nn.init.normal_(trtllm_kt_buf) + if trtllm_kt_buf.dtype == torch.float8_e5m2: + temp_buf = torch.empty_like(trtllm_kt_buf, dtype=torch.float16) + torch.nn.init.normal_(temp_buf) + trtllm_kt_buf.copy_(temp_buf.to(trtllm_kt_buf.dtype)) + else: + torch.nn.init.normal_(trtllm_kt_buf) # Map trtllm data to vanilla based on block offsets # TRTLLM: (num_blocks, kt_tokens_per_block, num_kv_heads, 2*head_dim) @@ -501,9 +515,11 @@ def test_sparse_attn_predict(batch_size, num_contexts): # Copy to vanilla buffer vanilla_block = vanilla_blocks[vanilla_block_idx] if vanilla_block >= 0: - vanilla_kt_buf[vanilla_block, kt_token_idx:kt_token_idx + - kt_tokens_in_this_block].copy_(trtllm_kt_buf[ - trtllm_block, :kt_tokens_in_this_block]) + vanilla_kt_buf[ + vanilla_block, kt_token_idx:kt_token_idx + + kt_tokens_in_this_block].copy_(trtllm_kt_buf[ + trtllm_block, :kt_tokens_in_this_block].to( + vanilla_kt_buf.dtype)) kt_token_idx += kt_tokens_in_this_block diff --git a/tests/unittest/_torch/attention/sparse/test_triton_bmm.py b/tests/unittest/_torch/attention/sparse/test_triton_bmm.py index b689e08c31c..93695fb26d6 100644 --- a/tests/unittest/_torch/attention/sparse/test_triton_bmm.py +++ b/tests/unittest/_torch/attention/sparse/test_triton_bmm.py @@ -167,13 +167,11 @@ def create_kt_cache_from_k( def pytorch_reference_paged_kt_cache_bmm( q: torch.Tensor, k: torch.Tensor, - dim_pos: torch.Tensor, kv_lens: torch.Tensor, kt_page_size: int, sm_scale: float = None, ) -> torch.Tensor: num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim = q.shape - total_num_heads = num_kv_heads * num_heads_per_kv device = q.device if sm_scale is None: @@ -182,7 +180,10 @@ def pytorch_reference_paged_kt_cache_bmm( max_kt_tokens = max((kv_len.item() + kt_page_size - 1) // kt_page_size for kv_len in kv_lens) total_kt_tokens = num_gen_tokens * 
max_kt_tokens - scores = torch.zeros((total_num_heads, 1, total_kt_tokens), dtype=torch.float32, device=device) + # Output shape matches kernel: [num_kv_heads, num_heads_per_kv, total_kt_tokens] + scores = torch.zeros( + (num_kv_heads, num_heads_per_kv, total_kt_tokens), dtype=torch.float32, device=device + ) # Process each generation token for batch_idx in range(num_gen_tokens): @@ -191,25 +192,27 @@ def pytorch_reference_paged_kt_cache_bmm( q_batch = q[batch_idx] # [num_kv_heads, num_heads_per_kv, head_dim] k_batch = k[batch_idx].view(num_kv_heads, head_dim) # [num_kv_heads, head_dim] - dim_pos_batch = dim_pos[batch_idx] # [num_kv_heads, head_dim] output_offset = batch_idx * max_kt_tokens for kv_head_idx in range(num_kv_heads): - for q_head_idx in range(num_heads_per_kv): - global_head_idx = kv_head_idx * num_heads_per_kv + q_head_idx + q_heads = q_batch[kv_head_idx] # [num_heads_per_kv, head_dim] + q_sum = q_heads.sum(dim=0) # [head_dim] + dim_pos_vec = q_sum > 0 # [head_dim], boolean mask - q_vec = q_batch[kv_head_idx, q_head_idx] # [head_dim] - k_vec = k_batch[kv_head_idx] # [head_dim] - dim_pos_vec = dim_pos_batch[kv_head_idx] # [head_dim] + k_vec = k_batch[kv_head_idx] # [head_dim] + + # Select k_max where dim_pos > 0, k_min otherwise + # For simplicity in test, we use k as both min and max + k_selected = torch.where(dim_pos_vec, k_vec, k_vec) - # Simulate KT selection based on dim_pos - k_selected = torch.where(dim_pos_vec > 0, k_vec, k_vec) + for q_head_idx in range(num_heads_per_kv): + q_vec = q_batch[kv_head_idx, q_head_idx] # [head_dim] # Compute score for each kt token (simplified) for kt_idx in range(num_kt_tokens): score = torch.dot(q_vec, k_selected) * sm_scale - scores[global_head_idx, 0, output_offset + kt_idx] = score + scores[kv_head_idx, q_head_idx, output_offset + kt_idx] = score return scores @@ -312,17 +315,6 @@ def test_triton_rocket_paged_kt_cache_bmm( # Create K tensor for reference: [num_gen_tokens, num_kv_heads * head_dim] k = torch.randn((num_gen_tokens, num_kv_heads * head_dim), dtype=dtype, device=device) - # Create dim_pos: [num_gen_tokens, num_kv_heads, head_dim] - # Randomly set some dimensions to head_dim (positive) and others to 0 - dim_pos = torch.zeros( - (num_gen_tokens, num_kv_heads, head_dim), dtype=torch.int32, device=device - ) - for i in range(num_gen_tokens): - for j in range(num_kv_heads): - num_positive = torch.randint(0, head_dim, (1,)).item() - positive_indices = torch.randperm(head_dim)[:num_positive] - dim_pos[i, j, positive_indices] = head_dim - # Create paged KT cache kt_cache_tensor, kt_cache_block_offsets, max_kt_blocks_per_seq = create_kt_cache_from_k( k=k, @@ -344,7 +336,6 @@ def test_triton_rocket_paged_kt_cache_bmm( q=q, kt_cache_tensor=kt_cache_tensor, kt_cache_block_offsets=kt_cache_block_offsets, - dim_pos=dim_pos, kv_lens=kv_lens_tensor, output_offsets=output_offsets, kt_page_size=kt_page_size, @@ -357,7 +348,6 @@ def test_triton_rocket_paged_kt_cache_bmm( reference_scores = pytorch_reference_paged_kt_cache_bmm( q=q, k=k, - dim_pos=dim_pos, kv_lens=kv_lens_tensor, kt_page_size=kt_page_size, sm_scale=None, diff --git a/tests/unittest/_torch/thop/parallel/test_indexer_topk.py b/tests/unittest/_torch/thop/parallel/test_indexer_topk.py index fc13106e1b6..846c6a4b103 100644 --- a/tests/unittest/_torch/thop/parallel/test_indexer_topk.py +++ b/tests/unittest/_torch/thop/parallel/test_indexer_topk.py @@ -159,7 +159,7 @@ def generate_seq_lens(batch_size, min_long_seq, num_tokens): @pytest.mark.parametrize("batch_size", [1, 64, 512, 
2048]) @pytest.mark.parametrize("next_n", [1, 2]) -@pytest.mark.parametrize("index_topk", [2048]) +@pytest.mark.parametrize("index_topk", [2048, 128]) @pytest.mark.parametrize("num_tokens", [4096, 8192]) def test_indexer_topk_decode(batch_size, next_n, index_topk, num_tokens): torch.manual_seed(24) @@ -179,7 +179,7 @@ def test_indexer_topk_decode(batch_size, next_n, index_topk, num_tokens): indices = torch.empty((num_gen_tokens, index_topk), dtype=torch.int32, device="cuda") # Run CUDA implementation - torch.ops.trtllm.indexer_topk_decode(logits, seq_lens, indices, next_n) + torch.ops.trtllm.indexer_topk_decode(logits, seq_lens, indices, next_n, index_topk) torch.cuda.synchronize() @@ -198,7 +198,7 @@ def test_indexer_topk_decode(batch_size, next_n, index_topk, num_tokens): @pytest.mark.parametrize("batch_size", [1, 512, 2048]) -@pytest.mark.parametrize("index_topk", [2048]) +@pytest.mark.parametrize("index_topk", [2048, 128]) @pytest.mark.parametrize("num_tokens", [4096, 8192]) def test_indexer_topk_prefill(batch_size, index_topk, num_tokens): torch.manual_seed(24) @@ -214,7 +214,7 @@ def test_indexer_topk_prefill(batch_size, index_topk, num_tokens): indices = torch.empty((batch_size, index_topk), dtype=torch.int32, device="cuda") # Run CUDA implementation - torch.ops.trtllm.indexer_topk_prefill(logits, row_starts, row_ends, indices) + torch.ops.trtllm.indexer_topk_prefill(logits, row_starts, row_ends, indices, index_topk) # Run reference implementation torch_indices = logits.topk(min(index_topk, max(row_ends)), dim=-1)[1]
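
The sketch below is not part of the diff; it is a minimal usage example showing how the new RocketKV knobs introduced by this change fit together through the LLM API. Config class names and keyword arguments follow llm_sparse_attention.py, test_rocketkv.py, and llm_args.py above; the exact import paths, the attn_backend keyword, and the model path are assumptions, and the memory arithmetic in the comments just instantiates the kv_factor formulas from rocket.py.

# Hedged usage sketch (assumptions noted inline); not part of the patch.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig  # import path assumed
from tensorrt_llm.llmapi.llm_args import \
    RocketSparseAttentionConfig  # import path assumed

sparse_attention_config = RocketSparseAttentionConfig(
    window_size=32,      # snap-KV observation window (new default)
    kernel_size=63,      # pooling kernel for score smoothing (new default)
    prompt_budget=2048,  # prefix budget + window kept per valid sequence
    topk=64,             # new: top-k forwarded to the indexer top-k ops
    kt_cache_dtype='float8_e5m2',  # new: KT cache dtype ('bfloat16' also valid)
)
kv_cache_config = KvCacheConfig(
    enable_block_reuse=False,  # sparse attention does not support KV reuse
    free_gpu_memory_fraction=0.7,
    tokens_per_block=32,       # now configurable (previously hardcoded to 64)
)
# Per-token cache overhead under these settings, following the formulas in
# RocketKV's get_cache_size_per_token / get_cache_bytes_per_token:
#   kt_tokens_per_block = next_power_of_2(ceil(32 / 4)) = 8
#   fp8 KT cache:  kv_factor = 2 + 1 * 8 / 32 = 2.25
#   bf16 KT cache: kv_factor = 2 + 2 * 8 / 32 = 2.50
# i.e. the float8_e5m2 KT cache halves the KT-cache share of the budget.
llm = LLM(model='<path-to-llama-3.1-8b-instruct>',  # placeholder path
          attn_backend='TRTLLM',                    # keyword name assumed
          sparse_attention_config=sparse_attention_config,
          kv_cache_config=kv_cache_config)
print(llm.generate(['Hello'], SamplingParams(max_tokens=8))[0].outputs[0].text)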