
Fix/fla triton autotune oom 34954 #36384

Closed
oneraghavan wants to merge 1 commit into vllm-project:main from oneraghavan:fix/fla-triton-autotune-oom-34954

Conversation

@oneraghavan (Contributor) commented Mar 8, 2026

Fixes #34954

Qwen3-Next / Qwen3.5 models use Flash Linear Attention (FLA) with Triton kernels that are decorated with @triton.autotune. When these kernels are invoked for the first time, Triton benchmarks all candidate configurations (up to 24 per kernel across 8 kernel files), each requiring temporary GPU memory allocations.
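The cost described here comes from Triton's benchmark-then-cache behavior: the first call for a given autotune key times every candidate config (each trial allocating scratch memory on a real GPU), and later calls with the same key reuse the cached winner. A stdlib-only sketch of that behavior (class, config, and kernel names here are invented for illustration; this is not Triton's actual implementation):

```python
import time

# Simplified model of @triton.autotune's benchmark-then-cache behavior:
# on the first call for a given key every candidate config is timed
# (each trial would need scratch GPU memory in the real autotuner);
# the winner is cached so later calls skip benchmarking entirely.
class Autotuner:
    def __init__(self, configs, key_fn):
        self.configs = configs          # candidate configurations
        self.key_fn = key_fn            # maps call args to a cache key
        self.cache = {}                 # key -> best config
        self.benchmark_runs = 0         # total configs benchmarked

    def __call__(self, kernel, *args):
        key = self.key_fn(*args)
        if key not in self.cache:       # first call for this key: tune
            best, best_t = None, float("inf")
            for cfg in self.configs:
                self.benchmark_runs += 1
                t0 = time.perf_counter()
                kernel(cfg, *args)      # each trial allocates scratch
                dt = time.perf_counter() - t0
                if dt < best_t:
                    best, best_t = cfg, dt
            self.cache[key] = best
        return kernel(self.cache[key], *args)

def dummy_kernel(cfg, n):
    # stand-in for a real Triton kernel launch
    return sum(range(n)) * cfg["block"] // cfg["block"]

tuner = Autotuner(configs=[{"block": b} for b in (16, 32, 64)],
                  key_fn=lambda n: n)
tuner(dummy_kernel, 1000)   # first call: benchmarks all 3 configs
tuner(dummy_kernel, 1000)   # same key: cache hit, no benchmarking
```

This is why a warmup call made while memory is plentiful removes the cost (and the allocations) from the first real request: the cache is keyed on shapes, so any later call with the same key is a pure lookup.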

During vLLM's initialization sequence:

1. profile_run() — runs a dummy forward pass to measure memory usage. attn_metadata is set to None, so Qwen3NextGatedDeltaNet._forward_core returns early and the FLA Triton kernels never execute.

2. initialize_kv_cache() — allocates KV cache blocks, consuming nearly all remaining GPU memory.

3. First real inference request — the FLA kernels execute for the first time, triggering @triton.autotune benchmarking. With only a few hundred MB of free memory, the autotuner's temporary allocations fail with torch.cuda.OutOfMemoryError.

Fix
Add a _warmup_fla_kernels() class method to Qwen3NextGatedDeltaNet that runs fla_chunk_gated_delta_rule with small dummy tensors. This is called during profile_run() (when attn_metadata is None), forcing Triton to autotune while GPU memory is still plentiful. The cached autotune results are reused for all subsequent kernel invocations without needing further benchmarking.

Key properties:

One-shot: A class-level flag (_fla_kernels_warmed_up) ensures warmup runs exactly once across all GatedDeltaNet layers, since they share the same autotune key values (H, K, V, BT).

Minimal overhead: Uses small tensors (T=128, single sequence) — just enough to trigger autotuning.

fp32 safety: Casts fp32 parameters to bf16 for warmup, since the FLA kernels don't support fp32.

@mergify mergify bot added the qwen Related to Qwen models label Mar 8, 2026
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request addresses a potential out-of-memory error during Triton kernel autotuning for Flash Linear Attention by pre-warming the kernels during the profile_run phase. The approach is sound and the accompanying tests are thorough. I've identified a potential race condition in the one-shot warmup logic and suggested a fix to ensure thread safety.

Comment on lines +627 to +678
    _fla_kernels_warmed_up: bool = False

    @classmethod
    def _warmup_fla_kernels(cls, layer: "Qwen3NextGatedDeltaNet") -> None:
        """Warm up FLA Triton kernels by running them with small dummy inputs.

        During vLLM's profile_run, attn_metadata is None so _forward_core
        returns early and the FLA Triton kernels never execute. Since these
        kernels use @triton.autotune, their first invocation triggers
        benchmarking of all configurations. If that first invocation happens
        after the KV cache has consumed most GPU memory, the autotuner OOMs.

        This class method forces one execution with minimal tensors so
        autotuning happens while memory is still plentiful. It only runs
        once across all layers since they share the same autotune key values.
        """
        if cls._fla_kernels_warmed_up:
            return
        cls._fla_kernels_warmed_up = True

        H_k = layer.num_k_heads // layer.tp_size
        H_v = layer.num_v_heads // layer.tp_size
        K = layer.head_k_dim
        V = layer.head_v_dim
        T = 128
        N = 1
        device = next(layer.parameters()).device
        dtype = next(layer.parameters()).dtype
        if dtype == torch.float32:
            dtype = torch.bfloat16

        q = torch.randn(1, T, H_k, K, dtype=dtype, device=device)
        k = torch.randn(1, T, H_k, K, dtype=dtype, device=device)
        v = torch.randn(1, T, H_v, V, dtype=dtype, device=device)
        g = torch.randn(1, T, H_v, dtype=dtype, device=device)
        beta = torch.rand(1, T, H_v, dtype=dtype, device=device).sigmoid()
        initial_state = torch.zeros(N, H_v, V, K, dtype=dtype, device=device)
        cu_seqlens = torch.tensor([0, T], dtype=torch.long, device=device)

        fla_chunk_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=True,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=True,
        )

        del q, k, v, g, beta, initial_state, cu_seqlens
critical

The current implementation of the one-shot warmup has a potential race condition. The _fla_kernels_warmed_up flag is checked and set without a lock. If multiple LLMEngine instances are initialized concurrently in different threads, they would share the Qwen3NextGatedDeltaNet class. This could lead to _warmup_fla_kernels being called simultaneously, causing the warmup to run multiple times, which is inefficient and could lead to an OOM error if it happens under memory pressure.

To ensure thread safety and guarantee the warmup runs exactly once, you should use a threading.Lock. Please also remember to add import threading at the top of the file.

    _fla_kernels_warmed_up: bool = False
    _warmup_lock = threading.Lock()

    @classmethod
    def _warmup_fla_kernels(cls, layer: "Qwen3NextGatedDeltaNet") -> None:
        """Warm up FLA Triton kernels by running them with small dummy inputs.

        During vLLM's profile_run, attn_metadata is None so _forward_core
        returns early and the FLA Triton kernels never execute. Since these
        kernels use @triton.autotune, their first invocation triggers
        benchmarking of all configurations. If that first invocation happens
        after the KV cache has consumed most GPU memory, the autotuner OOMs.

        This class method forces one execution with minimal tensors so
        autotuning happens while memory is still plentiful. It only runs
        once across all layers since they share the same autotune key values.
        """
        if cls._fla_kernels_warmed_up:
            return

        with cls._warmup_lock:
            if cls._fla_kernels_warmed_up:
                return
            cls._fla_kernels_warmed_up = True

            H_k = layer.num_k_heads // layer.tp_size
            H_v = layer.num_v_heads // layer.tp_size
            K = layer.head_k_dim
            V = layer.head_v_dim
            T = 128
            N = 1
            device = next(layer.parameters()).device
            dtype = next(layer.parameters()).dtype
            if dtype == torch.float32:
                dtype = torch.bfloat16

            q = torch.randn(1, T, H_k, K, dtype=dtype, device=device)
            k = torch.randn(1, T, H_k, K, dtype=dtype, device=device)
            v = torch.randn(1, T, H_v, V, dtype=dtype, device=device)
            g = torch.randn(1, T, H_v, dtype=dtype, device=device)
            beta = torch.rand(1, T, H_v, dtype=dtype, device=device).sigmoid()
            initial_state = torch.zeros(N, H_v, V, K, dtype=dtype, device=device)
            cu_seqlens = torch.tensor([0, T], dtype=torch.long, device=device)

            fla_chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=True,
            )

            del q, k, v, g, beta, initial_state, cu_seqlens

@mergify mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @oneraghavan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2026
…roject#34954)

The existing _warmup_prefill_kernels uses a per-instance hasattr check,
which means every GDN layer (40+ in Qwen3-Next) independently runs the
warmup even though Triton's autotune cache is process-global.  More
critically, the flag check-and-set has no synchronization: concurrent
LLMEngine initializations in different threads can race through the
guard and trigger overlapping warmup runs, risking OOM under memory
pressure.

Replace the instance-level hasattr guard with a class-level boolean flag
protected by a threading.Lock using double-checked locking.  This
guarantees exactly-once execution across all layers and all threads.

Also adds unit tests (mock-based, CPU) covering:
- Flag management and one-shot semantics
- Thread safety with 8 concurrent threads racing through a barrier
- Correct tensor shapes derived from layer attributes and tp_size
- _forward_core profile path (attn_metadata is None)

And GPU integration tests covering:
- FLA kernel basic execution and Qwen3-Next dimensions
- Warmup-then-memory-pressure scenario

Signed-off-by: raghavan <oneraghavan@gmail.com>
Made-with: Cursor
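The "8 concurrent threads racing through a barrier" unit test described in the commit message above can be sketched with the stdlib alone. The class and method names below are stand-ins that mirror the PR's double-checked-locking pattern, with the FLA kernel call stubbed out by a counter; this is an illustration, not the PR's actual test code:

```python
import threading

# Stub owner of the one-shot warmup: a class-level flag guarded by a
# lock with double-checked locking, mirroring _warmup_fla_kernels.
class WarmupOwner:
    _warmed_up = False
    _lock = threading.Lock()
    warmup_calls = 0                    # stands in for the kernel run

    @classmethod
    def warmup(cls):
        if cls._warmed_up:              # fast path: no lock once warm
            return
        with cls._lock:                 # slow path: serialize racers
            if cls._warmed_up:          # re-check under the lock
                return
            cls._warmed_up = True
            cls.warmup_calls += 1       # would launch the FLA kernel

barrier = threading.Barrier(8)

def racer():
    barrier.wait()                      # release all 8 threads at once
    WarmupOwner.warmup()

threads = [threading.Thread(target=racer) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert WarmupOwner.warmup_calls == 1    # exactly-once despite the race
```

The barrier maximizes contention by releasing all threads into warmup() simultaneously; without the lock, several threads could pass the flag check before any of them sets it, which is exactly the race the review comment flagged.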
@oneraghavan oneraghavan force-pushed the fix/fla-triton-autotune-oom-34954 branch from 5ddd9e6 to 26f2e56 Compare March 18, 2026 03:47
@mergify mergify bot removed the needs-rebase label Mar 18, 2026
@oneraghavan (Contributor, Author)

Closing — the core fix was merged in #36599 (March 12) by @xu-jinyang, which is more thorough (multiple BT values, error handling, empty_cache). The incremental thread-safety and test additions here don't justify keeping this open. Thanks for the review feedback @gemini-code-assist.

@gemini-code-assist (Contributor)

Thank you for the update, @oneraghavan! It's great to hear that a more thorough fix has been merged. I appreciate you taking the time to consider the feedback. Please feel free to reach out if you need any further reviews in the future.


Labels

qwen Related to Qwen models


Development

Successfully merging this pull request may close these issues.

[Bug]: Triton Error [CUDA]: out of memory when received query

1 participant