Fix/fla triton autotune oom 34954#36384
oneraghavan wants to merge 1 commit into vllm-project:main from
Conversation
Code Review
This pull request addresses a potential out-of-memory error during Triton kernel autotuning for Flash Linear Attention by pre-warming the kernels during the profile_run phase. The approach is sound and the accompanying tests are thorough. I've identified a potential race condition in the one-shot warmup logic and suggested a fix to ensure thread safety.
    _fla_kernels_warmed_up: bool = False

    @classmethod
    def _warmup_fla_kernels(cls, layer: "Qwen3NextGatedDeltaNet") -> None:
        """Warm up FLA Triton kernels by running them with small dummy inputs.

        During vLLM's profile_run, attn_metadata is None so _forward_core
        returns early and the FLA Triton kernels never execute. Since these
        kernels use @triton.autotune, their first invocation triggers
        benchmarking of all configurations. If that first invocation happens
        after the KV cache has consumed most GPU memory, the autotuner OOMs.

        This class method forces one execution with minimal tensors so
        autotuning happens while memory is still plentiful. It only runs
        once across all layers since they share the same autotune key values.
        """
        if cls._fla_kernels_warmed_up:
            return
        cls._fla_kernels_warmed_up = True

        H_k = layer.num_k_heads // layer.tp_size
        H_v = layer.num_v_heads // layer.tp_size
        K = layer.head_k_dim
        V = layer.head_v_dim
        T = 128
        N = 1
        device = next(layer.parameters()).device
        dtype = next(layer.parameters()).dtype
        if dtype == torch.float32:
            dtype = torch.bfloat16

        q = torch.randn(1, T, H_k, K, dtype=dtype, device=device)
        k = torch.randn(1, T, H_k, K, dtype=dtype, device=device)
        v = torch.randn(1, T, H_v, V, dtype=dtype, device=device)
        g = torch.randn(1, T, H_v, dtype=dtype, device=device)
        beta = torch.rand(1, T, H_v, dtype=dtype, device=device).sigmoid()
        initial_state = torch.zeros(N, H_v, V, K, dtype=dtype, device=device)
        cu_seqlens = torch.tensor([0, T], dtype=torch.long, device=device)

        fla_chunk_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            initial_state=initial_state,
            output_final_state=True,
            cu_seqlens=cu_seqlens,
            use_qk_l2norm_in_kernel=True,
        )

        del q, k, v, g, beta, initial_state, cu_seqlens
The current implementation of the one-shot warmup has a potential race condition. The _fla_kernels_warmed_up flag is checked and set without a lock. If multiple LLMEngine instances are initialized concurrently in different threads, they would share the Qwen3NextGatedDeltaNet class. This could lead to _warmup_fla_kernels being called simultaneously, causing the warmup to run multiple times, which is inefficient and could lead to an OOM error if it happens under memory pressure.
To ensure thread safety and guarantee the warmup runs exactly once, you should use a threading.Lock. Please also remember to add import threading at the top of the file.
    _fla_kernels_warmed_up: bool = False
    _warmup_lock = threading.Lock()

    @classmethod
    def _warmup_fla_kernels(cls, layer: "Qwen3NextGatedDeltaNet") -> None:
        """Warm up FLA Triton kernels by running them with small dummy inputs.

        During vLLM's profile_run, attn_metadata is None so _forward_core
        returns early and the FLA Triton kernels never execute. Since these
        kernels use @triton.autotune, their first invocation triggers
        benchmarking of all configurations. If that first invocation happens
        after the KV cache has consumed most GPU memory, the autotuner OOMs.

        This class method forces one execution with minimal tensors so
        autotuning happens while memory is still plentiful. It only runs
        once across all layers since they share the same autotune key values.
        """
        if cls._fla_kernels_warmed_up:
            return
        with cls._warmup_lock:
            if cls._fla_kernels_warmed_up:
                return
            cls._fla_kernels_warmed_up = True

            H_k = layer.num_k_heads // layer.tp_size
            H_v = layer.num_v_heads // layer.tp_size
            K = layer.head_k_dim
            V = layer.head_v_dim
            T = 128
            N = 1
            device = next(layer.parameters()).device
            dtype = next(layer.parameters()).dtype
            if dtype == torch.float32:
                dtype = torch.bfloat16

            q = torch.randn(1, T, H_k, K, dtype=dtype, device=device)
            k = torch.randn(1, T, H_k, K, dtype=dtype, device=device)
            v = torch.randn(1, T, H_v, V, dtype=dtype, device=device)
            g = torch.randn(1, T, H_v, dtype=dtype, device=device)
            beta = torch.rand(1, T, H_v, dtype=dtype, device=device).sigmoid()
            initial_state = torch.zeros(N, H_v, V, K, dtype=dtype, device=device)
            cu_seqlens = torch.tensor([0, T], dtype=torch.long, device=device)

            fla_chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=True,
            )

            del q, k, v, g, beta, initial_state, cu_seqlens
This pull request has merge conflicts that must be resolved before it can be merged.
…roject#34954)

The existing _warmup_prefill_kernels uses a per-instance hasattr check, which means every GDN layer (40+ in Qwen3-Next) independently runs the warmup even though Triton's autotune cache is process-global. More critically, the flag check-and-set has no synchronization: concurrent LLMEngine initializations in different threads can race through the guard and trigger overlapping warmup runs, risking OOM under memory pressure.

Replace the instance-level hasattr guard with a class-level boolean flag protected by a threading.Lock using double-checked locking. This guarantees exactly-once execution across all layers and all threads.

Also adds unit tests (mock-based, CPU) covering:
- Flag management and one-shot semantics
- Thread safety with 8 concurrent threads racing through a barrier
- Correct tensor shapes derived from layer attributes and tp_size
- _forward_core profile path (attn_metadata is None)

And GPU integration tests covering:
- FLA kernel basic execution and Qwen3-Next dimensions
- Warmup-then-memory-pressure scenario

Signed-off-by: raghavan <oneraghavan@gmail.com>
Made-with: Cursor
Force-pushed from 5ddd9e6 to 26f2e56
Closing — the core fix was merged in #36599 (March 12) by @xu-jinyang, which is more thorough (multiple BT values, error handling, empty_cache). The incremental thread-safety and test additions here don't justify keeping this open. Thanks for the review feedback @gemini-code-assist.
Thank you for the update, @oneraghavan! It's great to hear that a more thorough fix has been merged. I appreciate you taking the time to consider the feedback. Please feel free to reach out if you need any further reviews in the future.
Fixes #34954
Qwen3-Next / Qwen3.5 models use Flash Linear Attention (FLA) with Triton kernels that are decorated with @triton.autotune. When these kernels are invoked for the first time, Triton benchmarks all candidate configurations (up to 24 per kernel across 8 kernel files), each requiring temporary GPU memory allocations.
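For intuition, here is a toy, pure-Python sketch of the benchmark-once-then-cache behavior that @triton.autotune provides. This is not Triton's implementation — `toy_autotune` and `kernel` are illustrative names — but it shows why the first invocation for a given key is expensive (every candidate config runs) while later invocations are cheap:

```python
import time
from functools import wraps

def toy_autotune(configs, key):
    """Toy stand-in for @triton.autotune: on the first call for a given
    key, time every candidate config, cache the fastest, and reuse it."""
    def decorator(fn):
        cache = {}  # key tuple -> best config

        @wraps(fn)
        def wrapper(*args, **kwargs):
            k = tuple(kwargs[name] for name in key)
            if k not in cache:  # first invocation: benchmark all configs
                timings = []
                for cfg in configs:
                    start = time.perf_counter()
                    fn(*args, config=cfg, **kwargs)
                    timings.append((time.perf_counter() - start, cfg))
                cache[k] = min(timings, key=lambda t: t[0])[1]
            return fn(*args, config=cache[k], **kwargs)
        return wrapper
    return decorator

@toy_autotune(configs=[{"BLOCK": 32}, {"BLOCK": 64}], key=["n"])
def kernel(xs, *, n, config):
    # pretend kernel: chunked sum using the block size the tuner picked
    block = config["BLOCK"]
    return sum(sum(xs[i:i + block]) for i in range(0, n, block))
```

In the real kernels each benchmarked config also allocates temporary GPU buffers, which is exactly what fails when the first call happens after the KV cache has claimed most of the device memory.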
During vLLM's initialization sequence:

1. profile_run() — runs a dummy forward pass to measure memory usage. attn_metadata is set to None, so Qwen3NextGatedDeltaNet._forward_core returns early and the FLA Triton kernels never execute.
2. initialize_kv_cache() — allocates KV cache blocks, consuming nearly all remaining GPU memory.
3. First real inference request — FLA kernels execute for the first time, triggering @triton.autotune benchmarking. With only a few hundred MB of free memory, the autotuner's temporary allocations fail with torch.cuda.OutOfMemoryError.
Fix
Add a _warmup_fla_kernels() class method to Qwen3NextGatedDeltaNet that runs fla_chunk_gated_delta_rule with small dummy tensors. This is called during profile_run() (when attn_metadata is None), forcing Triton to autotune while GPU memory is still plentiful. The cached autotune results are reused for all subsequent kernel invocations without needing further benchmarking.
Key properties:
One-shot: A class-level flag (_fla_kernels_warmed_up) ensures warmup runs exactly once across all GatedDeltaNet layers, since they share the same autotune key values (H, K, V, BT).
Minimal overhead: Uses small tensors (T=128, single sequence) — just enough to trigger autotuning.
fp32 safety: Casts fp32 parameters to bf16 for warmup, since the FLA kernels don't support fp32.
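The one-shot guard can be exercised without a GPU. The following standalone sketch (hypothetical names; not vLLM code) shows the double-checked-locking pattern the review suggested, racing 8 threads through a barrier — mirroring the thread-safety unit test described in the commit message — so the warmup body runs exactly once:

```python
import threading

class WarmupGuard:
    """Class-level one-shot flag protected by double-checked locking,
    mirroring the _fla_kernels_warmed_up guard discussed above."""
    _warmed_up = False
    _lock = threading.Lock()
    run_count = 0  # instrumentation: how many times warmup actually ran

    @classmethod
    def warmup(cls):
        if cls._warmed_up:           # fast path: no lock once warmed
            return
        with cls._lock:              # slow path: serialize first callers
            if cls._warmed_up:       # re-check: another thread may have won
                return
            cls._warmed_up = True    # claim the slot, as the PR does
            cls.run_count += 1       # stands in for the real kernel warmup

barrier = threading.Barrier(8)

def racer():
    barrier.wait()                   # release all 8 threads at once
    WarmupGuard.warmup()

threads = [threading.Thread(target=racer) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The second flag check inside the lock is what makes the pattern correct: without it, two threads that both pass the unlocked fast-path check would each run the warmup once the lock is released.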