Some of the descriptions provided in docstrings and in the API documentation are incorrect:
- `BatchPrefillWithPagedKVCacheWrapper` states "The implementation backend, could be `auto`/`fa2`, `fa3` or `cudnn`.", but it in fact also allows `trtllm-gen` as a backend.
- Moreover, `BatchPrefillWithPagedKVCacheWrapper` returns an incorrect output when `causal=False`. It is suspected that the `trtllm-gen` backend always runs with `causal=True`. Example from `flashinfer_benchmark.py` below:
```shell
## Causal case:
$ python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --causal --allow_output_mismatch --generate_repro_command
[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --causal --allow_output_mismatch --generate_repro_command
[PERF] fa2 :: median time 2.492 ms; std 0.010 ms; achieved tflops 347.593 TFLOPs/sec; achieved tb_per_sec 0.108 TB/sec
[PERF] cudnn :: median time 0.862 ms; std 0.019 ms; achieved tflops 1005.157 TFLOPs/sec; achieved tb_per_sec 0.311 TB/sec
[PERF] trtllm-gen :: median time 0.841 ms; std 0.027 ms; achieved tflops 1030.423 TFLOPs/sec; achieved tb_per_sec 0.319 TB/sec

## Same dimensions but non-causal case
# An output mismatch occurs, and the trtllm-gen time is nearly identical to the causal case when it is expected to be ~2x
$ python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command
[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command
[ERROR] Output tensor mismatch between backends fa2 and trtllm-gen: 26629186 / 59564032 (44.71%) elements are different
[PERF] fa2 :: median time 4.684 ms; std 0.005 ms; achieved tflops 369.869 TFLOPs/sec; achieved tb_per_sec 0.057 TB/sec
[PERF] cudnn :: median time 1.442 ms; std 0.039 ms; achieved tflops 1201.270 TFLOPs/sec; achieved tb_per_sec 0.186 TB/sec
[PERF] trtllm-gen :: median time 0.849 ms; std 0.033 ms; achieved tflops 2041.414 TFLOPs/sec; achieved tb_per_sec 0.316 TB/sec
```
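To illustrate why a backend that silently applies a causal mask fails the non-causal refcheck, here is a minimal single-head attention reference in NumPy (not FlashInfer code, and the shapes are illustrative). Most output elements differ between the causal and non-causal settings, since the causal mask also skips roughly half the score computations, which is why the non-causal run is expected to take about 2x the time:

```python
import numpy as np

def attention(q, k, v, causal):
    # Naive single-head attention reference: scores = q @ k^T / sqrt(d).
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)
    if causal:
        # Mask out future positions (j > i) before the softmax.
        mask = np.triu(np.ones(scores.shape, dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    # Numerically stable row-wise softmax.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
s, d = 64, 16  # illustrative sequence length / head dim
q, k, v = (rng.standard_normal((s, d)) for _ in range(3))

out_causal = attention(q, k, v, causal=True)
out_full = attention(q, k, v, causal=False)

# A backend that always runs causal=True would produce out_causal even
# when causal=False was requested, so the refcheck sees a large mismatch.
mismatch = np.mean(~np.isclose(out_causal, out_full, atol=1e-3))
print(f"mismatching elements: {mismatch:.1%}")
```

Only the final query row (which attends to the full context in both settings) is guaranteed to match, so the mismatch fraction is large, consistent with the 44.71% element mismatch the refcheck reports above.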
- `BatchPrefillWithRaggedKVCacheWrapper` states that the backends `auto`, `fa2`, `fa3` and `trtllm-gen` are supported, but:
  - `cutlass` is also supported yet omitted from the docstring.
  - `trtllm-gen` is not actually supported, and selecting it prints a "not supported yet" message.
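A docstring fix could mirror what the code actually accepts. The sketch below is purely illustrative (the names `SUPPORTED_RAGGED_BACKENDS` and `check_backend` are hypothetical, not FlashInfer API); it shows the backend set the issue describes for the ragged wrapper, with `cutlass` included and `trtllm-gen` rejected:

```python
# Hypothetical guard reflecting the behavior reported in this issue:
# cutlass works but is undocumented, trtllm-gen is documented but rejected.
SUPPORTED_RAGGED_BACKENDS = {"auto", "fa2", "fa3", "cutlass"}

def check_backend(backend: str) -> str:
    """Return the backend name if supported, else raise (illustrative only)."""
    if backend not in SUPPORTED_RAGGED_BACKENDS:
        raise NotImplementedError(
            f"backend {backend!r} is not supported yet for ragged KV-cache prefill"
        )
    return backend

print(check_backend("cutlass"))       # accepted despite being undocumented
try:
    check_backend("trtllm-gen")       # documented, but rejected at runtime
except NotImplementedError as e:
    print(e)
```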