diff --git a/cognify/__main__.py b/cognify/__main__.py
index 546e8cf..48d75f7 100644
--- a/cognify/__main__.py
+++ b/cognify/__main__.py
@@ -1,6 +1,8 @@
 import argparse
 import logging
 import os
+import multiprocessing as mp
+import dotenv
 
 from cognify.cognify_args import (
     init_cognify_args,
@@ -18,6 +20,7 @@ from cognify.run.run import run
 from cognify._logging import _configure_logger
 from cognify._tracing import trace_cli_args, trace_workflow, initial_usage_message
+from cognify.rate_limiter import run_rate_limiter
 
 logger = logging.getLogger(__name__)
 
 
@@ -122,6 +125,11 @@ def main():
     cognify_args = from_cognify_args(raw_args)
 
+    os.environ["_cognify_rate_limit_base_url"] = f"http://127.0.0.1:{raw_args.rate_limit_port}"
+    dotenv.load_dotenv(cognify_args.key_env)
+    rate_limit_process = mp.Process(target=run_rate_limiter, args=(raw_args.rate_limit_port,))
+    rate_limit_process.start()
+
     if raw_args.mode == "optimize" or raw_args.mode == "evaluate":
         workflow_path = cognify_args.workflow
         if not os.path.exists(workflow_path):
@@ -137,6 +145,9 @@ def main():
         run_routine(cognify_args)
     else:
         raise ValueError(f"Unknown mode: {raw_args.mode}")
+
+    # Stop the rate limiter process.
+    rate_limit_process.terminate()
     return
 
diff --git a/cognify/cognify_args.py b/cognify/cognify_args.py
index 418aa89..b8c813d 100644
--- a/cognify/cognify_args.py
+++ b/cognify/cognify_args.py
@@ -11,6 +11,8 @@ class CommonArgs:
     workflow: str
     config: str = None
     log_level: str = "WARNING"
+    rate_limit_port: int = 55555
+    key_env: str = None
 
     def __post_init__(self):
         # Set missing values
@@ -18,6 +20,7 @@ def __post_init__(self):
 
     def find_files(self):
         self.search_at_workflow_dir_if_not_set("config", "config.py")
+        self.search_at_workflow_dir_if_not_set("key_env", ".env")
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -46,6 +49,23 @@ def add_cli_args(parser: argparse.ArgumentParser):
             choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
             metavar="log_level",
         )
+        parser.add_argument(
+            "-p",
+            "--rate_limit_port",
+            type=int,
+            default=CommonArgs.rate_limit_port,
+            help="Port number for the rate limit server",
+            metavar="port_number",
+        )
+        parser.add_argument(
+            "-k",
+            "--key_env",
+            type=str,
+            default=CommonArgs.key_env,
+            help="Path to the key env file for API access.\n"
+            "If not provided, searches for .env in the same directory as the workflow script.",
+            metavar="path_to_key_env",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace) -> "CommonArgs":
diff --git a/cognify/llm/litellm_wrapper.py b/cognify/llm/litellm_wrapper.py
index 5320cf3..5ab365f 100644
--- a/cognify/llm/litellm_wrapper.py
+++ b/cognify/llm/litellm_wrapper.py
@@ -1,6 +1,55 @@
 from litellm import completion
+from litellm.types.utils import ModelResponse
 from pydantic import BaseModel
+import requests
+import os
+
+
+class HTTPClient:
+    """Singleton client for the local rate limit server."""
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(HTTPClient, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        if not self._initialized:
+            self._initialized = True
+            self.base_url = None
+
+    def post_completion(self, data):
+        # Resolve the server address lazily; the CLI entry point sets it.
+        if self.base_url is None:
+            base_url = os.getenv("_cognify_rate_limit_base_url")
+            if base_url is None:
+                raise Exception("Rate limit base URL not found.")
+            self.base_url = base_url
+
+        url = f"{self.base_url}/completion_endpoint"
+        response = requests.post(url, json=data)
+        if response.status_code != 200:
+            raise Exception(response.json().get("detail", "Unknown error"))
+
+        # Rebuild a ModelResponse; hidden params and response headers are
+        # carried out-of-band because they are not part of model_dump().
+        result = response.json()["result"]
+        hidden_params = result.pop("_hidden_params", None)
+        response_headers = result.pop("_response_headers", None)
+        response = ModelResponse(**result)
+        if hidden_params:
+            response._hidden_params = hidden_params
+        if response_headers:
+            response._response_headers = response_headers
+        return response
+
+
+_client = HTTPClient()
+
 
 def litellm_completion(model: str, messages: list, model_kwargs: dict, response_format: BaseModel = None):
     if response_format:
         model_kwargs["response_format"] = response_format
@@ -20,10 +69,11 @@ def litellm_completion(model: str, messages: list, model_kwargs: dict, response
         del model_kwargs["response_format"]
         model_kwargs["format"] = response_format.model_json_schema()
 
-    response = completion(
-        model,
-        messages,
-        **model_kwargs
-    )
-
+    # Route the call through the local rate limit server instead of
+    # calling the provider directly.
+    response = _client.post_completion({
+        "model": model,
+        "messages": messages,
+        "model_kwargs": model_kwargs
+    })
     return response
\ No newline at end of file
diff --git a/cognify/optimizer/trace/checkpoint.py b/cognify/optimizer/trace/checkpoint.py
index 2cc39c3..3585d9e 100644
--- a/cognify/optimizer/trace/checkpoint.py
+++ b/cognify/optimizer/trace/checkpoint.py
@@ -92,8 +92,9 @@ def get_pareto_front(candidates: list[TrialLog]) -> list[TrialLog]:
                 trial_log.result.reduced_exec_time
             )
         )
 
-    vectors = np.array(list(map(list, zip(*score_cost_list))))
+    # One row per candidate: shape (n_candidates, n_objectives).
+    vectors = np.array(score_cost_list)
     is_efficient = np.ones(vectors.shape[0], dtype=bool)
     for i, v in enumerate(vectors):
         if is_efficient[i]:
diff --git a/cognify/rate_limiter.py b/cognify/rate_limiter.py
new file mode 100644
index 0000000..de41955
--- /dev/null
+++ b/cognify/rate_limiter.py
@@ -0,0 +1,157 @@
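+"""Per-model rate-limiting proxy for LLM completion calls.
+
+Each model gets a task queue, a pool of worker threads, and a limiter thread
+that releases one semaphore token every 1/rate seconds. The rate adapts
+AIMD-style: +1 request/s on each success, halved on RateLimitError. Shared
+dicts are mutated without locks, which relies on CPython's GIL.
+"""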
+from fastapi import FastAPI, HTTPException
+from litellm import completion, RateLimitError
+from dataclasses import dataclass
+import threading, queue, time, uuid, sys
+import uvicorn
+
+app = FastAPI()
+
+# --- Configuration ---
+DEFAULT_RATE = 500  # Max calls/s when the response header carries no rate
+                    # limit info; roughly Tier-5 gpt-4o-mini (30,000 RPM).
+NUM_WORKERS = 256   # Number of worker threads per model.
+
+# Holds (job_id, CompletionRequest) pairs per model.
+task_queue_pool = {}
+# Maps job_id -> result dict.
+job_results = {}
+# Maps job_id -> threading.Event.
+job_events = {}
+# Maps model -> currently allowed requests per second.
+rate_limit_pool = {}
+# Maps model -> semaphore gating request dispatch.
+rate_semaphore_pool = {}
+
+
+@dataclass
+class CompletionRequest:
+    model: str
+    messages: list
+    model_kwargs: dict
+
+
+# --- Rate Limiter Thread ---
+def rate_limiter(semaphore, name):
+    # Release one dispatch token every 1/rate seconds.
+    while True:
+        time.sleep(1.0 / rate_limit_pool[name])
+        semaphore.release()
+
+
+# --- Worker Thread Function ---
+def worker(semaphore, task_queue):
+    while True:
+        job_id, req = task_queue.get()
+        # Wait for a token (rate limiting).
+        semaphore.acquire()
+        try:
+            response = completion(req.model, req.messages, **req.model_kwargs)
+            job_results[job_id] = {"result": {
+                **response.model_dump(),
+                "_hidden_params": response._hidden_params,
+                "_response_headers": response._response_headers,
+            }}
+            # Additive increase: allow one more request per second.
+            rate_limit_pool[req.model] += 1
+        except RateLimitError:
+            # Multiplicative decrease: halve the rate and retry the job.
+            rate_limit_pool[req.model] /= 2
+            task_queue.put((job_id, req))
+        except Exception as e:
+            job_results[job_id] = {"error": str(e)}
+        # Signal completion unless the job was re-queued.
+        if job_id in job_results:
+            job_events[job_id].set()
+        task_queue.task_done()
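+
+
+# The first request for a model is served synchronously so that its response
+# headers can seed the initial rate before any worker threads exist.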
+def first_time_request(job_id, req: CompletionRequest):
+    try:
+        response = completion(req.model, req.messages, **req.model_kwargs)
+        job_results[job_id] = {"result": {
+            **response.model_dump(),
+            "_hidden_params": response._hidden_params,
+            "_response_headers": response._response_headers,
+        }}
+    except Exception as e:
+        # Record the error and roll back the per-model state so the next
+        # request retries the setup instead of waiting on a queue that
+        # has no workers yet.
+        job_results[job_id] = {"error": str(e)}
+        job_events[job_id].set()
+        task_queue_pool.pop(req.model, None)
+        rate_semaphore_pool.pop(req.model, None)
+        return
+    # Signal that the job is done.
+    job_events[job_id].set()
+
+    # Seed the rate limit for this model from the response headers.
+    headers = response._response_headers or {}
+    if limit := headers.get("x-ratelimit-remaining-requests", None):
+        rate = (int(limit) + 1) / 60  # +1 accounts for the request just made.
+    else:
+        rate = DEFAULT_RATE
+    rate_limit_pool[req.model] = rate
+    # Start the workers.
+    for _ in range(NUM_WORKERS):
+        threading.Thread(target=worker, args=(
+            rate_semaphore_pool[req.model],
+            task_queue_pool[req.model],
+        ), daemon=True).start()
+    # Start the token generator.
+    threading.Thread(target=rate_limiter, args=(
+        rate_semaphore_pool[req.model],
+        req.model,
+    ), daemon=True).start()
+
+
+# --- FastAPI Endpoint ---
+@app.post("/completion_endpoint")
+def completion_endpoint(req: CompletionRequest):
+    # Create a unique job ID and an Event to wait on for the result.
+    job_id = str(uuid.uuid4())
+    event = threading.Event()
+    job_events[job_id] = event
+
+    # Enqueue the task; if the model is new, set up a limiter for it first.
+    if req.model not in task_queue_pool:
+        task_queue_pool[req.model] = queue.Queue()
+        rate_semaphore_pool[req.model] = threading.Semaphore(0)
+        first_time_request(job_id, req)
+    else:
+        task_queue_pool[req.model].put((job_id, req))
+
+    # Wait for a worker to process the task.
+    event.wait()
+
+    result = job_results.pop(job_id, None)
+    job_events.pop(job_id, None)
+    if result is None:
+        raise HTTPException(status_code=500, detail="Job processing error")
+    if "error" in result:
+        raise HTTPException(status_code=500, detail=result["error"])
+    return result
+
+
+def run_rate_limiter(port):
+    uvicorn.run(app, host="0.0.0.0", port=port, log_level="error")
+
+
+if __name__ == "__main__":
+    run_rate_limiter(int(sys.argv[1]))
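+
+
+# Manual smoke test (sketch; the model name below is only an example):
+#   python -m cognify.rate_limiter 55555
+#   curl -X POST http://127.0.0.1:55555/completion_endpoint \
+#     -H "Content-Type: application/json" \
+#     -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}], "model_kwargs": {}}'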