tests/quantization/test_cpu_offload.py (1 change: 1 addition & 0 deletions)

@@ -70,4 +70,5 @@ def test_cpu_offload_compressed_tensors(monkeypatch):
         ["--enforce_eager"],
         ["--enforce_eager", "--cpu-offload-gb", "1"],
         max_wait_seconds=480,
+        include_seeded_sampling=False,
     )
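For orientation, this is roughly what the full call site looks like after the change. A minimal sketch, not the verbatim test: the diff only shows the tail of the call, so the import path is taken from this PR but the model identifier is a placeholder assumption.

    # Hypothetical reconstruction of the updated test; only the last five
    # lines of the call appear in the diff above.
    from tests.utils import compare_two_settings

    def test_cpu_offload_compressed_tensors(monkeypatch):
        compare_two_settings(
            "example-org/compressed-tensors-model",  # placeholder model id
            ["--enforce_eager"],
            ["--enforce_eager", "--cpu-offload-gb", "1"],
            max_wait_seconds=480,
            # New in this PR: skip the temperature=1.0 seeded-sampling checks.
            include_seeded_sampling=False,
        )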
tests/utils.py (71 changes: 46 additions & 25 deletions)

@@ -656,6 +656,7 @@ def _test_completion(
     model: str,
     prompt: str,
     token_ids: list[int],
+    include_seeded_sampling: bool = True,
 ):
     results = []

@@ -690,33 +691,40 @@
         }
     )

-    # test seeded random sampling
-    completion = client.completions.create(
-        model=model, prompt=prompt, max_tokens=5, seed=33, temperature=1.0
-    )
+    if include_seeded_sampling:
+        # test seeded random sampling
+        completion = client.completions.create(
+            model=model, prompt=prompt, max_tokens=5, seed=33, temperature=1.0
+        )

-    results.append(
-        {
-            "test": "seeded_sampling",
-            "text": completion.choices[0].text,
-            "finish_reason": completion.choices[0].finish_reason,
-            "usage": completion.usage,
-        }
-    )
+        results.append(
+            {
+                "test": "seeded_sampling",
+                "text": completion.choices[0].text,
+                "finish_reason": completion.choices[0].finish_reason,
+                "usage": completion.usage,
+            }
+        )

-    # test seeded random sampling with multiple prompts
-    completion = client.completions.create(
-        model=model, prompt=[prompt, prompt], max_tokens=5, seed=33, temperature=1.0
-    )
+        # test seeded random sampling with multiple prompts
+        completion = client.completions.create(
+            model=model,
+            prompt=[prompt, prompt],
+            max_tokens=5,
+            seed=33,
+            temperature=1.0,
+        )

-    results.append(
-        {
-            "test": "seeded_sampling",
-            "text": [choice.text for choice in completion.choices],
-            "finish_reason": [choice.finish_reason for choice in completion.choices],
-            "usage": completion.usage,
-        }
-    )
+        results.append(
+            {
+                "test": "seeded_sampling",
+                "text": [choice.text for choice in completion.choices],
+                "finish_reason": [
+                    choice.finish_reason for choice in completion.choices
+                ],
+                "usage": completion.usage,
+            }
+        )

     # test simple list
     batch = client.completions.create(
Expand Down Expand Up @@ -911,6 +919,7 @@ def compare_two_settings(
*,
method: str = "generate",
max_wait_seconds: float | None = None,
include_seeded_sampling: bool = True,
) -> None:
"""
Launch API server with two different sets of arguments/environments
@@ -922,6 +931,8 @@
         arg2: The second set of arguments to pass to the API server.
         env1: The first set of environment variables to pass to the API server.
         env2: The second set of environment variables to pass to the API server.
+        include_seeded_sampling: Whether to include temperature=1.0 seeded
+            sampling checks in the default generate comparison.
     """

     compare_all_settings(
@@ -930,6 +941,7 @@
         [env1, env2],
         method=method,
         max_wait_seconds=max_wait_seconds,
+        include_seeded_sampling=include_seeded_sampling,
     )

@@ -940,6 +952,7 @@ def compare_all_settings(
     *,
     method: str = "generate",
     max_wait_seconds: float | None = None,
+    include_seeded_sampling: bool = True,
 ) -> None:
     """
     Launch API server with several different sets of arguments/environments
@@ -948,6 +961,8 @@
         model: The model to test.
         all_args: A list of argument lists to pass to the API server.
         all_envs: A list of environment dictionaries to pass to the API server.
+        include_seeded_sampling: Whether to include temperature=1.0 seeded
+            sampling checks in the default generate comparison.
     """

     trust_remote_code = False
@@ -1008,7 +1023,13 @@
             )

             if method == "generate":
-                results += _test_completion(client, model, prompt, token_ids)
+                results += _test_completion(
+                    client,
+                    model,
+                    prompt,
+                    token_ids,
+                    include_seeded_sampling=include_seeded_sampling,
+                )
             elif method == "generate_close":
                 results += _test_completion_close(client, model, prompt)
             elif method == "generate_chat":
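Taken together, the new keyword threads from compare_two_settings and compare_all_settings down to _test_completion, so any comparison can opt out of the temperature=1.0 seeded-sampling checks. A hedged usage sketch of the multi-setting entry point; the model id and argument lists below are illustrative assumptions, not taken from this PR:

    from tests.utils import compare_all_settings

    # Compare three server configurations without the seeded-sampling checks.
    # all_envs pairs an environment dict with each argument list; None here
    # assumes no overrides are needed for any setting.
    compare_all_settings(
        "example-org/some-model",  # placeholder model id
        [["--enforce_eager"], [], ["--cpu-offload-gb", "1"]],
        [None, None, None],
        include_seeded_sampling=False,
    )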