From f7285325224de4cab2d2f2cf9f2858ec008d79d6 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Sat, 7 Dec 2024 18:24:03 +0530 Subject: [PATCH] feat(agents-api): Add in-memory rate limiter to transition activity Signed-off-by: Diwank Singh Tomer --- .../activities/task_steps/transition_step.py | 16 +++++++- agents-api/agents_api/activities/utils.py | 39 +++++++++++++++++++ agents-api/agents_api/env.py | 3 ++ agents-api/agents_api/worker/codec.py | 2 +- agents-api/tests/test_execution_workflow.py | 8 +++- 5 files changed, 65 insertions(+), 3 deletions(-) diff --git a/agents-api/agents_api/activities/task_steps/transition_step.py b/agents-api/agents_api/activities/task_steps/transition_step.py index 100befb56..335f33e8a 100644 --- a/agents-api/agents_api/activities/task_steps/transition_step.py +++ b/agents-api/agents_api/activities/task_steps/transition_step.py @@ -8,11 +8,19 @@ from ...clients.temporal import get_workflow_handle from ...common.protocol.tasks import StepContext from ...common.storage_handler import load_from_blob_store_if_remote -from ...env import temporal_activity_after_retry_timeout, testing +from ...env import ( + temporal_activity_after_retry_timeout, + testing, + transition_requests_per_minute, +) from ...exceptions import LastErrorInput, TooManyRequestsError from ...models.execution.create_execution_transition import ( create_execution_transition_async, ) +from ..utils import RateLimiter + +# Global rate limiter instance +rate_limiter = RateLimiter(max_requests=transition_requests_per_minute) @beartype @@ -21,6 +29,12 @@ async def transition_step( transition_info: CreateTransitionRequest, last_error: BaseException | None = None, ) -> Transition: + # Check rate limit first + if not await rate_limiter.acquire(): + raise TooManyRequestsError( + f"Rate limit exceeded. Maximum {transition_requests_per_minute} requests per minute allowed." + ) + from ...workflows.task_execution import TaskExecutionWorkflow activity_info = activity.info() diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index 802d10c13..c3c810ef0 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -1,3 +1,4 @@ +import asyncio import base64 import datetime as dt import functools @@ -10,6 +11,9 @@ import time import urllib.parse import zoneinfo +from collections import deque +from dataclasses import dataclass +from threading import Lock as ThreadLock from typing import Any, Callable, ParamSpec, TypeVar import re2 @@ -378,3 +382,38 @@ def get_handler(system: SystemDef) -> Callable: raise NotImplementedError( f"System call not implemented for {system.resource}.{system.operation}" ) + + +@dataclass +class RateLimiter: + max_requests: int # Maximum requests per minute + window_size: int = 60 # Window size in seconds (1 minute) + + def __post_init__(self): + self._requests = deque() + self._lock = ThreadLock() # Thread-safe lock + self._async_lock = asyncio.Lock() # Async-safe lock + + def _clean_old_requests(self): + now = time.time() + while self._requests and now - self._requests[0] > self.window_size: + self._requests.popleft() + + async def acquire(self): + async with self._async_lock: + with self._lock: + now = time.time() + self._clean_old_requests() + + if len(self._requests) >= self.max_requests: + return False + + self._requests.append(now) + return True + + @property + def current_usage(self) -> int: + """Return current number of requests in the window""" + with self._lock: + self._clean_old_requests() + return len(self._requests) diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py index 46b019048..26835db52 100644 --- a/agents-api/agents_api/env.py +++ b/agents-api/agents_api/env.py @@ -29,6 +29,9 @@ # Tasks # ----- task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100) +transition_requests_per_minute: int = env.int( + "AGENTS_API_TRANSITION_REQUESTS_PER_MINUTE", default=100 +) # Blob Store diff --git a/agents-api/agents_api/worker/codec.py b/agents-api/agents_api/worker/codec.py index f909ebef1..18cb2afc9 100644 --- a/agents-api/agents_api/worker/codec.py +++ b/agents-api/agents_api/worker/codec.py @@ -133,7 +133,7 @@ def to_payload(self, value: Any) -> Optional[Payload]: # TODO: In production, we don't want to crash the workflow # But the sentinel object must be handled by the caller logging.warning(f"WARNING: Could not encode {value}: {e}") - return FailedEncodingSentinel(payload_data=data) + return FailedEncodingSentinel(payload_data=value) def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any: current_python_version = ( diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index 17777ed6c..e733f81c0 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -15,7 +15,13 @@ ) from agents_api.models.task.create_task import create_task from agents_api.routers.tasks.create_task_execution import start_execution -from tests.fixtures import cozo_client, async_cozo_client, test_agent, test_developer_id, cozo_clients_with_migrations +from tests.fixtures import ( + async_cozo_client, + cozo_client, + cozo_clients_with_migrations, + test_agent, + test_developer_id, +) from tests.utils import patch_integration_service, patch_testing_temporal EMBEDDING_SIZE: int = 1024