Skip to content

Commit d44edb6

Browse files
marklyszeHk669
authored andcommitted
Client class utilities (#2949)
* Addition of client utilities, initially for parameter validation * Corrected test * update: type checks and few tests * fix: docs, tests --------- Co-authored-by: Hk669 <[email protected]>
1 parent 51971fd commit d44edb6

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed

autogen/oai/client_utils.py

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Utilities for client classes"""
2+
3+
import warnings
4+
from typing import Any, Dict, List, Optional, Tuple
5+
6+
7+
def validate_parameter(
8+
params: Dict[str, Any],
9+
param_name: str,
10+
allowed_types: Tuple,
11+
allow_None: bool,
12+
default_value: Any,
13+
numerical_bound: Tuple,
14+
allowed_values: list,
15+
) -> Any:
16+
"""
17+
Validates a given config parameter, checking its type, values, and setting defaults
18+
Parameters:
19+
params (Dict[str, Any]): Dictionary containing parameters to validate.
20+
param_name (str): The name of the parameter to validate.
21+
allowed_types (Tuple): Tuple of acceptable types for the parameter.
22+
allow_None (bool): Whether the parameter can be `None`.
23+
default_value (Any): The default value to use if the parameter is invalid or missing.
24+
numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
25+
A tuple specifying the lower and upper bounds for numerical parameters.
26+
Each bound can be `None` if not applicable.
27+
allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
28+
Can be `None` if no specific values are required.
29+
30+
Returns:
31+
Any: The validated parameter value or the default value if validation fails.
32+
33+
Raises:
34+
TypeError: If `allowed_values` is provided but is not a list.
35+
36+
Example Usage:
37+
```python
38+
# Validating a numerical parameter within specific bounds
39+
params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
40+
temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
41+
# Result: 0.5
42+
43+
# Validating a parameter that can be one of a list of allowed values
44+
model = validate_parameter(
45+
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
46+
)
47+
# If "safety_model" is missing or invalid in params, defaults to "default"
48+
```
49+
"""
50+
51+
if allowed_values is not None and not isinstance(allowed_values, list):
52+
raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")
53+
54+
param_value = params.get(param_name, default_value)
55+
warning = ""
56+
57+
if param_value is None and allow_None:
58+
pass
59+
elif param_value is None:
60+
if not allow_None:
61+
warning = "cannot be None"
62+
elif not isinstance(param_value, allowed_types):
63+
# Check types and list possible types if invalid
64+
if isinstance(allowed_types, tuple):
65+
formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")"
66+
else:
67+
formatted_types = f"{allowed_types.__name__}"
68+
warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
69+
elif numerical_bound:
70+
# Check the value fits in possible bounds
71+
lower_bound, upper_bound = numerical_bound
72+
if (lower_bound is not None and param_value < lower_bound) or (
73+
upper_bound is not None and param_value > upper_bound
74+
):
75+
warning = "has numerical bounds"
76+
if lower_bound is not None:
77+
warning += f", >= {str(lower_bound)}"
78+
if upper_bound is not None:
79+
if lower_bound is not None:
80+
warning += " and"
81+
warning += f" <= {str(upper_bound)}"
82+
if allow_None:
83+
warning += ", or can be None"
84+
85+
elif allowed_values:
86+
# Check if the value matches any allowed values
87+
if not (allow_None and param_value is None):
88+
if param_value not in allowed_values:
89+
warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"
90+
91+
# If we failed any checks, warn and set to default value
92+
if warning:
93+
warnings.warn(
94+
f"Config error - {param_name} {warning}, defaulting to {default_value}.",
95+
UserWarning,
96+
)
97+
param_value = default_value
98+
99+
return param_value

test/oai/test_client_utils.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#!/usr/bin/env python3 -m pytest
2+
3+
import pytest
4+
5+
import autogen
6+
from autogen.oai.client_utils import validate_parameter
7+
8+
9+
def test_validate_parameter():
10+
# Test valid parameters
11+
params = {
12+
"model": "Qwen/Qwen2-72B-Instruct",
13+
"max_tokens": 1000,
14+
"stream": False,
15+
"temperature": 1,
16+
"top_p": 0.8,
17+
"top_k": 50,
18+
"repetition_penalty": 0.5,
19+
"presence_penalty": 1.5,
20+
"frequency_penalty": 1.5,
21+
"min_p": 0.2,
22+
"safety_model": "Meta-Llama/Llama-Guard-7b",
23+
}
24+
25+
# Should return the original value as they are valid
26+
assert params["model"] == validate_parameter(params, "model", str, False, None, None, None)
27+
assert params["max_tokens"] == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
28+
assert params["stream"] == validate_parameter(params, "stream", bool, False, False, None, None)
29+
assert params["temperature"] == validate_parameter(params, "temperature", (int, float), True, None, None, None)
30+
assert params["top_k"] == validate_parameter(params, "top_k", int, True, None, None, None)
31+
assert params["repetition_penalty"] == validate_parameter(
32+
params, "repetition_penalty", float, True, None, None, None
33+
)
34+
assert params["presence_penalty"] == validate_parameter(
35+
params, "presence_penalty", (int, float), True, None, (-2, 2), None
36+
)
37+
assert params["safety_model"] == validate_parameter(params, "safety_model", str, True, None, None, None)
38+
39+
# Test None allowed
40+
params = {
41+
"max_tokens": None,
42+
}
43+
44+
# Should remain None
45+
assert validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) is None
46+
47+
# Test not None allowed
48+
params = {
49+
"max_tokens": None,
50+
}
51+
52+
# Should return default
53+
assert 512 == validate_parameter(params, "max_tokens", int, False, 512, (0, None), None)
54+
55+
# Test invalid parameters
56+
params = {
57+
"stream": "Yes",
58+
"temperature": "0.5",
59+
"top_p": "0.8",
60+
"top_k": "50",
61+
"repetition_penalty": "0.5",
62+
"presence_penalty": "1.5",
63+
"frequency_penalty": "1.5",
64+
"min_p": "0.2",
65+
"safety_model": False,
66+
}
67+
68+
# Should all be set to defaults
69+
assert validate_parameter(params, "stream", bool, False, False, None, None) is not None
70+
assert validate_parameter(params, "temperature", (int, float), True, None, None, None) is None
71+
assert validate_parameter(params, "top_p", (int, float), True, None, None, None) is None
72+
assert validate_parameter(params, "top_k", int, True, None, None, None) is None
73+
assert validate_parameter(params, "repetition_penalty", float, True, None, None, None) is None
74+
assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None
75+
assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None
76+
assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None
77+
assert validate_parameter(params, "safety_model", str, True, None, None, None) is None
78+
79+
# Test parameters outside of bounds
80+
params = {
81+
"max_tokens": -200,
82+
"presence_penalty": -5,
83+
"frequency_penalty": 5,
84+
"min_p": -0.5,
85+
}
86+
87+
# Should all be set to defaults
88+
assert 512 == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
89+
assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None
90+
assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None
91+
assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None
92+
93+
# Test valid list options
94+
params = {
95+
"safety_model": "Meta-Llama/Llama-Guard-7b",
96+
}
97+
98+
# Should all be set to defaults
99+
assert "Meta-Llama/Llama-Guard-7b" == validate_parameter(
100+
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
101+
)
102+
103+
# Test invalid list options
104+
params = {
105+
"stream": True,
106+
}
107+
108+
# Should all be set to defaults
109+
assert not validate_parameter(params, "stream", bool, False, False, None, [False])
110+
111+
# test invalid type
112+
params = {
113+
"temperature": None,
114+
}
115+
116+
# should be set to defaults
117+
assert validate_parameter(params, "temperature", (int, float), False, 0.7, (0.0, 1.0), None) == 0.7
118+
119+
# test value out of bounds
120+
params = {
121+
"temperature": 23,
122+
}
123+
124+
# should be set to defaults
125+
assert validate_parameter(params, "temperature", (int, float), False, 1.0, (0.0, 1.0), None) == 1.0
126+
127+
# type error for the parameters
128+
with pytest.raises(TypeError):
129+
validate_parameter({}, "param", str, True, None, None, "not_a_list")
130+
131+
# passing empty params, which will set to defaults
132+
assert validate_parameter({}, "max_tokens", int, True, 512, (0, None), None) == 512
133+
134+
135+
if __name__ == "__main__":
136+
test_validate_parameter()

0 commit comments

Comments
 (0)