Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CD-publish_to_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ jobs:
"runpod-workers/worker-controlnet",
"runpod-workers/worker-blip",
"runpod-workers/worker-deforum",
runpod-workers/mock-worker,
]

runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Added

- BETA: CLI DevEx functionality to create development projects.
- `test_output` can be passed in as an arg to compare the results of `test_input`
- Generator/Streaming handlers supported with local testing

## Release 1.3.0 (10/12/23)

Expand Down
4 changes: 4 additions & 0 deletions runpod/serverless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def _set_config_args(config) -> dict:
if config["rp_args"]["test_input"]:
config["rp_args"]["test_input"] = json.loads(config["rp_args"]["test_input"])

# Parse the test output from JSON
if config["rp_args"].get("test_output", None):
config["rp_args"]["test_output"] = json.loads(config["rp_args"]["test_output"])

# Set the log level
if config["rp_args"]["rp_log_level"]:
log.set_level(config["rp_args"]["rp_log_level"])
Expand Down
11 changes: 9 additions & 2 deletions runpod/serverless/modules/rp_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel

from .rp_job import run_job
from .rp_handler import is_generator
from .rp_job import run_job, run_job_generator
from .worker_state import Jobs
from .rp_ping import Heartbeat
from ...version import __version__ as runpod_version
Expand Down Expand Up @@ -125,7 +126,13 @@ async def _debug_run(self, job: TestJob):
# Set the current job ID.
job_list.add_job(job.id)

job_results = await run_job(self.config["handler"], job.__dict__)
if is_generator(self.config["handler"]):
generator_output = run_job_generator(self.config["handler"], job.__dict__)
job_results = {"output": []}
async for stream_output in generator_output:
job_results["output"].append(stream_output["output"])
else:
job_results = await run_job(self.config["handler"], job.__dict__)

job_results["id"] = job.id

Expand Down
8 changes: 8 additions & 0 deletions runpod/serverless/modules/rp_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Retrieve handler info. """

import inspect
from typing import Callable

def is_generator(handler: Callable) -> bool:
"""Check if handler is a generator function. """
return inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler)
8 changes: 8 additions & 0 deletions runpod/serverless/modules/rp_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,13 @@ async def run_local(config: Dict[str, Any]) -> None:
log.info(f"Job {local_job['id']} completed successfully.")
log.info(f"Job result: {job_result}")

# Compare to sample output, if provided
if config['rp_args'].get('test_output', None):
log.info("test_output set, comparing output to test_output.")
if job_result != config['rp_args']['test_output']:
log.error("Job output does not match test_output.")
sys.exit(1)
log.info("Job output matches test_output.")

log.info("Local testing complete, exiting.")
sys.exit(0)
4 changes: 4 additions & 0 deletions runpod/serverless/modules/rp_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def start_ping(self, test=False):
'''
Sends heartbeat pings to the Runpod server.
'''
if os.environ.get('RUNPOD_AI_API_KEY') is None:
log.debug("Not deployed on RunPod serverless, pings will not be sent.")
return

if os.environ.get('RUNPOD_POD_ID') is None:
log.info("Not running on RunPod, pings will not be sent.")
return
Expand Down
7 changes: 3 additions & 4 deletions runpod/serverless/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
"""
import os
import asyncio
import inspect
from typing import Dict, Any

import aiohttp

from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.serverless.modules.rp_scale import JobScaler
from .modules import rp_local
from .modules.rp_handler import is_generator
from .modules.rp_ping import Heartbeat
from .modules.rp_job import run_job, run_job_generator
from .modules.rp_http import send_result, stream_result
Expand Down Expand Up @@ -46,11 +46,10 @@ def _is_local(config) -> bool:


async def _process_job(job, session, job_scaler, config):
if inspect.isgeneratorfunction(config["handler"]) \
or inspect.isasyncgenfunction(config["handler"]):
if is_generator(config["handler"]):
generator_output = run_job_generator(config["handler"], job)

log.debug("Handler is a generator, streaming results.")

