diff --git a/tests/quantization/test_cpu_offload.py b/tests/quantization/test_cpu_offload.py
index 3b58614e58d4..151b5d97ddf3 100644
--- a/tests/quantization/test_cpu_offload.py
+++ b/tests/quantization/test_cpu_offload.py
@@ -70,4 +70,5 @@ def test_cpu_offload_compressed_tensors(monkeypatch):
         ["--enforce_eager"],
         ["--enforce_eager", "--cpu-offload-gb", "1"],
         max_wait_seconds=480,
+        include_seeded_sampling=False,
     )
diff --git a/tests/utils.py b/tests/utils.py
index 5ccdaa0d64e2..982dbbb063a1 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -656,6 +656,7 @@ def _test_completion(
     model: str,
     prompt: str,
     token_ids: list[int],
+    include_seeded_sampling: bool = True,
 ):
     results = []
 
@@ -690,33 +691,40 @@ def _test_completion(
         }
     )
 
-    # test seeded random sampling
-    completion = client.completions.create(
-        model=model, prompt=prompt, max_tokens=5, seed=33, temperature=1.0
-    )
+    if include_seeded_sampling:
+        # test seeded random sampling
+        completion = client.completions.create(
+            model=model, prompt=prompt, max_tokens=5, seed=33, temperature=1.0
+        )
 
-    results.append(
-        {
-            "test": "seeded_sampling",
-            "text": completion.choices[0].text,
-            "finish_reason": completion.choices[0].finish_reason,
-            "usage": completion.usage,
-        }
-    )
+        results.append(
+            {
+                "test": "seeded_sampling",
+                "text": completion.choices[0].text,
+                "finish_reason": completion.choices[0].finish_reason,
+                "usage": completion.usage,
+            }
+        )
 
-    # test seeded random sampling with multiple prompts
-    completion = client.completions.create(
-        model=model, prompt=[prompt, prompt], max_tokens=5, seed=33, temperature=1.0
-    )
+        # test seeded random sampling with multiple prompts
+        completion = client.completions.create(
+            model=model,
+            prompt=[prompt, prompt],
+            max_tokens=5,
+            seed=33,
+            temperature=1.0,
+        )
 
-    results.append(
-        {
-            "test": "seeded_sampling",
-            "text": [choice.text for choice in completion.choices],
-            "finish_reason": [choice.finish_reason for choice in completion.choices],
-            "usage": completion.usage,
-        }
-    )
+        results.append(
+            {
+                "test": "seeded_sampling",
+                "text": [choice.text for choice in completion.choices],
+                "finish_reason": [
+                    choice.finish_reason for choice in completion.choices
+                ],
+                "usage": completion.usage,
+            }
+        )
 
     # test simple list
     batch = client.completions.create(
@@ -911,6 +919,7 @@ def compare_two_settings(
     *,
     method: str = "generate",
     max_wait_seconds: float | None = None,
+    include_seeded_sampling: bool = True,
 ) -> None:
     """
     Launch API server with two different sets of arguments/environments
@@ -922,6 +931,8 @@
         arg2: The second set of arguments to pass to the API server.
         env1: The first set of environment variables to pass to the API server.
         env2: The second set of environment variables to pass to the API server.
+        include_seeded_sampling: Whether to include temperature=1.0 seeded
+            sampling checks in the default generate comparison.
     """
 
     compare_all_settings(
@@ -930,6 +941,7 @@
         [env1, env2],
         method=method,
         max_wait_seconds=max_wait_seconds,
+        include_seeded_sampling=include_seeded_sampling,
     )
 
 
@@ -940,6 +952,7 @@ def compare_all_settings(
     *,
     method: str = "generate",
     max_wait_seconds: float | None = None,
+    include_seeded_sampling: bool = True,
 ) -> None:
     """
     Launch API server with several different sets of arguments/environments
@@ -948,6 +961,8 @@
         model: The model to test.
         all_args: A list of argument lists to pass to the API server.
         all_envs: A list of environment dictionaries to pass to the API server.
+        include_seeded_sampling: Whether to include temperature=1.0 seeded
+            sampling checks in the default generate comparison.
     """
 
     trust_remote_code = False
@@ -1008,7 +1023,13 @@
             )
 
             if method == "generate":
-                results += _test_completion(client, model, prompt, token_ids)
+                results += _test_completion(
+                    client,
+                    model,
+                    prompt,
+                    token_ids,
+                    include_seeded_sampling=include_seeded_sampling,
+                )
             elif method == "generate_close":
                 results += _test_completion_close(client, model, prompt)
             elif method == "generate_chat":
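For context, a minimal sketch of how a caller opts out of the seeded-sampling checks via the new flag, mirroring the `test_cpu_offload_compressed_tensors` change above; the model ID here is a placeholder, not taken from the diff:

```python
from tests.utils import compare_two_settings

# Compare the same model served under two configurations. The seeded
# temperature=1.0 sampling checks are skipped via the new flag, since
# sampled outputs are not expected to match exactly across the two setups.
compare_two_settings(
    "example-org/example-model",  # placeholder model ID
    ["--enforce_eager"],
    ["--enforce_eager", "--cpu-offload-gb", "1"],
    max_wait_seconds=480,
    include_seeded_sampling=False,
)
```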