diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 8f6e3f185b6a..a2f59edbb2e6 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -16,6 +16,8 @@
 from autogen.runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
 from autogen.token_count_utils import count_token
 
+from .rate_limiters import RateLimiter, TimeRateLimiter
+
 TOOL_ENABLED = False
 try:
     import openai
@@ -207,7 +209,9 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
         """
         iostream = IOStream.get_default()
 
-        completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions  # type: ignore [attr-defined]
+        completions: Completions = (
+            self._oai_client.chat.completions if "messages" in params else self._oai_client.completions
+        )  # type: ignore [attr-defined]
         # If streaming is enabled and has messages, then iterate over the chunks of the response.
         if params.get("stream", False) and "messages" in params:
             response_contents = [""] * params.get("n", 1)
@@ -427,8 +431,11 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
 
         self._clients: List[ModelClient] = []
         self._config_list: List[Dict[str, Any]] = []
+        self._rate_limiters: List[Optional[RateLimiter]] = []
 
         if config_list:
+            self._initialize_rate_limiters(config_list)
+
             config_list = [config.copy() for config in config_list]  # make a copy before modifying
             for config in config_list:
                 self._register_default_client(config, openai_config)  # could modify the config
@@ -749,6 +756,7 @@ def yes_or_no_filter(context, response):
                             return response
                         continue  # filter is not passed; try the next config
             try:
+                self._throttle_api_calls(i)
                 request_ts = get_current_ts()
                 response = client.create(params)
             except APITimeoutError as err:
@@ -1042,3 +1050,20 @@ def extract_text_or_completion_object(
             A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
         """
         return response.message_retrieval_function(response)
+
+    def _throttle_api_calls(self, idx: int) -> None:
+        """Rate limit API calls."""
+        if self._rate_limiters[idx]:
+            limiter = self._rate_limiters[idx]
+
+            assert limiter is not None
+            limiter.sleep()
+
+    def _initialize_rate_limiters(self, config_list: List[Dict[str, Any]]) -> None:
+        for config in config_list:
+            # Instantiate the rate limiter
+            if "api_rate_limit" in config:
+                self._rate_limiters.append(TimeRateLimiter(config["api_rate_limit"]))
+                del config["api_rate_limit"]
+            else:
+                self._rate_limiters.append(None)
diff --git a/autogen/oai/rate_limiters.py b/autogen/oai/rate_limiters.py
new file mode 100644
index 000000000000..4b84a7f99400
--- /dev/null
+++ b/autogen/oai/rate_limiters.py
@@ -0,0 +1,36 @@
+import time
+from typing import Protocol
+
+
+class RateLimiter(Protocol):
+    def sleep(self, *args, **kwargs): ...
+
+
+class TimeRateLimiter:
+    """A class to implement a time-based rate limiter.
+
+    This rate limiter ensures that a certain operation does not exceed a specified frequency.
+    It can be used to limit the rate of requests sent to a server or the rate of any repeated action.
+    """
+
+    def __init__(self, rate: float):
+        """
+        Args:
+            rate (float): The maximum number of operations permitted per second (a frequency, not a time interval).
+        """
+        self._time_interval_seconds = 1.0 / rate
+        self._last_time_called = 0.0
+
+    def sleep(self, *args, **kwargs):
+        """Synchronously waits until enough time has passed to allow the next operation.
+
+        If the elapsed time since the last operation is less than the required time interval,
+        this method will block the execution by sleeping for the remaining time.
+        """
+        if self._elapsed_time() < self._time_interval_seconds:
+            time.sleep(self._time_interval_seconds - self._elapsed_time())
+
+        self._last_time_called = time.perf_counter()
+
+    def _elapsed_time(self):
+        return time.perf_counter() - self._last_time_called
diff --git a/test/oai/test_client.py b/test/oai/test_client.py
index 443ec995de48..bd8b072e6127 100755
--- a/test/oai/test_client.py
+++ b/test/oai/test_client.py
@@ -4,6 +4,7 @@
 import shutil
 import sys
 import time
+from types import SimpleNamespace
 
 import pytest
 
@@ -31,6 +32,40 @@
 OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
 
 
+class _MockClient:
+    def __init__(self, config, **kwargs):
+        pass
+
+    def create(self, params):
+        # A custom response data class could be defined here;
+        # SimpleNamespace is used for simplicity, as long as the object
+        # adheres to the ModelClientResponseProtocol.
+
+        response = SimpleNamespace()
+        response.choices = []
+        response.model = "mock_model"
+
+        text = "this is a dummy text response"
+        choice = SimpleNamespace()
+        choice.message = SimpleNamespace()
+        choice.message.content = text
+        choice.message.function_call = None
+        response.choices.append(choice)
+        return response
+
+    def message_retrieval(self, response):
+        choices = response.choices
+        return [choice.message.content for choice in choices]
+
+    def cost(self, response) -> float:
+        response.cost = 0
+        return 0
+
+    @staticmethod
+    def get_usage(response):
+        return {}
+
+
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
 def test_aoai_chat_completion():
     config_list = config_list_from_json(
@@ -322,6 +357,32 @@ def test_cache():
     assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))
 
 
+def test_throttled_api_calls():
+    # API calls limited to 0.2 requests per second, i.e. 1 request per 5 seconds
+    rate = 1 / 5.0
+
+    config_list = [
+        {
+            "model": "mock_model",
+            "model_client_cls": "_MockClient",
+            # Adding a timeout to catch false positives
+            "timeout": 1 / rate,
+            "api_rate_limit": rate,
+        }
+    ]
+
+    client = OpenAIWrapper(config_list=config_list, cache_seed=None)
+    client.register_model_client(_MockClient)
+
+    n_loops = 2
+    current_time = time.time()
+    for _ in range(n_loops):
+        client.create(messages=[{"role": "user", "content": "hello"}])
+
+    min_expected_time = (n_loops - 1) / rate
+    assert time.time() - current_time > min_expected_time
+
+
 if __name__ == "__main__":
     # test_aoai_chat_completion()
     # test_oai_tool_calling_extraction()
@@ -329,5 +390,6 @@ def test_cache():
     test_completion()
     # # test_cost()
     # test_usage_summary()
-    # test_legacy_cache()
-    # test_cache()
+    test_legacy_cache()
+    test_cache()
+    test_throttled_api_calls()
diff --git a/test/oai/test_rate_limiters.py b/test/oai/test_rate_limiters.py
new file mode 100644
index 000000000000..a04429c0dea2
--- /dev/null
+++ b/test/oai/test_rate_limiters.py
@@ -0,0 +1,21 @@
+import time
+
+import pytest
+
+from autogen.oai.rate_limiters import TimeRateLimiter
+
+
+@pytest.mark.parametrize("execute_n_times", range(5))
+def test_time_rate_limiter(execute_n_times):
+    current_time_seconds = time.time()
+
+    rate = 1
+    rate_limiter = TimeRateLimiter(rate)
+
+    n_loops = 2
+    for _ in range(n_loops):
+        rate_limiter.sleep()
+
+    total_time = time.time() - current_time_seconds
+    min_expected_time = (n_loops - 1) / rate
+    assert total_time >= min_expected_time
diff --git a/website/docs/FAQ.mdx b/website/docs/FAQ.mdx
index 2798ae9375b2..a367a9b20635 100644
--- a/website/docs/FAQ.mdx
+++ b/website/docs/FAQ.mdx
@@ -37,7 +37,15 @@ Yes. You currently have two options:
 - Autogen can work with any API endpoint which complies with OpenAI-compatible RESTful APIs - e.g. serving local LLM via FastChat or LM Studio. Please check https://microsoft.github.io/autogen/blog/2023/07/14/Local-LLMs for an example.
 - You can supply your own custom model implementation and use it with Autogen. Please check https://microsoft.github.io/autogen/blog/2024/01/26/Custom-Models for more information.
 
-## Handle Rate Limit Error and Timeout Error
+## Handling API Rate Limits
+
+### Setting the API Rate Limit
+
+You can set the `api_rate_limit` in a `config_list` for an agent, which will be used to control the rate at which API requests are sent.
+
+- `api_rate_limit` (float): the maximum number of API requests allowed per second.
+
+### Handle Rate Limit Error and Timeout Error
 
 You can set `max_retries` to handle rate limit error. And you can set `timeout` to handle timeout error. They can all be specified in `llm_config` for an agent, which will be used in the OpenAI client for LLM inference. They can be set differently for different clients if they are set in the `config_list`.
 
diff --git a/website/docs/topics/llm_configuration.ipynb b/website/docs/topics/llm_configuration.ipynb
index 0c094f6531ed..a9c42592a866 100644
--- a/website/docs/topics/llm_configuration.ipynb
+++ b/website/docs/topics/llm_configuration.ipynb
@@ -63,6 +63,7 @@
     "    \n",
     "    - `model` (str, required): The identifier of the model to be used, such as 'gpt-4', 'gpt-3.5-turbo'.\n",
     "    - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
+    "    - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
     "    - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
     "    - `tags` (List[str], optional): Tags which can be used for filtering.\n",
     "\n",
@@ -72,6 +73,7 @@
     "        {\n",
     "            \"model\": \"gpt-4\",\n",
-    "            \"api_key\": os.environ['OPENAI_API_KEY']\n",
+    "            \"api_key\": os.environ['OPENAI_API_KEY'],\n",
+    "            \"api_rate_limit\": 60.0,  # Set to allow up to 60 API requests per second.\n",
     "        }\n",
     "    ]\n",
     "    ```\n",
@@ -80,6 +82,7 @@
     "    - `model` (str, required): The deployment to be used. The model corresponds to the deployment name on Azure OpenAI.\n",
     "    - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
     "    - `api_type`: `azure`\n",
+    "    - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
     "    - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
     "    - `api_version` (str, optional): The version of the Azure API you wish to use.\n",
     "    - `tags` (List[str], optional): Tags which can be used for filtering.\n",
@@ -100,6 +103,7 @@
     "    \n",
     "    - `model` (str, required): The identifier of the model to be used, such as 'llama-7B'.\n",
     "    - `api_key` (str, optional): The API key required for authenticating requests to the model's API endpoint.\n",
+    "    - `api_rate_limit` (float, optional): Specifies the maximum number of API requests permitted per second.\n",
     "    - `base_url` (str, optional): The base URL of the API endpoint. This is the root address where API calls are directed.\n",
     "    - `tags` (List[str], optional): Tags which can be used for filtering.\n",
     "\n",
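
For reference, a minimal sketch (not part of the patch) of how `api_rate_limit` is meant to be used from user code. The model name, the rate value, and the environment variable below are illustrative, and a working OpenAI key plus `openai>=1` are assumed; only `OpenAIWrapper`, `config_list`, and `api_rate_limit` come from the change above.

```python
import os
import time

from autogen import OpenAIWrapper

# One rate limiter is created per config_list entry; "api_rate_limit" is consumed
# by OpenAIWrapper and stripped from the config before it reaches the OpenAI client.
config_list = [
    {
        "model": "gpt-4",  # illustrative model name
        "api_key": os.environ["OPENAI_API_KEY"],
        "api_rate_limit": 0.5,  # at most one request every 2 seconds
    }
]

# cache_seed=None disables response caching, so every call hits the API and is throttled.
client = OpenAIWrapper(config_list=config_list, cache_seed=None)

start = time.time()
for _ in range(3):
    client.create(messages=[{"role": "user", "content": "hello"}])

# With a 0.5 req/s limit, three sequential calls should take at least (3 - 1) / 0.5 = 4 seconds.
print(f"elapsed: {time.time() - start:.1f}s")
```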
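
`RateLimiter` is a structural `typing.Protocol`, so any object with a compatible `sleep()` method satisfies it; the wrapper itself only ever instantiates `TimeRateLimiter` from `api_rate_limit`, so the no-op limiter in the sketch below is purely illustrative of the protocol rather than a supported extension point.

```python
import time

from autogen.oai.rate_limiters import RateLimiter, TimeRateLimiter


class _NoOpRateLimiter:
    """Illustrative limiter that never blocks; it still satisfies the RateLimiter protocol."""

    def sleep(self, *args, **kwargs):
        pass


def throttled_call(limiter: RateLimiter, fn):
    # Mirrors what OpenAIWrapper._throttle_api_calls does before each request.
    limiter.sleep()
    return fn()


# Both objects type-check against the Protocol; only the first one actually throttles.
for limiter in (TimeRateLimiter(rate=2), _NoOpRateLimiter()):
    start = time.time()
    for _ in range(3):
        throttled_call(limiter, lambda: None)
    print(type(limiter).__name__, f"{time.time() - start:.2f}s")
```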