
[Bugfix] Warm up Triton autotuner for GDN layers during V1 profiling#36599

Merged
ywang96 merged 6 commits into vllm-project:main from AuYang261:fix/gdn-triton-warmup on Mar 12, 2026

Conversation

@AuYang261
Contributor

@AuYang261 AuYang261 commented Mar 10, 2026

Purpose

Fix Triton autotuner OOM for Qwen3.5 / Qwen3-Next models with Gated Delta Net (GDN) linear attention layers.

As mentioned in #36598, during V1 profile runs, _forward_core in Qwen3NextGatedDeltaNet returns early when attn_metadata is None, so the Triton-autotuned kernels used by GDN (solve_tril, chunk_scaled_dot_kkt, chunk_gated_delta_rule_fwd_h, chunk_fwd_o) are never invoked. After profiling, vLLM allocates the KV cache using most of the remaining GPU memory. The first real inference then triggers the Triton autotuner, which needs temporary GPU memory to benchmark kernel configurations, causing an OOM on GPUs where post-allocation headroom is tight.

This only affects GPUs that use the Triton-based forward_native path (i.e., non-SM90 GPUs — anything other than H100/H200), since SM90 uses the FlashInfer forward_cuda path which has no Triton autotuner.

Fix: Add _warmup_triton_kernels() which runs a minimal forward pass through chunk_gated_delta_rule with small dummy tensors (B=1, T=64) during the V1 profile phase, forcing Triton autotuning to complete while GPU memory is still plentiful. Autotuner results are cached globally by Triton, so only the first GDN layer incurs actual benchmarking cost.
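Why only the first GDN layer pays the benchmarking cost can be illustrated with a toy stand-in for Triton's global autotune cache (the names, the H value, and the cache structure below are illustrative, not vLLM's actual code):

```python
# Toy model of Triton's global autotune cache: results are keyed on the
# autotune key (here just (B, H, BT)), so after one warmup call every
# later call with the same key is a cache hit and does no benchmarking.

_AUTOTUNE_CACHE: dict = {}

def run_gdn_kernel(B: int, H: int, BT: int) -> bool:
    """Return True if this call had to benchmark (a cache miss)."""
    key = (B, H, BT)
    if key in _AUTOTUNE_CACHE:
        return False  # cache hit: no benchmarking, no temporary GPU memory
    _AUTOTUNE_CACHE[key] = "winning-config"  # pretend we benchmarked here
    return True

# Warmup uses the same key the real prefill path will produce (B=1, BT=64),
# so only the very first GDN "layer" pays the benchmarking cost:
first_layer_benchmarked = run_gdn_kernel(B=1, H=16, BT=64)
later_layers_benchmarked = [run_gdn_kernel(B=1, H=16, BT=64) for _ in range(23)]
```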

Test Plan

Run Qwen3.5-9B inference on a non-SM90 GPU (tested on RTX 5090, SM120, 32GB VRAM):

python -c "
from vllm import LLM, SamplingParams
llm = LLM(model='Qwen/Qwen3.5-9B', tensor_parallel_size=1, max_model_len=4096, enforce_eager=True)
output = llm.generate(['Hello'], SamplingParams(max_tokens=16))
print(output[0].outputs[0].text)
"

Test Result

Before fix

Triton OOM during first inference:

triton Error [CUDA]: out of memory
...
File ".../vllm/model_executor/layers/fla/ops/solve_tril.py", line 63, in solve_tril_kernel

After fix

All 24 GDN layers warm up successfully during profiling, and the model generates normally:

(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:14 [qwen3_next.py:700] GDN Triton kernel warmup completed for layer language_model.model.layers.0.linear_attn
(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:14 [qwen3_next.py:700] GDN Triton kernel warmup completed for layer language_model.model.layers.1.linear_attn
(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:14 [qwen3_next.py:700] GDN Triton kernel warmup completed for layer language_model.model.layers.2.linear_attn
... (24 GDN layers total)
(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:15 [qwen3_next.py:700] GDN Triton kernel warmup completed for layer language_model.model.layers.30.linear_attn
(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:16 [gpu_worker.py:456] Available KV cache memory: 4.93 GiB
(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:16 [kv_cache_utils.py:1314] GPU KV cache size: 40,128 tokens
(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:16 [kv_cache_utils.py:1319] Maximum concurrency for 4,096 tokens per request: 27.73x
(EngineCore_DP0 pid=3501575) 2026-03-10 14:25:16,214 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
(EngineCore_DP0 pid=3501575) 2026-03-10 14:25:16,289 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:16 [core.py:288] init engine (profile, create kv cache, warmup model) took 15.96 seconds
(EngineCore_DP0 pid=3501575) INFO 03-10 14:25:17 [vllm.py:754] Asynchronous scheduling is enabled.
INFO 03-10 14:25:17 [llm.py:391] Supported tasks: ['generate']
Rendering prompts: 100%|█████| 1/1 [00:00<00:00, 37.93it/s]
Processed prompts: 100%|█████| 1/1 [00:00<00:00, 1.19it/s, est. speed input: 1.19 toks/s, output: 19.09 toks/s]
.
Please correct my writing.
yesterday, i jumped on a platform

@AuYang261 AuYang261 requested a review from sighingnow as a code owner March 10, 2026 06:37
@mergify bot added labels qwen (Related to Qwen models) and bug (Something isn't working) on Mar 10, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a warmup mechanism for Triton kernels in GDN layers to prevent out-of-memory errors during the first inference. The implementation is sound, adding a _warmup_triton_kernels method that is invoked during the V1 profiling run. My review identifies a minor logic bug in the logging within the new warmup function, where a success message is logged even if the warmup fails. I've provided a suggestion to correct this behavior.

Signed-off-by: AuYang <459461160@qq.com>
@ZJY0516
Member

ZJY0516 commented Mar 10, 2026

And it seems that the Triton kernel will re-autotune if any of these keys change:

@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
)

I'm afraid your warmup is not enough

@AuYang261
Contributor Author

The warmup tensor shapes and parameters are chosen to match the exact autotune key values that the real prefill path produces. In varlen mode, B=1 and IS_VARLEN=True (cu_seqlens is not None) for both warmup and real inference. BT=64 is the hardcoded chunk_size in chunk_gated_delta_rule_fwd. REVERSE=False is the default, and chunk_local_cumsum is called without passing reverse, so it stays False. H, K, V are determined by the model config, which is the same in both cases. So the autotuner cache entries created during warmup will be reused by real inference without re-benchmarking.

The one edge case is chunk_fwd_o, where BT = min(64, max(16, next_power_of_2(T))) when FLA_GDN_FIX_BT is unset. If a prefill request has fewer than 64 tokens, BT could be 16 or 32, triggering a re-autotune for that single kernel. However, this is a lightweight re-tune (one kernel, not the full chain) and unlikely to OOM given the small tensor sizes involved.
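For reference, that BT selection can be reproduced with a pure-Python stand-in for triton.next_power_of_2 (chunk_fwd_o_bt is a hypothetical helper name, not vLLM's actual function):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n; matches triton.next_power_of_2 for n >= 1.
    return 1 << (n - 1).bit_length()

def chunk_fwd_o_bt(t: int) -> int:
    # BT used by chunk_fwd_o when FLA_GDN_FIX_BT is unset.
    return min(64, max(16, next_power_of_2(t)))

# e.g. a 20-token prefill gets BT=32, a 46-token prefill gets BT=64
```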

@ZJY0516
Member

ZJY0516 commented Mar 10, 2026

> If a prefill request has fewer than 64 tokens, BT could be 16 or 32, triggering a re-autotune for that single kernel. However, this is a lightweight re-tune (one kernel, not the full chain) and unlikely to OOM given the small tensor sizes involved.

Could you please help to verify this? (Need to remove your triton cache first)

Signed-off-by: AuYang <459461160@qq.com>
@AuYang261
Contributor Author

Thanks for the suggestion. I've updated the warmup to run three passes with T=16, 32, and 64, covering all possible BT values for chunk_fwd_kernel_o (BT = min(64, max(16, next_power_of_2(T)))). The other kernels always use BT = chunk_size (64), so they are fully cached after the first pass.

Verified on V100-32GB with Triton cache fully cleared (rm -rf ~/.triton/cache/*):

20 tokens (BT=32): no OOM, inference succeeded
46 tokens (BT=64): no OOM, inference succeeded
No autotune cache miss is possible at inference time now.
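The coverage claim can also be checked exhaustively with the same formula (bt_for is a pure-Python stand-in for the BT selection, not vLLM's actual code):

```python
def bt_for(t: int) -> int:
    npow2 = 1 << (t - 1).bit_length()  # next power of two >= t
    return min(64, max(16, npow2))

warmed_bts = {bt_for(t) for t in (16, 32, 64)}    # the three warmup passes
needed_bts = {bt_for(t) for t in range(1, 4097)}  # any prefill length up to 4096
# warmed_bts covers every BT value that needed_bts can take
```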

Test script:

import os
import shutil
import subprocess
import sys


def clear_triton_cache():
    """Remove Triton kernel cache to force re-autotuning."""
    cache_dir = os.environ.get(
        "TRITON_CACHE_DIR",
        os.path.expanduser("~/.triton/cache"),
    )
    if os.path.isdir(cache_dir):
        entries = os.listdir(cache_dir)
        if entries:
            print(
                f"[Setup] Clearing Triton cache at {cache_dir} "
                f"({len(entries)} entries)"
            )
            shutil.rmtree(cache_dir)
            os.makedirs(cache_dir, exist_ok=True)
        else:
            print(f"[Setup] Triton cache already empty: {cache_dir}")
    else:
        print(f"[Setup] Triton cache dir does not exist: {cache_dir}")


def gpu_mem_info() -> str:
    """Return a human-readable GPU memory summary."""
    import torch

    if not torch.cuda.is_available():
        return "CUDA not available"
    alloc = torch.cuda.memory_allocated() / (1024**3)
    reserved = torch.cuda.memory_reserved() / (1024**3)
    total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    return (
        f"allocated={alloc:.2f}GB, reserved={reserved:.2f}GB, " f"total={total:.2f}GB"
    )


def test_short_prompt():
    """Run Qwen3.5-9B with a very short prompt (< 64 tokens).

    After chat template is applied, the total token count should be
    around 20-30 tokens, resulting in BT = 32 in chunk_fwd_o
    (next_power_of_2(~27) = 32).

    This tests the edge case where warmup uses BT=64 but inference
    uses BT=32, triggering a re-tune of chunk_fwd_kernel_o only.
    """
    from vllm import LLM, SamplingParams

    import torch

    print(f"\n[GPU] Before model load: {gpu_mem_info()}")

    model_name = "Qwen/Qwen3.5-9B"
    print(f"\n[Test] Loading model: {model_name}")
    llm = LLM(
        model=model_name,
        tensor_parallel_size=1,
        max_model_len=4096,
        enforce_eager=True,
    )
    print(f"[GPU] After model load: {gpu_mem_info()}")

    tokenizer = llm.get_tokenizer()

    # --- Test 1: Short prompt (< 64 tokens → BT=32) ---
    short_message = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi"},
    ]
    short_prompt = tokenizer.apply_chat_template(
        short_message,
        add_generation_prompt=True,
        tokenize=False,
    )
    short_tokens = tokenizer(short_prompt, add_special_tokens=False)
    n_tokens = len(short_tokens["input_ids"])
    import triton as _triton

    expected_bt = min(64, max(16, _triton.next_power_of_2(n_tokens)))
    print(
        f"\n[Test 1] Short prompt: {n_tokens} tokens " f"(expected BT = {expected_bt})"
    )

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=32,
        stop_token_ids=[tokenizer.eos_token_id],
    )

    print(f"[GPU] Before inference: {gpu_mem_info()}")
    print("[Test 1] Running inference with short prompt...")

    try:
        outputs = llm.generate(
            prompts=[short_prompt],
            sampling_params=sampling_params,
        )
        print(f"[GPU] After inference: {gpu_mem_info()}")
        text = outputs[0].outputs[0].text
        print(
            f"[Test 1] PASSED - Output ({len(text)} chars): "
            f"{text[:100]}{'...' if len(text) > 100 else ''}"
        )
    except Exception as exc:
        print(f"[Test 1] FAILED - OOM during short prompt inference: {exc}")
        print(f"[GPU] At OOM: {gpu_mem_info()}")
        sys.exit(1)

    # --- Test 2: Exact 64 tokens (BT=64, same as warmup) ---
    # This should always work since BT matches warmup.
    long_content = "Explain " + "very " * 20 + "briefly what gravity is."
    long_message = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": long_content},
    ]
    long_prompt = tokenizer.apply_chat_template(
        long_message,
        add_generation_prompt=True,
        tokenize=False,
    )
    long_tokens = tokenizer(long_prompt, add_special_tokens=False)
    n_tokens_long = len(long_tokens["input_ids"])
    expected_bt_long = min(64, max(16, _triton.next_power_of_2(n_tokens_long)))
    print(
        f"\n[Test 2] Longer prompt: {n_tokens_long} tokens "
        f"(expected BT = {expected_bt_long})"
    )

    print("[Test 2] Running inference with longer prompt...")
    try:
        outputs = llm.generate(
            prompts=[long_prompt],
            sampling_params=sampling_params,
        )
        print(f"[GPU] After inference: {gpu_mem_info()}")
        text = outputs[0].outputs[0].text
        print(
            f"[Test 2] PASSED - Output ({len(text)} chars): "
            f"{text[:100]}{'...' if len(text) > 100 else ''}"
        )
    except Exception as exc:
        print(f"[Test 2] FAILED - OOM during longer prompt inference: {exc}")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("ALL TESTS PASSED")
    print("The GDN prefill kernel warmup fix works correctly.")
    print("=" * 60)


if __name__ == "__main__":
    # Step 1: Clear Triton cache
    clear_triton_cache()

    # Step 2: Run the test
    test_short_prompt()

Member

@ZJY0516 ZJY0516 left a comment


overall LGTM. cc @vadiklyutiy

Signed-off-by: AuYang <459461160@qq.com>
@ZJY0516
Member

ZJY0516 commented Mar 12, 2026

Could you please update the branch and run CI again? Then we can merge this.

@ywang96 ywang96 merged commit 3e64fe4 into vllm-project:main Mar 12, 2026
52 checks passed
tdoublep added a commit to tdoublep/vllm that referenced this pull request Mar 12, 2026
FlashInfer's chunk_gated_delta_rule returns a single tensor when
output_final_state=False, but the wrapper always unpacked two values.
This caused a ValueError during GDN kernel warmup (added in vllm-project#36599)
on SM90 GPUs (H100/H200).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
tdoublep added a commit to tdoublep/vllm that referenced this pull request Mar 12, 2026
PR vllm-project#36599 added Triton autotuner warmup for GDN layers, but it runs
unconditionally on all GPU types. On SM90 GPUs (H100/H200), the
FlashInfer path is used instead of Triton, making the warmup both
unnecessary and broken (FlashInfer returns a single tensor when
output_final_state=False, causing a ValueError).

Skip the warmup entirely on SM90 since FlashInfer has no Triton
autotuner to warm up.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
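The SM90 skip described in this commit amounts to a device-capability check; a minimal sketch (the function name is hypothetical, and the tuple is in the format returned by torch.cuda.get_device_capability()):

```python
def should_warmup_gdn_triton(capability: tuple) -> bool:
    # capability as returned by torch.cuda.get_device_capability(), e.g. (9, 0).
    # SM90 (H100/H200) uses the FlashInfer forward_cuda path, so there is
    # no Triton autotuner to warm up there.
    return capability[0] != 9

# should_warmup_gdn_triton((9, 0))  -> False  (H100/H200: skip warmup)
# should_warmup_gdn_triton((12, 0)) -> True   (e.g. RTX 5090: warm up)
```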
tdoublep added a commit to tdoublep/vllm that referenced this pull request Mar 12, 2026
FlashInfer's chunk_gated_delta_rule returns a single tensor when
output_final_state=False, but the wrapper always unpacked two values.
This caused a ValueError during GDN kernel warmup (added in vllm-project#36599)
on SM90 GPUs (H100/H200).

Handle the return value based on output_final_state: unpack the tuple
when True, use the single tensor when False.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
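The return-value handling described in this commit can be sketched as a small wrapper (call_gated_delta_rule and fn are placeholders, not the actual vLLM signature):

```python
def call_gated_delta_rule(fn, output_final_state: bool):
    """fn returns (core_out, final_state) when output_final_state=True,
    and just core_out when False; normalize to a 2-tuple either way."""
    result = fn(output_final_state=output_final_state)
    if output_final_state:
        core_out, final_state = result
    else:
        core_out, final_state = result, None
    return core_out, final_state
```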
@vadiklyutiy
Collaborator

> during V1 profile runs, _forward_core in Qwen3NextGatedDeltaNet returns early when attn_metadata is None, so the Triton-autotuned kernels used by GDN (solve_tril, chunk_scaled_dot_kkt, chunk_gated_delta_rule_fwd_h, chunk_fwd_o) are never invoked.

Does anybody know why None is passed to _forward_core? Running the real path there would be a natural way to warm up the Triton kernels.

@ZJY0516
Member

ZJY0516 commented Mar 13, 2026

> during V1 profile runs, _forward_core in Qwen3NextGatedDeltaNet returns early when attn_metadata is None, so the Triton-autotuned kernels used by GDN (solve_tril, chunk_scaled_dot_kkt, chunk_gated_delta_rule_fwd_h, chunk_fwd_o) are never invoked.
>
> Does somebody know why None is passed to _forward_core? It would be a natural way to warm up the Triton kernels.

I think it's a historical artifact of the V1 design: attn_metadata being None is used to indicate a profile run.

haosdent added a commit to haosdent/vllm that referenced this pull request Mar 15, 2026
…vllm-project#36973)

PR vllm-project#36599 added `_warmup_prefill_kernels()` that runs during
`profile_run()` inside the `memory_profiling()` context. Triton
compiles GDN kernels, loading CUDA modules (~1.5 GiB) that
permanently reduce KV cache budget because the overhead is captured
as `non_torch_increase`.

Move GDN warmup from `_forward_core` (during profiling) to
`kernel_warmup()` (after KV cache allocation) so the Triton CUDA
module overhead comes from `gpu_memory_utilization` headroom instead.

Verified on Qwen3.5-35B-A3B (NVIDIA GB10):
- KV cache memory: 36.18 → 37.55 GiB (+1.37 GiB)
- KV cache tokens: 474,144 → 492,096 (+3.8%)
- Init time: 74.39s → 22.41s (3.3x faster)

Co-authored-by: Claude

Signed-off-by: haosdent <haosdent@gmail.com>
haosdent added a commit to haosdent/vllm that referenced this pull request Mar 16, 2026
…-project#36973)

PR vllm-project#36599 added `_warmup_prefill_kernels()` during `profile_run()`.
Triton's `@triton.autotune` compiles ALL config variants as CUDA
modules (~1.5 GiB across ~300 configs) that permanently reduce the
KV cache budget via `non_torch_increase`.

Fix by running GDN warmup **before** `memory_profiling()` when GPU
memory is plentiful, then calling `cuModuleUnload` on every compiled
CUDA module to free the memory.  The Triton autotuner's in-memory
cache (winning Config per key) is preserved — subsequent kernel calls
skip benchmarking and compile only the winning config (~1 module per
kernel instead of ~20).

This addresses both sides of the problem:
- KV cache is maximized (CUDA module overhead excluded from profiling)
- No OOM during warmup (runs before KV cache allocation)
- No first-inference OOM (autotuner cache pre-populated)

Co-authored-by: Claude

Signed-off-by: haosdent <haosdent@gmail.com>
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
haosdent added a commit to haosdent/vllm that referenced this pull request Mar 19, 2026
…vllm-project#36973)

PR vllm-project#36599 added `_warmup_prefill_kernels()` during `profile_run()`.
Triton's `@triton.autotune` compiles ALL config variants as CUDA
modules (~1.3 GiB across ~137 configs).  This overhead is captured
in `non_torch_increase`, permanently reducing the KV cache budget.

The root cause is CUDA driver context memory: each compiled Triton
kernel module allocates ~10 MiB of GPU memory via `cuModuleLoad` that
cannot be freed — `cuModuleUnload` and Triton disk caching do not
help because the memory is allocated at module load time, not
compilation time.

Fix by running GDN autotuning in a **subprocess** with
`TRITON_CACHE_AUTOTUNING=1` before memory profiling.  The child
benchmarks all ~137 configs and saves winning configs to disk cache,
then exits (freeing all CUDA context memory).  The main process
enables `cache_results=True` on FLA Autotuner instances so the
autotuner reads winning configs from disk and compiles **only
those** — loading ~10 modules instead of ~137.

Verified on Qwen3.5-35B-A3B (NVIDIA GB10):
- KV cache memory: 36.18 → 38.56 GiB (+2.38 GiB recovered)
- non_torch_increase: 3.83 → 1.75 GiB (-2.08 GiB)
- Profiling time: 62s → 16s (3.9x faster)
- Inference: works, no OOM

Co-authored-by: Claude

Signed-off-by: haosdent <haosdent@gmail.com>
haosdent added a commit to haosdent/vllm that referenced this pull request Mar 19, 2026
…vllm-project#36973)

PR vllm-project#36599 added `_warmup_prefill_kernels()` during `profile_run()`.
Triton's `@triton.autotune` compiles ALL config variants as CUDA
modules (~1.5 GiB across ~157 configs).  This overhead is captured
in `non_torch_increase`, permanently reducing the KV cache budget.

The root cause is CUDA driver context memory from `cuModuleLoad` —
it cannot be freed by `cuModuleUnload` or Triton disk caching.

Fix with two changes:

1. **Skip GDN warmup during profiling**: `_forward_core()` returns
   early when `attn_metadata is None`, so no GDN CUDA modules are
   loaded during `memory_profiling()`.  KV cache gets the full
   memory budget.

2. **Single-config warmup after KV cache allocation**: In
   `kernel_warmup()`, temporarily swap each FLA Autotuner's configs
   to `[configs[0]]`.  With `len(configs) == 1`, Triton skips
   benchmarking entirely (autotuner.py:243-244) and compiles only
   1 module per autotune key.  This reduces the post-allocation
   overhead from ~1.5 GiB (157 modules) to ~1.2 GiB (non-autotuned
   kernels only), fitting within headroom.

Verified on Qwen3.5-35B-A3B (NVIDIA GB10):
- KV cache memory: 36.18 → 37.75 GiB (+1.57 GiB recovered)
- non_torch_increase: 3.83 → 2.55 GiB (-1.28 GiB)
- Profiling time: 62s → 17s (3.7x faster)
- Inference: works, no OOM

Co-authored-by: Claude

Signed-off-by: haosdent <haosdent@gmail.com>
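The config-swap trick from this commit can be sketched with a context manager over a minimal Autotuner stand-in (both names are illustrative, not the FLA classes):

```python
from contextlib import contextmanager

class Autotuner:
    """Minimal stand-in for an FLA/Triton Autotuner instance."""
    def __init__(self, configs):
        self.configs = configs

@contextmanager
def single_config(tuner):
    # With exactly one config, Triton's autotuner skips benchmarking and
    # compiles just that one module; the full list is restored afterwards.
    saved = tuner.configs
    tuner.configs = saved[:1]
    try:
        yield tuner
    finally:
        tuner.configs = saved
```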
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026

Labels: bug (Something isn't working), qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed)

5 participants