
Inaccurate API Docstrings for Attention Prefill #1709

@bkryu

Description


Some descriptions in the docstrings and API documentation are incorrect:

  1. BatchPrefillWithPagedKVCacheWrapper states "The implementation backend, could be auto/fa2,fa3 or cudnn.", but in fact also allows trtllm-gen as a backend.
    • Moreover, with the trtllm-gen backend, BatchPrefillWithPagedKVCacheWrapper returns incorrect output when planned with causal=False. It is suspected that the trtllm-gen backend always runs with causal=True. Example from flashinfer_benchmark.py below:
## Causal case:
$ python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --causal --allow_output_mismatch --generate_repro_command
[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --causal --allow_output_mismatch --generate_repro_command
[PERF] fa2            :: median time 2.492 ms; std 0.010 ms; achieved tflops 347.593 TFLOPs/sec; achieved tb_per_sec 0.108 TB/sec
[PERF] cudnn          :: median time 0.862 ms; std 0.019 ms; achieved tflops 1005.157 TFLOPs/sec; achieved tb_per_sec 0.311 TB/sec
[PERF] trtllm-gen     :: median time 0.841 ms; std 0.027 ms; achieved tflops 1030.423 TFLOPs/sec; achieved tb_per_sec 0.319 TB/sec

## Same dimensions but non-causal case
# Output mismatch occurs, and runtime is nearly identical to the causal case when it is expected to be ~2x
python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command
[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 8192 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command
[ERROR] Output tensor mismatch between backends fa2 and trtllm-gen: 26629186 / 59564032 (44.71%) elements are different
[PERF] fa2            :: median time 4.684 ms; std 0.005 ms; achieved tflops 369.869 TFLOPs/sec; achieved tb_per_sec 0.057 TB/sec
[PERF] cudnn          :: median time 1.442 ms; std 0.039 ms; achieved tflops 1201.270 TFLOPs/sec; achieved tb_per_sec 0.186 TB/sec
[PERF] trtllm-gen     :: median time 0.849 ms; std 0.033 ms; achieved tflops 2041.414 TFLOPs/sec; achieved tb_per_sec 0.316 TB/sec
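The near-identical timings above are consistent with the suspicion that the kernel applies a causal mask regardless of the flag: a causal prefill only computes roughly half of the score matrix. A back-of-the-envelope FLOP sketch using the benchmarked dimensions (the 2x factor is the point, not the absolute numbers):

```python
# Approximate attention FLOP accounting for the benchmarked shape.
# Assumes two matmuls (Q*K^T and P*V) and 2 FLOPs per multiply-add.
s_qo = s_kv = 8192
num_qo_heads = 64
head_dim = 128

def attn_flops(causal: bool) -> float:
    # Number of (query, key) pairs actually attended to:
    # a causal mask keeps roughly half of the s_qo x s_kv score matrix.
    pairs = s_qo * s_kv / 2 if causal else s_qo * s_kv
    return 2 * 2 * pairs * head_dim * num_qo_heads

ratio = attn_flops(causal=False) / attn_flops(causal=True)
print(ratio)  # -> 2.0: non-causal does ~2x the work, so ~2x the expected time
```

Since trtllm-gen reports ~0.85 ms in both runs while fa2 and cudnn roughly double, the timing alone points at the mask never being lifted.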
  2. BatchPrefillWithRaggedKVCacheWrapper states that the backends auto, fa2, fa3, and trtllm-gen are supported, but:
    • cutlass is also supported but omitted.
    • trtllm-gen is not actually supported, and selecting it produces a "not supported yet" error (message printed here).
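One way to keep the docstring and the runtime check from disagreeing like this is a single source of truth for the supported-backend list, which the docstring can be generated from or checked against. A hypothetical sketch (SUPPORTED_RAGGED_BACKENDS and validate_backend are illustrative names, not the actual flashinfer code):

```python
# Hypothetical: one canonical tuple of supported backends, so the
# docstring and the runtime validation cannot drift apart.
SUPPORTED_RAGGED_BACKENDS = ("auto", "fa2", "fa3", "cutlass")

def validate_backend(backend: str) -> str:
    if backend not in SUPPORTED_RAGGED_BACKENDS:
        raise ValueError(
            f"Backend {backend!r} is not supported yet; "
            f"choose one of {SUPPORTED_RAGGED_BACKENDS}."
        )
    return backend

print(validate_backend("cutlass"))  # accepted, though omitted from the docstring
try:
    validate_backend("trtllm-gen")  # documented today, but rejected at runtime
except ValueError as e:
    print(e)
```

This mirrors the reported mismatch: cutlass works but is undocumented, while trtllm-gen is documented but raises.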


Labels: documentation (Improvements or additions to documentation), good first issue (Good for newcomers)
