Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type issues in openai_utils.py #2062

Merged
merged 6 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/type-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ jobs:
mypy \
autogen/logger \
autogen/exception_utils.py \
autogen/coding
autogen/coding \
autogen/oai/openai_utils.py
65 changes: 36 additions & 29 deletions autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,8 @@

from dotenv import find_dotenv, load_dotenv

try:
from openai import OpenAI
from openai.types.beta.assistant import Assistant

ERROR = None
except ImportError:
ERROR = ImportError("Please install openai>=1 to use autogen.OpenAIWrapper.")
OpenAI = object
Assistant = object
from openai import OpenAI
from openai.types.beta.assistant import Assistant

NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
Expand Down Expand Up @@ -75,7 +68,7 @@ def get_key(config: Dict[str, Any]) -> str:
return json.dumps(config, sort_keys=True)


def is_valid_api_key(api_key: str):
def is_valid_api_key(api_key: str) -> bool:
"""Determine if input is valid OpenAI API key.

Args:
Expand All @@ -89,8 +82,11 @@ def is_valid_api_key(api_key: str):


def get_config_list(
api_keys: List, base_urls: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
) -> List[Dict]:
api_keys: List[str],
base_urls: Optional[List[str]] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get a list of configs for OpenAI API client.

Args:
Expand Down Expand Up @@ -143,7 +139,7 @@ def config_list_openai_aoai(
openai_api_base_file: Optional[str] = "base_openai.txt",
aoai_api_base_file: Optional[str] = "base_aoai.txt",
exclude: Optional[str] = None,
) -> List[Dict]:
) -> List[Dict[str, Any]]:
"""Get a list of configs for OpenAI API client (including Azure or local model deployments that support OpenAI's chat completion API).

This function constructs configurations by reading API keys and base URLs from environment variables or text files.
Expand Down Expand Up @@ -250,8 +246,8 @@ def config_list_openai_aoai(
else []
)
# process openai base urls
base_urls = os.environ.get("OPENAI_API_BASE", None)
base_urls = base_urls if base_urls is None else base_urls.split("\n")
base_urls_env_var = os.environ.get("OPENAI_API_BASE", None)
base_urls = base_urls_env_var if base_urls_env_var is None else base_urls_env_var.split("\n")
openai_config = (
get_config_list(
# Assuming OpenAI API_KEY in os.environ["OPENAI_API_KEY"]
Expand All @@ -271,8 +267,8 @@ def config_list_from_models(
aoai_api_key_file: Optional[str] = "key_aoai.txt",
aoai_api_base_file: Optional[str] = "base_aoai.txt",
exclude: Optional[str] = None,
model_list: Optional[list] = None,
) -> List[Dict]:
model_list: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""
Get a list of configs for API calls with models specified in the model list.

Expand Down Expand Up @@ -338,7 +334,7 @@ def config_list_gpt4_gpt35(
aoai_api_key_file: Optional[str] = "key_aoai.txt",
aoai_api_base_file: Optional[str] = "base_aoai.txt",
exclude: Optional[str] = None,
) -> List[Dict]:
) -> List[Dict[str, Any]]:
"""Get a list of configs for 'gpt-4' followed by 'gpt-3.5-turbo' API calls.

Args:
Expand All @@ -361,7 +357,10 @@ def config_list_gpt4_gpt35(
)


def filter_config(config_list, filter_dict):
def filter_config(
config_list: List[Dict[str, Any]],
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
) -> List[Dict[str, Any]]:
"""
This function filters `config_list` by checking each configuration dictionary against the
criteria specified in `filter_dict`. A configuration dictionary is retained if for every
Expand Down Expand Up @@ -426,7 +425,7 @@ def filter_config(config_list, filter_dict):
dictionaries that do not have that key will also be considered a match.
"""

def _satisfies(config_value, acceptable_values):
def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
Expand All @@ -445,7 +444,7 @@ def config_list_from_json(
env_or_file: str,
file_location: Optional[str] = "",
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]] = None,
) -> List[Dict]:
) -> List[Dict[str, Any]]:
"""
Retrieves a list of API configurations from a JSON stored in an environment variable or a file.

Expand Down Expand Up @@ -497,15 +496,22 @@ def config_list_from_json(
else:
# The environment variable does not exist.
# So, `env_or_file` is a filename. We should use the file location.
config_list_path = os.path.join(file_location, env_or_file)
if file_location is not None:
config_list_path = os.path.join(file_location, env_or_file)
else:
config_list_path = env_or_file

with open(config_list_path) as json_file:
config_list = json.load(json_file)
return filter_config(config_list, filter_dict)


def get_config(
api_key: str, base_url: Optional[str] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
) -> Dict:
api_key: Optional[str],
base_url: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
) -> Dict[str, Any]:
"""
Constructs a configuration dictionary for a single model with the provided API configurations.

Expand Down Expand Up @@ -544,7 +550,9 @@ def get_config(


def config_list_from_dotenv(
dotenv_file_path: Optional[str] = None, model_api_key_map: Optional[dict] = None, filter_dict: Optional[dict] = None
dotenv_file_path: Optional[str] = None,
model_api_key_map: Optional[Dict[str, Any]] = None,
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]] = None,
) -> List[Dict[str, Union[str, Set[str]]]]:
"""
Load API configurations from a specified .env file or environment variables and construct a list of configurations.
Expand Down Expand Up @@ -582,9 +590,10 @@ def config_list_from_dotenv(
else:
logging.warning(f"The specified .env file {dotenv_path} does not exist.")
else:
dotenv_path = find_dotenv()
if not dotenv_path:
dotenv_path_str = find_dotenv()
if not dotenv_path_str:
logging.warning("No .env file found. Loading configurations from environment variables.")
dotenv_path = Path(dotenv_path_str)
load_dotenv(dotenv_path)

# Ensure the model_api_key_map is not None to prevent TypeErrors during key assignment.
Expand Down Expand Up @@ -647,8 +656,6 @@ def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
"""
Return the assistants with the given name from OAI assistant API
"""
if ERROR:
raise ERROR
assistants = client.beta.assistants.list()
candidate_assistants = []
for assistant in assistants.data:
Expand Down
Loading