[Model] Sync upstream BT=chunk_size fix for GDN chunk_fwd_kernel_o, simplify warmup to single pass#38343
Conversation
…implify warmup to single pass Signed-off-by: AuYang <459461160@qq.com>
Code Review
This pull request simplifies the FLA kernel configuration by removing the FLA_GDN_FIX_BT environment variable and standardizing the block size BT to always equal chunk_size. As a result, the kernel warmup logic in gdn_linear_attn.py has been streamlined from multiple iterations to a single pass with T=64. I have no feedback to provide.
|
Fix the pre-commit error, please |
Sorry, are you referring to the pre-commit / pre-run-check that failed due to the message "PR must have the 'ready' label or the author must have at least 4 merged PRs (found 2)," or referring to the skipped pre-commit / pre-commit? |
Oh sorry, I didn't realize it |
|
@claude review |
|
I'd propose to make some E2E bench just in case |
This is a clean upstream sync that simplifies the BT computation and warmup logic, but the inference kernel changes deserve expert eyes before approval.
Overview
The PR syncs fla-org/flash-linear-attention#619 into vLLM's fork, making three coordinated changes: (1) fixing BT = chunk_size unconditionally in chunk_fwd_kernel_o rather than dynamically computing min(chunk_size, max(16, next_power_of_2(T))), (2) removing the now-unnecessary FLA_GDN_FIX_BT env-var flag from utils.py, and (3) collapsing the 3-pass warmup loop (T=16,32,64) in _warmup_prefill_kernels to a single pass at T=64.
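For reference, the old and new BT computations can be sketched in plain Python; `CHUNK_SIZE` and the helper function are illustrative stand-ins for the kernel's actual constants:

```python
def next_power_of_2(n: int) -> int:
    # smallest power of two >= n
    return 1 if n <= 0 else 1 << (n - 1).bit_length()

CHUNK_SIZE = 64  # FLA chunk size assumed throughout this PR

def bt_old(T: int) -> int:
    # previous dynamic formula: varies for small T
    return min(CHUNK_SIZE, max(16, next_power_of_2(T)))

def bt_new(T: int) -> int:
    # upstream fix: BT is always chunk_size, regardless of T
    return CHUNK_SIZE

# small sequences produced three distinct BT values before the fix
print(sorted({bt_old(T) for T in (16, 32, 64, 128, 512)}))  # [16, 32, 64]
print(sorted({bt_new(T) for T in (16, 32, 64, 128, 512)}))  # [64]
```

Boundary handling for T < 64 is unaffected: the kernel's existing masks (`m_t = o_t < T`, `boundary_check=(0,1)`) cover the padded tail.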
Security Risks
None. This is purely a kernel dispatch and warmup optimization with no auth, crypto, or data-exposure surface.
Level of Scrutiny
The change is logically sound — fixing BT=64 eliminates the need to warm multiple Triton autotune cache entries — and the kernel already has boundary checks (m_t = o_t < T, boundary_check=(0,1)) that correctly handle T < 64. The test results demonstrate end-to-end correctness and a ~35% warmup speedup. That said, this touches production inference kernel dispatch and the warmup path that guards against OOM during autotuning. A domain expert familiar with the FLA/GDN kernel stack should give it a final look before merging.
Other Factors
No bugs were found. The author is a relatively new contributor (2 merged PRs), and a knowledgeable reviewer has been tagged. The diff is small and well-described, making human review straightforward.
|
Looks good to me. Since we now have less warmup workload, Triton autotuning will probably consume less memory. If possible, @AuYang261, please check how much GPU memory this change saves now. LGTM |
|
I have run a benchmark covering both correctness and timing.
The outputs are completely consistent.
No regression in memory, throughput, or correctness. Benchmark script |
It appears that no GPU memory was actually saved: the available KVCache capacity remained the same in tests conducted both before and after the modification. This is somewhat strange. |
|
Try to count how many Triton autotuning kernels were done in warmup before and after this change |
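One hedged way to do that count is to enable Triton's autotune logging and tally the events in the saved server log; a minimal sketch (the exact log phrasing may vary across Triton versions, and the invocation in the comment mirrors the test-plan command):

```python
# Run the server beforehand with autotune logging enabled, e.g.:
#   TRITON_PRINT_AUTOTUNING=1 vllm serve ... > warmup.log 2>&1
# then count how many autotuning events the warmup emitted.
def count_autotune_events(log_text: str) -> int:
    # Triton's autotune messages contain the phrase "Triton autotuning";
    # adjust the substring if your Triton version logs differently.
    return sum("Triton autotuning" in line for line in log_text.splitlines())

sample_log = (
    "Triton autotuning for function chunk_fwd_kernel_o finished\n"
    "INFO: unrelated server line\n"
    "Triton autotuning for function chunk_fwd_kernel_o finished\n"
)
print(count_autotune_events(sample_log))  # 2
```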
Signed-off-by: AuYang <459461160@qq.com>
I also ran the benchmark using the non- mode, before and after. The performance is essentially the same. |
|
@AuYang261 what about an accuracy test? |
I performed accuracy validation at two levels:
- Kernel correctness: I saved output tensors before the fix and compared them after, confirming bitwise equality (max_diff = 0.0) for all T values (16, 32, 64, 128, 512). The bench script only prints norms for quick visual sanity checking.
- E2E generation: with temperature=0 (greedy decoding), the model produces identical output text before and after the fix. |
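The bitwise-equality criterion above reduces to a max absolute elementwise difference. The author's actual script compared saved torch tensors; this framework-free sketch (names illustrative) just shows the check:

```python
def max_diff(before, after):
    """Max absolute elementwise difference between two equal-shape nested lists."""
    if isinstance(before, list):
        return max(max_diff(b, a) for b, a in zip(before, after))
    return abs(before - after)

# bitwise-identical outputs yield max_diff == 0.0 for every tested T
saved = [[0.25, -1.5], [3.0, 0.0]]
print(max_diff(saved, saved))            # 0.0
print(max_diff([[1.0]], [[1.5]]))        # 0.5
```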
Let me give it a try. |
I have profiled the Triton autotuning behavior during warmup using TRITON_PRINT_AUTOTUNING=1 and checked the disk cache. Here are the results:
As expected, removing T=16 and T=32 avoided the compilation and autotuning of 36 configurations per T value. Regarding the KVCache size, it remains exactly the same (even down to the byte). After digging into it, this makes sense for two reasons:
While this change doesn't technically yield more KV Cache blocks due to how memory profiling works, it still brings solid benefits:
|
…ication - Guard chunk_indices/chunk_offsets pre-computation with backend check so FlashInfer path skips unnecessary CPU work and HtoD copies - Add warning_once in forward_cuda for unexpected non-None chunk params - Fix chunk_kda_scaled_dot_kkt_fwd default to FLA_CHUNK_SIZE and pass chunk_size explicitly from chunk_kda_fwd - Simplify BT computation in chunk_o.py (sync with PR vllm-project#38343) - Remove unused FLA_GDN_FIX_BT env var - Fix mypy union-attr error for additional_config.get() Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
|
Good results, thank you! |
…implify warmup to single pass (vllm-project#38343) Signed-off-by: AuYang <459461160@qq.com> Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Signed-off-by: EricccYang <yangyang4991@gmail.com>
Purpose
This PR syncs fla-org/flash-linear-attention#619 into the vLLM fork, as suggested by @lgeiger on #36599.
Upstream simplified chunk_fwd_kernel_o to always use BT = chunk_size (64) instead of the dynamic formula min(chunk_size, max(16, next_power_of_2(T))). This PR applies the same change and removes the FLA_GDN_FIX_BT flag.
Why the loop was needed before: BT was part of the Triton autotune key for chunk_fwd_kernel_o. With small sequences (T < 64), BT was computed as next_power_of_2(T) (16 or 32), resulting in distinct kernel variants that each required a separate benchmark pass. With BT fixed at chunk_size, the autotuner cache is fully populated after one pass.
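A back-of-envelope sketch of the warmup cost: it scales with the number of distinct autotune cache keys times the configs benchmarked per key. The 36-config figure comes from the profiling comment in this thread; the helper below is illustrative:

```python
NUM_CONFIGS = 36  # configs benchmarked per distinct BT key (reported above)

def distinct_bt_keys(warmup_Ts, chunk_size=64, fixed=False):
    # BT participates in the Triton autotune key, so each distinct BT value
    # triggers its own benchmark pass during warmup
    if fixed:
        return {chunk_size}
    npow2 = lambda t: 1 << (t - 1).bit_length()
    return {min(chunk_size, max(16, npow2(t))) for t in warmup_Ts}

before = NUM_CONFIGS * len(distinct_bt_keys((16, 32, 64)))      # 3 keys -> 108
after = NUM_CONFIGS * len(distinct_bt_keys((64,), fixed=True))  # 1 key  -> 36
print(before, after)  # 108 36
```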
Changes
chunk_o.py: BT = chunk_size unconditionally; remove FLA_GDN_FIX_BT import
utils.py: remove FLA_GDN_FIX_BT env-var flag
gdn_linear_attn.py: replace 3-pass warmup loop (T = 16, 32, 64) with a single pass at T = 64
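A hedged sketch of the warmup change; the function and parameter names are illustrative, not vLLM's actual API:

```python
CHUNK_SIZE = 64  # FLA chunk size; BT is now fixed to this value

def warmup_prefill_kernels_old(run_prefill):
    # three passes were needed because T < 64 produced BT = 16 or 32,
    # i.e. distinct autotune cache entries for the same kernel
    for T in (16, 32, 64):
        run_prefill(T)

def warmup_prefill_kernels_new(run_prefill):
    # with BT fixed at chunk_size, a single pass populates the cache
    run_prefill(CHUNK_SIZE)
```

With the dummy `run_prefill` callback replaced by the real prefill launch, the new version does a third of the autotune work.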
Test Plan
Tested on Qwen3.5-9B, RTX 5090 32GB, with enforce_eager=True (otherwise an OOM error occurs on the RTX 5090; #37700, if merged, would resolve this):
vllm serve Qwen/Qwen3.5-9B --host 0.0.0.0 --port 8001 --max-model-len 4096 --enforce-eager
Warmup time is the time spent in _warmup_prefill_kernels.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.