[Bugfix][V1] Warm up slot mapping before JIT monitor #42165

Open
lesj0610 wants to merge 1 commit into vllm-project:main from lesj0610:lesj/v1-slot-mapping-jit-warmup-upstream-20260509

[Bugfix][V1] Warm up slot mapping before JIT monitor#42165
lesj0610 wants to merge 1 commit into
vllm-project:mainfrom
lesj0610:lesj/v1-slot-mapping-jit-warmup-upstream-20260509

Conversation

Contributor

@lesj0610 lesj0610 commented May 9, 2026

Purpose

#40137 added Triton JIT monitoring that activates after warmup finishes. However, the V1 warmup path (_dummy_run()) never calls BlockTable.compute_slot_mapping(). As a result, when the first real request comes in, _compute_slot_mapping_kernel compiles while the JIT monitor is already active, and users see an unexpected compilation warning during normal inference.

The problem has two sides. First, the V1 warmup simply does not exercise the slot-mapping path. Second, _compute_slot_mapping_kernel was specialized on the num_tokens parameter, so even after warming up with one token count, a request of a different size triggers recompilation.

The fix has two parts.

First, I add do_not_specialize=["num_tokens"] to the kernel so that one compilation covers all request sizes. max_num_tokens stays specialized: it is constant for the engine's lifetime, and Triton can use it to optimize the padding loop.
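To illustrate why excluding num_tokens from specialization matters, here is a minimal, hypothetical sketch (not vLLM's actual code, and no real Triton involved): a toy JIT cache whose specialization key includes every argument value except those listed in do_not_specialize, mirroring how value-specialization would otherwise force one compile per distinct num_tokens.

```python
# Toy model of JIT value-specialization. Arguments listed in
# do_not_specialize are excluded from the compile-cache key, so one
# "compile" serves every value of that argument.
compile_count = 0

def make_jit(do_not_specialize=()):
    cache = {}
    def jit(fn):
        def launch(**kwargs):
            global compile_count
            # Only non-excluded arguments participate in the cache key.
            key = tuple(sorted((k, v) for k, v in kwargs.items()
                               if k not in do_not_specialize))
            if key not in cache:
                compile_count += 1  # simulate a JIT compilation
                cache[key] = fn
            return cache[key](**kwargs)
        return launch
    return jit

@make_jit(do_not_specialize=("num_tokens",))
def kernel(num_tokens, max_num_tokens):
    return num_tokens

# max_num_tokens is fixed for the engine lifetime; num_tokens varies
# per request, yet all three calls hit the same cached "compilation".
for n in (1, 7, 128):
    kernel(num_tokens=n, max_num_tokens=2048)
print(compile_count)  # → 1
```

Without the do_not_specialize entry, the key would include num_tokens and the loop above would count three compiles, one per request size.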

Second, I add a small warmup, warmup_v1_slot_mapping_kernel(), that calls compute_slot_mapping() directly before the JIT monitor activates. It temporarily uses block id 1 (block 0 is the null block), then clears the entry in a finally block. This runs on all PP ranks, because every rank calls compute_slot_mapping() during input preparation.
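The warmup described above can be sketched roughly as follows. This is a hypothetical, simplified illustration: the class, method bodies, and default block_size here are illustrative stand-ins, not vLLM's exact API; the point is the temporary block-id-1 entry and the finally-block cleanup.

```python
# Illustrative stand-in for vLLM's BlockTable (real slot mapping runs
# a Triton kernel; here it is plain Python for clarity).
class BlockTable:
    def __init__(self, max_blocks):
        self.block_ids = [0] * max_blocks  # block 0 is the null block
        self.slot_mapping = []

    def compute_slot_mapping(self, num_tokens, block_size):
        # First call is what would trigger the JIT compilation.
        self.slot_mapping = [
            self.block_ids[t // block_size] * block_size + t % block_size
            for t in range(num_tokens)
        ]

def warmup_v1_slot_mapping_kernel(table, block_size=16):
    table.block_ids[0] = 1  # temporarily use block id 1 (0 is the null block)
    try:
        table.compute_slot_mapping(num_tokens=1, block_size=block_size)
    finally:
        # Clear the dummy entry so warmup leaves no residue behind.
        table.block_ids[0] = 0
        table.slot_mapping = []

table = BlockTable(max_blocks=4)
warmup_v1_slot_mapping_kernel(table)
print(table.block_ids[0])  # → 0
```

The finally block guarantees cleanup even if compute_slot_mapping() raises, so a failed warmup cannot leave a stale block id visible to the first real request.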

I did not add a synthetic execute_model() warmup: that would need model-specific dummy inputs and is not safe for all model types. This PR covers only the slot-mapping kernel.

The V2 warmup path and the V1 sampler warmup are not touched.

I checked open PRs; there is no existing PR for this issue.

Test Plan

.venv/bin/python -m pytest tests/v1/worker/test_gpu_model_runner.py -v

pre-commit run ruff-format --files \
  vllm/v1/worker/block_table.py \
  vllm/v1/worker/gpu/warmup.py \
  vllm/v1/worker/gpu_worker.py \
  tests/v1/worker/test_gpu_model_runner.py

pre-commit run ruff-check --files \
  vllm/v1/worker/block_table.py \
  vllm/v1/worker/gpu/warmup.py \
  vllm/v1/worker/gpu_worker.py \
  tests/v1/worker/test_gpu_model_runner.py

pre-commit run mypy-3.10 --files \
  vllm/v1/worker/block_table.py \
  vllm/v1/worker/gpu/warmup.py \
  vllm/v1/worker/gpu_worker.py \
  tests/v1/worker/test_gpu_model_runner.py \
  --hook-stage manual

git diff --check

Test Result

tests/v1/worker/test_gpu_model_runner.py: 34 passed, 16 warnings.

ruff format / ruff check: passed.

mypy-3.10: passed.

git diff --check: passed.

Local smoke on V1 runner with Qwen3-8B text-only: HTTP 200, no _compute_slot_mapping_kernel warning on first request.

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR
  • The test plan, such as providing test command.
  • The test results

AI assistance was used (Codex, Claude).

@lesj0610 lesj0610 requested review from WoosukKwon and njhill as code owners May 9, 2026 13:13

@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added v1 bug Something isn't working labels May 9, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a warmup mechanism for the V1 slot mapping kernel to ensure it is compiled before the JIT monitor is enabled. Key changes include the implementation of warmup_v1_slot_mapping_kernel, its integration into the GPUWorker warmup sequence, and a Triton JIT optimization to prevent specialization on num_tokens. New unit tests verify the warmup process and its error handling. I have no feedback to provide.

@lesj0610
Contributor Author

lesj0610 commented May 9, 2026

@ZJY0516 @qiching @tdoublep @vadiklyutiy Hi, this is a follow-up fix for #40137.

I found that the V1 path still triggers the JIT warning on the first real request. _dummy_run() does not call BlockTable.compute_slot_mapping(), and the kernel was specialized on num_tokens, so each different request size recompiles.

The fix is small: do_not_specialize=["num_tokens"], plus a direct compute_slot_mapping() call during V1 warmup before the monitor activates. There is no synthetic execute_model() warmup; that is model-specific, so I keep the scope to slot mapping only.

You all reviewed #40137, so your feedback would be very helpful. Thanks.

Signed-off-by: lesj0610 <lesj0610@users.noreply.github.com>
@lesj0610 lesj0610 force-pushed the lesj/v1-slot-mapping-jit-warmup-upstream-20260509 branch from cad2699 to 877e619 Compare May 9, 2026 13:24