diff --git a/pyproject.toml b/pyproject.toml index 133266ec..0641259b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ runpod = "runpod.cli.entry:runpod_cli" test = [ "asynctest", "nest_asyncio", + "faker", "pytest-asyncio", "pytest-cov", "pytest-timeout", diff --git a/runpod/http_client.py b/runpod/http_client.py index 145060bf..0621fccd 100644 --- a/runpod/http_client.py +++ b/runpod/http_client.py @@ -8,7 +8,6 @@ from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError from .cli.groups.config.functions import get_credentials -from .tracer import create_aiohttp_tracer, create_request_tracer from .user_agent import USER_AGENT @@ -37,13 +36,11 @@ def AsyncClientSession(*args, **kwargs): # pylint: disable=invalid-name """ Deprecation from aiohttp.ClientSession forbids inheritance. This is now a factory method - TODO: use httpx """ return ClientSession( connector=TCPConnector(limit=0), headers=get_auth_header(), timeout=ClientTimeout(600, ceil_threshold=400), - trace_configs=[create_aiohttp_tracer()], *args, **kwargs, ) @@ -52,43 +49,5 @@ def AsyncClientSession(*args, **kwargs): # pylint: disable=invalid-name class SyncClientSession(requests.Session): """ Inherits requests.Session to override `request()` method for tracing - TODO: use httpx """ - - def request(self, method, url, **kwargs): # pylint: disable=arguments-differ - """ - Override for tracing. Not using super().request() - to capture metrics for connection and transfer times - """ - with create_request_tracer() as tracer: - # Separate out the kwargs that are not applicable to `requests.Request` - request_kwargs = { - k: v - for k, v in kwargs.items() - # contains the names of the arguments - if k in requests.Request.__init__.__code__.co_varnames - } - - # Separate out the kwargs that are applicable to `requests.Request` - send_kwargs = {k: v for k, v in kwargs.items() if k not in request_kwargs} - - # Create a PreparedRequest object to hold the request details - req = requests.Request(method, url, **request_kwargs) - prepped = self.prepare_request(req) - tracer.request = prepped # Assign the request to the tracer - - # Merge environment settings - settings = self.merge_environment_settings( - prepped.url, - send_kwargs.get("proxies"), - send_kwargs.get("stream"), - send_kwargs.get("verify"), - send_kwargs.get("cert"), - ) - send_kwargs.update(settings) - - # Send the request - response = self.send(prepped, **send_kwargs) - tracer.response = response # Assign the response to the tracer - - return response + pass \ No newline at end of file diff --git a/runpod/serverless/__init__.py b/runpod/serverless/__init__.py index 62f70722..541ac56c 100644 --- a/runpod/serverless/__init__.py +++ b/runpod/serverless/__init__.py @@ -23,6 +23,13 @@ log = RunPodLogger() + +def handle_uncaught_exception(exc_type, exc_value, exc_traceback): + log.error(f"Uncaught exception | {exc_type}; {exc_value}; {exc_traceback};") + +sys.excepthook = handle_uncaught_exception + + # ---------------------------------------------------------------------------- # # Run Time Arguments # # ---------------------------------------------------------------------------- # diff --git a/runpod/serverless/modules/rp_fastapi.py b/runpod/serverless/modules/rp_fastapi.py index 15c75133..1747337d 100644 --- a/runpod/serverless/modules/rp_fastapi.py +++ b/runpod/serverless/modules/rp_fastapi.py @@ -286,12 +286,12 @@ async def _realtime(self, job: Job): Performs model inference on the input data using the provided handler. 
If handler is not provided, returns an error message. """ - await job_list.add(job.id) + job_list.add(job.id) # Process the job using the provided handler, passing in the job input. job_results = await run_job(self.config["handler"], job.__dict__) - await job_list.remove(job.id) + job_list.remove(job.id) # Return the results of the job processing. return jsonable_encoder(job_results) @@ -304,7 +304,7 @@ async def _realtime(self, job: Job): async def _sim_run(self, job_request: DefaultRequest) -> JobOutput: """Development endpoint to simulate run behavior.""" assigned_job_id = f"test-{uuid.uuid4()}" - await job_list.add({ + job_list.add({ "id": assigned_job_id, "input": job_request.input, "webhook": job_request.webhook @@ -345,7 +345,7 @@ async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput: # ---------------------------------- stream ---------------------------------- # async def _sim_stream(self, job_id: str) -> StreamOutput: """Development endpoint to simulate stream behavior.""" - stashed_job = await job_list.get(job_id) + stashed_job = job_list.get(job_id) if stashed_job is None: return jsonable_encoder( {"id": job_id, "status": "FAILED", "error": "Job ID not found"} @@ -367,7 +367,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput: } ) - await job_list.remove(job.id) + job_list.remove(job.id) if stashed_job.webhook: thread = threading.Thread( @@ -384,7 +384,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput: # ---------------------------------- status ---------------------------------- # async def _sim_status(self, job_id: str) -> JobOutput: """Development endpoint to simulate status behavior.""" - stashed_job = await job_list.get(job_id) + stashed_job = job_list.get(job_id) if stashed_job is None: return jsonable_encoder( {"id": job_id, "status": "FAILED", "error": "Job ID not found"} @@ -400,7 +400,7 @@ async def _sim_status(self, job_id: str) -> JobOutput: else: job_output = await run_job(self.config["handler"], job.__dict__) - await job_list.remove(job.id) + job_list.remove(job.id) if job_output.get("error", None): return jsonable_encoder( diff --git a/runpod/serverless/modules/rp_scale.py b/runpod/serverless/modules/rp_scale.py index 55bad7fa..0a478005 100644 --- a/runpod/serverless/modules/rp_scale.py +++ b/runpod/serverless/modules/rp_scale.py @@ -10,10 +10,9 @@ from ...http_client import AsyncClientSession, ClientSession, TooManyRequests from .rp_job import get_job, handle_job from .rp_logger import RunPodLogger -from .worker_state import JobsQueue, JobsProgress +from .worker_state import JobsProgress, IS_LOCAL_TEST log = RunPodLogger() -job_list = JobsQueue() job_progress = JobsProgress() @@ -38,16 +37,50 @@ class JobScaler: """ def __init__(self, config: Dict[str, Any]): - concurrency_modifier = config.get("concurrency_modifier") - if concurrency_modifier is None: - self.concurrency_modifier = _default_concurrency_modifier - else: - self.concurrency_modifier = concurrency_modifier - self._shutdown_event = asyncio.Event() self.current_concurrency = 1 self.config = config + self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency) + + self.concurrency_modifier = _default_concurrency_modifier + self.jobs_fetcher = get_job + self.jobs_fetcher_timeout = 90 + self.jobs_handler = handle_job + + if concurrency_modifier := config.get("concurrency_modifier"): + self.concurrency_modifier = concurrency_modifier + + if not IS_LOCAL_TEST: + # below cannot be changed unless local + return + + if jobs_fetcher := 
self.config.get("jobs_fetcher"): + self.jobs_fetcher = jobs_fetcher + + if jobs_fetcher_timeout := self.config.get("jobs_fetcher_timeout"): + self.jobs_fetcher_timeout = jobs_fetcher_timeout + + if jobs_handler := self.config.get("jobs_handler"): + self.jobs_handler = jobs_handler + + async def set_scale(self): + self.current_concurrency = self.concurrency_modifier(self.current_concurrency) + + if self.jobs_queue and (self.current_concurrency == self.jobs_queue.maxsize): + # no need to resize + return + + while self.current_occupancy() > 0: + # not safe to scale when jobs are in flight + await asyncio.sleep(1) + continue + + self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency) + log.debug( + f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}" + ) + def start(self): """ This is required for the worker to be able to shut down gracefully @@ -105,6 +138,15 @@ def kill_worker(self): log.info("Kill worker.") self._shutdown_event.set() + def current_occupancy(self) -> int: + current_queue_count = self.jobs_queue.qsize() + current_progress_count = job_progress.get_job_count() + + log.debug( + f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}" + ) + return current_progress_count + current_queue_count + async def get_jobs(self, session: ClientSession): """ Retrieve multiple jobs from the server in batches using blocking requests. @@ -114,29 +156,21 @@ async def get_jobs(self, session: ClientSession): Adds jobs to the JobsQueue """ while self.is_alive(): - log.debug("JobScaler.get_jobs | Starting job acquisition.") - - self.current_concurrency = self.concurrency_modifier( - self.current_concurrency - ) - log.debug(f"JobScaler.get_jobs | current Concurrency set to: {self.current_concurrency}") + await self.set_scale() - current_progress_count = await job_progress.get_job_count() - log.debug(f"JobScaler.get_jobs | current Jobs in progress: {current_progress_count}") - - current_queue_count = job_list.get_job_count() - log.debug(f"JobScaler.get_jobs | current Jobs in queue: {current_queue_count}") - - jobs_needed = self.current_concurrency - current_progress_count - current_queue_count + jobs_needed = self.current_concurrency - self.current_occupancy() if jobs_needed <= 0: log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.") await asyncio.sleep(1) # don't go rapidly continue try: - # Keep the connection to the blocking call up to 30 seconds + log.debug("JobScaler.get_jobs | Starting job acquisition.") + + # Keep the connection to the blocking call with timeout acquired_jobs = await asyncio.wait_for( - get_job(session, jobs_needed), timeout=30 + self.jobs_fetcher(session, jobs_needed), + timeout=self.jobs_fetcher_timeout, ) if not acquired_jobs: @@ -144,15 +178,20 @@ async def get_jobs(self, session: ClientSession): continue for job in acquired_jobs: - await job_list.add_job(job) + await self.jobs_queue.put(job) + job_progress.add(job) + log.debug("Job Queued", job["id"]) - log.info(f"Jobs in queue: {job_list.get_job_count()}") + log.info(f"Jobs in queue: {self.jobs_queue.qsize()}") except TooManyRequests: - log.debug(f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds.") + log.debug( + f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds." 
+ ) await asyncio.sleep(5) # debounce for 5 seconds except asyncio.CancelledError: log.debug("JobScaler.get_jobs | Request was cancelled.") + raise # CancelledError is a BaseException except TimeoutError: log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.") except TypeError as error: @@ -173,10 +212,10 @@ async def run_jobs(self, session: ClientSession): """ tasks = [] # Store the tasks for concurrent job processing - while self.is_alive() or not job_list.empty(): + while self.is_alive() or not self.jobs_queue.empty(): # Fetch as many jobs as the concurrency allows - while len(tasks) < self.current_concurrency and not job_list.empty(): - job = await job_list.get_job() + while len(tasks) < self.current_concurrency and not self.jobs_queue.empty(): + job = await self.jobs_queue.get() # Create a new task for each job and add it to the task list task = asyncio.create_task(self.handle_job(session, job)) @@ -204,9 +243,9 @@ async def handle_job(self, session: ClientSession, job: dict): Process an individual job. This function is run concurrently for multiple jobs. """ try: - await job_progress.add(job) + log.debug("Handling Job", job["id"]) - await handle_job(session, self.config, job) + await self.jobs_handler(session, self.config, job) if self.config.get("refresh_worker", False): self.kill_worker() @@ -216,8 +255,10 @@ async def handle_job(self, session: ClientSession, job: dict): raise err finally: - # Inform JobsQueue of a task completion - job_list.task_done() + # Inform Queue of a task completion + self.jobs_queue.task_done() # Job is no longer in progress - await job_progress.remove(job["id"]) + job_progress.remove(job) + + log.debug("Finished Job", job["id"]) diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index d0b1c07c..5e1a2f98 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -6,7 +6,6 @@ import time import uuid from typing import Any, Dict, Optional -from asyncio import Queue, Lock from .rp_logger import RunPodLogger @@ -72,19 +71,13 @@ def __new__(cls): JobsProgress._instance = set.__new__(cls) return JobsProgress._instance - def __init__(self): - if not hasattr(self, "_lock"): - # Initialize the lock once - self._lock = Lock() - def __repr__(self) -> str: return f"<{self.__class__.__name__}>: {self.get_job_list()}" - async def clear(self) -> None: - async with self._lock: - return super().clear() + def clear(self) -> None: + return super().clear() - async def add(self, element: Any): + def add(self, element: Any): """ Adds a Job object to the set. @@ -101,11 +94,9 @@ async def add(self, element: Any): if not isinstance(element, Job): raise TypeError("Only Job objects can be added to JobsProgress.") - async with self._lock: - log.debug(f"JobsProgress.add", element.id) - super().add(element) + return super().add(element) - async def remove(self, element: Any): + def remove(self, element: Any): """ Removes a Job object from the set. 
@@ -122,21 +113,18 @@ async def remove(self, element: Any): if not isinstance(element, Job): raise TypeError("Only Job objects can be removed from JobsProgress.") - async with self._lock: - log.debug(f"JobsProgress.remove", element.id) - return super().discard(element) + return super().discard(element) - async def get(self, element: Any) -> Job: + def get(self, element: Any) -> Job: if isinstance(element, str): element = Job(id=element) if not isinstance(element, Job): raise TypeError("Only Job objects can be retrieved from JobsProgress.") - async with self._lock: - for job in self: - if job == element: - return job + for job in self: + if job == element: + return job def get_job_list(self) -> str: """ @@ -147,66 +135,8 @@ def get_job_list(self) -> str: return ",".join(str(job) for job in self) - async def get_job_count(self) -> int: - """ - Returns the number of jobs in a thread-safe manner. - """ - async with self._lock: - return len(self) - - -class JobsQueue(Queue): - """Central Jobs Queue for job take and job processing""" - - _instance = None - - def __new__(cls): - if JobsQueue._instance is None: - JobsQueue._instance = object.__new__(cls) - return JobsQueue._instance - - def __iter__(self): - return iter(list(self._queue)) - - async def add_job(self, job: dict): - """ - Adds a job to the queue. - - If the queue is full, wait until a free - slot is available before adding item. - """ - log.debug(f"JobsQueue.add_job", job["id"]) - return await self.put(job) - - async def get_job(self) -> dict: - """ - Remove and return the next job from the queue. - - If queue is empty, wait until a job is available. - - Note: make sure to call `.task_done()` when processing the job is finished. - """ - return await self.get() - - def get_job_list(self) -> Optional[str]: - """ - Returns the comma-separated list of jobs as a string. (read-only) - """ - if self.empty(): - return None - - return ",".join(job.get("id") for job in self) - def get_job_count(self) -> int: """ Returns the number of jobs. """ - return self.qsize() - - async def clear(self): - """ - Empties the Queue by getting each item. 
- """ - while not self.empty(): - await self.get() - self.task_done() + return len(self) diff --git a/runpod/tracer.py b/runpod/tracer.py deleted file mode 100644 index 81dd21b5..00000000 --- a/runpod/tracer.py +++ /dev/null @@ -1,296 +0,0 @@ -# pylint: disable-all -# Temporary tracer while we're still using aiohttp and requests -# TODO: use httpx and opentelemetry - -import asyncio -import json -from datetime import datetime, timezone -from time import monotonic, time -from types import SimpleNamespace -from uuid import uuid4 - -from aiohttp import ( - ClientSession, - TraceConfig, - TraceConnectionCreateEndParams, - TraceConnectionCreateStartParams, - TraceConnectionReuseconnParams, - TraceRequestChunkSentParams, - TraceRequestEndParams, - TraceRequestExceptionParams, - TraceRequestStartParams, - TraceResponseChunkReceivedParams, -) -from requests import PreparedRequest, Response, structures - -from .serverless.modules.rp_logger import RunPodLogger - -log = RunPodLogger() - - -def time_to_iso8601(ts: float) -> str: - """Convert a Unix timestamp to an ISO 8601 formatted string in UTC.""" - dt = datetime.fromtimestamp(ts, tz=timezone.utc) - return dt.isoformat() - - -def headers_to_context(context: SimpleNamespace, headers: dict): - """Generate a context object based on the provided headers.""" - context.trace_id = str(uuid4()) - context.request_id = None - context.user_agent = None - - if headers: - headers = structures.CaseInsensitiveDict(headers) - context.trace_id = headers.get("x-trace-id", context.trace_id) - context.request_id = headers.get("x-request-id") - context.user_agent = headers.get("user-agent") - - return context - - -# Tracer for aiohttp - - -async def on_request_start( - session: ClientSession, - context: SimpleNamespace, - params: TraceRequestStartParams, -): - """Handle the start of a request.""" - headers = params.headers if hasattr(params, "headers") else {} - context = headers_to_context(context, headers) - context.start_time = time() - context.on_request_start = asyncio.get_event_loop().time() - context.method = params.method - context.url = params.url.human_repr() - context.mode = "async" - - if hasattr(context, "trace_request_ctx") and context.trace_request_ctx: - context.retries = context.trace_request_ctx["current_attempt"] - - -async def on_connection_create_start( - session: ClientSession, - context: SimpleNamespace, - params: TraceConnectionCreateStartParams, -): - """Handle the event when a connection is started.""" - context.connect = asyncio.get_event_loop().time() - context.on_request_start - - -async def on_connection_create_end( - session: ClientSession, - context: SimpleNamespace, - params: TraceConnectionCreateEndParams, -): - """Handle the event when a connection is created.""" - context.connect = asyncio.get_event_loop().time() - context.on_request_start - - -async def on_connection_reuseconn( - session: ClientSession, - context: SimpleNamespace, - params: TraceConnectionReuseconnParams, -): - """Handle the event when a connection is reused.""" - context.connect = asyncio.get_event_loop().time() - context.on_request_start - - -async def on_request_chunk_sent( - session: ClientSession, - context: SimpleNamespace, - params: TraceRequestChunkSentParams, -): - """Handle the event when a request chunk is sent.""" - if not hasattr(context, "payload_size_bytes"): - context.payload_size_bytes = 0 - context.payload_size_bytes += len(params.chunk) - - -async def on_response_chunk_received( - session: ClientSession, - context: SimpleNamespace, - params: 
TraceResponseChunkReceivedParams, -): - """Handle the event when a response chunk is received.""" - if not hasattr(context, "response_size_bytes"): - context.response_size_bytes = 0 - context.response_size_bytes += len(params.chunk) - - -async def on_request_end( - session: ClientSession, - context: SimpleNamespace, - params: TraceRequestEndParams, -): - """Handle the end of a request.""" - elapsed = asyncio.get_event_loop().time() - context.on_request_start - context.transfer = elapsed - context.connect - context.end_time = time() - - # log to trace level - report_trace(context, params, elapsed) - - -async def on_request_exception( - session: ClientSession, - context: SimpleNamespace, - params: TraceRequestExceptionParams, -): - """Handle the exception that occurred during the request.""" - context.exception = params.exception - elapsed = asyncio.get_event_loop().time() - context.on_request_start - context.transfer = elapsed - context.connect - context.end_time = time() - - # log to error level - report_trace(context, params, elapsed, log.trace) - - -def report_trace( - context: SimpleNamespace, params: object, elapsed: float, logger=log.trace -): - """ - Report the trace of a request. - The logger function is called with the JSON representation of the context object and the request ID. - - Args: - context (SimpleNamespace): The context object containing trace information. - params: The parameters of the request. - elapsed (float): The elapsed time of the request. - logger (function, optional): The logger function to use. Defaults to log.trace. - """ - context.start_time = time_to_iso8601(context.start_time) - context.end_time = time_to_iso8601(context.end_time) - context.total = round(elapsed * 1000, 1) - - if hasattr(context, "transfer") and context.transfer: - context.transfer = round(context.transfer * 1000, 1) - - if hasattr(context, "connect") and context.connect: - context.connect = round(context.connect * 1000, 1) - - if hasattr(context, "on_request_start"): - delattr(context, "on_request_start") - - if hasattr(context, "trace_request_ctx"): - delattr(context, "trace_request_ctx") - - if hasattr(params, "response") and params.response: - context.response_status = params.response.status - - logger(json.dumps(vars(context)), context.request_id) - - -def create_aiohttp_tracer() -> TraceConfig: - """ - Creates a TraceConfig object for aiohttp tracing. - - This function initializes a TraceConfig object with event handlers for various tracing events. - The TraceConfig object is used to configure and customize the tracing behavior of aiohttp. - - Returns: - TraceConfig: The initialized TraceConfig object. - - """ - # https://docs.aiohttp.org/en/stable/tracing_reference.html - trace_config = TraceConfig() - - trace_config.on_request_start.append(on_request_start) - trace_config.on_connection_create_start.append(on_connection_create_start) - trace_config.on_connection_create_end.append(on_connection_create_end) - trace_config.on_connection_reuseconn.append(on_connection_reuseconn) - trace_config.on_request_chunk_sent.append(on_request_chunk_sent) - trace_config.on_response_chunk_received.append(on_response_chunk_received) - trace_config.on_request_end.append(on_request_end) - trace_config.on_request_exception.append(on_request_exception) - - return trace_config - - -# Tracer for requests - - -class TraceRequest: - """ - Context manager for tracing requests. - - This class is used to trace requests made by the `requests` library. 
- It stores the request and response objects in the `request` and `response` - attributes respectively. It also provides a context manager interface - allowing the tracing of requests, including the connection and transfer - times. - - When the context manager is entered, the request start time is recorded. - When the context manager is exited, the trace is reported. - - Attributes: - context (SimpleNamespace): The context object used to store - trace information. - request (PreparedRequest): The request object. - response (Response): The response object. - request_start (float): The start time of the request. - """ - - def __init__(self): - self.context = SimpleNamespace() - self.request: PreparedRequest = None - self.response: Response = None - self.request_start = None - - def __enter__(self): - """ - Enter the context manager and record the start time of the request. - """ - self.request_start = ( - monotonic() - ) # consistency with asyncio.get_event_loop().time() - self.context.start_time = time() # reported timestamp - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Exit the context manager and report the trace. - """ - if self.request is not None: - self.context = headers_to_context(self.context, self.request.headers) - self.context.method = self.request.method - self.context.url = self.request.url - self.context.mode = "sync" - - if hasattr(self.request, "body") and \ - self.request.body and \ - isinstance(self.request.body, bytes): - self.context.payload_size_bytes = len(self.request.body) - - if self.response is not None: - self.context.end_time = time() - request_end = monotonic() - self.request_start - self.context.transfer = self.response.elapsed.total_seconds() - self.context.connect = request_end - self.context.transfer - - self.context.response_status = self.response.status_code - self.context.response_size_bytes = len(self.response.content) - - if hasattr(self.response.raw, "retries"): - self.context.retries = self.response.raw.retries.total - - logger = log.trace if self.response.ok else log.error - report_trace(self.context, {}, request_end, logger) - - -def create_request_tracer(): - """ - This function creates and returns a new instance of the `TraceRequest` class. - The `TraceRequest` class is used to trace the execution of a request in a context manager. - - Returns: - TraceRequest: An instance of the `TraceRequest` class. - - Example: - >>> with get_request_tracer() as tracer: - ... result = requests.get("https://example.com") - ... 
tracer.response = result.response - """ - return TraceRequest() diff --git a/setup.py b/setup.py index 01fb4aaa..11fe7ce5 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ extras_require = { "test": [ "asynctest", + "faker", "nest_asyncio", "pytest", "pytest-cov", diff --git a/tests/test_serverless/test_modules/run_scale.py b/tests/test_serverless/test_modules/run_scale.py new file mode 100644 index 00000000..5983c7a6 --- /dev/null +++ b/tests/test_serverless/test_modules/run_scale.py @@ -0,0 +1,63 @@ +import asyncio +import math +from faker import Faker +from typing import Any, Dict, Optional, List + +from runpod.serverless.modules.rp_scale import JobScaler, RunPodLogger, JobsProgress + +fake = Faker() +log = RunPodLogger() +job_progress = JobsProgress() + + +# Change this number to your desired concurrency +start = 1 + + +# sample concurrency modifier that loops +def collatz_conjecture(current_concurrency): + if current_concurrency == 1: + return start + + if current_concurrency % 2 == 0: + return math.floor(current_concurrency / 2) + else: + return current_concurrency * 3 + 1 + + +def fake_job(): + # Change this number to your desired delay + delay = fake.random_digit_above_two() + return { + "id": fake.uuid4(), + "input": fake.sentence(), + "mock_delay": delay, + } + + +async def fake_get_job(session, num_jobs: int = 1) -> Optional[List[Dict[str, Any]]]: + # Change this number to your desired delay + delay = fake.random_digit_above_two() - 1 + + log.info(f"... artificial delay ({delay}s)") + await asyncio.sleep(delay) # Simulates a blocking process + + jobs = [fake_job() for _ in range(num_jobs)] + log.info(f"... Generated # jobs: {len(jobs)}") + return jobs + + +async def fake_handle_job(session, config, job) -> dict: + await asyncio.sleep(job["mock_delay"]) # Simulates a blocking process + log.info(f"... 
Job handled ({job['mock_delay']}s)", job["id"]) + + +job_scaler = JobScaler( + { + # "concurrency_modifier": collatz_conjecture, + # "jobs_fetcher_timeout": 5, + "jobs_fetcher": fake_get_job, + "jobs_handler": fake_handle_job, + } +) +job_scaler.start() diff --git a/tests/test_serverless/test_modules/test_ping.py b/tests/test_serverless/test_modules/test_ping.py index 695a79c8..0a5517a3 100644 --- a/tests/test_serverless/test_modules/test_ping.py +++ b/tests/test_serverless/test_modules/test_ping.py @@ -136,8 +136,8 @@ async def test_send_ping(self, mock_get): mock_get.return_value = mock_response jobs = JobsProgress() - await jobs.add("job1") - await jobs.add("job2") + jobs.add("job1") + jobs.add("job2") heartbeat = Heartbeat() heartbeat._send_ping() diff --git a/tests/test_serverless/test_modules/test_state.py b/tests/test_serverless/test_modules/test_state.py index fdab963b..f3bb3372 100644 --- a/tests/test_serverless/test_modules/test_state.py +++ b/tests/test_serverless/test_modules/test_state.py @@ -6,7 +6,6 @@ from runpod.serverless.modules.worker_state import ( Job, JobsProgress, - JobsQueue, IS_LOCAL_TEST, WORKER_ID, ) @@ -97,79 +96,6 @@ def test_missing_attributes(self): _ = job.non_existent_attr -class TestJobsQueue(unittest.IsolatedAsyncioTestCase): - """Tests for JobsQueue class""" - - async def asyncSetUp(self): - """ - Set up test variables - """ - self.jobs = JobsQueue() - await self.jobs.clear() # clear jobs before each test - - def test_singleton(self): - """ - Tests if Jobs is a singleton class - """ - jobs2 = JobsQueue() - self.assertEqual(self.jobs, jobs2) - - async def test_add_job(self): - """ - Tests if add_job() method works as expected - """ - assert not self.jobs.get_job_count() - - job_input = {"id": "123"} - await self.jobs.add_job(job_input) - - assert self.jobs.get_job_count() == 1 - - async def test_remove_job(self): - """ - Tests if get_job() method removes the job from the queue - """ - job = {"id": "123"} - await self.jobs.add_job(job) - await self.jobs.get_job() - assert job not in self.jobs - - async def test_get_job(self): - """ - Tests if get_job() is FIFO - """ - job1 = {"id": "123"} - await self.jobs.add_job(job1) - - job2 = {"id": "456"} - await self.jobs.add_job(job2) - - next_job = await self.jobs.get_job() - assert next_job not in self.jobs - assert next_job == job1 - - next_job = await self.jobs.get_job() - assert next_job not in self.jobs - assert next_job == job2 - - async def test_get_job_list(self): - """ - Tests if get_job_list() returns comma-separated IDs - """ - self.assertTrue(self.jobs.get_job_list() is None) - - job1 = {"id": "123"} - await self.jobs.add_job(job1) - - job2 = {"id": "456"} - await self.jobs.add_job(job2) - - assert self.jobs.get_job_count() == 2 - assert job1 in self.jobs - assert job2 in self.jobs - assert self.jobs.get_job_list() in ["123,456", "456,123"] - - class TestJobsProgress(unittest.IsolatedAsyncioTestCase): """Tests for JobsProgress class""" @@ -178,56 +104,56 @@ async def asyncSetUp(self): Set up test variables """ self.jobs = JobsProgress() - await self.jobs.clear() # clear jobs before each test + self.jobs.clear() # clear jobs before each test def test_singleton(self): jobs2 = JobsProgress() self.assertEqual(self.jobs, jobs2) async def test_add_job(self): - assert not await self.jobs.get_job_count() + assert not self.jobs.get_job_count() id = "123" - await self.jobs.add({"id": id}) - assert await self.jobs.get_job_count() == 1 + self.jobs.add({"id": id}) + assert self.jobs.get_job_count() == 1 - job1 = 
await self.jobs.get(id) + job1 = self.jobs.get(id) assert job1 in self.jobs id = "234" - await self.jobs.add(id) - assert await self.jobs.get_job_count() == 2 + self.jobs.add(id) + assert self.jobs.get_job_count() == 2 - job2 = await self.jobs.get(id) + job2 = self.jobs.get(id) assert job2 in self.jobs async def test_remove_job(self): - assert not await self.jobs.get_job_count() + assert not self.jobs.get_job_count() job = {"id": "123"} - await self.jobs.add(job) - assert await self.jobs.get_job_count() + self.jobs.add(job) + assert self.jobs.get_job_count() - await self.jobs.remove("123") - assert not await self.jobs.get_job_count() + self.jobs.remove("123") + assert not self.jobs.get_job_count() async def test_get_job(self): for id in ["123", "234", "345"]: - await self.jobs.add({"id": id}) + self.jobs.add({"id": id}) - job1 = await self.jobs.get(id) + job1 = self.jobs.get(id) assert job1 in self.jobs async def test_get_job_list(self): assert self.jobs.get_job_list() is None job1 = {"id": "123"} - await self.jobs.add(job1) + self.jobs.add(job1) job2 = {"id": "456"} - await self.jobs.add(job2) + self.jobs.add(job2) - assert await self.jobs.get_job_count() == 2 + assert self.jobs.get_job_count() == 2 assert self.jobs.get_job_list() in ["123,456", "456,123"] async def test_get_job_count(self): diff --git a/tests/test_serverless/test_worker.py b/tests/test_serverless/test_worker.py index 87d6b629..65f5c62d 100644 --- a/tests/test_serverless/test_worker.py +++ b/tests/test_serverless/test_worker.py @@ -212,7 +212,7 @@ async def test_run_worker( runpod.serverless.start(self.config) # Make assertions about the behaviors - self.assertEqual(mock_get_job.call_count, 2) # Verify get_job called twice + self.assertEqual(mock_get_job.call_count, 1) mock_run_job.assert_called_once() mock_send_result.assert_called_once() @@ -282,7 +282,7 @@ async def test_run_worker_generator_handler_exception( assert not mock_run_job.called # Check that `send_result` was called - assert mock_send_result.call_count == 2 # Adjust expectation if multiple calls are valid + assert mock_send_result.call_count == 1 # Adjust expectation if multiple calls are valid # Inspect the arguments for each call to `send_result` for call in mock_send_result.call_args_list: diff --git a/tests/test_tracer.py b/tests/test_tracer.py deleted file mode 100644 index 6228887e..00000000 --- a/tests/test_tracer.py +++ /dev/null @@ -1,224 +0,0 @@ -# pylint: disable-all -# Temporary tracer while we're still using aiohttp and requests -# TODO: use httpx and opentelemetry -import asyncio -import json -import unittest -from time import time -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from aiohttp import ( - TraceConfig, - TraceConnectionCreateEndParams, - TraceConnectionCreateStartParams, - TraceConnectionReuseconnParams, - TraceRequestChunkSentParams, - TraceRequestExceptionParams, - TraceRequestStartParams, - TraceResponseChunkReceivedParams, -) -from yarl import URL - -from runpod.tracer import ( - create_aiohttp_tracer, - on_connection_create_end, - on_connection_create_start, - on_connection_reuseconn, - on_request_chunk_sent, - on_request_end, - on_request_exception, - on_request_start, - on_response_chunk_received, - report_trace, - time_to_iso8601, -) - - -class TestTracer(unittest.TestCase): - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - self.loop.close() - - def test_get_aiohttp_tracer(self): - assert 
isinstance(create_aiohttp_tracer(), TraceConfig) - - def test_on_request_start(self): - session = MagicMock() - context = SimpleNamespace(trace_request_ctx={"current_attempt": 0}) - params = TraceRequestStartParams( - "GET", URL("http://test.com/"), {"X-Request-ID": "myRequestId"} - ) - - self.loop.run_until_complete(on_request_start(session, context, params)) - assert hasattr(context, "on_request_start") - assert hasattr(context, "trace_id") - assert context.method == params.method - assert context.url == params.url.human_repr() - - def test_on_connection_create_start(self): - session = MagicMock() - context = SimpleNamespace(on_request_start=self.loop.time()) - params = TraceConnectionCreateStartParams() - - self.loop.run_until_complete( - on_connection_create_start(session, context, params) - ) - - assert context.connect - - def test_on_connection_create_end(self): - session = MagicMock() - context = SimpleNamespace(on_request_start=self.loop.time()) - params = TraceConnectionCreateEndParams() - - self.loop.run_until_complete(on_connection_create_end(session, context, params)) - - assert context.connect - - def test_on_connection_reuseconn(self): - session = MagicMock() - context = SimpleNamespace(on_request_start=self.loop.time()) - params = TraceConnectionReuseconnParams() - - self.loop.run_until_complete(on_connection_reuseconn(session, context, params)) - - assert context.connect - - def test_on_request_chunk_sent(self): - session = MagicMock() - context = SimpleNamespace(on_request_start=self.loop.time()) - params = TraceRequestChunkSentParams( - "GET", URL("http://test.com/"), chunk=b"test data" - ) - - # Initial call to on_request_start to initialize context - self.loop.run_until_complete(on_request_start(session, context, params)) - - # Call on_request_chunk_sent multiple times to simulate multiple chunks being sent - for _ in range(3): - self.loop.run_until_complete( - on_request_chunk_sent(session, context, params) - ) - - # Verify that payload_size_bytes has accumulated - assert context.payload_size_bytes == len(params.chunk) * 3 - - def test_on_response_chunk_received(self): - session = MagicMock() - context = SimpleNamespace(on_request_start=self.loop.time()) - params = TraceResponseChunkReceivedParams( - "GET", URL("http://test.com/"), chunk=b"received data" - ) - - # Initial call to on_request_start to initialize context - self.loop.run_until_complete(on_request_start(session, context, params)) - - # Call on_response_chunk_received multiple times to simulate multiple chunks being received - for _ in range(3): - self.loop.run_until_complete( - on_response_chunk_received(session, context, params) - ) - - # Verify that payload_size_bytes has accumulated - assert context.response_size_bytes == len(params.chunk) * 3 - - @patch("runpod.tracer.report_trace") - def test_on_request_end(self, mock_report_trace): - session = MagicMock() - context = SimpleNamespace(on_request_start=self.loop.time(), connect=0.5) - params = MagicMock() - - self.loop.run_until_complete(on_request_end(session, context, params)) - mock_report_trace.assert_called_once() - - @patch("runpod.tracer.report_trace") - def test_on_request_exception(self, mock_report_trace): - session = MagicMock() - context = SimpleNamespace(on_request_start=self.loop.time(), connect=0.5) - params = TraceRequestExceptionParams( - "GET", - URL("http://test.com/"), - headers={}, - exception=Exception("Test Exception"), - ) - - self.loop.run_until_complete(on_request_exception(session, context, params)) - 
mock_report_trace.assert_called_once() - assert context.exception - - @patch("runpod.tracer.log") - def test_report_trace(self, mock_log): - context = SimpleNamespace() - context.trace_id = "test-trace-id" - context.request_id = "test-request-id" - context.start_time = time() - context.end_time = time() + 2 - context.connect = 0.5 - context.payload_size_bytes = 1024 - context.response_size_bytes = 2048 - context.retries = 0 - context.trace_request_ctx = {"current_attempt": 0} - context.transfer = 1.0 - - params = MagicMock() - params.response.status = 200 - - elapsed = 1.5 - - expected_report = { - "trace_id": "test-trace-id", - "request_id": "test-request-id", - "connect": 500.0, - "payload_size_bytes": 1024, - "response_size_bytes": 2048, - "retries": 0, - "start_time": time_to_iso8601(context.start_time), - "end_time": time_to_iso8601(context.end_time), - "total": 1500.0, # 1.5 seconds to milliseconds - "transfer": 1000.0, # 1.5 - 0.5 seconds to milliseconds - "response_status": 200, - } - - report_trace(context, params, elapsed, mock_log.trace) - - assert expected_report == json.loads(mock_log.trace.call_args[0][0]) - - @patch("runpod.tracer.log") - def test_report_trace_error_log(self, mock_log): - context = SimpleNamespace() - context.trace_id = "test-trace-id" - context.request_id = "test-request-id" - context.start_time = time() - context.end_time = time() + 2 - context.connect = 0.5 - context.retries = 3 - context.trace_request_ctx = {"current_attempt": 3} - context.exception = str(Exception("Test Exception")) - context.transfer = 1.0 - - params = MagicMock() - params.response.status = 502 - - elapsed = 1.5 - - expected_report = { - "trace_id": "test-trace-id", - "request_id": "test-request-id", - "connect": 500.0, - "retries": 3, - "exception": "Test Exception", - "start_time": time_to_iso8601(context.start_time), - "end_time": time_to_iso8601(context.end_time), - "total": 1500.0, # 1.5 seconds to milliseconds - "transfer": 1000.0, # 1.5 - 0.5 seconds to milliseconds - "response_status": 502, - } - - report_trace(context, params, elapsed, mock_log.error) - - assert expected_report == json.loads(mock_log.error.call_args[0][0])
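
The net effect of the `worker_state.py` changes above is that `JobsProgress` becomes a plain synchronous set (no asyncio lock, no `await`), while the job queue now lives on each `JobScaler` instance instead of the removed `JobsQueue` singleton. As a minimal sketch of the synchronous `JobsProgress` API after this change, based on the behavior exercised by the updated tests (the job ids below are illustrative):

```python
# Minimal usage sketch of the now-synchronous JobsProgress API (illustrative ids;
# JobsProgress is a singleton set tracking in-flight jobs).
from runpod.serverless.modules.worker_state import JobsProgress

progress = JobsProgress()

progress.add("job-123")            # plain string ids are wrapped into Job objects
progress.add({"id": "job-456"})    # dicts with an "id" key are accepted as well

assert progress.get_job_count() == 2
assert progress.get("job-123").id == "job-123"

progress.remove("job-123")         # backed by set.discard(); removing an absent job is a no-op
progress.clear()
```

Because every method is now synchronous, callers such as `rp_fastapi.py` and `rp_scale.py` drop their `await`s at these call sites, which is what the corresponding hunks above reflect.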