[Performance] Enable Triton autotuning disk cache by default #37188

Merged
zou3519 merged 3 commits into vllm-project:main from arpera:artem/enable-triton-cache-autotuning
Mar 19, 2026

Conversation

@arpera (Contributor) commented Mar 16, 2026

Purpose

Triton's `@triton.autotune` decorator re-runs kernel autotuning on every process restart because `TRITON_CACHE_AUTOTUNING` defaults to disabled. For vLLM serving workloads this adds significant latency to the first inference request after each server start.

Set `TRITON_CACHE_AUTOTUNING=1` via `os.environ.setdefault` so that autotuning results are persisted to `TRITON_CACHE_DIR` and reused across restarts. Users can still opt out by explicitly setting `TRITON_CACHE_AUTOTUNING=0`.
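As a sketch (not the exact vLLM source), the change amounts to one line at import time; the second half of the snippet just demonstrates why `setdefault` preserves a user's explicit opt-out:

```python
import os

# Sketch of the change: enable Triton's autotuning disk cache by
# default without clobbering a value the user already exported.
os.environ.setdefault("TRITON_CACHE_AUTOTUNING", "1")

# Because setdefault never overwrites an existing key, an explicit
# opt-out exported before the process starts still wins:
os.environ["TRITON_CACHE_AUTOTUNING"] = "0"   # simulate the user's export
os.environ.setdefault("TRITON_CACHE_AUTOTUNING", "1")
print(os.environ["TRITON_CACHE_AUTOTUNING"])  # prints "0"
```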

Test description

Setup: 8×B200, Qwen/Qwen3.5-397B-A17B-FP8, dp=8, expert parallelism enabled.

Server command:

TRITON_PRINT_AUTOTUNING=1 vllm serve Qwen/Qwen3.5-397B-A17B-FP8 \
  --port 8000 -tp 1 -pp 1 -dp 8 --enable-expert-parallel \
  --language-model-only --reasoning-parser qwen3 \
  --stream-interval 100 --safetensors-load-strategy prefetch

Benchmark command:

vllm bench serve \
  --backend vllm \
  --model Qwen/Qwen3.5-397B-A17B-FP8 \
  --port 8000 \
  --endpoint /v1/completions \
  --dataset-name random \
  --random-input 8000 \
  --random-output 1 \
  --max-concurrency 8 \
  --num-prompt 128 \
  --ignore-eos \
  --temperature 0.0

Before this change

1st run — started the server, ran the benchmark. Autotuning messages appeared in the log:

(Worker_DP4_EP4 pid=395399) Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
(Worker_DP5_EP5 pid=395400) Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None
(Worker_DP4_EP4 pid=395399) Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
(Worker_DP5_EP5 pid=395400) Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 2, num_ctas: 1, num_stages: 3, maxnreg: None
(Worker_DP4_EP4 pid=395399) Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
(Worker_DP5_EP5 pid=395400) Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 4, num_ctas: 1, num_stages: 3, maxnreg: None
(Worker_DP5_EP5 pid=395400) Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
(Worker_DP4_EP4 pid=395399) Autotuning kernel chunk_local_cumsum_scalar_kernel with config num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None
(Worker_DP5_EP5 pid=395400) Triton autotuning for function chunk_local_cumsum_scalar_kernel,
(Worker_DP5_EP5 pid=395400) with key as (1, 64, 64, True, False, 'torch.float32', 'torch.float32', 'torch.int32', 'torch.int32'),
(Worker_DP5_EP5 pid=395400) finished after 0.56s,
(Worker_DP5_EP5 pid=395400) best config selected: num_warps: 1, num_ctas: 1, num_stages: 3, maxnreg: None;

2nd run — restarted the server with the exact same command and ran the same benchmark. The same autotuning messages appeared again: with TRITON_CACHE_AUTOTUNING disabled (the default), Triton ignores the disk cache entirely and re-runs autotuning from scratch on every process restart.

After this change

1st run — started the server, ran the benchmark. Autotuning messages appeared in the log (cache is cold, expected).

2nd run — restarted the server with the exact same command, ran the same benchmark. No autotuning messages in the log. Triton successfully loaded the cached autotuning results from disk and skipped re-autotuning.
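A quick way to spot-check cache reuse between runs is to look at the Triton cache directory. The helper below is illustrative only; the default location (`~/.triton/cache` when `TRITON_CACHE_DIR` is unset) is an assumption and may vary across Triton versions.

```python
import os
from pathlib import Path

def count_cache_entries(cache_dir=None):
    """Count top-level entries in the Triton cache directory.

    Returns None when the directory does not exist yet (cold cache).
    The ~/.triton/cache default is an assumption; TRITON_CACHE_DIR,
    when set, takes precedence.
    """
    path = Path(cache_dir or os.environ.get(
        "TRITON_CACHE_DIR", Path.home() / ".triton" / "cache"))
    if not path.is_dir():
        return None
    return sum(1 for _ in path.iterdir())
```

Running it before and after the first server start should show the entry count going from nothing to nonzero, and staying populated across restarts.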

Test result

Setup: 8x NVIDIA B200, Qwen/Qwen3.5-397B-A17B-FP8, dp=8, expert parallelism, 128 prompts, random input 8000 tokens, output 1 token, concurrency 8.

| Metric | 2nd run before change | 2nd run after change |
|---|---|---|
| Benchmark duration (s) | 121.49 | 39.45 |
| Request throughput (req/s) | 1.05 | 3.24 |
| Total token throughput (tok/s) | 8,430.07 | 25,961.99 |
| Mean TTFT (ms) | 7,555.32 | 2,449.88 |
| Median TTFT (ms) | 1,001.79 | 985.23 |
| P99 TTFT (ms) | 90,318.72 | 20,389.16 |
| Successful requests | 128 | 128 |
| Failed requests | 0 | 0 |

By caching Triton kernel autotuning results to disk, the 2nd benchmark run achieved a 3x speedup.
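The 3x figure follows directly from the durations in the table; a trivial sanity check:

```python
# Sanity-check the headline speedup from the table above.
before_s, after_s = 121.49, 39.45   # 2nd-run benchmark durations
speedup = before_s / after_s
print(f"{speedup:.2f}x")            # prints "3.08x", i.e. roughly 3x
```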


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist (bot) left a comment

Code Review

This pull request effectively addresses a significant performance issue by enabling Triton's autotuning disk cache by default. The use of os.environ.setdefault is appropriate, allowing existing user configurations to take precedence while providing a beneficial default. The accompanying comments clearly explain the purpose and impact of this change, including how users can override the setting. This is a valuable improvement for vLLM serving workloads, reducing latency on subsequent server starts.

@vadiklyutiy (Collaborator) left a comment

LGTM

Triton's `@triton.autotune` decorator re-runs kernel autotuning on
every process restart because `TRITON_CACHE_AUTOTUNING` defaults to
disabled. For vLLM serving workloads this adds significant latency
to the first inference request after each server start.

Set `TRITON_CACHE_AUTOTUNING=1` via `os.environ.setdefault` so that
autotuning results are persisted to `TRITON_CACHE_DIR` and reused
across restarts. Users can still opt out by explicitly setting
`TRITON_CACHE_AUTOTUNING=0`.

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Made-with: Cursor
@vadiklyutiy (Collaborator) commented:

Likely your perf data was collected without #36599

Now there should be a difference in runtime, but the server startup speed should still be sufficient. Could you please remeasure after syncing with main?

@arpera (Contributor, Author) commented Mar 17, 2026

No, the tests were done on top of PR #36599. PR #36599 has a bug in the kernel warmup phase: due to this bug, Triton kernel autotuning runs during inference. That bug was fixed in PR #37338; have a look.

@vadiklyutiy vadiklyutiy added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
@vadiklyutiy (Collaborator) commented:

@mgoin and @robertgshaw2-redhat
there was no auto-assignment of reviewers. The change is simple. Could you take a look, please?

@mgoin (Member) left a comment

Let's see if it breaks anything :)

@mgoin mgoin added the performance Performance-related issues label Mar 19, 2026
Keep the related TORCHINDUCTOR_COMPILE_THREADS env var and
torch._inductor.config.compile_threads assignment adjacent.

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
@zou3519 zou3519 merged commit b55156e into vllm-project:main Mar 19, 2026
45 checks passed
chooper26 pushed a commit to intellistream/vllm-hust that referenced this pull request Mar 21, 2026
…oject#37188)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>