Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions cpp/tensorrt_llm/thop/IndexerTopKOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ void indexer_topk_decode(
TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");

TORCH_CHECK(next_n > 0, "next_n must be greater than 0");
TORCH_CHECK(index_topk == 2048, "index_topk must be 2048 for now");

int32_t num_rows = static_cast<int32_t>(numRows64);
int32_t num_columns = static_cast<int32_t>(numColumns64);
Expand Down Expand Up @@ -95,7 +94,6 @@ void indexer_topk_prefill(th::Tensor const& logits, th::Tensor const& row_starts

TORCH_CHECK(indices.dim() == 2, "indices must be a 2D Tensor");
TORCH_CHECK(logits.dim() == 2, "logits must be a 2D Tensor");
TORCH_CHECK(index_topk == 2048, "index_topk must be 2048 for now");

auto const inputSize = logits.sizes();
auto const numRows64 = inputSize[0];
Expand Down
16 changes: 15 additions & 1 deletion examples/llm-api/llm_sparse_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def parse_arguments():
type=str,
default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
)

# Build config
parser.add_argument('--algo',
type=str,
Expand All @@ -53,6 +54,8 @@ def parse_arguments():
type=str,
default='TRTLLM',
choices=['VANILLA', 'TRTLLM'])

# RocketKV config
parser.add_argument('--window_size',
type=int,
default=32,
Expand All @@ -65,6 +68,14 @@ def parse_arguments():
type=int,
default=2048,
help="The prompt budget for RocketKV.")
parser.add_argument('--topk',
type=int,
default=64,
help='Top-k for RocketKV')
parser.add_argument('--kt_cache_dtype',
type=str,
default='float8_e5m2',
choices=['bfloat16', 'float8_e5m2'])
parser.add_argument('--index_max_chunk_size',
type=int,
default=32768,
Expand Down Expand Up @@ -106,6 +117,7 @@ def parse_arguments():
# KV cache
parser.add_argument('--kv_cache_dtype', type=str, default='auto')
parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
parser.add_argument('--tokens_per_block', type=int, default=32)
parser.add_argument('--num_samples', type=int, default=10)

# Runtime
Expand Down Expand Up @@ -139,8 +151,8 @@ def run_llm(args, sparse_attention_config):
enable_block_reuse=
False, # sparse attention does not support kv cache reuse now
free_gpu_memory_fraction=args.kv_cache_fraction,
tokens_per_block=args.tokens_per_block,
dtype=args.kv_cache_dtype,
tokens_per_block=64,
)

cuda_graph_config = CudaGraphConfig(
Expand Down Expand Up @@ -191,6 +203,8 @@ def run_RocketKV(args):
window_size=args.window_size,
kernel_size=args.kernel_size,
prompt_budget=args.prompt_budget,
topk=args.topk,
kt_cache_dtype=args.kt_cache_dtype,
)
run_llm(args, sparse_attention_config)

Expand Down
24 changes: 20 additions & 4 deletions examples/longbench/eval_longbench_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,22 @@ def parse_arguments() -> argparse.Namespace:
type=int,
default=63,
help='Kernel size for RocketKV')
parser.add_argument('--topr',
parser.add_argument('--topk',
type=int,
default=90,
help='Top-r for RocketKV')
default=64,
help='Top-k for RocketKV')
parser.add_argument('--kt_cache_dtype',
type=str,
default='float8_e5m2',
choices=['bfloat16', 'float8_e5m2'],
help='KT cache data type')

# KV cache configuration
parser.add_argument('--kv_cache_dtype',
type=str,
default='auto',
help='KV cache data type')
parser.add_argument('--tokens_per_block', type=int, default=32)
parser.add_argument('--kv_cache_fraction',
type=float,
default=0.7,
Expand Down Expand Up @@ -320,6 +326,7 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
# sparse attention doesn't support KV cache reuse
enable_block_reuse=False,
free_gpu_memory_fraction=args.kv_cache_fraction,
tokens_per_block=args.tokens_per_block,
)

# Configure CUDA graph
Expand All @@ -335,7 +342,8 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
window_size=args.window_size,
kernel_size=args.kernel_size,
prompt_budget=args.token_budget,
topr=args.topr,
topk=args.topk,
kt_cache_dtype=args.kt_cache_dtype,
)
logger.info(f"Using RocketKV sparse attention")
else:
Expand Down Expand Up @@ -427,6 +435,14 @@ def evaluate_single_dataset(
formatted_prompt = format_prompt_style(sample, prompt_format,
chat_template, dataset,
tokenizer)
# Truncate prompt if it's too long
token_ids = tokenizer.encode(formatted_prompt, truncation=False)
if len(token_ids) > args.max_seq_len:
half = (args.max_seq_len - max_new_tokens) // 2
formatted_prompt = tokenizer.decode(
token_ids[:half], skip_special_tokens=True) + tokenizer.decode(
token_ids[-half:], skip_special_tokens=True)

prompts.append(formatted_prompt)

if len(prompts) == 0:
Expand Down
14 changes: 10 additions & 4 deletions examples/longbench/eval_longbench_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,15 @@ def parse_arguments() -> argparse.Namespace:
type=int,
default=63,
help='Kernel size for RocketKV')
parser.add_argument('--topr',
parser.add_argument('--topk',
type=int,
default=90,
help='Top-r for RocketKV')
default=64,
help='Top-k for RocketKV')
parser.add_argument('--kt_cache_dtype',
type=str,
default='float8_e5m2',
choices=['bfloat16', 'float8_e5m2'],
help='KT cache data type')

# KV cache configuration
parser.add_argument('--kv_cache_dtype',
Expand Down Expand Up @@ -356,7 +361,8 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
window_size=args.window_size,
kernel_size=args.kernel_size,
prompt_budget=args.token_budget,
topr=args.topr,
topk=args.topk,
kt_cache_dtype=args.kt_cache_dtype,
)
logger.info(f"Using RocketKV sparse attention")
else:
Expand Down
Loading