Skip to content

Commit

Permalink
feat(agents-api): Add in-memory rate limiter to transition activity
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Dec 7, 2024
1 parent 277a24d commit f728532
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 3 deletions.
16 changes: 15 additions & 1 deletion agents-api/agents_api/activities/task_steps/transition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@
from ...clients.temporal import get_workflow_handle
from ...common.protocol.tasks import StepContext
from ...common.storage_handler import load_from_blob_store_if_remote
from ...env import temporal_activity_after_retry_timeout, testing
from ...env import (
temporal_activity_after_retry_timeout,
testing,
transition_requests_per_minute,
)
from ...exceptions import LastErrorInput, TooManyRequestsError
from ...models.execution.create_execution_transition import (
create_execution_transition_async,
)
from ..utils import RateLimiter

# Global rate limiter instance
rate_limiter = RateLimiter(max_requests=transition_requests_per_minute)


@beartype
Expand All @@ -21,6 +29,12 @@ async def transition_step(
transition_info: CreateTransitionRequest,
last_error: BaseException | None = None,
) -> Transition:
# Check rate limit first
if not await rate_limiter.acquire():
raise TooManyRequestsError(
f"Rate limit exceeded. Maximum {transition_requests_per_minute} requests per minute allowed."
)

from ...workflows.task_execution import TaskExecutionWorkflow

activity_info = activity.info()
Expand Down
39 changes: 39 additions & 0 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import base64
import datetime as dt
import functools
Expand All @@ -10,6 +11,9 @@
import time
import urllib.parse
import zoneinfo
from collections import deque
from dataclasses import dataclass
from threading import Lock as ThreadLock
from typing import Any, Callable, ParamSpec, TypeVar

import re2
Expand Down Expand Up @@ -378,3 +382,38 @@ def get_handler(system: SystemDef) -> Callable:
raise NotImplementedError(
f"System call not implemented for {system.resource}.{system.operation}"
)


@dataclass
class RateLimiter:
max_requests: int # Maximum requests per minute
window_size: int = 60 # Window size in seconds (1 minute)

def __post_init__(self):
self._requests = deque()
self._lock = ThreadLock() # Thread-safe lock
self._async_lock = asyncio.Lock() # Async-safe lock

def _clean_old_requests(self):
now = time.time()
while self._requests and now - self._requests[0] > self.window_size:
self._requests.popleft()

async def acquire(self):
async with self._async_lock:
with self._lock:
now = time.time()
self._clean_old_requests()

if len(self._requests) >= self.max_requests:
return False

self._requests.append(now)
return True

@property
def current_usage(self) -> int:
"""Return current number of requests in the window"""
with self._lock:
self._clean_old_requests()
return len(self._requests)
3 changes: 3 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
# Tasks
# -----
task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)
transition_requests_per_minute: int = env.int(
"AGENTS_API_TRANSITION_REQUESTS_PER_MINUTE", default=100
)


# Blob Store
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/worker/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def to_payload(self, value: Any) -> Optional[Payload]:
# TODO: In production, we don't want to crash the workflow
# But the sentinel object must be handled by the caller
logging.warning(f"WARNING: Could not encode {value}: {e}")
return FailedEncodingSentinel(payload_data=data)
return FailedEncodingSentinel(payload_data=value)

def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
current_python_version = (
Expand Down
8 changes: 7 additions & 1 deletion agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
)
from agents_api.models.task.create_task import create_task
from agents_api.routers.tasks.create_task_execution import start_execution
from tests.fixtures import cozo_client, async_cozo_client, test_agent, test_developer_id, cozo_clients_with_migrations
from tests.fixtures import (
async_cozo_client,
cozo_client,
cozo_clients_with_migrations,
test_agent,
test_developer_id,
)
from tests.utils import patch_integration_service, patch_testing_temporal

EMBEDDING_SIZE: int = 1024
Expand Down

0 comments on commit f728532

Please sign in to comment.