Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client class utilities #2949

Merged
merged 4 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions autogen/oai/client_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Utilities for client classes"""

import warnings
from typing import Any, Dict, List, Optional, Tuple


def validate_parameter(
params: Dict[str, Any],
param_name: str,
allowed_types: Tuple,
allow_None: bool,
default_value: Any,
numerical_bound: Tuple,
allowed_values: list,
) -> Any:
"""
Validates a given config parameter, checking its type, values, and setting defaults
Parameters:
params (Dict[str, Any]): Dictionary containing parameters to validate.
param_name (str): The name of the parameter to validate.
allowed_types (Tuple): Tuple of acceptable types for the parameter.
allow_None (bool): Whether the parameter can be `None`.
default_value (Any): The default value to use if the parameter is invalid or missing.
numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
A tuple specifying the lower and upper bounds for numerical parameters.
Each bound can be `None` if not applicable.
allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
Can be `None` if no specific values are required.

Returns:
Any: The validated parameter value or the default value if validation fails.

Raises:
TypeError: If `allowed_values` is provided but is not a list.

Example Usage:
```python
# Validating a numerical parameter within specific bounds
params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
# Result: 0.5

# Validating a parameter that can be one of a list of allowed values
model = validate_parameter(
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
)
# If "safety_model" is missing or invalid in params, defaults to "default"
```
"""

if allowed_values is not None and not isinstance(allowed_values, list):
raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")

param_value = params.get(param_name, default_value)
warning = ""

if param_value is None and allow_None:
pass
elif param_value is None:
if not allow_None:
warning = "cannot be None"
elif not isinstance(param_value, allowed_types):
# Check types and list possible types if invalid
if isinstance(allowed_types, tuple):
formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")"
else:
formatted_types = f"{allowed_types.__name__}"
warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
elif numerical_bound:
# Check the value fits in possible bounds
lower_bound, upper_bound = numerical_bound
if (lower_bound is not None and param_value < lower_bound) or (
upper_bound is not None and param_value > upper_bound
):
warning = "has numerical bounds"
if lower_bound is not None:
warning += f", >= {str(lower_bound)}"
if upper_bound is not None:
if lower_bound is not None:
warning += " and"
warning += f" <= {str(upper_bound)}"
if allow_None:
warning += ", or can be None"

elif allowed_values:
# Check if the value matches any allowed values
if not (allow_None and param_value is None):
if param_value not in allowed_values:
warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"

# If we failed any checks, warn and set to default value
if warning:
warnings.warn(
f"Config error - {param_name} {warning}, defaulting to {default_value}.",
UserWarning,
)
param_value = default_value

return param_value
136 changes: 136 additions & 0 deletions test/oai/test_client_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#!/usr/bin/env python3 -m pytest

import pytest

import autogen
from autogen.oai.client_utils import validate_parameter


def test_validate_parameter():
# Test valid parameters
params = {
"model": "Qwen/Qwen2-72B-Instruct",
"max_tokens": 1000,
"stream": False,
"temperature": 1,
"top_p": 0.8,
"top_k": 50,
"repetition_penalty": 0.5,
"presence_penalty": 1.5,
"frequency_penalty": 1.5,
"min_p": 0.2,
"safety_model": "Meta-Llama/Llama-Guard-7b",
}

# Should return the original value as they are valid
assert params["model"] == validate_parameter(params, "model", str, False, None, None, None)
assert params["max_tokens"] == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
assert params["stream"] == validate_parameter(params, "stream", bool, False, False, None, None)
assert params["temperature"] == validate_parameter(params, "temperature", (int, float), True, None, None, None)
assert params["top_k"] == validate_parameter(params, "top_k", int, True, None, None, None)
assert params["repetition_penalty"] == validate_parameter(
params, "repetition_penalty", float, True, None, None, None
)
assert params["presence_penalty"] == validate_parameter(
params, "presence_penalty", (int, float), True, None, (-2, 2), None
)
assert params["safety_model"] == validate_parameter(params, "safety_model", str, True, None, None, None)

# Test None allowed
params = {
"max_tokens": None,
}

# Should remain None
assert validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) is None

# Test not None allowed
params = {
"max_tokens": None,
}

# Should return default
assert 512 == validate_parameter(params, "max_tokens", int, False, 512, (0, None), None)

# Test invalid parameters
params = {
"stream": "Yes",
"temperature": "0.5",
"top_p": "0.8",
"top_k": "50",
"repetition_penalty": "0.5",
"presence_penalty": "1.5",
"frequency_penalty": "1.5",
"min_p": "0.2",
"safety_model": False,
}

# Should all be set to defaults
assert validate_parameter(params, "stream", bool, False, False, None, None) is not None
assert validate_parameter(params, "temperature", (int, float), True, None, None, None) is None
assert validate_parameter(params, "top_p", (int, float), True, None, None, None) is None
assert validate_parameter(params, "top_k", int, True, None, None, None) is None
assert validate_parameter(params, "repetition_penalty", float, True, None, None, None) is None
assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None
assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None
assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None
assert validate_parameter(params, "safety_model", str, True, None, None, None) is None

# Test parameters outside of bounds
params = {
"max_tokens": -200,
"presence_penalty": -5,
"frequency_penalty": 5,
"min_p": -0.5,
}

# Should all be set to defaults
assert 512 == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None
assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None
assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None

# Test valid list options
params = {
"safety_model": "Meta-Llama/Llama-Guard-7b",
}

# Should all be set to defaults
assert "Meta-Llama/Llama-Guard-7b" == validate_parameter(
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
)

# Test invalid list options
params = {
"stream": True,
}

# Should all be set to defaults
assert not validate_parameter(params, "stream", bool, False, False, None, [False])

# test invalid type
params = {
"temperature": None,
}

# should be set to defaults
assert validate_parameter(params, "temperature", (int, float), False, 0.7, (0.0, 1.0), None) == 0.7

# test value out of bounds
params = {
"temperature": 23,
}

# should be set to defaults
assert validate_parameter(params, "temperature", (int, float), False, 1.0, (0.0, 1.0), None) == 1.0

# type error for the parameters
with pytest.raises(TypeError):
validate_parameter({}, "param", str, True, None, None, "not_a_list")

# passing empty params, which will set to defaults
assert validate_parameter({}, "max_tokens", int, True, 512, (0, None), None) == 512


if __name__ == "__main__":
test_validate_parameter()
Loading