Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/v1/sample/test_logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform

from ...conftest import HfRunner, VllmRunner

Expand All @@ -31,6 +32,23 @@
PROMPT = BatchLogprobsComposition.PROMPT
SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT

# ROCm-only workaround: attention/GEMM kernels on ROCm perform
# floating-point reductions whose order depends on batch geometry, and
# float reduction is non-associative. Because the reference LLM (no spec
# decode, default scheduling) and the spec-decode LLM (chunked prefill,
# different effective batch sizes) batch differently, their logprobs can
# diverge numerically for reasons unrelated to spec-decode correctness.
#
# Pinning both engines to a single-sequence batch forces an identical,
# deterministic execution path, so the test compares spec-decode
# behavior alone. On every other platform this is a no-op.
ROCM_DETERMINISM_KWARGS: dict = (
    {"max_num_seqs": 1} if current_platform.is_rocm() else {}
)
Comment on lines +44 to +50
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The ROCM_DETERMINISM_KWARGS dictionary currently only sets max_num_seqs=1. The PR description mentions enforce_eager and async_scheduling=False as part of the determinism kwargs. These should also be included in the dictionary to fully align with the described fix and ensure consistent execution paths on ROCm.

Suggested change
ROCM_DETERMINISM_KWARGS: dict = (
dict(
max_num_seqs=1,
)
if current_platform.is_rocm()
else {}
)
ROCM_DETERMINISM_KWARGS: dict = (
dict(
enforce_eager=True,
async_scheduling=False,
max_num_seqs=1,
)
if current_platform.is_rocm()
else {}
)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the description already, apparently those args were unnecessary.



@pytest.fixture(
scope="module",
Expand Down Expand Up @@ -1035,6 +1053,7 @@ def test_spec_decode_logprobs(
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
enable_prefix_caching=False,
**ROCM_DETERMINISM_KWARGS,
)
ref_results = ref_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
Expand Down Expand Up @@ -1064,6 +1083,7 @@ def test_spec_decode_logprobs(
enable_chunked_prefill=True,
max_num_batched_tokens=32,
enable_prefix_caching=False,
**ROCM_DETERMINISM_KWARGS,
)
spec_results = spec_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
Expand Down