1 change: 1 addition & 0 deletions .gitignore
@@ -16,6 +16,7 @@ __pycache__/

# C extensions
*.so
!sls_core.so

# Distribution / packaging
.Python
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,17 @@
# Change Log

## Release 1.5.0 (12/28/23)

### Added

- Optional serverless core implementation; enable it with the environment variable `RUNPOD_USE_CORE=True` or `RUNPOD_CORE_PATH=/path/to/core.so`

### Changed

- Reduced `await asyncio.sleep` durations to 0 to reduce execution time.

---

## Release 1.4.2 (12/14/23)

### Fixed
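The 1.5.0 entry above is the only documentation of the opt-in, so a minimal usage sketch may help (the handler body and config are illustrative, not part of this PR):

```python
# Sketch: opting into the serverless core via environment variable.
import os

os.environ["RUNPOD_USE_CORE"] = "True"  # must be set before start() dispatches

import runpod


def handler(job):
    # Illustrative handler: echo the job input back.
    return {"echo": job["input"]}


runpod.serverless.start({"handler": handler})
```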
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -70,4 +70,4 @@ dependencies = { file = ["requirements.txt"] }

# Used by pytest coverage
[tool.coverage.run]
omit = ["runpod/_version.py"]
omit = ["runpod/_version.py", "runpod/serverless/core.py"]
11 changes: 9 additions & 2 deletions runpod/serverless/__init__.py
@@ -12,6 +12,7 @@
import argparse
from typing import Dict, Any

from runpod.serverless import core
from . import worker
from .modules import rp_fastapi
from .modules.rp_logger import RunPodLogger
@@ -125,7 +126,7 @@ def start(config: Dict[str, Any]):
    realtime_concurrency = _get_realtime_concurrency()

    if config["rp_args"]["rp_serve_api"]:
        print("Starting API server.")
        log.info("Starting API server.")
        api_server = rp_fastapi.WorkerAPI(config)

        api_server.start_uvicorn(
@@ -135,7 +136,7 @@
        )

    elif realtime_port:
        print("Starting API server for realtime.")
        log.info("Starting API server for realtime.")
        api_server = rp_fastapi.WorkerAPI(config)

        api_server.start_uvicorn(
@@ -144,5 +145,11 @@
            api_concurrency=realtime_concurrency
        )

    # --------------------------------- SLS-Core --------------------------------- #
    elif os.environ.get("RUNPOD_USE_CORE", None) or os.environ.get("RUNPOD_CORE_PATH", None):
        log.info("Starting worker with SLS-Core.")
        core.main(config)

    # --------------------------------- Standard --------------------------------- #
    else:
        worker.main(config)
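Note: because `os.environ.get` returns the raw string, any non-empty value of `RUNPOD_USE_CORE` takes the SLS-Core branch. A quick sketch of the gotcha:

```python
import os

os.environ["RUNPOD_USE_CORE"] = "False"  # still truthy as a string
print(bool(os.environ.get("RUNPOD_USE_CORE", None)))  # prints True -> core path is taken
```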
227 changes: 227 additions & 0 deletions runpod/serverless/core.py
@@ -0,0 +1,227 @@
""" Core functionality for the runpod serverless worker. """

import ctypes
import inspect
import json
import os
import pathlib
import asyncio
from ctypes import CDLL, byref, c_char_p, c_int
from typing import Any, Callable, List, Dict, Optional

from runpod.serverless.modules.rp_logger import RunPodLogger


log = RunPodLogger()


class CGetJobResult(ctypes.Structure):  # pylint: disable=too-few-public-methods
    """
    Result of _runpod_sls_get_jobs.
    ## fields
    - `res_len`: the number of bytes written to the `dst_buf` passed to _runpod_sls_get_jobs.
    - `status_code`: tells you what happened; see `CGetJobResult.status_code` for more information.
    """

    _fields_ = [("status_code", ctypes.c_int), ("res_len", ctypes.c_int)]

    def __str__(self) -> str:
        return f"CGetJobResult(res_len={self.res_len}, status_code={self.status_code})"

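# For reference, a C-side layout this Structure would mirror, inferred from the
# field order above (an assumption -- the actual Rust/C definition is not in this PR):
#
#     struct CGetJobResult { int status_code; int res_len; };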

class Hook:  # pylint: disable=too-many-instance-attributes
    """ Singleton class for interacting with sls_core.so """

    _instance = None

    # C function pointers
    _get_jobs: Optional[Callable] = None
    _progress_update: Optional[Callable] = None
    _stream_output: Optional[Callable] = None
    _post_output: Optional[Callable] = None
    _finish_stream: Optional[Callable] = None

    def __new__(cls):
        if Hook._instance is None:
            Hook._instance = object.__new__(cls)
            Hook._initialized = False
        return Hook._instance

    def __init__(self, rust_so_path: Optional[str] = None) -> None:
        if self._initialized:
            return

        if rust_so_path is None:
            default_path = os.path.join(
                pathlib.Path(__file__).parent.absolute(), "sls_core.so"
            )
            self.rust_so_path = os.environ.get("RUNPOD_SLS_CORE_PATH", str(default_path))
        else:
            self.rust_so_path = rust_so_path

        rust_library = CDLL(self.rust_so_path)
        buffer = ctypes.create_string_buffer(1024)  # 1 KiB
        num_bytes = rust_library._runpod_sls_crate_version(byref(buffer), c_int(len(buffer)))

        self.rust_crate_version = buffer.raw[:num_bytes].decode("utf-8")

        # Get Jobs
        self._get_jobs = rust_library._runpod_sls_get_jobs
        self._get_jobs.restype = CGetJobResult

        # Progress Update
        self._progress_update = rust_library._runpod_sls_progress_update
        self._progress_update.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int   # json_ptr, json_len
        ]
        self._progress_update.restype = c_int  # 1 if success, 0 if failure

        # Stream Output
        self._stream_output = rust_library._runpod_sls_stream_output
        self._stream_output.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int,  # json_ptr, json_len
        ]
        self._stream_output.restype = c_int  # 1 if success, 0 if failure

        # Post Output
        self._post_output = rust_library._runpod_sls_post_output
        self._post_output.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int,  # json_ptr, json_len
        ]
        self._post_output.restype = c_int  # 1 if success, 0 if failure

        # Finish Stream
        self._finish_stream = rust_library._runpod_sls_finish_stream
        self._finish_stream.argtypes = [c_char_p, c_int]  # id_ptr, id_len
        self._finish_stream.restype = c_int  # 1 if success, 0 if failure

        rust_library._runpod_sls_crate_version.restype = c_int

        rust_library._runpod_sls_init.argtypes = []
        rust_library._runpod_sls_init.restype = c_int
        rust_library._runpod_sls_init()

        self._initialized = True

    def _json_serialize_job_data(self, job_data: Any) -> bytes:
        return json.dumps(job_data, ensure_ascii=False).encode("utf-8")

    def get_jobs(self, max_concurrency: int, max_jobs: int) -> List[Dict[str, Any]]:
        """Get a job or jobs from the queue. The jobs are returned as a list of Job objects."""
        buffer = ctypes.create_string_buffer(1024 * 1024 * 20)  # 20 MiB buffer to store jobs in
        destination_length = len(buffer.raw)
        result: CGetJobResult = self._get_jobs(
            c_int(max_concurrency), c_int(max_jobs),
            byref(buffer), c_int(destination_length)
        )
        if result.status_code == 1:  # success: the jobs were stored in bytes 0..res_len of buffer.raw
            return list(json.loads(buffer.raw[: result.res_len].decode("utf-8")))

        if result.status_code not in [0, 1]:
            raise RuntimeError(f"get_jobs failed with status code {result.status_code}")

        return []  # Status code 0: still waiting for jobs

    def progress_update(self, job_id: str, json_data: bytes) -> bool:
        """
        Send a progress update to AI-API.
        """
        id_bytes = job_id.encode("utf-8")
        return bool(self._progress_update(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def stream_output(self, job_id: str, job_output: Any) -> bool:
        """
        Send part of a streaming result to AI-API.
        """
        json_data = self._json_serialize_job_data(job_output)
        id_bytes = job_id.encode("utf-8")
        return bool(self._stream_output(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def post_output(self, job_id: str, job_output: Any) -> bool:
        """
        Send the result of a job to AI-API.
        Returns True if the task was successfully stored, False otherwise.
        """
        json_data = self._json_serialize_job_data(job_output)
        id_bytes = job_id.encode("utf-8")
        return bool(self._post_output(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def finish_stream(self, job_id: str) -> bool:
        """
        Tell the SLS queue that the result of a streaming job is complete.
        """
        id_bytes = job_id.encode("utf-8")
        return bool(self._finish_stream(
            c_char_p(id_bytes), c_int(len(id_bytes))
        ))
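
# A hypothetical manual-use sketch of the Hook singleton (the worker normally
# drives it via run() below; the output payload here is illustrative):
#
#     hook = Hook()                                        # loads sls_core.so once
#     jobs = hook.get_jobs(max_concurrency=4, max_jobs=4)  # poll the SLS queue
#     for job in jobs:
#         hook.post_output(job["id"], {"output": "done"})  # JSON-encoded and sent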


# -------------------------------- Process Job ------------------------------- #
async def _process_job(handler: Callable, job: Dict[str, Any]) -> None:
    """ Process a single job. """
    hook = Hook()

    try:
        result = handler(job)
    except Exception as err:
        raise RuntimeError(
            f"run {job['id']}: user code raised {type(err).__name__}") from err

    if inspect.isgeneratorfunction(handler):
        for part in result:
            hook.stream_output(job['id'], part)

        hook.finish_stream(job['id'])

    else:
        hook.post_output(job['id'], result)

# -------------------------------- Run Worker -------------------------------- #
async def run(config: Dict[str, Any]) -> None:
    """ Run the worker.

    Args:
        config: A dictionary containing the following keys:
            handler: A function that takes a job and returns a result.
    """
    handler = config['handler']
    max_concurrency = config.get('max_concurrency', 4)
    max_jobs = config.get('max_jobs', 4)

    hook = Hook()

    while True:
        jobs = hook.get_jobs(max_concurrency, max_jobs)

        # Yield to the event loop before polling again so scheduled
        # job tasks get a chance to run.
        if len(jobs) == 0:
            await asyncio.sleep(0)
            continue

        for job in jobs:
            asyncio.create_task(_process_job(handler, job))
            await asyncio.sleep(0)

        await asyncio.sleep(0)


def main(config: Dict[str, Any]) -> None:
    """Run the worker in an asyncio event loop."""
    work_loop = asyncio.new_event_loop()
    try:
        asyncio.ensure_future(run(config), loop=work_loop)
        work_loop.run_forever()
    finally:
        work_loop.close()
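Because `_process_job` detects generator handlers with `inspect.isgeneratorfunction`, streaming needs no extra configuration. A minimal sketch of driving the core loop with a streaming handler (handler and values are illustrative):

```python
from runpod.serverless import core


def streaming_handler(job):
    # Each yielded part goes out via Hook.stream_output; finish_stream
    # fires once the generator is exhausted.
    for token in ["Hello", " ", "world"]:
        yield {"token": token}


core.main({"handler": streaming_handler, "max_concurrency": 4, "max_jobs": 4})
```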
2 changes: 1 addition & 1 deletion runpod/serverless/modules/rp_job.py
@@ -106,7 +106,7 @@ async def get_job(session: ClientSession, retry=True) -> Optional[Dict[str, Any]
        if retry is False:
            break

        await asyncio.sleep(1)
        await asyncio.sleep(0)
    else:
        job_list.add_job(next_job["id"])
        log.debug("Request ID added.", next_job['id'])
Expand Down
3 changes: 1 addition & 2 deletions runpod/serverless/modules/rp_scale.py
@@ -81,7 +81,6 @@ async def get_jobs(self, session):
            if job:
                yield job

            await asyncio.sleep(1)

            await asyncio.sleep(0)

        log.debug(f"Concurrency set to: {self.current_concurrency}")
Binary file added runpod/serverless/sls_core.so
12 changes: 12 additions & 0 deletions tests/test_serverless/test_worker.py
@@ -511,3 +511,15 @@ def mock_is_alive():
        # 5 calls with actual jobs
        assert mock_run_job.call_count == 5
        assert mock_send_result.call_count == 5

    # Test with sls-core
    async def test_run_worker_with_sls_core(self):
        '''
        Test run_worker with sls-core.
        '''
        os.environ["RUNPOD_USE_CORE"] = "true"

        with patch("runpod.serverless.core.main") as mock_main:
            runpod.serverless.start(self.config)

        assert mock_main.called
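A companion case for the `RUNPOD_CORE_PATH` branch could mirror the test above (a sketch only; the `.so` path is hypothetical):

```python
    async def test_run_worker_with_core_path(self):
        '''
        Test run_worker with RUNPOD_CORE_PATH set.
        '''
        os.environ.pop("RUNPOD_USE_CORE", None)
        os.environ["RUNPOD_CORE_PATH"] = "/path/to/core.so"

        with patch("runpod.serverless.core.main") as mock_main:
            runpod.serverless.start(self.config)

        assert mock_main.called
```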