diff --git a/CHANGELOG.md b/CHANGELOG.md index b0ee2080..fd7235e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # Change Log +## Release 1.4.1 (12/13/23) + +### Added + +- Local test API server includes simulated endpoints that mimic the behavior of `run`, `runsync`, `stream`, and `status`. +- Internal job tracker can be used to track job inputs. + +--- + ## Release 1.4.0 (12/4/23) ### Changed diff --git a/examples/serverless/simple_handler.py b/examples/serverless/simple_handler.py new file mode 100644 index 00000000..cc1c72fc --- /dev/null +++ b/examples/serverless/simple_handler.py @@ -0,0 +1,17 @@ +""" Simple Handler + +To setup a local API server, run the following command: +python simple_handler.py --rp_serve_api +""" + +import runpod + + +def handler(job): + """ Simple handler """ + job_input = job["input"] + + return f"Hello {job_input['name']}!" + + +runpod.serverless.start({"handler": handler}) diff --git a/runpod/serverless/__init__.py b/runpod/serverless/__init__.py index d16c345b..d1700715 100644 --- a/runpod/serverless/__init__.py +++ b/runpod/serverless/__init__.py @@ -126,8 +126,7 @@ def start(config: Dict[str, Any]): if config["rp_args"]["rp_serve_api"]: print("Starting API server.") - api_server = rp_fastapi.WorkerAPI() - api_server.config = config + api_server = rp_fastapi.WorkerAPI(config) api_server.start_uvicorn( api_host=config['rp_args']['rp_api_host'], @@ -137,8 +136,7 @@ def start(config: Dict[str, Any]): elif realtime_port: print("Starting API server for realtime.") - api_server = rp_fastapi.WorkerAPI() - api_server.config = config + api_server = rp_fastapi.WorkerAPI(config) api_server.start_uvicorn( api_host='0.0.0.0', diff --git a/runpod/serverless/modules/rp_fastapi.py b/runpod/serverless/modules/rp_fastapi.py index f1500cc9..6ffb7c59 100644 --- a/runpod/serverless/modules/rp_fastapi.py +++ b/runpod/serverless/modules/rp_fastapi.py @@ -2,11 +2,13 @@ # pylint: disable=too-few-public-methods import os -from typing import Union +import uuid +from typing import Union, Optional, Dict, Any import uvicorn from fastapi import FastAPI, APIRouter from fastapi.encoders import jsonable_encoder +from fastapi.responses import RedirectResponse from pydantic import BaseModel from .rp_handler import is_generator @@ -47,14 +49,39 @@ class TestJob(BaseModel): ''' Represents a test job. input can be any type of data. ''' - id: str = "test_job" - input: Union[dict, list, str, int, float, bool] + id: Optional[str] + input: Optional[Union[dict, list, str, int, float, bool]] + + +class DefaultInput(BaseModel): + """ Represents a test input. """ + input: Dict[str, Any] + + +# ------------------------------ Output Objects ------------------------------ # +class JobOutput(BaseModel): + ''' Represents the output of a job. ''' + id: str + status: str + output: Optional[Union[dict, list, str, int, float, bool]] + error: Optional[str] + + +class StreamOutput(BaseModel): + """ Stream representation of a job. """ + id: str + status: str = "IN_PROGRESS" + stream: Optional[Union[dict, list, str, int, float, bool]] + error: Optional[str] +# ---------------------------------------------------------------------------- # +# API Worker # +# ---------------------------------------------------------------------------- # class WorkerAPI: ''' Used to launch the FastAPI web server when the worker is running in API mode. ''' - def __init__(self, handler=None): + def __init__(self, config: Dict[str, Any]): ''' Initializes the WorkerAPI class. 1. Starts the heartbeat thread. @@ -64,23 +91,50 @@ def __init__(self, handler=None): # Start the heartbeat thread. heartbeat.start_ping() - # Set the handler for processing jobs. - self.config = {"handler": handler} + self.config = config # Initialize the FastAPI web server. self.rp_app = FastAPI( title="RunPod | Test Worker | API", description=DESCRIPTION, version=runpod_version, + docs_url="/" ) # Create an APIRouter and add the route for processing jobs. api_router = APIRouter() - if RUNPOD_ENDPOINT_ID: - api_router.add_api_route(f"/{RUNPOD_ENDPOINT_ID}/realtime", self._run, methods=["POST"]) + # Docs Redirect /docs -> / + api_router.add_api_route( + "/docs", lambda: RedirectResponse(url="/"), + include_in_schema=False + ) - api_router.add_api_route("/runsync", self._debug_run, methods=["POST"]) + if RUNPOD_ENDPOINT_ID: + api_router.add_api_route(f"/{RUNPOD_ENDPOINT_ID}/realtime", + self._realtime, methods=["POST"]) + + # Simulation endpoints. + api_router.add_api_route( + "/run", self._sim_run, methods=["POST"], response_model_exclude_none=True, + summary="Simulate run behavior.", + description="Returns job ID to be used with `/stream` and `/status` endpoints." + ) + api_router.add_api_route( + "/runsync", self._sim_runsync, methods=["POST"], response_model_exclude_none=True, + summary="Simulate runsync behavior.", + description="Returns job output directly when called." + ) + api_router.add_api_route( + "/stream/{job_id}", self._sim_stream, methods=["POST"], + response_model_exclude_none=True, summary="Simulate stream behavior.", + description="Aggregates the output of the job and returns it when the job is complete." + ) + api_router.add_api_route( + "/status/{job_id}", self._sim_status, methods=["POST"], + response_model_exclude_none=True, summary="Simulate status behavior.", + description="Returns the output of the job when the job is complete." + ) # Include the APIRouter in the FastAPI application. self.rp_app.include_router(api_router) @@ -96,47 +150,111 @@ def start_uvicorn(self, api_host='localhost', api_port=8000, api_concurrency=1): access_log=False ) - async def _run(self, job: Job): + # ----------------------------- Realtime Endpoint ---------------------------- # + async def _realtime(self, job: Job): ''' Performs model inference on the input data using the provided handler. If handler is not provided, returns an error message. ''' - if self.config["handler"] is None: - return {"error": "Handler not provided"} - - # Set the current job ID. job_list.add_job(job.id) - # Process the job using the provided handler. + # Process the job using the provided handler, passing in the job input. job_results = await run_job(self.config["handler"], job.__dict__) - # Reset the job ID. job_list.remove_job(job.id) # Return the results of the job processing. return jsonable_encoder(job_results) - async def _debug_run(self, job: TestJob): - ''' - Performs model inference on the input data using the provided handler. - ''' - if self.config["handler"] is None: - return {"error": "Handler not provided"} + # ---------------------------------------------------------------------------- # + # Simulation Endpoints # + # ---------------------------------------------------------------------------- # - # Set the current job ID. - job_list.add_job(job.id) + # ------------------------------------ run ----------------------------------- # + async def _sim_run(self, job_input: DefaultInput) -> JobOutput: + """ Development endpoint to simulate run behavior. """ + assigned_job_id = f"test-{uuid.uuid4()}" + job_list.add_job(assigned_job_id, job_input.input) + return jsonable_encoder({"id": assigned_job_id, "status": "IN_PROGRESS"}) + + # ---------------------------------- runsync --------------------------------- # + async def _sim_runsync(self, job_input: DefaultInput) -> JobOutput: + """ Development endpoint to simulate runsync behavior. """ + assigned_job_id = f"test-{uuid.uuid4()}" + job = TestJob(id=assigned_job_id, input=job_input.input) if is_generator(self.config["handler"]): generator_output = run_job_generator(self.config["handler"], job.__dict__) - job_results = {"output": []} + job_output = {"output": []} async for stream_output in generator_output: - job_results["output"].append(stream_output["output"]) + job_output['output'].append(stream_output["output"]) else: - job_results = await run_job(self.config["handler"], job.__dict__) + job_output = await run_job(self.config["handler"], job.__dict__) + + return jsonable_encoder({ + "id": job.id, + "status": "COMPLETED", + "output": job_output['output'] + }) + + # ---------------------------------- stream ---------------------------------- # + async def _sim_stream(self, job_id: str) -> StreamOutput: + """ Development endpoint to simulate stream behavior. """ + job_input = job_list.get_job_input(job_id) + if job_input is None: + return jsonable_encoder({ + "id": job_id, + "status": "FAILED", + "error": "Job ID not found" + }) + + job = TestJob(id=job_id, input=job_input) - job_results["id"] = job.id + if is_generator(self.config["handler"]): + generator_output = run_job_generator(self.config["handler"], job.__dict__) + stream_accumulator = [] + async for stream_output in generator_output: + stream_accumulator.append({"output": stream_output["output"]}) + else: + return jsonable_encoder({ + "id": job_id, + "status": "FAILED", + "error": "Stream not supported, handler must be a generator." + }) - # Reset the job ID. job_list.remove_job(job.id) - return jsonable_encoder(job_results) + return jsonable_encoder({ + "id": job_id, + "status": "COMPLETED", + "stream": stream_accumulator + }) + + # ---------------------------------- status ---------------------------------- # + async def _sim_status(self, job_id: str) -> JobOutput: + """ Development endpoint to simulate status behavior. """ + job_input = job_list.get_job_input(job_id) + if job_input is None: + return jsonable_encoder({ + "id": job_id, + "status": "FAILED", + "error": "Job ID not found" + }) + + job = TestJob(id=job_id, input=job_input) + + if is_generator(self.config["handler"]): + generator_output = run_job_generator(self.config["handler"], job.__dict__) + job_output = {"output": []} + async for stream_output in generator_output: + job_output['output'].append(stream_output["output"]) + else: + job_output = await run_job(self.config["handler"], job.__dict__) + + job_list.remove_job(job.id) + + return jsonable_encoder({ + "id": job_id, + "status": "COMPLETED", + "output": job_output['output'] + }) diff --git a/runpod/serverless/modules/rp_job.py b/runpod/serverless/modules/rp_job.py index 30ac32eb..0a3cc15f 100644 --- a/runpod/serverless/modules/rp_job.py +++ b/runpod/serverless/modules/rp_job.py @@ -4,7 +4,7 @@ # pylint: disable=too-many-branches import inspect -from typing import Any, Callable, Dict, Generator, Optional, Union +from typing import Any, Callable, Dict, Optional, Union, AsyncGenerator import os import json @@ -179,9 +179,9 @@ async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]: async def run_job_generator( handler: Callable, - job: Dict[str, Any]) -> Generator[Dict[str, Union[str, Any]], None, None]: + job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Union[str, Any]], None]: ''' - Run generator job. + Run generator job used to stream output. Yields output partials from the generator. ''' try: diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index db4e1d71..22a02dd5 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -5,6 +5,7 @@ import os import uuid import time +from typing import Optional, Dict, Any, Union REF_COUNT_ZERO = time.perf_counter() # Used for benchmarking with the debugger. @@ -22,6 +23,25 @@ def get_auth_header(): return {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"} +# ------------------------------- Job Tracking ------------------------------- # +class Job: + """ Represents a job. """ + + def __init__(self, job_id: str, job_input: Optional[Dict[str, Any]] = None) -> None: + self.job_id = job_id + self.job_input = job_input + + def __eq__(self, other: object) -> bool: + if isinstance(other, Job): + return self.job_id == other.job_id + return False + + def __hash__(self) -> int: + return hash(self.job_id) + + def __str__(self) -> str: + return self.job_id + class Jobs: ''' Track the state of current jobs.''' @@ -35,23 +55,31 @@ def __new__(cls): Jobs._instance.jobs = set() return Jobs._instance - def add_job(self, job_id): + def add_job(self, job_id, job_input=None): ''' Adds a job to the list of jobs. ''' - self.jobs.add(job_id) + self.jobs.add(Job(job_id, job_input)) def remove_job(self, job_id): ''' Removes a job from the list of jobs. ''' - self.jobs.remove(job_id) + self.jobs.remove(Job(job_id)) + + def get_job_input(self, job_id) -> Optional[Union[dict, list, str, int, float, bool]]: + ''' + Returns the job with the given id. + Used within rp_fastapi.py for local testing. + ''' + for job in self.jobs: + if job.job_id == job_id: + return job.job_input + + return None def get_job_list(self): ''' Returns the list of jobs as a string. ''' - if len(self.jobs) == 0: - return None - - return ','.join(list(self.jobs)) + return ','.join(str(job) for job in self.jobs) if self.jobs else None diff --git a/tests/test_serverless/test_modules/test_fastapi.py b/tests/test_serverless/test_modules/test_fastapi.py index d8dedc80..f71105d5 100644 --- a/tests/test_serverless/test_modules/test_fastapi.py +++ b/tests/test_serverless/test_modules/test_fastapi.py @@ -1,4 +1,6 @@ ''' Tests for runpod.serverless.modules.rp_fastapi.py ''' +# pylint: disable=protected-access + import os import asyncio @@ -9,6 +11,7 @@ import runpod from runpod.serverless.modules import rp_fastapi + class TestFastAPI(unittest.TestCase): ''' Tests the FastAPI ''' @@ -22,10 +25,9 @@ def test_start_serverless_with_realtime(self): ''' module_location = "runpod.serverless.modules.rp_fastapi" with patch(f"{module_location}.Heartbeat.start_ping", Mock()) as mock_ping, \ - patch(f"{module_location}.FastAPI", Mock()) as mock_fastapi, \ - patch(f"{module_location}.APIRouter", return_value=Mock()) as mock_router, \ - patch(f"{module_location}.uvicorn", Mock()) as mock_uvicorn: - + patch(f"{module_location}.FastAPI", Mock()) as mock_fastapi, \ + patch(f"{module_location}.APIRouter", return_value=Mock()) as mock_router, \ + patch(f"{module_location}.uvicorn", Mock()) as mock_uvicorn: rp_fastapi.RUNPOD_REALTIME_PORT = '1111' rp_fastapi.RUNPOD_ENDPOINT_ID = 'test_endpoint_id' @@ -45,7 +47,6 @@ def test_start_serverless_with_realtime(self): self.assertTrue(mock_uvicorn.run.called) - @pytest.mark.asyncio def test_run(self): ''' @@ -55,35 +56,31 @@ def test_run(self): module_location = "runpod.serverless.modules.rp_fastapi" with patch(f"{module_location}.Heartbeat.start_ping", Mock()) as mock_ping, \ - patch(f"{module_location}.FastAPI", Mock()), \ - patch(f"{module_location}.APIRouter", return_value=Mock()), \ - patch(f"{module_location}.uvicorn", Mock()): + patch(f"{module_location}.FastAPI", Mock()), \ + patch(f"{module_location}.APIRouter", return_value=Mock()), \ + patch(f"{module_location}.uvicorn", Mock()), \ + patch(f"{module_location}.uuid.uuid4", return_value="123"): job_object = rp_fastapi.Job( id="test_job_id", input={"test_input": "test_input"} ) - # Test without handler - worker_api_without_handler = rp_fastapi.WorkerAPI() - - handlerless_run_return = asyncio.run(worker_api_without_handler._run(job_object)) # pylint: disable=protected-access - assert handlerless_run_return == {"error": "Handler not provided"} - - handlerless_debug_run = asyncio.run(worker_api_without_handler._debug_run(job_object)) # pylint: disable=protected-access - assert handlerless_debug_run == {"error": "Handler not provided"} + default_input_object = rp_fastapi.DefaultInput( + input={"test_input": "test_input"} + ) # Test with handler - worker_api = rp_fastapi.WorkerAPI(handler=self.handler) + worker_api = rp_fastapi.WorkerAPI({"handler": self.handler}) - run_return = asyncio.run(worker_api._run(job_object)) # pylint: disable=protected-access + run_return = asyncio.run(worker_api._realtime(job_object)) assert run_return == {"output": {"result": "success"}} - debug_run_return = asyncio.run(worker_api._debug_run(job_object)) # pylint: disable=protected-access + debug_run_return = asyncio.run(worker_api._sim_run(default_input_object)) assert debug_run_return == { - "id": "test_job_id", - "output": {"result": "success"} - } + "id": "test-123", + "status": "IN_PROGRESS" + } self.assertTrue(mock_ping.called) @@ -92,11 +89,158 @@ def generator_handler(job): del job yield {"result": "success"} - generator_worker_api = rp_fastapi.WorkerAPI(handler=generator_handler) - generator_run_return = asyncio.run(generator_worker_api._debug_run(job_object)) # pylint: disable=protected-access + generator_worker_api = rp_fastapi.WorkerAPI({"handler": generator_handler}) + generator_run_return = asyncio.run(generator_worker_api._sim_run(default_input_object)) assert generator_run_return == { - "id": "test_job_id", - "output": [{"result": "success"}] - } + "id": "test-123", + "status": "IN_PROGRESS" + } + + loop.close() + + @pytest.mark.asyncio + def test_runsync(self): + ''' + Tests the _runsync() method. + ''' + loop = asyncio.get_event_loop() + + module_location = "runpod.serverless.modules.rp_fastapi" + with patch(f"{module_location}.FastAPI", Mock()), \ + patch(f"{module_location}.APIRouter", return_value=Mock()), \ + patch(f"{module_location}.uvicorn", Mock()), \ + patch(f"{module_location}.uuid.uuid4", return_value="123"): + + default_input_object = rp_fastapi.DefaultInput( + input={"test_input": "test_input"} + ) + + # Test with handler + worker_api = rp_fastapi.WorkerAPI({"handler": self.handler}) + + runsync_return = asyncio.run(worker_api._sim_runsync(default_input_object)) + assert runsync_return == { + "id": "test-123", + "status": "COMPLETED", + "output": {"result": "success"} + } + + # Test with generator handler + def generator_handler(job): + del job + yield {"result": "success"} + + generator_worker_api = rp_fastapi.WorkerAPI({"handler": generator_handler}) + generator_runsync_return = asyncio.run( + generator_worker_api._sim_runsync(default_input_object)) + assert generator_runsync_return == { + "id": "test-123", + "status": "COMPLETED", + "output": [{"result": "success"}] + } + + loop.close() + + @pytest.mark.asyncio + def test_stream(self): + ''' + Tests the _stream() method. + ''' + loop = asyncio.get_event_loop() + + module_location = "runpod.serverless.modules.rp_fastapi" + with patch(f"{module_location}.FastAPI", Mock()), \ + patch(f"{module_location}.APIRouter", return_value=Mock()), \ + patch(f"{module_location}.uvicorn", Mock()), \ + patch(f"{module_location}.uuid.uuid4", return_value="123"): + + default_input_object = rp_fastapi.DefaultInput( + input={"test_input": "test_input"} + ) + + worker_api = rp_fastapi.WorkerAPI({"handler": self.handler}) + + # Add job to job_list + asyncio.run(worker_api._sim_run(default_input_object)) + + stream_return = asyncio.run(worker_api._sim_stream("test_job_id")) + assert stream_return == { + "id": "test_job_id", + "status": "FAILED", + "error": "Job ID not found" + } + + stream_return = asyncio.run(worker_api._sim_stream("test-123")) + assert stream_return == { + "id": "test-123", + "status": "FAILED", + "error": "Stream not supported, handler must be a generator." + } + + # Test with generator handler + def generator_handler(job): + del job + yield {"result": "success"} + + generator_worker_api = rp_fastapi.WorkerAPI({"handler": generator_handler}) + generator_stream_return = asyncio.run( + generator_worker_api._sim_stream("test-123")) + assert generator_stream_return == { + "id": "test-123", + "status": "COMPLETED", + "stream": [{"output": {"result": "success"}}] + } + + loop.close() + + @pytest.mark.asyncio + def test_status(self): + ''' + Tests the _status() method. + ''' + loop = asyncio.get_event_loop() + + module_location = "runpod.serverless.modules.rp_fastapi" + with patch(f"{module_location}.FastAPI", Mock()), \ + patch(f"{module_location}.APIRouter", return_value=Mock()), \ + patch(f"{module_location}.uvicorn", Mock()), \ + patch(f"{module_location}.uuid.uuid4", return_value="123"): + + worker_api = rp_fastapi.WorkerAPI({"handler": self.handler}) + + default_input_object = rp_fastapi.DefaultInput( + input={"test_input": "test_input"} + ) + + # Add job to job_list + asyncio.run(worker_api._sim_run(default_input_object)) + + status_return = asyncio.run(worker_api._sim_status("test_job_id")) + assert status_return == { + "id": "test_job_id", + "status": "FAILED", + "error": "Job ID not found" + } + + status_return = asyncio.run(worker_api._sim_status("test-123")) + assert status_return == { + "id": "test-123", + "status": "COMPLETED", + "output": {"result": "success"} + } + + # Test with generator handler + def generator_handler(job): + del job + yield {"result": "success"} + generator_worker_api = rp_fastapi.WorkerAPI({"handler": generator_handler}) + asyncio.run(generator_worker_api._sim_run(default_input_object)) + generator_stream_return = asyncio.run( + generator_worker_api._sim_status("test-123")) + assert generator_stream_return == { + "id": "test-123", + "status": "COMPLETED", + "output": [{"result": "success"}] + } loop.close() diff --git a/tests/test_serverless/test_modules/test_state.py b/tests/test_serverless/test_modules/test_state.py index 61dd544d..445b53f5 100644 --- a/tests/test_serverless/test_modules/test_state.py +++ b/tests/test_serverless/test_modules/test_state.py @@ -3,7 +3,10 @@ import os import unittest -from runpod.serverless.modules.worker_state import Jobs, IS_LOCAL_TEST, WORKER_ID, get_auth_header +from runpod.serverless.modules.worker_state import ( + Job, Jobs, IS_LOCAL_TEST, WORKER_ID, get_auth_header +) + class TestEnvVars(unittest.TestCase): ''' Tests for environment variables module ''' @@ -36,6 +39,7 @@ def test_get_auth_header(self): ''' self.assertEqual(get_auth_header(), {'Authorization': self.test_api_key}) + class TestJobs(unittest.TestCase): ''' Tests for Jobs class ''' @@ -58,7 +62,7 @@ def test_add_job(self): Tests if add_job() method works as expected ''' self.jobs.add_job('123') - self.assertIn('123', self.jobs.jobs) + self.assertIn(Job('123'), self.jobs.jobs) def test_remove_job(self): ''' @@ -66,7 +70,23 @@ def test_remove_job(self): ''' self.jobs.add_job('123') self.jobs.remove_job('123') - self.assertNotIn('123', self.jobs.jobs) + self.assertNotIn(Job('123'), self.jobs.jobs) + + def test_get_job_input(self): + ''' + Tests if get_job_input() method works as expected + ''' + job1 = Job(job_id="id1") + job2 = Job(job_id="id2") + self.assertNotEqual(job1, job2) + + job = Job(job_id="id1") + non_job_object = "some_string" + self.assertNotEqual(job, non_job_object) + + self.assertEqual(self.jobs.get_job_input('123'), None) + self.jobs.add_job('123', 'test_input') + self.assertEqual(self.jobs.get_job_input('123'), 'test_input') def test_get_job_list(self): ''' @@ -77,7 +97,7 @@ def test_get_job_list(self): self.jobs.add_job('123') self.jobs.add_job('456') self.assertEqual(len(self.jobs.jobs), 2) - self.assertTrue('123' in self.jobs.jobs) - self.assertTrue('456' in self.jobs.jobs) + self.assertTrue(Job('123') in self.jobs.jobs) + self.assertTrue(Job('456') in self.jobs.jobs) self.assertTrue(self.jobs.get_job_list() in ['123,456', '456,123'])