11 changes: 11 additions & 0 deletions cognify/__main__.py
@@ -1,6 +1,8 @@
import argparse
import logging
import os
import multiprocessing as mp
import dotenv

from cognify.cognify_args import (
init_cognify_args,
@@ -18,6 +20,7 @@
from cognify.run.run import run
from cognify._logging import _configure_logger
from cognify._tracing import trace_cli_args, trace_workflow, initial_usage_message
from cognify.rate_limiter import run_rate_limiter

logger = logging.getLogger(__name__)

@@ -122,6 +125,11 @@ def main():

cognify_args = from_cognify_args(raw_args)

os.environ["_cognify_rate_limit_base_url"] = f"http://127.0.0.1:{raw_args.rate_limit_port}"
dotenv.load_dotenv(cognify_args.key_env)
rate_limit_process = mp.Process(target=run_rate_limiter, args=(raw_args.rate_limit_port,))
rate_limit_process.start()

if raw_args.mode == "optimize" or raw_args.mode == "evaluate":
workflow_path = cognify_args.workflow
if not os.path.exists(workflow_path):
@@ -137,6 +145,9 @@
run_routine(cognify_args)
else:
raise ValueError(f"Unknown mode: {raw_args.mode}")

# end rate limiter
rate_limit_process.terminate()
return


20 changes: 20 additions & 0 deletions cognify/cognify_args.py
@@ -11,13 +11,16 @@ class CommonArgs:
workflow: str
config: str = None
log_level: str = "WARNING"
rate_limit_port: int = 55555
key_env: str = None

def __post_init__(self):
# Set missing values
self.find_files()

def find_files(self):
self.search_at_workflow_dir_if_not_set("config", "config.py")
self.search_at_workflow_dir_if_not_set("key_env", ".env")

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -46,6 +49,23 @@ def add_cli_args(parser: argparse.ArgumentParser):
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
metavar="log_level",
)
parser.add_argument(
"-p",
"--rate_limit_port",
type=int,
default=CommonArgs.rate_limit_port,
help="Port number for rate limit server",
metavar="port_number",
)
parser.add_argument(
"-k",
"--key_env",
type=str,
default=CommonArgs.key_env,
help="Path to the key env file for API access.\n"
"If not provided, will search .env in the same directory as workflow script.",
metavar="path_to_key_env",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "CommonArgs":
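For reference, a minimal sketch of how the two new flags parse in isolation. This is a standalone parser used only for illustration, not the real CLI wiring (which goes through init_cognify_args and from_cognify_args); the defaults mirror CommonArgs.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--rate_limit_port", type=int, default=55555,
                    metavar="port_number")
parser.add_argument("-k", "--key_env", type=str, default=None,
                    metavar="path_to_key_env")

args = parser.parse_args(["-p", "55556", "-k", "examples/.env"])
print(args.rate_limit_port, args.key_env)   # 55556 examples/.env

# With no flags, rate_limit_port falls back to 55555 and key_env stays None,
# in which case find_files() looks for a .env next to the workflow script.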
62 changes: 56 additions & 6 deletions cognify/llm/litellm_wrapper.py
@@ -1,7 +1,52 @@
from litellm import completion
from litellm.types.utils import ModelResponse
from pydantic import BaseModel
import requests
import os

import zmq

class HTTPClient:
_instance = None
_initialized = False

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(HTTPClient, cls).__new__(cls)
return cls._instance

def __init__(self):
if not self._initialized:
self._initialized = True
self.base_url = None

def post_completion(self, data):
if self.base_url is None:
base_url = os.getenv("_cognify_rate_limit_base_url")
if base_url is None:
raise Exception("Rate limit base URL not found.")
self.base_url = base_url

url = f"{self.base_url}/completion_endpoint"
response = requests.post(url, json=data)
if response.status_code != 200:
raise Exception(response.json().get("detail", "Unknown error"))
result = response.json()["result"]
hidden_params = result.pop('_hidden_params', None)
response_headers = result.pop('_response_headers', None)
response = ModelResponse(**result)
if hidden_params:
response._hidden_params = hidden_params
if response_headers:
response._response_headers = response_headers
return response


_client = HTTPClient()


def litellm_completion(model: str, messages: list, model_kwargs: dict, response_format: BaseModel = None):

if response_format:
model_kwargs["response_format"] = response_format

@@ -20,10 +65,15 @@ def litellm_completion(model: str, messages: list, model_kwargs: dict, response_
del model_kwargs["response_format"]
model_kwargs["format"] = response_format.model_json_schema()

response = completion(
model,
messages,
**model_kwargs
)

# response = completion(
# model,
# messages,
# **model_kwargs
# )

Reviewer comment (attached to the commented-out call above): If the client does not want the rate limiter (e.g., to ease debugging by avoiding another HTTP endpoint, or in replay mode), can they revert to the non-rate-limited functionality? (One possible opt-out is sketched after this file's diff.)
response = _client.post_completion({
"model": model,
"messages": messages,
"model_kwargs": model_kwargs
})
# print(response)
return response
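In response to the question above about reverting to the non-rate-limited path: a minimal sketch of one possible opt-out, gated on a hypothetical _cognify_disable_rate_limiter environment variable. The variable name and the fallback wiring are assumptions, not part of this PR; the ollama-specific response_format handling from the elided hunk is omitted for brevity, and _client is the module-level HTTPClient instance defined earlier in this file.

import os
from litellm import completion
from pydantic import BaseModel
from cognify.llm.litellm_wrapper import _client  # HTTPClient singleton from this diff

def litellm_completion(model: str, messages: list, model_kwargs: dict,
                       response_format: BaseModel = None):
    if response_format:
        model_kwargs["response_format"] = response_format
    # Hypothetical escape hatch: bypass the rate-limit server entirely,
    # e.g. for debugging or replay mode.
    if os.getenv("_cognify_disable_rate_limiter"):
        return completion(model, messages, **model_kwargs)
    # Default path: go through the rate-limit server.
    return _client.post_completion({
        "model": model,
        "messages": messages,
        "model_kwargs": model_kwargs,
    })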
4 changes: 2 additions & 2 deletions cognify/optimizer/trace/checkpoint.py
@@ -92,8 +92,8 @@ def get_pareto_front(candidates: list[TrialLog]) -> list[TrialLog]:
trial_log.result.reduced_exec_time
)
)

vectors = np.array(list(map(list, zip(*score_cost_list))))
vectors = np.array(score_cost_list)
is_efficient = np.ones(vectors.shape[0], dtype=bool)
for i, v in enumerate(vectors):
if is_efficient[i]:
135 changes: 135 additions & 0 deletions cognify/rate_limiter.py
@@ -0,0 +1,135 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import threading, queue, time, uuid
from litellm import completion, RateLimitError
from dataclasses import dataclass
import litellm
import sys

# litellm.set_verbose=True

app = FastAPI()

# --- Configuration ---
DEFAULT_RATE = 500 # Maximum calls per second, used when response header has no rate limit info
# equivalent to Tier-5 openai-4o-mini rate limit
NUM_WORKERS = 256 # Number of worker threads per model

# Holds (job_id, task data) per model
task_queue_pool = {}
# Maps job_id -> result dict
job_results = {}
# Maps job_id -> threading.Event
job_events = {}
rate_limit_pool = {}
rate_semaphore_pool = {}

@dataclass
class CompletionRequest:
model: str
messages: list
model_kwargs: dict

# --- Rate Limiter Thread ---
def rate_limiter(semaphore, name):
while True:
time.sleep(1.0 / rate_limit_pool[name])
semaphore.release()

# --- Worker Thread Function ---
def worker(semaphore, task_queue):

Reviewer comment (on def worker): the worker pool is single-process, multi-threaded. I'm not sure if it will become a performance bottleneck. Can we write a short piece of code to test its performance limit? (A minimal benchmark sketch is given after this file's diff.)

while True:
job_id, req = task_queue.get()
# Wait for a token (rate limiting)
semaphore.acquire()
try:
# Call the underlying completion function.
response = completion(req.model, req.messages, **req.model_kwargs)
result = {"result": {**response.model_dump(), "_hidden_params": response._hidden_params, "_response_headers": response._response_headers}}
job_results[job_id] = result
# increase rate limit by 1
rate_limit_pool[req.model] += 1

Reviewer comment: this should also be refactored into a separate function so the user can decide the strategy; in some cases we don't need to increase the rate limit. (See the pluggable-strategy sketch after this file's diff.)

except RateLimitError as e:
# reduce rate limit by half and put to the back of the queue
rate_limit_pool[req.model] /= 2

Reviewer comment: what about refactoring this into a backoff function so that the user can control the strategy? "/2" may be too aggressive. (The same strategy sketch after this file's diff covers the backoff side.)

task_queue.put((job_id, req))

Reviewer comment: this assumes the requester won't time out. That is probably OK for Cognify, but not necessarily OK for a generic rate limiter. Can you add a comment?

except Exception as e:
job_results[job_id] = {"error": str(e)}
# Signal that the job is done.
if job_id in job_results:
job_events[job_id].set()
task_queue.task_done()

def first_time_request(job_id, req: CompletionRequest):
try:
response = completion(req.model, req.messages, **req.model_kwargs)
result = {"result": {**response.model_dump(), "_hidden_params": response._hidden_params, "_response_headers": response._response_headers}}
job_results[job_id] = result
except Exception as e:
job_results[job_id] = {"error": str(e)}
raise e
# Signal that the job is done.
if job_id in job_results:
job_events[job_id].set()

# setup rate limit for this model
if limit := response._response_headers.get("x-ratelimit-remaining-requests", None):
rate = (int(limit) + 1) / 60 # to account for the current request
# print(f"Rate limit for {req.model}: {rate}")
else:
rate = DEFAULT_RATE
# start workers
for _ in range(NUM_WORKERS):
t = threading.Thread(target=worker, args=(
rate_semaphore_pool[req.model],
task_queue_pool[req.model]
), daemon=True)
t.start()
# start ticket generator
rate_limit_pool[req.model] = rate
threading.Thread(target=rate_limiter, args=(
rate_semaphore_pool[req.model],
req.model
), daemon=True).start()


# --- FastAPI Endpoint ---
@app.post("/completion_endpoint")
def completion_endpoint(req: CompletionRequest):

# Create a unique job ID and an Event to wait for the result.
job_id = str(uuid.uuid4())
event = threading.Event()
job_events[job_id] = event

# Enqueue the task.
# If model is new, create a new limiter for it
if req.model not in task_queue_pool:
task_queue_pool[req.model] = queue.Queue()
rate_semaphore_pool[req.model] = threading.Semaphore(0)
first_time_request(job_id, req)
else:
task_queue_pool[req.model].put((job_id, req))

# Wait for the worker to process the task
event.wait()
# if not event.wait(timeout=30):
# job_events.pop(job_id, None)
# job_results.pop(job_id, None)
# raise HTTPException(status_code=504, detail="Task timed out")

result = job_results.pop(job_id, None)
job_events.pop(job_id, None)
if result is None:
raise HTTPException(status_code=500, detail="Job processing error")
if "error" in result:
raise HTTPException(status_code=500, detail=result["error"])
return result

import uvicorn

def run_rate_limiter(port):
uvicorn.run(app, host="0.0.0.0", port=port, log_level="error")

if __name__ == "__main__":
run_rate_limiter(int(sys.argv[1]))
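Regarding the comment above about the single-process, multi-threaded worker pool: a minimal benchmark sketch that measures how many queue items a pool of daemon threads can drain per second when the completion call is replaced with a fixed sleep. The sleep stands in for network latency only; it does not exercise the GIL the way real response parsing would, so the printed throughput should be read as an upper bound.

import queue
import threading
import time

NUM_WORKERS = 256        # same thread count as the rate limiter
NUM_JOBS = 10_000
FAKE_LATENCY_S = 0.2     # stand-in for one completion() round trip

task_queue: "queue.Queue[int]" = queue.Queue()
done = threading.Semaphore(0)

def worker():
    while True:
        _ = task_queue.get()
        time.sleep(FAKE_LATENCY_S)   # pretend to wait on the API
        done.release()
        task_queue.task_done()

for _ in range(NUM_WORKERS):
    threading.Thread(target=worker, daemon=True).start()

start = time.time()
for i in range(NUM_JOBS):
    task_queue.put(i)
for _ in range(NUM_JOBS):
    done.acquire()
elapsed = time.time() - start
print(f"{NUM_JOBS / elapsed:.0f} jobs/sec with {NUM_WORKERS} threads")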
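Regarding the two comments above about making the rate adjustments configurable: a minimal sketch of one possible shape, pulling the additive increase and the halving into a small strategy object the caller can swap out. The names RateStrategy, on_success, and on_rate_limit are illustrative only and not part of this PR; the defaults mirror the current behavior (+1 request/sec on success, halve on RateLimitError).

from dataclasses import dataclass

@dataclass
class RateStrategy:
    success_increment: float = 1.0   # added to the rate after each successful call
    backoff_factor: float = 0.5      # multiplied into the rate on RateLimitError
    min_rate: float = 1.0            # floor so the rate never collapses to zero

    def on_success(self, rate: float) -> float:
        return rate + self.success_increment

    def on_rate_limit(self, rate: float) -> float:
        return max(self.min_rate, rate * self.backoff_factor)

# Inside worker(), the hard-coded adjustments could then become:
#   rate_limit_pool[req.model] = strategy.on_success(rate_limit_pool[req.model])
# and, in the RateLimitError branch:
#   rate_limit_pool[req.model] = strategy.on_rate_limit(rate_limit_pool[req.model])
# A gentler policy could set backoff_factor=0.8, or success_increment=0.0 for
# models whose limits should never be probed upward.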