diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 42584938bc06..7f421a757cc2 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -556,6 +556,7 @@ def test_spec_decode_logprobs( seed=42, logprobs_mode=logprobs_mode, gpu_memory_utilization=0.4, + dtype=DTYPE, ) ref_results = ref_llm.generate([prompt], sampling_params) # Collect logprobs outputs from reference LLM. @@ -582,6 +583,7 @@ def test_spec_decode_logprobs( seed=42, logprobs_mode=logprobs_mode, gpu_memory_utilization=0.4, + dtype=DTYPE, ) spec_results = spec_llm.generate([prompt], sampling_params) # Collect logprobs outputs from spec decode LLM.