From ae194ef32ddeb0d2df110a98145547823e106f04 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 6 Jun 2023 15:17:02 -0400 Subject: [PATCH 1/4] support for streaming job: --- runpod/serverless/modules/job.py | 34 +++++++++++++++++++---- runpod/serverless/modules/worker_state.py | 10 +++++++ runpod/serverless/work_loop.py | 7 ++++- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/runpod/serverless/modules/job.py b/runpod/serverless/modules/job.py index 9648b4c8..cec6c35c 100644 --- a/runpod/serverless/modules/job.py +++ b/runpod/serverless/modules/job.py @@ -6,9 +6,10 @@ import time import json import traceback +import types import runpod.serverless.modules.logging as log -from .worker_state import JOB_GET_URL, get_done_url +from .worker_state import JOB_GET_URL, get_done_url, get_stream_url from .retry import retry from .rp_tips import check_return_size @@ -86,7 +87,12 @@ def run_job(handler, job): job_output = handler(job) log.debug(f'Job {job["id"]} handler output: {job_output}') - if isinstance(job_output, bool): + if isinstance(job_output, types.GeneratorType): + # should this emit the whole job stream to this point? + for ji_partial in job_output: + yield {"output": ji_partial} + return + elif isinstance(job_output, bool): run_result = {"output": job_output} elif "error" in job_output: run_result = {"error": str(job_output["error"])} @@ -113,7 +119,7 @@ def run_job(handler, job): @retry(max_attempts=3, base_delay=1, max_delay=3) -async def retry_send_result(session, job_data): +async def retry_send_result(session, job_data, url): """ Wrapper for sending results. """ @@ -123,7 +129,7 @@ async def retry_send_result(session, job_data): } log.debug("Initiating result API call") - async with session.post(get_done_url(), + async with session.post(url, data=job_data, headers=headers, raise_for_status=True) as resp: @@ -141,7 +147,25 @@ async def send_result(session, job_data, job): job_data = json.dumps(job_data, ensure_ascii=False) if not _IS_LOCAL_TEST: log.info(f"Sending job results for {job['id']}: {job_data}") - await retry_send_result(session, job_data) + await retry_send_result(session, job_data, get_done_url()) + else: + log.warn(f"Local test job results for {job['id']}: {job_data}") + + except Exception as err: # pylint: disable=broad-except + log.error(f"Error while returning job result {job['id']}: {err}") + else: + log.info(f"Successfully returned job result {job['id']}") + + +async def send_stream_result(session, job_data, job): + ''' + Return the stream job results. + ''' + try: + job_data = json.dumps(job_data, ensure_ascii=False) + if not _IS_LOCAL_TEST: + log.info(f"Sending job results for {job['id']}: {job_data}") + await retry_send_result(session, job_data, get_stream_url()) else: log.warn(f"Local test job results for {job['id']}: {job_data}") diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index 70f6391f..588cc944 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -24,6 +24,9 @@ def get_auth_header(): JOB_DONE_URL_TEMPLATE = str(os.environ.get('RUNPOD_WEBHOOK_POST_OUTPUT')) JOB_DONE_URL_TEMPLATE = JOB_DONE_URL_TEMPLATE.replace('$RUNPOD_POD_ID', WORKER_ID) +JOB_STREAM_URL_TEMPLATE = str(os.environ.get('RUNPOD_WEBHOOK_POST_STREAM')) +JOB_STREAM_URL_TEMPLATE = JOB_STREAM_URL_TEMPLATE.replace('$RUNPOD_POD_ID', WORKER_ID) + WEBHOOK_PING = os.environ.get('RUNPOD_WEBHOOK_PING', None) if WEBHOOK_PING is not None: PING_URL = WEBHOOK_PING.replace('$RUNPOD_POD_ID', WORKER_ID) @@ -47,6 +50,13 @@ def get_done_url(): return JOB_DONE_URL_TEMPLATE.replace('$ID', CURRENT_JOB_ID) +def get_stream_url(): + ''' + Constructs the stream URL using the current job id. + ''' + return JOB_STREAM_URL_TEMPLATE.replace('$ID', CURRENT_JOB_ID) + + def set_job_id(new_job_id): ''' Sets the current job id. diff --git a/runpod/serverless/work_loop.py b/runpod/serverless/work_loop.py index 802a8e95..93505c81 100644 --- a/runpod/serverless/work_loop.py +++ b/runpod/serverless/work_loop.py @@ -5,12 +5,13 @@ import os import sys +import types import aiohttp import runpod.serverless.modules.logging as log from .modules.heartbeat import HeartbeatSender -from .modules.job import get_job, run_job, send_result +from .modules.job import get_job, run_job, send_result, send_stream_result from .modules.worker_state import REF_COUNT_ZERO, set_job_id from .utils import rp_debugger @@ -60,6 +61,10 @@ async def start_worker(config): job_result = {"error": error_msg} else: job_result = run_job(config["handler"], job) + # check if job result is a generator + if isinstance(job_result, types.GeneratorType): + for job_stream in job_result: + await send_stream_result(session, job_stream, job) # If refresh_worker is set, pod will be reset after job is complete. if config.get("refresh_worker", False): From f0fc7b787d4eacaca5b0185d0bec187c9f380b52 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 6 Jun 2023 17:46:13 -0400 Subject: [PATCH 2/4] fix job completion --- runpod/serverless/modules/job.py | 1 - runpod/serverless/work_loop.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/runpod/serverless/modules/job.py b/runpod/serverless/modules/job.py index cec6c35c..3084b3ee 100644 --- a/runpod/serverless/modules/job.py +++ b/runpod/serverless/modules/job.py @@ -88,7 +88,6 @@ def run_job(handler, job): log.debug(f'Job {job["id"]} handler output: {job_output}') if isinstance(job_output, types.GeneratorType): - # should this emit the whole job stream to this point? for ji_partial in job_output: yield {"output": ji_partial} return diff --git a/runpod/serverless/work_loop.py b/runpod/serverless/work_loop.py index 93505c81..81be9244 100644 --- a/runpod/serverless/work_loop.py +++ b/runpod/serverless/work_loop.py @@ -65,6 +65,7 @@ async def start_worker(config): if isinstance(job_result, types.GeneratorType): for job_stream in job_result: await send_stream_result(session, job_stream, job) + job_result = None # If refresh_worker is set, pod will be reset after job is complete. if config.get("refresh_worker", False): From 15e6a0f7fbbf90ce80252acf0fed83fc3a32a44b Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Mon, 12 Jun 2023 12:31:41 -0400 Subject: [PATCH 3/4] fix: refactored generator stream --- runpod/serverless/modules/job.py | 72 +++-------------------- runpod/serverless/modules/rp_http.py | 66 +++++++++++++++++++++ runpod/serverless/modules/worker_state.py | 24 ++++---- runpod/serverless/work_loop.py | 6 +- 4 files changed, 91 insertions(+), 77 deletions(-) create mode 100644 runpod/serverless/modules/rp_http.py diff --git a/runpod/serverless/modules/job.py b/runpod/serverless/modules/job.py index 3084b3ee..8ce54c19 100644 --- a/runpod/serverless/modules/job.py +++ b/runpod/serverless/modules/job.py @@ -9,12 +9,9 @@ import types import runpod.serverless.modules.logging as log -from .worker_state import JOB_GET_URL, get_done_url, get_stream_url -from .retry import retry +from .worker_state import IS_LOCAL_TEST, JOB_GET_URL from .rp_tips import check_return_size -_IS_LOCAL_TEST = os.environ.get("RUNPOD_WEBHOOK_GET_JOB", None) is None - def _get_local(): """ @@ -45,7 +42,7 @@ async def get_job(session, config): log.warn("test_input set, using test_input as job input") next_job = config["rp_args"]["test_input"] next_job["id"] = "test_input_provided" - elif _IS_LOCAL_TEST: + elif IS_LOCAL_TEST: log.warn("RUNPOD_WEBHOOK_GET_JOB not set, switching to get_local") next_job = _get_local() else: @@ -87,11 +84,13 @@ def run_job(handler, job): job_output = handler(job) log.debug(f'Job {job["id"]} handler output: {job_output}') + # Generator type is used for streaming jobs. if isinstance(job_output, types.GeneratorType): - for ji_partial in job_output: - yield {"output": ji_partial} + for output_partial in job_output: + yield {"output": output_partial} return - elif isinstance(job_output, bool): + + if isinstance(job_output, bool): run_result = {"output": job_output} elif "error" in job_output: run_result = {"error": str(job_output["error"])} @@ -115,60 +114,3 @@ def run_job(handler, job): log.debug(f"Run result: {run_result}") return run_result # pylint: disable=lost-exception - - -@retry(max_attempts=3, base_delay=1, max_delay=3) -async def retry_send_result(session, job_data, url): - """ - Wrapper for sending results. - """ - headers = { - "charset": "utf-8", - "Content-Type": "application/x-www-form-urlencoded" - } - - log.debug("Initiating result API call") - async with session.post(url, - data=job_data, - headers=headers, - raise_for_status=True) as resp: - result = await resp.text() - log.debug(f"Result API response: {result}") - - log.info("Completed result API call") - - -async def send_result(session, job_data, job): - ''' - Return the job results. - ''' - try: - job_data = json.dumps(job_data, ensure_ascii=False) - if not _IS_LOCAL_TEST: - log.info(f"Sending job results for {job['id']}: {job_data}") - await retry_send_result(session, job_data, get_done_url()) - else: - log.warn(f"Local test job results for {job['id']}: {job_data}") - - except Exception as err: # pylint: disable=broad-except - log.error(f"Error while returning job result {job['id']}: {err}") - else: - log.info(f"Successfully returned job result {job['id']}") - - -async def send_stream_result(session, job_data, job): - ''' - Return the stream job results. - ''' - try: - job_data = json.dumps(job_data, ensure_ascii=False) - if not _IS_LOCAL_TEST: - log.info(f"Sending job results for {job['id']}: {job_data}") - await retry_send_result(session, job_data, get_stream_url()) - else: - log.warn(f"Local test job results for {job['id']}: {job_data}") - - except Exception as err: # pylint: disable=broad-except - log.error(f"Error while returning job result {job['id']}: {err}") - else: - log.info(f"Successfully returned job result {job['id']}") diff --git a/runpod/serverless/modules/rp_http.py b/runpod/serverless/modules/rp_http.py new file mode 100644 index 00000000..3fc8bb61 --- /dev/null +++ b/runpod/serverless/modules/rp_http.py @@ -0,0 +1,66 @@ +""" + This module is used to handle HTTP requests. +""" + +import json + +from .retry import retry +import runpod.serverless.modules.logging as log +from .worker_state import IS_LOCAL_TEST, get_done_url, get_stream_url + + +@retry(max_attempts=3, base_delay=1, max_delay=3) +async def transmit(session, job_data, url): + """ + Wrapper for sending results. + """ + headers = { + "charset": "utf-8", + "Content-Type": "application/x-www-form-urlencoded" + } + + log.debug("Initiating result API call") + async with session.post(url, + data=job_data, + headers=headers, + raise_for_status=True) as resp: + result = await resp.text() + log.debug(f"Result API response: {result}") + + log.info("Completed result API call") + + +async def send_result(session, job_data, job): + ''' + Return the job results. + ''' + try: + job_data = json.dumps(job_data, ensure_ascii=False) + if not IS_LOCAL_TEST: + log.info(f"Sending job results for {job['id']}: {job_data}") + await transmit(session, job_data, get_done_url()) + else: + log.warn(f"Local test job results for {job['id']}: {job_data}") + + except Exception as err: # pylint: disable=broad-except + log.error(f"Error while returning job result {job['id']}: {err}") + else: + log.info(f"Successfully returned job result {job['id']}") + + +async def stream_result(session, job_data, job): + ''' + Return the stream job results. + ''' + try: + job_data = json.dumps(job_data, ensure_ascii=False) + if not IS_LOCAL_TEST: + log.info(f"Sending job results for {job['id']}: {job_data}") + await transmit(session, job_data, get_stream_url()) + else: + log.warn(f"Local test job results for {job['id']}: {job_data}") + + except Exception as err: # pylint: disable=broad-except + log.error(f"Error while returning job result {job['id']}: {err}") + else: + log.info(f"Successfully returned job result {job['id']}") diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index 588cc944..022c064b 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -6,21 +6,14 @@ import uuid import time -REF_COUNT_ZERO = time.perf_counter() - -CURRENT_JOB_ID = None +REF_COUNT_ZERO = time.perf_counter() # Used for benchmarking with the debugger. WORKER_ID = os.environ.get('RUNPOD_POD_ID', str(uuid.uuid4())) - -def get_auth_header(): - ''' - Returns the authorization header with the API key. - ''' - return {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"} - +CURRENT_JOB_ID = None JOB_GET_URL = str(os.environ.get('RUNPOD_WEBHOOK_GET_JOB')).replace('$ID', WORKER_ID) + JOB_DONE_URL_TEMPLATE = str(os.environ.get('RUNPOD_WEBHOOK_POST_OUTPUT')) JOB_DONE_URL_TEMPLATE = JOB_DONE_URL_TEMPLATE.replace('$RUNPOD_POD_ID', WORKER_ID) @@ -36,6 +29,17 @@ def get_auth_header(): PING_INTERVAL = int(os.environ.get('RUNPOD_PING_INTERVAL', 10000)) +# ----------------------------------- Flags ---------------------------------- # +IS_LOCAL_TEST = os.environ.get("RUNPOD_WEBHOOK_GET_JOB", None) is None + + +def get_auth_header(): + ''' + Returns the authorization header with the API key. + ''' + return {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"} + + def get_current_job_id(): ''' Returns the current job id. diff --git a/runpod/serverless/work_loop.py b/runpod/serverless/work_loop.py index 81be9244..99253ee1 100644 --- a/runpod/serverless/work_loop.py +++ b/runpod/serverless/work_loop.py @@ -11,7 +11,8 @@ import runpod.serverless.modules.logging as log from .modules.heartbeat import HeartbeatSender -from .modules.job import get_job, run_job, send_result, send_stream_result +from .modules.job import get_job, run_job +from .modules.rp_http import send_result, stream_result from .modules.worker_state import REF_COUNT_ZERO, set_job_id from .utils import rp_debugger @@ -61,10 +62,11 @@ async def start_worker(config): job_result = {"error": error_msg} else: job_result = run_job(config["handler"], job) + # check if job result is a generator if isinstance(job_result, types.GeneratorType): for job_stream in job_result: - await send_stream_result(session, job_stream, job) + await stream_result(session, job_stream, job) job_result = None # If refresh_worker is set, pod will be reset after job is complete. From aa40c3dd5c10c3d920e3d4036767bbf19364ed0e Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Mon, 12 Jun 2023 12:44:11 -0400 Subject: [PATCH 4/4] fix: linting --- runpod/serverless/modules/job.py | 7 +++++-- runpod/serverless/modules/rp_http.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/runpod/serverless/modules/job.py b/runpod/serverless/modules/job.py index 8ce54c19..b9f8055f 100644 --- a/runpod/serverless/modules/job.py +++ b/runpod/serverless/modules/job.py @@ -88,18 +88,21 @@ def run_job(handler, job): if isinstance(job_output, types.GeneratorType): for output_partial in job_output: yield {"output": output_partial} - return + run_result = None - if isinstance(job_output, bool): + elif isinstance(job_output, bool): run_result = {"output": job_output} + elif "error" in job_output: run_result = {"error": str(job_output["error"])} + elif "refresh_worker" in job_output: job_output.pop("refresh_worker") run_result = { "stopPod": True, "output": job_output } + else: run_result = {"output": job_output} diff --git a/runpod/serverless/modules/rp_http.py b/runpod/serverless/modules/rp_http.py index 3fc8bb61..e5267723 100644 --- a/runpod/serverless/modules/rp_http.py +++ b/runpod/serverless/modules/rp_http.py @@ -4,8 +4,8 @@ import json -from .retry import retry import runpod.serverless.modules.logging as log +from .retry import retry from .worker_state import IS_LOCAL_TEST, get_done_url, get_stream_url