11 changes: 11 additions & 0 deletions README.md
@@ -19,6 +19,7 @@ Official Python library for RunPod API & SDK.
- [Installation](#installation)
- [SDK - Serverless Worker](#sdk---serverless-worker)
- [Quick Start](#quick-start)
- [Local Test Worker](#local-test-worker)
- [API Language Library](#api-language-library)
- [Endpoints](#endpoints)
- [GPU Pod Control](#gpu-pod-control)
@@ -44,6 +45,8 @@ This python package can also be used to create a serverless worker that can be d
Create a Python script in your project that contains your model definition and the RunPod worker start code. Run this Python code as your default container start command:

```python
# my_worker.py

import runpod

def is_even(job):
@@ -66,6 +69,14 @@ Make sure that this file is run when your container starts. This can be accompli

See our [blog post](https://www.runpod.io/blog/serverless-create-a-basic-api) for creating a basic Serverless API, or view the [detailed docs](https://docs.runpod.io/serverless-ai/custom-apis) for more information.

### Local Test Worker

You can also run your worker locally as an API server before deploying it to RunPod. This is useful for debugging your handler and for testing code that will make requests to your worker.

```bash
python my_worker.py --rp_serve_api
```
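
Once the server is up you can send requests to it. The sketch below assumes the default host and port (`localhost:8000`), the `/runsync` route added in this PR, and that the `requests` package is installed; the `number` input is only an illustration for the `is_even` example handler.

```python
# send_test_request.py -- post a job to the locally running test worker.
import requests

# /runsync runs the handler synchronously and returns its result.
response = requests.post(
    "http://localhost:8000/runsync",
    json={"input": {"number": 2}},  # illustrative payload for the is_even handler
)

print(response.json())
```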

## API Language Library

You can use this library to make requests to the RunPod API.
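
A minimal sketch of authenticating before making API calls (reading the key from an environment variable is just one option; the exact helper functions available depend on your installed version of the library):

```python
import os
import runpod

# Authenticate once; subsequent API calls use this key.
runpod.api_key = os.environ["RUNPOD_API_KEY"]

# Example query -- list available GPU types.
print(runpod.get_gpus())
```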
43 changes: 35 additions & 8 deletions runpod/serverless/__init__.py
@@ -21,14 +21,25 @@
prog="runpod",
description="Runpod Serverless Worker Arguments."
)
parser.add_argument("--test_input", type=str, default=None,
help="Test input for the worker, formatted as JSON.")
parser.add_argument("--rp_debugger", action="store_true", default=None,
help="Flag to enable the Debugger.")
parser.add_argument("rp_log_level", default=None,
parser.add_argument("--rp_log_level", type=str,
help="""Controls what level of logs are printed to the console.
Options: ERROR, WARN, INFO, and DEBUG.""")

parser.add_argument("--rp_debugger", action="store_true", default=None,
help="Flag to enable the Debugger.")

parser.add_argument("--rp_serve_api", action="store_true", default=None,
help="Flag to start the API server.")
parser.add_argument("--rp_api_port", type=int, default=8000,
help="Port to start the FastAPI server on.")
parser.add_argument("--rp_api_concurrency", type=int, default=1,
help="Number of concurrent FastAPI workers.")
parser.add_argument("--rp_api_host", type=str, default="localhost",
help="Host to start the FastAPI server on.")

parser.add_argument("--test_input", type=str, default=None,
help="Test input for the worker, formatted as JSON.")


def _set_config_args(config) -> dict:
"""
@@ -47,7 +58,7 @@ def _set_config_args(config) -> dict:

# Set the log level
if config["rp_args"]["rp_log_level"]:
log.set_level(config["rp_args"]["rp_debug_level"])
log.set_level(config["rp_args"]["rp_log_level"])

return config

@@ -73,17 +84,33 @@ def start(config):
"""
Starts the serverless worker.
"""
print("--- Starting Serverless Worker ---")

config["reference_counter_start"] = time.perf_counter()
config = _set_config_args(config)

realtime_port = _get_realtime_port()
realtime_concurrency = _get_realtime_concurrency()

if realtime_port:
if config["rp_args"]["rp_serve_api"]:
api_server = rp_fastapi.WorkerAPI()
api_server.config = config

api_server.start_uvicorn(
api_host=config['rp_args']['rp_api_host'],
api_port=config['rp_args']['rp_api_port'],
api_concurrency=config['rp_args']['rp_api_concurrency']
)

elif realtime_port:
api_server = rp_fastapi.WorkerAPI()
api_server.config = config

api_server.start_uvicorn(realtime_port, realtime_concurrency)
api_server.start_uvicorn(
api_host='0.0.0.0',
api_port=realtime_port,
api_concurrency=realtime_concurrency
)

else:
asyncio.run(work_loop.start_worker(config))
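
For context, a minimal worker script that exercises the new branches in `start()` might look like the sketch below (the handler and its input key are illustrative):

```python
# local_worker.py -- minimal handler for exercising the new start() branches.
import runpod

def handler(job):
    # job["input"] carries whatever the caller posted under "input".
    return {"echo": job["input"]}

# `python local_worker.py --rp_serve_api` starts the local FastAPI test server;
# running without flags falls through to the realtime/worker-loop branches.
runpod.serverless.start({"handler": handler})
```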
5 changes: 3 additions & 2 deletions runpod/serverless/modules/job.py
@@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, Generator, Optional, Union

import os
import sys
import json
import traceback
from aiohttp import ClientSession
@@ -21,8 +22,8 @@ def _get_local() -> Optional[Dict[str, Any]]:
Returns contents of test_input.json.
"""
if not os.path.exists("test_input.json"):
log.warn("test_input.json not found, skipping local testing")
return None
log.warn("test_input.json not found, exiting.")
sys.exit(1)

with open("test_input.json", "r", encoding="UTF-8") as file:
test_inputs = json.loads(file.read())
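
With this change a missing `test_input.json` now exits the worker instead of being skipped. A quick sketch for generating a minimal file (the `number` key is only an illustrative input for the example handler):

```python
# make_test_input.py -- write a minimal test_input.json for local runs.
import json

test_input = {"input": {"number": 2}}  # match the shape your handler expects

with open("test_input.json", "w", encoding="UTF-8") as file:
    json.dump(test_input, file, indent=2)
```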
74 changes: 65 additions & 9 deletions runpod/serverless/modules/rp_fastapi.py
@@ -11,14 +11,37 @@
from .worker_state import set_job_id
from .heartbeat import HeartbeatSender

RUNPOD_ENDPOINT_ID = os.environ.get("RUNPOD_ENDPOINT_ID", None)

DESCRIPTION = """
This API server is provided as a method of testing and debugging your worker locally.
Additionally, you can use this to test code that will be making requests to your worker.

### Endpoints

The URLs are named to match the endpoints that you will be given when running on RunPod.

---

*Note: When running your worker on the RunPod platform, this API server will not be used.*
"""


heartbeat = HeartbeatSender()


class Job(BaseModel):
''' Represents a job. '''
id: str
input: dict
input: dict | list | str | int | float | bool


class TestJob(BaseModel):
''' Represents a test job.
input can be any type of data.
'''
id: str = "test_job"
input: dict | list | str | int | float | bool


class WorkerAPI:
@@ -38,25 +61,39 @@ def __init__(self, handler=None):
self.config = {"handler": handler}

# Initialize the FastAPI web server.
self.rp_app = FastAPI()

try:
import runpod # pylint: disable=import-outside-toplevel,cyclic-import
runpod_version = runpod.__version__
except AttributeError:
runpod_version = "0.0.0"

self.rp_app = FastAPI(
title="RunPod | Test Worker | API",
description=DESCRIPTION,
version=runpod_version,
)

# Create an APIRouter and add the route for processing jobs.
api_router = APIRouter()
api_router.add_api_route(
f"/{os.environ.get('RUNPOD_ENDPOINT_ID')}/realtime",
self.run, methods=["POST"]
)

if RUNPOD_ENDPOINT_ID:
api_router.add_api_route(f"/{RUNPOD_ENDPOINT_ID}/realtime", self.run, methods=["POST"])

api_router.add_api_route("/runsync", self.test_run, methods=["POST"])

# Include the APIRouter in the FastAPI application.
self.rp_app.include_router(api_router)

def start_uvicorn(self, api_port, api_concurrency):
def start_uvicorn(self, api_host='localhost', api_port=8000, api_concurrency=1):
'''
Starts the Uvicorn server.
'''
uvicorn.run(
self.rp_app, host='0.0.0.0',
port=int(api_port), workers=int(api_concurrency)
self.rp_app, host=api_host,
port=int(api_port), workers=int(api_concurrency),
log_level="info",
access_log=False
)

async def run(self, job: Job):
@@ -78,3 +115,22 @@ async def run(self, job: Job):

# Return the results of the job processing.
return jsonable_encoder(job_results)

async def test_run(self, job: TestJob):
'''
Performs model inference on the input data using the provided handler.
'''
if self.config["handler"] is None:
return {"error": "Handler not provided"}

# Set the current job ID.
set_job_id(job.id)

job_results = run_job(self.config["handler"], job.__dict__)

job_results["id"] = job.id

# Reset the job ID.
set_job_id(None)

return jsonable_encoder(job_results)
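
Based on the constructor and `start_uvicorn()` signature above, the refactored `WorkerAPI` can also be driven directly rather than through `runpod.serverless.start()`; a rough sketch with illustrative values:

```python
from runpod.serverless.modules import rp_fastapi

def handler(job):
    return {"echo": job["input"]}

# Build the test API around a handler and serve it locally.
api = rp_fastapi.WorkerAPI(handler=handler)
api.start_uvicorn(api_host="localhost", api_port=8000, api_concurrency=1)
```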