diff --git a/runpod/serverless/modules/job.py b/runpod/serverless/modules/job.py index 9648b4c8..b9f8055f 100644 --- a/runpod/serverless/modules/job.py +++ b/runpod/serverless/modules/job.py @@ -6,14 +6,12 @@ import time import json import traceback +import types import runpod.serverless.modules.logging as log -from .worker_state import JOB_GET_URL, get_done_url -from .retry import retry +from .worker_state import IS_LOCAL_TEST, JOB_GET_URL from .rp_tips import check_return_size -_IS_LOCAL_TEST = os.environ.get("RUNPOD_WEBHOOK_GET_JOB", None) is None - def _get_local(): """ @@ -44,7 +42,7 @@ async def get_job(session, config): log.warn("test_input set, using test_input as job input") next_job = config["rp_args"]["test_input"] next_job["id"] = "test_input_provided" - elif _IS_LOCAL_TEST: + elif IS_LOCAL_TEST: log.warn("RUNPOD_WEBHOOK_GET_JOB not set, switching to get_local") next_job = _get_local() else: @@ -86,16 +84,25 @@ def run_job(handler, job): job_output = handler(job) log.debug(f'Job {job["id"]} handler output: {job_output}') - if isinstance(job_output, bool): + # Generator type is used for streaming jobs. + if isinstance(job_output, types.GeneratorType): + for output_partial in job_output: + yield {"output": output_partial} + run_result = None + + elif isinstance(job_output, bool): run_result = {"output": job_output} + elif "error" in job_output: run_result = {"error": str(job_output["error"])} + elif "refresh_worker" in job_output: job_output.pop("refresh_worker") run_result = { "stopPod": True, "output": job_output } + else: run_result = {"output": job_output} @@ -110,42 +117,3 @@ def run_job(handler, job): log.debug(f"Run result: {run_result}") return run_result # pylint: disable=lost-exception - - -@retry(max_attempts=3, base_delay=1, max_delay=3) -async def retry_send_result(session, job_data): - """ - Wrapper for sending results. - """ - headers = { - "charset": "utf-8", - "Content-Type": "application/x-www-form-urlencoded" - } - - log.debug("Initiating result API call") - async with session.post(get_done_url(), - data=job_data, - headers=headers, - raise_for_status=True) as resp: - result = await resp.text() - log.debug(f"Result API response: {result}") - - log.info("Completed result API call") - - -async def send_result(session, job_data, job): - ''' - Return the job results. - ''' - try: - job_data = json.dumps(job_data, ensure_ascii=False) - if not _IS_LOCAL_TEST: - log.info(f"Sending job results for {job['id']}: {job_data}") - await retry_send_result(session, job_data) - else: - log.warn(f"Local test job results for {job['id']}: {job_data}") - - except Exception as err: # pylint: disable=broad-except - log.error(f"Error while returning job result {job['id']}: {err}") - else: - log.info(f"Successfully returned job result {job['id']}") diff --git a/runpod/serverless/modules/rp_http.py b/runpod/serverless/modules/rp_http.py new file mode 100644 index 00000000..e5267723 --- /dev/null +++ b/runpod/serverless/modules/rp_http.py @@ -0,0 +1,66 @@ +""" + This module is used to handle HTTP requests. +""" + +import json + +import runpod.serverless.modules.logging as log +from .retry import retry +from .worker_state import IS_LOCAL_TEST, get_done_url, get_stream_url + + +@retry(max_attempts=3, base_delay=1, max_delay=3) +async def transmit(session, job_data, url): + """ + Wrapper for sending results. + """ + headers = { + "charset": "utf-8", + "Content-Type": "application/x-www-form-urlencoded" + } + + log.debug("Initiating result API call") + async with session.post(url, + data=job_data, + headers=headers, + raise_for_status=True) as resp: + result = await resp.text() + log.debug(f"Result API response: {result}") + + log.info("Completed result API call") + + +async def send_result(session, job_data, job): + ''' + Return the job results. + ''' + try: + job_data = json.dumps(job_data, ensure_ascii=False) + if not IS_LOCAL_TEST: + log.info(f"Sending job results for {job['id']}: {job_data}") + await transmit(session, job_data, get_done_url()) + else: + log.warn(f"Local test job results for {job['id']}: {job_data}") + + except Exception as err: # pylint: disable=broad-except + log.error(f"Error while returning job result {job['id']}: {err}") + else: + log.info(f"Successfully returned job result {job['id']}") + + +async def stream_result(session, job_data, job): + ''' + Return the stream job results. + ''' + try: + job_data = json.dumps(job_data, ensure_ascii=False) + if not IS_LOCAL_TEST: + log.info(f"Sending job results for {job['id']}: {job_data}") + await transmit(session, job_data, get_stream_url()) + else: + log.warn(f"Local test job results for {job['id']}: {job_data}") + + except Exception as err: # pylint: disable=broad-except + log.error(f"Error while returning job result {job['id']}: {err}") + else: + log.info(f"Successfully returned job result {job['id']}") diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index 70f6391f..022c064b 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -6,24 +6,20 @@ import uuid import time -REF_COUNT_ZERO = time.perf_counter() - -CURRENT_JOB_ID = None +REF_COUNT_ZERO = time.perf_counter() # Used for benchmarking with the debugger. WORKER_ID = os.environ.get('RUNPOD_POD_ID', str(uuid.uuid4())) - -def get_auth_header(): - ''' - Returns the authorization header with the API key. - ''' - return {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"} - +CURRENT_JOB_ID = None JOB_GET_URL = str(os.environ.get('RUNPOD_WEBHOOK_GET_JOB')).replace('$ID', WORKER_ID) + JOB_DONE_URL_TEMPLATE = str(os.environ.get('RUNPOD_WEBHOOK_POST_OUTPUT')) JOB_DONE_URL_TEMPLATE = JOB_DONE_URL_TEMPLATE.replace('$RUNPOD_POD_ID', WORKER_ID) +JOB_STREAM_URL_TEMPLATE = str(os.environ.get('RUNPOD_WEBHOOK_POST_STREAM')) +JOB_STREAM_URL_TEMPLATE = JOB_STREAM_URL_TEMPLATE.replace('$RUNPOD_POD_ID', WORKER_ID) + WEBHOOK_PING = os.environ.get('RUNPOD_WEBHOOK_PING', None) if WEBHOOK_PING is not None: PING_URL = WEBHOOK_PING.replace('$RUNPOD_POD_ID', WORKER_ID) @@ -33,6 +29,17 @@ def get_auth_header(): PING_INTERVAL = int(os.environ.get('RUNPOD_PING_INTERVAL', 10000)) +# ----------------------------------- Flags ---------------------------------- # +IS_LOCAL_TEST = os.environ.get("RUNPOD_WEBHOOK_GET_JOB", None) is None + + +def get_auth_header(): + ''' + Returns the authorization header with the API key. + ''' + return {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"} + + def get_current_job_id(): ''' Returns the current job id. @@ -47,6 +54,13 @@ def get_done_url(): return JOB_DONE_URL_TEMPLATE.replace('$ID', CURRENT_JOB_ID) +def get_stream_url(): + ''' + Constructs the stream URL using the current job id. + ''' + return JOB_STREAM_URL_TEMPLATE.replace('$ID', CURRENT_JOB_ID) + + def set_job_id(new_job_id): ''' Sets the current job id. diff --git a/runpod/serverless/work_loop.py b/runpod/serverless/work_loop.py index 802a8e95..99253ee1 100644 --- a/runpod/serverless/work_loop.py +++ b/runpod/serverless/work_loop.py @@ -5,12 +5,14 @@ import os import sys +import types import aiohttp import runpod.serverless.modules.logging as log from .modules.heartbeat import HeartbeatSender -from .modules.job import get_job, run_job, send_result +from .modules.job import get_job, run_job +from .modules.rp_http import send_result, stream_result from .modules.worker_state import REF_COUNT_ZERO, set_job_id from .utils import rp_debugger @@ -61,6 +63,12 @@ async def start_worker(config): else: job_result = run_job(config["handler"], job) + # check if job result is a generator + if isinstance(job_result, types.GeneratorType): + for job_stream in job_result: + await stream_result(session, job_stream, job) + job_result = None + # If refresh_worker is set, pod will be reset after job is complete. if config.get("refresh_worker", False): log.info(f"Refresh worker flag set, stopping pod after job {job['id']}.")