diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index 329f286683f1..8a384dd8463f 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -20,6 +20,7 @@
 from vllm import SamplingParams
 from vllm.config.model import LogprobsMode
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
 
 from ...conftest import HfRunner, VllmRunner
 
@@ -31,6 +32,23 @@
 PROMPT = BatchLogprobsComposition.PROMPT
 SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT
 
+# On ROCm, floating-point reductions in attention and GEMM kernels are
+# non-associative and sensitive to batch geometry. The ref LLM (no spec
+# decode, default scheduling) and the spec-decode LLM (chunked prefill,
+# different effective batch sizes) follow different reduction orders,
+# producing numerically divergent logprobs that get mis-attributed to
+# spec-decode incorrectness.
+#
+# Force LLM instances into an identical, deterministic execution
+# mode so the test isolates spec-decode correctness only:
+ROCM_DETERMINISM_KWARGS: dict = (
+    dict(
+        max_num_seqs=1,
+    )
+    if current_platform.is_rocm()
+    else {}
+)
+
 
 @pytest.fixture(
     scope="module",
@@ -1035,6 +1053,7 @@ def test_spec_decode_logprobs(
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
         enable_prefix_caching=False,
+        **ROCM_DETERMINISM_KWARGS,
     )
     ref_results = ref_llm.generate(
         [prompt, prompt], [sampling_params, penalty_sampling_params]
@@ -1064,6 +1083,7 @@
         enable_chunked_prefill=True,
         max_num_batched_tokens=32,
         enable_prefix_caching=False,
+        **ROCM_DETERMINISM_KWARGS,
     )
     spec_results = spec_llm.generate(
         [prompt, prompt], [sampling_params, penalty_sampling_params]
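For context (not part of the patch): the comment's rationale rests on floating-point reductions being non-associative, and the fix relies on Python's keyword-splat pattern making the extra kwargs a no-op off ROCm. A minimal, self-contained sketch of both ideas, with illustrative names only (make_llm is a stand-in, not the vLLM API):

# 1. Floating-point addition is not associative, so a different reduction
#    order (e.g. a different batch/tile geometry) can change the result.
a, b, c = 1e16, -1e16, 1.0
print((a + b) + c)  # 1.0
print(a + (b + c))  # 0.0 -- the 1.0 is rounded away before the cancellation

# 2. The **kwargs pattern used by the patch: splatting an empty dict into a
#    call adds no arguments, so non-ROCm platforms see the original config.
def make_llm(**kwargs):
    return kwargs  # stand-in for the real LLM constructor

rocm_kwargs = {"max_num_seqs": 1}   # what ROCM_DETERMINISM_KWARGS holds on ROCm
other_kwargs = {}                   # what it holds everywhere else
print(make_llm(gpu_memory_utilization=0.4, **rocm_kwargs))
print(make_llm(gpu_memory_utilization=0.4, **other_kwargs))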