
[Model] Sync upstream BT=chunk_size fix for GDN chunk_fwd_kernel_o, simplify warmup to single pass#38343

Merged
vadiklyutiy merged 4 commits into vllm-project:main from AuYang261:fix-gdn-bt-warmup
Mar 31, 2026

Conversation

@AuYang261
Contributor

Purpose

This PR syncs fla-org/flash-linear-attention#619 into the vLLM fork, as suggested by @lgeiger on #36599.

Upstream simplified chunk_fwd_kernel_o to always use BT = chunk_size (64) instead of the dynamic formula min(chunk_size, max(16, next_power_of_2(T))). This PR applies the same change and removes the FLA_GDN_FIX_BT flag.

Why the loop was needed before: BT was part of the Triton autotune key for chunk_fwd_kernel_o. With small sequences (T < 64), BT was computed as next_power_of_2(T) (16 or 32), resulting in distinct kernel variants that each required a separate benchmark pass. With BT fixed at chunk_size, the autotuner cache is fully populated after one pass.
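For reference, the old vs. new BT selection can be sketched as follows (a minimal sketch, not the actual vLLM source; the helper names are illustrative):

```python
# Minimal sketch (not the vLLM source) of the old vs. new BT selection
# and why the old formula multiplied Triton autotune cache entries.

def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 2 ** (n - 1).bit_length()

def bt_old(T: int, chunk_size: int = 64) -> int:
    # dynamic formula: distinct BT values for small T
    return min(chunk_size, max(16, next_power_of_2(T)))

def bt_new(T: int, chunk_size: int = 64) -> int:
    # upstream fix: BT is always chunk_size
    return chunk_size

# BT is part of the autotune key, so each distinct value forced its own
# benchmark pass during warmup.
print(sorted({bt_old(T) for T in (16, 32, 64)}))  # [16, 32, 64] -> 3 passes
print(sorted({bt_new(T) for T in (16, 32, 64)}))  # [64] -> 1 pass
```

With BT pinned to chunk_size, a single T=64 warmup call covers the only autotune key the kernel will ever see.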

Changes

chunk_o.py: BT = chunk_size unconditionally; remove FLA_GDN_FIX_BT import
utils.py: remove FLA_GDN_FIX_BT env-var flag
gdn_linear_attn.py: replace 3-pass warmup loop (T = 16, 32, 64) with a single pass at T = 64

Test Plan

vllm serve Qwen/Qwen3.5-9B --host 0.0.0.0 --port 8001 --max-model-len 4096 --enforce-eager

Tested on Qwen3.5-9B, RTX 5090 32 GB, enforce_eager=True (otherwise an OOM error occurs on the RTX 5090; #37700, if merged, would resolve this):

| Metric | Before | After |
| --- | --- | --- |
| Layer-0 warmup time | 13.86 s (3 passes) | 9.04 s (1 pass) |
| End-to-end inference | | |

Warmup time is the time spent in _warmup_prefill_kernels.



Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

…implify warmup to single pass

Signed-off-by: AuYang <459461160@qq.com>
@AuYang261 AuYang261 requested a review from tdoublep as a code owner March 27, 2026 08:47

@claude claude bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request simplifies the FLA kernel configuration by removing the FLA_GDN_FIX_BT environment variable and standardizing the block size BT to always equal chunk_size. As a result, the kernel warmup logic in gdn_linear_attn.py has been streamlined from multiple iterations to a single pass with T=64. I have no feedback to provide.

@ZJY0516
Member

ZJY0516 commented Mar 30, 2026

FIX the pre-commit error please

@AuYang261
Contributor Author

FIX the pre-commit error please

Sorry, are you referring to the pre-commit / pre-run-check that failed due to the message "PR must have the 'ready' label or the author must have at least 4 merged PRs (found 2)," or referring to the skipped pre-commit / pre-commit?

@ZJY0516
Member

ZJY0516 commented Mar 30, 2026

FIX the pre-commit error please

Sorry, are you referring to the pre-commit / pre-run-check that failed due to the message "PR must have the 'ready' label or the author must have at least 4 merged PRs (found 2)," or referring to the skipped pre-commit / pre-commit?

Oh sorry, I didn't realize it

@ZJY0516
Member

ZJY0516 commented Mar 30, 2026

cc @vadiklyutiy @arpera

@vadiklyutiy
Collaborator

@claude review

@vadiklyutiy vadiklyutiy moved this to In review in Qwen3.5 Mar 30, 2026
@vadiklyutiy
Collaborator

I'd propose running an E2E benchmark just in case.


@claude claude bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a clean upstream sync that simplifies the BT computation and warmup logic, but the inference kernel changes deserve expert eyes before approval.


Overview

The PR syncs fla-org/flash-linear-attention#619 into vLLM's fork, making three coordinated changes: (1) fixing BT = chunk_size unconditionally in chunk_fwd_kernel_o rather than dynamically computing min(chunk_size, max(16, next_power_of_2(T))), (2) removing the now-unnecessary FLA_GDN_FIX_BT env-var flag from utils.py, and (3) collapsing the 3-pass warmup loop (T=16,32,64) in _warmup_prefill_kernels to a single pass at T=64.

Security Risks

None. This is purely a kernel dispatch and warmup optimization with no auth, crypto, or data-exposure surface.

Level of Scrutiny

The change is logically sound — fixing BT=64 eliminates the need to warm multiple Triton autotune cache entries — and the kernel already has boundary checks (m_t = o_t < T, boundary_check=(0,1)) that correctly handle T < 64. The test results demonstrate end-to-end correctness and a ~35% warmup speedup. That said, this touches production inference kernel dispatch and the warmup path that guards against OOM during autotuning. A domain expert familiar with the FLA/GDN kernel stack should give it a final look before merging.

Other Factors

No bugs were found. The author is a relatively new contributor (2 merged PRs), and a knowledgeable reviewer has been tagged. The diff is small and well-described, making human review straightforward.

@arpera
Contributor

arpera commented Mar 30, 2026

Looks good to me. Since now we have less warmup workload probably Triton autotuning will consume less memory. If possible @AuYang261, please, check how much gpu memory this change saves now.

LGTM

@vadiklyutiy vadiklyutiy added ready ONLY add when PR is ready to merge/full CI is needed qwen Related to Qwen models labels Mar 30, 2026
@AuYang261
Contributor Author

I ran a benchmark covering both correctness and timing.
Setup: Qwen3.5-9B, A100 80 GB, enforce_eager=True, max_model_len=4096, v0.18.1rc1.dev156+gbf5eec638,
TRITON_CACHE_AUTOTUNING=0 for warmup timing.

  1. Kernel correctness (chunk_gated_delta_rule)
    Tested with torch.manual_seed(42), proper log-decay g = cumsum(-|randn|):
  • before
    T=  16: nan=False inf=False norm=0.6016
    T=  32: nan=False inf=False norm=0.7734
    T=  64: nan=False inf=False norm=1.0781
    T= 128: nan=False inf=False norm=1.5859
    T= 512: nan=False inf=False norm=3.0469
    
  • after
    T=  16: nan=False inf=False norm=0.6016
    T=  32: nan=False inf=False norm=0.7734
    T=  64: nan=False inf=False norm=1.0781
    T= 128: nan=False inf=False norm=1.5859
    T= 512: nan=False inf=False norm=3.0469
    

The outputs are completely consistent.

  2. E2E inference correctness
  • before
    ...
    Available KV cache memory: 2.35 GiB
    GPU KV cache size: 19,008 tokens
    ...
    Generation: 8.63s  (input=11 tok, output=363 tok)
    Throughput: input 1.3 tok/s, output 42.1 tok/s
    ...
    
  • after
    ...
    Available KV cache memory: 2.35 GiB
    GPU KV cache size: 19,008 tokens
    ...
    Generation: 8.09s  (input=11 tok, output=363 tok)
    Throughput: input 1.4 tok/s, output 44.9 tok/s
    ...
    

No regression in memory, throughput, or correctness.

Benchmark script
import time

import torch
from vllm import LLM, SamplingParams
from vllm.triton_utils.allocation import set_triton_allocator
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule

if __name__ == "__main__":
    # ── 1. Kernel correctness ────────────────────────────────────────
    print("=" * 60)
    print("1. Kernel correctness (chunk_gated_delta_rule)")
    print("=" * 60)

    device = torch.device("cuda")
    set_triton_allocator(device)

    for T in [16, 32, 64, 128, 512]:
        torch.manual_seed(42)
        B, H, K, V = 1, 4, 64, 64
        q = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16)
        k = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16)
        v = torch.randn(B, T, H, V, device=device, dtype=torch.bfloat16)
        log_decay = -torch.abs(
            torch.randn(B, T, H, device=device, dtype=torch.bfloat16)
        )
        g = torch.cumsum(log_decay, dim=1)
        beta = torch.sigmoid(torch.randn(B, T, H, device=device, dtype=torch.bfloat16))
        state = torch.zeros(B, H, K, V, device=device, dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
        o, _ = chunk_gated_delta_rule(
            q,
            k,
            v,
            g,
            beta,
            scale=K**-0.5,
            initial_state=state,
            output_final_state=True,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=True,
        )
        has_nan = o.isnan().any().item()
        has_inf = o.isinf().any().item()
        print(f"  T={T:4d}: nan={has_nan} inf={has_inf} norm={o.norm().item():.4f}")

    # ── 2. E2E inference (warmup + generation) ───────────────────────
    print()
    print("=" * 60)
    print("2. E2E inference (Qwen3.5-9B, enforce_eager, max_model_len=4096)")
    print("=" * 60)

    t_start = time.perf_counter()
    llm = LLM(
        model="Qwen/Qwen3.5-9B",
        tensor_parallel_size=1,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=0.82,
    )
    t_init = time.perf_counter() - t_start
    print(f"  Engine init (incl. warmup): {t_init:.2f}s")

    tokenizer = llm.get_tokenizer()
    formatted = tokenizer.apply_chat_template(
        [{"role": "user", "content": "hello"}],
        add_generation_prompt=True,
        tokenize=False,
    )

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=4096,
        stop_token_ids=[tokenizer.eos_token_id],
    )

    t_gen = time.perf_counter()
    outputs = llm.generate(prompts=[formatted], sampling_params=sampling_params)
    t_gen = time.perf_counter() - t_gen

    out_text = outputs[0].outputs[0].text
    n_input = len(outputs[0].prompt_token_ids)
    n_output = len(outputs[0].outputs[0].token_ids)
    print(f"  Generation: {t_gen:.2f}s  (input={n_input} tok, output={n_output} tok)")
    print(
        f"  Throughput: input {n_input/t_gen:.1f} tok/s, output {n_output/t_gen:.1f} tok/s"
    )
    print(f"  Output preview: {out_text}...")

@AuYang261
Contributor Author

Looks good to me. Since now we have less warmup workload probably Triton autotuning will consume less memory. If possible @AuYang261, please, check how much gpu memory this change saves now.

LGTM

It appears that no GPU memory was actually saved: the available KV cache capacity remained the same in tests before and after the modification. This is somewhat strange.

@arpera
Contributor

arpera commented Mar 30, 2026

Try to count how many Triton autotuning kernels were done in warmup before and after this change

Signed-off-by: AuYang <459461160@qq.com>
@AuYang261 AuYang261 requested a review from vadiklyutiy as a code owner March 31, 2026 01:49
Member

@ZJY0516 ZJY0516 left a comment


LGTM. Could you also add some eval? not just enforce-eager

@AuYang261
Contributor Author

LGTM. Could you also add some eval? not just enforce-eager

I also ran the benchmark in non-enforce-eager mode (with #37700 applied to avoid OOM). It executed successfully both before and after:

before

...
Available KV cache memory: 5.96 GiB
GPU KV cache size: 48,576 tokens
...
Generation: 3.48s  (input=11 tok, output=299 tok)
Throughput: input 3.2 tok/s, output 86.0 tok/s
...

after

...
Available KV cache memory: 5.96 GiB
GPU KV cache size: 48,576 tokens
...
Generation: 3.49s  (input=11 tok, output=299 tok)
Throughput: input 3.2 tok/s, output 85.8 tok/s
...

The performance is essentially the same.

@ZJY0516
Member

ZJY0516 commented Mar 31, 2026

@AuYang261 what about accuracy test

@AuYang261
Contributor Author

@AuYang261 what about accuracy test

I validated accuracy at two levels:

Kernel correctness: I saved output tensors before the fix and compared after, confirming bitwise equality (max_diff = 0.0) for all T values (16, 32, 64, 128, 512). The bench script only prints norms for quick visual sanity checking.

E2E generation: With temperature=0 (greedy decoding), the model produces identical output text before and after the fix.
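The bitwise-equality check can be sketched like this (a minimal stand-in: the real comparison loaded tensors saved from chunk_gated_delta_rule before the change; run_once and the list-based data here are purely illustrative):

```python
import random

# Minimal sketch of the bitwise-equality check. The real test compared
# tensors saved from chunk_gated_delta_rule before the fix against fresh
# outputs after it; run_once here is a deterministic stand-in.
def run_once(T: int) -> list[float]:
    rng = random.Random(42)  # fixed seed per run, as in the bench script
    return [rng.random() for _ in range(T)]

for T in (16, 32, 64, 128, 512):
    before, after = run_once(T), run_once(T)  # "before fix" vs "after fix" runs
    max_diff = max(abs(a - b) for a, b in zip(before, after))
    assert max_diff == 0.0, f"T={T}: max_diff={max_diff}"
print("bitwise equal for all T")
```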

@AuYang261
Contributor Author

Try to count how many Triton autotuning kernels were done in warmup before and after this change

Let me give it a try.

@AuYang261
Contributor Author

Try to count how many Triton autotuning kernels were done in warmup before and after this change

I have profiled the Triton autotuning behavior during warmup using TRITON_PRINT_AUTOTUNING=1 and checked the disk cache. Here are the results:

  • Before the change (T in 16, 32, 64): The autotuner evaluated 176 kernel configurations. The ~/.triton/cache/ size was 78MB.
  • After the change (T=64 only): The autotuner evaluated 108 kernel configurations. The ~/.triton/cache/ size dropped to 56MB.

As expected, removing T=16 and T=32 avoided the compilation and autotuning of 36 configurations per T value (chunk_fwd_kernel_o kernel).

Regarding the KV cache size remaining exactly the same (down to the byte): after digging into it, this makes sense for two reasons:

  • PyTorch Allocator vs. CUDA Context: Triton JIT compilations and configs consume disk space and CUDA Context memory, but they do not use the PyTorch memory allocator. Since vLLM calculates available KV Cache based on torch.accelerator.memory_stats(self.device).get("allocated_bytes.all.peak", 0), the memory saved from skipping these Triton configs is completely invisible to vLLM's memory profiler.
  • Peak Activation Memory: During warmup profiling, the peak memory allocation is determined by the largest forward pass (i.e., T=64). The runs for T=16 and T=32 produce smaller memory peaks that are overshadowed by T=64. Thus, removing them does not lower the absolute peak memory recorded.

While this change doesn't technically yield more KV Cache blocks due to how memory profiling works, it still brings solid benefits:

  • Faster startup time by skipping ~70 redundant autotuning sweeps and kernel compilations during warmup.
  • Reduced disk space usage (~22MB saved in ~/.triton/cache/).
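For anyone reproducing the count, one possible workflow (an assumption about the local setup; the exact autotuning log wording may differ by Triton version) is to tee the warmup log and grep it:

```shell
# Assumed workflow sketch; the log line format may differ by Triton version.
# The real run would be something like:
#   TRITON_PRINT_AUTOTUNING=1 vllm serve Qwen/Qwen3.5-9B ... 2>&1 | tee warmup.log
# Fake a two-line log here so the counting commands are demonstrable:
cat > warmup.log <<'EOF'
Triton autotuning for function chunk_fwd_kernel_o finished after 1.21s
Triton autotuning for function chunk_fwd_kernel_o finished after 0.87s
EOF
grep -c "Triton autotuning" warmup.log        # number of autotuning sweeps
du -sh ~/.triton/cache 2>/dev/null || true    # on-disk kernel cache size
rm -f warmup.log
```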

arpera added a commit to arpera/vllm that referenced this pull request Mar 31, 2026
…ication

- Guard chunk_indices/chunk_offsets pre-computation with backend check
  so FlashInfer path skips unnecessary CPU work and HtoD copies
- Add warning_once in forward_cuda for unexpected non-None chunk params
- Fix chunk_kda_scaled_dot_kkt_fwd default to FLA_CHUNK_SIZE and pass
  chunk_size explicitly from chunk_kda_fwd
- Simplify BT computation in chunk_o.py (sync with PR vllm-project#38343)
- Remove unused FLA_GDN_FIX_BT env var
- Fix mypy union-attr error for additional_config.get()

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
@arpera
Contributor

arpera commented Mar 31, 2026

Good results, thank you!

Member

@ZJY0516 ZJY0516 left a comment


Thanks for the contribution

@vadiklyutiy vadiklyutiy merged commit b779eb3 into vllm-project:main Mar 31, 2026
56 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in Qwen3.5 Mar 31, 2026
@AuYang261 AuYang261 deleted the fix-gdn-bt-warmup branch April 1, 2026 01:43
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…implify warmup to single pass (vllm-project#38343)

Signed-off-by: AuYang <459461160@qq.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>

Labels

qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done


4 participants