diff --git a/README.md b/README.md index 936a26c1..7accd986 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Official Python library for RunPod API & SDK. - [Installation](#installation) - [SDK - Serverless Worker](#sdk---serverless-worker) - [Quick Start](#quick-start) + - [Local Test Worker](#local-test-worker) - [API Language Library](#api-language-library) - [Endpoints](#endpoints) - [GPU Pod Control](#gpu-pod-control) @@ -44,6 +45,8 @@ This python package can also be used to create a serverless worker that can be d Create an python script in your project that contains your model definition and the RunPod worker start code. Run this python code as your default container start command: ```python +# my_worker.py + import runpod def is_even(job): @@ -66,6 +69,14 @@ Make sure that this file is ran when your container starts. This can be accompli See our [blog post](https://www.runpod.io/blog/serverless-create-a-basic-api) for creating a basic Serverless API, or view the [details docs](https://docs.runpod.io/serverless-ai/custom-apis) for more information. +### Local Test Worker + +You can also test your worker locally before deploying it to RunPod. This is useful for debugging and testing. + +```bash +python my_worker.py --rp_serve_api +``` + ## API Language Library When interacting with the RunPod API you can use this library to make requests to the API. diff --git a/runpod/serverless/__init__.py b/runpod/serverless/__init__.py index aa3ee390..95546a04 100644 --- a/runpod/serverless/__init__.py +++ b/runpod/serverless/__init__.py @@ -21,14 +21,25 @@ prog="runpod", description="Runpod Serverless Worker Arguments." ) -parser.add_argument("--test_input", type=str, default=None, - help="Test input for the worker, formatted as JSON.") -parser.add_argument("--rp_debugger", action="store_true", default=None, - help="Flag to enable the Debugger.") -parser.add_argument("rp_log_level", default=None, +parser.add_argument("--rp_log_level", type=str, help="""Controls what level of logs are printed to the console. Options: ERROR, WARN, INFO, and DEBUG.""") +parser.add_argument("--rp_debugger", action="store_true", default=None, + help="Flag to enable the Debugger.") + +parser.add_argument("--rp_serve_api", action="store_true", default=None, + help="Flag to start the API server.") +parser.add_argument("--rp_api_port", type=int, default=8000, + help="Port to start the FastAPI server on.") +parser.add_argument("--rp_api_concurrency", type=int, default=1, + help="Number of concurrent FastAPI workers.") +parser.add_argument("--rp_api_host", type=str, default="localhost", + help="Host to start the FastAPI server on.") + +parser.add_argument("--test_input", type=str, default=None, + help="Test input for the worker, formatted as JSON.") + def _set_config_args(config) -> dict: """ @@ -47,7 +58,7 @@ def _set_config_args(config) -> dict: # Set the log level if config["rp_args"]["rp_log_level"]: - log.set_level(config["rp_args"]["rp_debug_level"]) + log.set_level(config["rp_args"]["rp_log_level"]) return config @@ -73,17 +84,33 @@ def start(config): """ Starts the serverless worker. """ + print("--- Starting Serverless Worker ---") + config["reference_counter_start"] = time.perf_counter() config = _set_config_args(config) realtime_port = _get_realtime_port() realtime_concurrency = _get_realtime_concurrency() - if realtime_port: + if config["rp_args"]["rp_serve_api"]: + api_server = rp_fastapi.WorkerAPI() + api_server.config = config + + api_server.start_uvicorn( + api_host=config['rp_args']['rp_api_host'], + api_port=config['rp_args']['rp_api_port'], + api_concurrency=config['rp_args']['rp_api_concurrency'] + ) + + elif realtime_port: api_server = rp_fastapi.WorkerAPI() api_server.config = config - api_server.start_uvicorn(realtime_port, realtime_concurrency) + api_server.start_uvicorn( + api_host='0.0.0.0', + api_port=realtime_port, + api_concurrency=realtime_concurrency + ) else: asyncio.run(work_loop.start_worker(config)) diff --git a/runpod/serverless/modules/job.py b/runpod/serverless/modules/job.py index 872208a1..7667804d 100644 --- a/runpod/serverless/modules/job.py +++ b/runpod/serverless/modules/job.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Generator, Optional, Union import os +import sys import json import traceback from aiohttp import ClientSession @@ -21,8 +22,8 @@ def _get_local() -> Optional[Dict[str, Any]]: Returns contents of test_input.json. """ if not os.path.exists("test_input.json"): - log.warn("test_input.json not found, skipping local testing") - return None + log.warn("test_input.json not found, exiting.") + sys.exit(1) with open("test_input.json", "r", encoding="UTF-8") as file: test_inputs = json.loads(file.read()) diff --git a/runpod/serverless/modules/rp_fastapi.py b/runpod/serverless/modules/rp_fastapi.py index 0d58b15a..33082a26 100644 --- a/runpod/serverless/modules/rp_fastapi.py +++ b/runpod/serverless/modules/rp_fastapi.py @@ -11,6 +11,21 @@ from .worker_state import set_job_id from .heartbeat import HeartbeatSender +RUNPOD_ENDPOINT_ID = os.environ.get("RUNPOD_ENDPOINT_ID", None) + +DESCRIPTION = """ +This API server is provided as a method of testing and debugging your worker locally. +Additionally, you can use this to test code that will be making requests to your worker. + +### Endpoints + +The URLs provided are named to match the endpoints that you will be provided when running on RunPod. + +--- + +*Note: When running your worker on the RunPod platform, this API server will not be used.* +""" + heartbeat = HeartbeatSender() @@ -18,7 +33,15 @@ class Job(BaseModel): ''' Represents a job. ''' id: str - input: dict + input: dict | list | str | int | float | bool + + +class TestJob(BaseModel): + ''' Represents a test job. + input can be any type of data. + ''' + id: str = "test_job" + input: dict | list | str | int | float | bool class WorkerAPI: @@ -38,25 +61,39 @@ def __init__(self, handler=None): self.config = {"handler": handler} # Initialize the FastAPI web server. - self.rp_app = FastAPI() + + try: + import runpod # pylint: disable=import-outside-toplevel,cyclic-import + runpod_version = runpod.__version__ + except AttributeError: + runpod_version = "0.0.0" + + self.rp_app = FastAPI( + title="RunPod | Test Worker | API", + description=DESCRIPTION, + version=runpod_version, + ) # Create an APIRouter and add the route for processing jobs. api_router = APIRouter() - api_router.add_api_route( - f"/{os.environ.get('RUNPOD_ENDPOINT_ID')}/realtime", - self.run, methods=["POST"] - ) + + if RUNPOD_ENDPOINT_ID: + api_router.add_api_route(f"/{RUNPOD_ENDPOINT_ID}/realtime", self.run, methods=["POST"]) + + api_router.add_api_route("/runsync", self.test_run, methods=["POST"]) # Include the APIRouter in the FastAPI application. self.rp_app.include_router(api_router) - def start_uvicorn(self, api_port, api_concurrency): + def start_uvicorn(self, api_host='localhost', api_port=8000, api_concurrency=1): ''' Starts the Uvicorn server. ''' uvicorn.run( - self.rp_app, host='0.0.0.0', - port=int(api_port), workers=int(api_concurrency) + self.rp_app, host=api_host, + port=int(api_port), workers=int(api_concurrency), + log_level="info", + access_log=False ) async def run(self, job: Job): @@ -78,3 +115,22 @@ async def run(self, job: Job): # Return the results of the job processing. return jsonable_encoder(job_results) + + async def test_run(self, job: TestJob): + ''' + Performs model inference on the input data using the provided handler. + ''' + if self.config["handler"] is None: + return {"error": "Handler not provided"} + + # Set the current job ID. + set_job_id(job.id) + + job_results = run_job(self.config["handler"], job.__dict__) + + job_results["id"] = job.id + + # Reset the job ID. + set_job_id(None) + + return jsonable_encoder(job_results)