job_result = {'output': []}
async for stream_output in generator_output:
if 'error' in stream_output:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_serverless/test_modules/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,16 @@ def test_run(self):

self.assertTrue(mock_ping.called)

# Test with generator handler
def generator_handler(job):
del job
yield {"result": "success"}

generator_worker_api = rp_fastapi.WorkerAPI(handler=generator_handler)
generator_run_return = asyncio.run(generator_worker_api._debug_run(job_object)) # pylint: disable=protected-access
assert generator_run_return == {
"id": "test_job_id",
"output": [{"result": "success"}]
}

loop.close()
33 changes: 33 additions & 0 deletions tests/test_serverless/test_modules/test_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
""" Unit tests for the handler module.
"""
import unittest

from runpod.serverless.modules.rp_handler import is_generator


class TestIsGenerator(unittest.TestCase):
"""Tests for the is_generator function."""

def test_regular_function(self):
"""Test that a regular function is not a generator."""
def regular_func():
return "I'm a regular function!"
self.assertFalse(is_generator(regular_func))

def test_generator_function(self):
"""Test that a generator function is a generator."""
def generator_func():
yield "I'm a generator function!"
self.assertTrue(is_generator(generator_func))

def test_async_function(self):
"""Test that an async function is not a generator."""
async def async_func():
return "I'm an async function!"
self.assertFalse(is_generator(async_func))

def test_async_generator_function(self):
"""Test that an async generator function is a generator."""
async def async_gen_func():
yield "I'm an async generator function!"
self.assertTrue(is_generator(async_gen_func))
12 changes: 10 additions & 2 deletions tests/test_serverless/test_modules/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class TestRunLocal(IsolatedAsyncioTestCase):
''' Tests for run_local function '''

@patch("runpod.serverless.modules.rp_local.run_job", return_value={})
@patch("runpod.serverless.modules.rp_local.run_job", return_value={"result": "success"})
@patch("builtins.open", new_callable=mock_open, read_data='{"input": "test"}')
async def test_run_local_with_test_input(self, mock_file, mock_run):
'''
Expand All @@ -21,12 +21,20 @@ async def test_run_local_with_test_input(self, mock_file, mock_run):
"test_input": {
"input": "test",
"id": "test_id"
},
"test_output": {
"result": "success"
}
}
}
with self.assertRaises(SystemExit) as sys_exit:
await rp_local.run_local(config)
self.assertEqual(sys_exit.exception.code, 0)
self.assertEqual(sys_exit.exception.code, 0)

config["rp_args"]["test_output"] = {"result": "fail"}
with self.assertRaises(SystemExit) as sys_exit:
await rp_local.run_local(config)
self.assertEqual(sys_exit.exception.code, 1)

assert mock_file.called is False
assert mock_run.called
Expand Down
7 changes: 7 additions & 0 deletions tests/test_serverless/test_modules/test_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def test_start_ping(self, mock_get_return):
'''
Tests that the start_ping function works correctly
'''
# No RUNPOD_AI_API_KEY case
with patch("threading.Thread.start") as mock_thread_start:
rp_ping.Heartbeat().start_ping(test=True)
assert mock_thread_start.call_count == 0

os.environ["RUNPOD_AI_API_KEY"] = "test_key"

# No RUNPOD_POD_ID case
with patch("threading.Thread.start") as mock_thread_start:
rp_ping.Heartbeat().start_ping(test=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_serverless/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def test_local_api(self):
'''
Test local FastAPI setup.
'''

known_args = argparse.Namespace()
known_args.rp_log_level = None
known_args.rp_debugger = None
Expand Down Expand Up @@ -126,6 +125,7 @@ def test_worker_bad_local(self):
known_args.rp_api_concurrency = 1
known_args.rp_api_host = "localhost"
known_args.test_input = '{"test": "test"}'
known_args.test_output = '{"test": "test"}'

with patch("argparse.ArgumentParser.parse_known_args") as mock_parse_known_args, \
self.assertRaises(SystemExit):
Expand Down