
[ROCm][CI] Fix spec decode profile assertion and logprob test determinism #35043

Merged
vllm-bot merged 4 commits into vllm-project:main from ROCm:akaratza_fix_v1_others
Feb 23, 2026

Conversation

AndreasKaratzas (Collaborator) commented Feb 22, 2026:

Fixes two issues blocking spec decode logprob tests on ROCm:

  1. Profile run assertion failure (gpu_model_runner.py): The _dummy_run method asserted num_tokens <= self.scheduler_config.max_num_batched_tokens, but with speculative decoding, max_num_tokens (which accounts for verification tokens) can exceed max_num_batched_tokens. The assertion now uses self.max_num_tokens, consistent with the rest of the runner.

  2. Non-deterministic logprob comparison (test_logprobs.py): The reference LLM and the spec-decode LLM used different batch sizes, which on ROCm triggers non-associative floating-point reduction differences in attention/GEMM kernels. These numerical divergences were misattributed to spec decode incorrectness. Added ROCM_DETERMINISM_KWARGS (max_num_seqs=1), applied to both LLM instances on ROCm only, pinning identical execution paths. No behavioral change on other platforms.
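The assertion change in fix 1 can be sketched as follows. This is an illustrative stand-in, not vLLM's actual GPUModelRunner; only the names num_tokens, max_num_batched_tokens, and max_num_tokens come from the description above, and the constructor wiring is assumed for the example.

```python
class GPUModelRunnerSketch:
    """Hypothetical stand-in for the runner, for illustration only."""

    def __init__(self, max_num_batched_tokens: int, num_spec_tokens: int):
        # Scheduler-level cap on tokens scheduled per step.
        self.max_num_batched_tokens = max_num_batched_tokens
        # Runner-level cap that also accounts for speculative
        # verification tokens, so it can exceed the scheduler cap.
        self.max_num_tokens = max_num_batched_tokens + num_spec_tokens

    def _dummy_run(self, num_tokens: int) -> None:
        # Before the fix: assert num_tokens <= self.max_num_batched_tokens,
        # which a spec-decode profile run can legitimately exceed.
        # After the fix: compare against the runner-wide limit instead.
        assert num_tokens <= self.max_num_tokens
```

With this version, a profile run sized just above the scheduler cap but within the runner cap passes, which is the case the old assertion rejected.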

Test Plan

pytest -s -v tests/v1/sample/test_logprobs.py
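The gating pattern from fix 2 can be sketched as below. Here is_rocm() is a hardcoded stand-in for current_platform.is_rocm(), and llm_kwargs is a hypothetical helper showing how the same kwargs would be merged into both LLM constructions; neither name is from the actual test file.

```python
from typing import Any


def is_rocm() -> bool:
    """Stand-in for current_platform.is_rocm(); hardcoded for illustration."""
    return False


# On ROCm, pin max_num_seqs=1 so both LLMs see identical batch shapes;
# on other platforms the dict is empty and nothing changes.
ROCM_DETERMINISM_KWARGS: dict[str, Any] = (
    dict(max_num_seqs=1) if is_rocm() else {}
)


def llm_kwargs(model: str, **extra: Any) -> dict[str, Any]:
    # Hypothetical helper: the reference LLM and the spec-decode LLM are
    # both built with the same determinism kwargs merged in, so their
    # kernel execution paths match token-for-token on ROCm.
    return dict(model=model, **ROCM_DETERMINISM_KWARGS, **extra)
```

Because the dict is empty off-ROCm, the helper degrades to a plain passthrough there, which is what "no behavioral change on other platforms" means in practice.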

…ched_tokens

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
…rob test

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
dosubot bot commented Feb 22, 2026:

Related Documentation

Checked 0 published document(s) in 1 knowledge base(s). No updates required.


mergify bot added the rocm (Related to AMD ROCm) and v1 labels Feb 22, 2026
github-project-automation bot moved this to Todo in AMD Feb 22, 2026
mergify bot commented Feb 22, 2026:

Hi @AndreasKaratzas, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

gemini-code-assist bot (Contributor) left a comment:

Code Review

The pull request addresses two issues related to speculative decoding on ROCm: an assertion failure in gpu_model_runner.py and non-deterministic logprob comparisons in test_logprobs.py. The changes correctly update the assertion to use self.max_num_tokens and introduce ROCM_DETERMINISM_KWARGS to ensure deterministic execution for logprob tests on ROCm. The changes are well-explained and directly address the identified problems.

Comment on lines +44 to +50
ROCM_DETERMINISM_KWARGS: dict = (
    dict(
        max_num_seqs=1,
    )
    if current_platform.is_rocm()
    else {}
)
gemini-code-assist bot (Contributor) commented (severity: high):

The ROCM_DETERMINISM_KWARGS dictionary currently only sets max_num_seqs=1. The PR description mentions enforce_eager and async_scheduling=False as part of the determinism kwargs. These should also be included in the dictionary to fully align with the described fix and ensure consistent execution paths on ROCm.

Suggested change
ROCM_DETERMINISM_KWARGS: dict = (
    dict(
        max_num_seqs=1,
    )
    if current_platform.is_rocm()
    else {}
)
ROCM_DETERMINISM_KWARGS: dict = (
    dict(
        enforce_eager=True,
        async_scheduling=False,
        max_num_seqs=1,
    )
    if current_platform.is_rocm()
    else {}
)

AndreasKaratzas (Collaborator, Author) replied:

I've already updated the description; those args turned out to be unnecessary.

AndreasKaratzas (Collaborator, Author) commented Feb 22, 2026:

This PR depends on:

That's why pre-commit is failing.

EDIT: This is no longer true. I have reverted the change in vllm/v1/worker/gpu_model_runner.py in order to merge this PR quickly.

AndreasKaratzas (Collaborator, Author) commented:

cc @LucasWilkinson

AndreasKaratzas (Collaborator, Author) commented:

I'm going to revert the change in vllm/v1/worker/gpu_model_runner.py since this PR is critical for AMD CI.

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
tjtanaa (Collaborator) left a comment:

LGTM

tjtanaa added the ready (ONLY add when PR is ready to merge/full CI is needed) label Feb 23, 2026
vllm-bot merged commit 5f68464 into vllm-project:main Feb 23, 2026
16 of 18 checks passed
github-project-automation bot moved this from Todo to Done in AMD Feb 23, 2026
AndreasKaratzas deleted the akaratza_fix_v1_others branch February 23, 2026 17:42
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…nism (vllm-project#35043)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…nism (vllm-project#35043)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
…nism (vllm-project#35043)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
…nism (vllm-project#35043)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1

Projects

Status: Done

3 participants