
[Fix] Upgrade FlashInfer to v0.6.4+ to Resolve SM90 Performance Regression #99

Closed
cswuyg wants to merge 1 commit into sgl-project:main from cswuyg:feature/cswuyg_flashinfer_version2

Conversation

@cswuyg
Contributor

@cswuyg cswuyg commented Mar 9, 2026

🚀 Problem Description

On SM90 (Hopper) architectures, FlashInfer versions prior to v0.6.4 exhibit a significant performance regression within the CUDAGraphBatchDecodeWithPagedKVCacheWrapper.

  • Root Cause: In versions prior to v0.6.4, the auto backend incorrectly dispatches to FA3 (FlashAttention-3) on SM90.
  • Performance Impact: Decode-phase benchmarks on long sequences show the FA3 PrefillWithKVCacheKernel running approximately 5x slower than the FA2 BatchPrefillWithPagedKVCacheKernel.

🛠️ Proposed Solution

We pin the FlashInfer dependency to v0.6.4 or newer.

Two solutions were considered:

  1. Manual Workaround (Rejected): Replace CUDAGraphBatchDecodeWithPagedKVCacheWrapper with BatchDecodeWithPagedKVCacheWrapper and manually specify backend="fa2".
  2. Version Upgrade (Selected): Enforce a version bump to leverage the fixed auto dispatch logic in newer releases.
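The selected option can also be enforced at runtime with a fail-fast version guard. The sketch below is illustrative only (the helper names and the simplistic dot-separated version parsing are assumptions, not code from this PR):

```python
# Sketch: refuse to run against a FlashInfer build that predates the
# SM90 auto-dispatch fix. Version parsing here handles plain numeric
# "X.Y.Z" strings only, for illustration.
MIN_FLASHINFER = (0, 6, 4)

def parse_version(version: str) -> tuple:
    """Turn a version string like '0.6.4' into (0, 6, 4) for tuple comparison."""
    return tuple(int(part) for part in version.split(".")[:3])

def has_fa2_dispatch_fix(installed: str) -> bool:
    """True if the installed FlashInfer includes the fixed auto-dispatch logic."""
    return parse_version(installed) >= MIN_FLASHINFER
```

In practice the same constraint would live in the package metadata (e.g. `flashinfer>=0.6.4`), with a guard like this only as a defensive check.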

📊 Benchmark Comparison (Decode Phase - Long Sequence)

| Backend | Kernel | Performance |
| --- | --- | --- |
| FA3 (Old Auto) | PrefillWithKVCacheKernel | ~5x Latency ❌ |
| FA2 (New Auto) | BatchPrefillWithPagedKVCacheKernel | Baseline |
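A comparison like the one above is typically produced with a warmup-then-average timing loop. The helper below is a dependency-free sketch (for real CUDA kernels you would synchronize the device around the timed region, e.g. with `torch.cuda.synchronize()`):

```python
import time

def time_kernel(fn, iters: int = 50, warmup: int = 5) -> float:
    """Average wall-clock seconds per call of fn(), after warmup iterations.

    Warmup runs absorb one-time costs (JIT, caches, graph capture) so the
    averaged measurement reflects steady-state latency.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters
```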

🔗 FlashInfer Fix Reference

@DarkSharpness
Collaborator

Thanks! We already pinned the attention backend to fa2 a while ago to avoid this regression:

```python
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
    self.float_workspace_buffer,
    kv_layout="NHD",
    backend="fa2",  # flashinfer fa3 is slow, use fa2 instead
)
self.decode_wrappers = BatchDecodeWithPagedKVCacheWrapper(
    self.float_workspace_buffer,
    use_tensor_cores=self.use_tensor_cores,
    kv_layout="NHD",
    backend="fa2",  # flashinfer fa3 is slow, use fa2 instead
)
```

```python
self.graph_wrappers[bs]._backend = "fa2"
```
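The hard-coded `backend="fa2"` above amounts to a dispatch rule keyed on GPU architecture and FlashInfer version. A minimal standalone sketch of that rule (the function name, signature, and thresholds are assumptions for illustration, not sglang's actual logic):

```python
def pick_attention_backend(sm_major: int, sm_minor: int,
                           flashinfer_fixed: bool) -> str:
    """Choose the FlashInfer attention backend.

    On SM90 (Hopper), FlashInfer releases before v0.6.4 mis-dispatch
    'auto' to the slow FA3 prefill kernel, so force fa2 there unless a
    fixed release is installed.
    """
    if (sm_major, sm_minor) >= (9, 0) and not flashinfer_fixed:
        return "fa2"  # sidestep the FA3 regression described in this PR
    return "auto"     # fixed FlashInfer can be trusted to dispatch
```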

@cswuyg cswuyg closed this Mar 10, 2026
