Skip to content

Commit e6c044b

Browse files
afourneysonichi
andauthored
Added the ability to add tags to the OAI_CONFIG_LIST, and filter (microsoft#1226)
* Added the ability to add tags to the OAI_CONFIG_LIST, and filter on them. * Update openai_utils.py Co-authored-by: Chi Wang <[email protected]> --------- Co-authored-by: Chi Wang <[email protected]>
1 parent df5fe57 commit e6c044b

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

autogen/oai/client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class OpenAIWrapper:
5252
"""A wrapper class for openai client."""
5353

5454
cache_path_root: str = ".cache"
55-
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
55+
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"}
5656
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
5757
total_usage_summary: Optional[Dict[str, Any]] = None
5858
actual_usage_summary: Optional[Dict[str, Any]] = None

autogen/oai/openai_utils.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,11 @@ def filter_config(config_list, filter_dict):
356356
filter_dict (dict): A dictionary representing the filter criteria, where each key is a
357357
field name to check within the configuration dictionaries, and the
358358
corresponding value is a list of acceptable values for that field.
359+
If the configuration's field's value is not a list, then a match occurs
360+
when it is found in the list of acceptable values. If the configuration's
361+
field's value is a list, then a match occurs if there is a non-empty
362+
intersection with the acceptable values.
363+
359364
360365
Returns:
361366
list of dict: A list of configuration dictionaries that meet all the criteria specified
@@ -368,6 +373,7 @@ def filter_config(config_list, filter_dict):
368373
{'model': 'gpt-3.5-turbo'},
369374
{'model': 'gpt-4'},
370375
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
376+
{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
371377
]
372378
373379
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
@@ -382,6 +388,19 @@ def filter_config(config_list, filter_dict):
382388
383389
# The resulting `filtered_configs` will be:
384390
# [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
391+
392+
393+
# Define a filter to select a given tag
394+
filter_criteria = {
395+
'tags': ['gpt35_turbo'],
396+
}
397+
398+
# Apply the filter to the configuration list
399+
filtered_configs = filter_config(configs, filter_criteria)
400+
401+
# The resulting `filtered_configs` will be:
402+
# [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
403+
385404
```
386405
387406
Note:
@@ -391,9 +410,18 @@ def filter_config(config_list, filter_dict):
391410
- If the list of acceptable values for a key in `filter_dict` includes None, then configuration
392411
dictionaries that do not have that key will also be considered a match.
393412
"""
413+
414+
def _satisfies(config_value, acceptable_values):
415+
if isinstance(config_value, list):
416+
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
417+
else:
418+
return config_value in acceptable_values
419+
394420
if filter_dict:
395421
config_list = [
396-
config for config in config_list if all(config.get(key) in value for key, value in filter_dict.items())
422+
config
423+
for config in config_list
424+
if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
397425
]
398426
return config_list
399427

test/oai/test_utils.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
import autogen # noqa: E402
11-
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION
11+
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config
1212

1313
# Example environment variables
1414
ENV_VARS = {
@@ -48,6 +48,7 @@
4848
},
4949
{
5050
"model": "gpt-35-turbo-v0301",
51+
"tags": ["gpt-3.5-turbo", "gpt35_turbo"],
5152
"api_key": "111113fc7e8a46419bfac511bb301111",
5253
"base_url": "https://1111.openai.azure.com",
5354
"api_type": "azure",
@@ -342,5 +343,32 @@ def test_get_config_list():
342343
assert len(config_list_with_empty_key) == 2, "The config_list should exclude configurations with empty api_keys."
343344

344345

346+
def test_tags():
347+
config_list = json.loads(JSON_SAMPLE)
348+
349+
target_list = filter_config(config_list, {"model": ["gpt-35-turbo-v0301"]})
350+
assert len(target_list) == 1
351+
352+
list_1 = filter_config(config_list, {"tags": ["gpt35_turbo"]})
353+
assert len(list_1) == 1
354+
assert list_1[0] == target_list[0]
355+
356+
list_2 = filter_config(config_list, {"tags": ["gpt-3.5-turbo"]})
357+
assert len(list_2) == 1
358+
assert list_2[0] == target_list[0]
359+
360+
list_3 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "gpt35_turbo"]})
361+
assert len(list_3) == 1
362+
assert list_3[0] == target_list[0]
363+
364+
# Will still match because there's a non-empty intersection
365+
list_4 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "does_not_exist"]})
366+
assert len(list_4) == 1
367+
assert list_4[0] == target_list[0]
368+
369+
list_5 = filter_config(config_list, {"tags": ["does_not_exist"]})
370+
assert len(list_5) == 0
371+
372+
345373
if __name__ == "__main__":
346374
pytest.main()

0 commit comments

Comments
 (0)