Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 13 additions & 45 deletions runpod/serverless/modules/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
import time
import json
import traceback
import types

import runpod.serverless.modules.logging as log
from .worker_state import JOB_GET_URL, get_done_url
from .retry import retry
from .worker_state import IS_LOCAL_TEST, JOB_GET_URL
from .rp_tips import check_return_size

_IS_LOCAL_TEST = os.environ.get("RUNPOD_WEBHOOK_GET_JOB", None) is None


def _get_local():
"""
Expand Down Expand Up @@ -44,7 +42,7 @@ async def get_job(session, config):
log.warn("test_input set, using test_input as job input")
next_job = config["rp_args"]["test_input"]
next_job["id"] = "test_input_provided"
elif _IS_LOCAL_TEST:
elif IS_LOCAL_TEST:
log.warn("RUNPOD_WEBHOOK_GET_JOB not set, switching to get_local")
next_job = _get_local()
else:
Expand Down Expand Up @@ -86,16 +84,25 @@ def run_job(handler, job):
job_output = handler(job)
log.debug(f'Job {job["id"]} handler output: {job_output}')

if isinstance(job_output, bool):
# Generator type is used for streaming jobs.
if isinstance(job_output, types.GeneratorType):
for output_partial in job_output:
yield {"output": output_partial}
run_result = None

elif isinstance(job_output, bool):
run_result = {"output": job_output}

elif "error" in job_output:
run_result = {"error": str(job_output["error"])}

elif "refresh_worker" in job_output:
job_output.pop("refresh_worker")
run_result = {
"stopPod": True,
"output": job_output
}

else:
run_result = {"output": job_output}

Expand All @@ -110,42 +117,3 @@ def run_job(handler, job):
log.debug(f"Run result: {run_result}")

return run_result # pylint: disable=lost-exception


@retry(max_attempts=3, base_delay=1, max_delay=3)
async def retry_send_result(session, job_data):
    """
    Post the job result payload to the job-done URL.

    The call is wrapped by the `retry` decorator (max_attempts=3,
    base_delay=1, max_delay=3); the retry policy itself lives in `.retry`.

    Args:
        session: Client session exposing an async `post` context manager
            (presumably an aiohttp.ClientSession -- confirm against caller).
        job_data: Request body to send.
    """
    request_options = {
        "data": job_data,
        "headers": {
            "charset": "utf-8",
            "Content-Type": "application/x-www-form-urlencoded",
        },
        # Forwarded to session.post so that error responses raise
        # (aiohttp semantics -- confirm if the session type changes).
        "raise_for_status": True,
    }

    log.debug("Initiating result API call")
    # The done URL is resolved at call time, after the debug line above.
    async with session.post(get_done_url(), **request_options) as response:
        response_text = await response.text()
        log.debug(f"Result API response: {response_text}")

    log.info("Completed result API call")


async def send_result(session, job_data, job):
    '''
    Serialize the job result to JSON and deliver it.

    In local-test mode the payload is only logged; otherwise it is handed to
    `retry_send_result`. Any exception is logged and swallowed, so this
    coroutine does not propagate delivery failures to its caller.

    Args:
        session: Client session passed through to `retry_send_result`.
        job_data: JSON-serializable result object.
        job: Job dict; its "id" entry is used in log messages.
    '''
    try:
        # ensure_ascii=False keeps non-ASCII characters unescaped.
        payload = json.dumps(job_data, ensure_ascii=False)

        if _IS_LOCAL_TEST:
            log.warn(f"Local test job results for {job['id']}: {payload}")
        else:
            log.info(f"Sending job results for {job['id']}: {payload}")
            await retry_send_result(session, payload)

    except Exception as err:  # pylint: disable=broad-except
        log.error(f"Error while returning job result {job['id']}: {err}")
    else:
        # Also reached in local-test mode, where nothing was transmitted.
        log.info(f"Successfully returned job result {job['id']}")
66 changes: 66 additions & 0 deletions runpod/serverless/modules/rp_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
This module is used to handle HTTP requests.
"""

import json

import runpod.serverless.modules.logging as log
from .retry import retry
from .worker_state import IS_LOCAL_TEST, get_done_url, get_stream_url


@retry(max_attempts=3, base_delay=1, max_delay=3)
async def transmit(session, job_data, url):
    """
    Post `job_data` to `url` and log the response body.

    The call is wrapped by the `retry` decorator (max_attempts=3,
    base_delay=1, max_delay=3); the retry policy itself lives in `.retry`.

    Args:
        session: Client session exposing an async `post` context manager
            (presumably an aiohttp.ClientSession -- confirm against caller).
        job_data: Request body; callers in this module pass a JSON string.
        url: Endpoint that receives the POST.
    """
    request_options = {
        "data": job_data,
        "headers": {
            "charset": "utf-8",
            "Content-Type": "application/x-www-form-urlencoded",
        },
        # Forwarded to session.post so that error responses raise
        # (aiohttp semantics -- confirm if the session type changes).
        "raise_for_status": True,
    }

    log.debug("Initiating result API call")
    async with session.post(url, **request_options) as response:
        response_text = await response.text()
        log.debug(f"Result API response: {response_text}")

    log.info("Completed result API call")


async def send_result(session, job_data, job):
    '''
    Serialize the final job result to JSON and deliver it to the done URL.

    In local-test mode the payload is only logged; otherwise it is sent with
    `transmit`. Any exception is logged and swallowed, so this coroutine does
    not propagate delivery failures to its caller.

    Args:
        session: Client session passed through to `transmit`.
        job_data: JSON-serializable result object.
        job: Job dict; its "id" entry is used in log messages.
    '''
    try:
        # ensure_ascii=False keeps non-ASCII characters unescaped.
        payload = json.dumps(job_data, ensure_ascii=False)

        if IS_LOCAL_TEST:
            log.warn(f"Local test job results for {job['id']}: {payload}")
        else:
            log.info(f"Sending job results for {job['id']}: {payload}")
            await transmit(session, payload, get_done_url())

    except Exception as err:  # pylint: disable=broad-except
        log.error(f"Error while returning job result {job['id']}: {err}")
    else:
        # Also reached in local-test mode, where nothing was transmitted.
        log.info(f"Successfully returned job result {job['id']}")


async def stream_result(session, job_data, job):
    '''
    Serialize one partial (streamed) job output to JSON and deliver it to
    the stream URL.

    In local-test mode the payload is only logged; otherwise it is sent with
    `transmit`. Any exception is logged and swallowed, so a failed partial
    does not propagate to the caller.

    The log messages name the stream explicitly so that partial outputs can
    be told apart from the final result logged by `send_result`.

    Args:
        session: Client session passed through to `transmit`.
        job_data: JSON-serializable partial output.
        job: Job dict; its "id" entry is used in log messages.
    '''
    try:
        # ensure_ascii=False keeps non-ASCII characters unescaped.
        payload = json.dumps(job_data, ensure_ascii=False)

        if IS_LOCAL_TEST:
            log.warn(f"Local test stream output for {job['id']}: {payload}")
        else:
            log.info(f"Sending stream output for {job['id']}: {payload}")
            await transmit(session, payload, get_stream_url())

    except Exception as err:  # pylint: disable=broad-except
        log.error(f"Error while streaming job output {job['id']}: {err}")
    else:
        # Also reached in local-test mode, where nothing was transmitted.
        log.info(f"Successfully streamed job output {job['id']}")
34 changes: 24 additions & 10 deletions runpod/serverless/modules/worker_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,20 @@
import uuid
import time

REF_COUNT_ZERO = time.perf_counter()

CURRENT_JOB_ID = None
REF_COUNT_ZERO = time.perf_counter() # Used for benchmarking with the debugger.

WORKER_ID = os.environ.get('RUNPOD_POD_ID', str(uuid.uuid4()))


def get_auth_header():
    '''
    Build the authorization header from the RUNPOD_AI_API_KEY variable.

    The environment is read on every call. When the variable is unset the
    header value is the literal string "None".
    '''
    api_key = os.environ.get('RUNPOD_AI_API_KEY')
    return {"Authorization": str(api_key)}

CURRENT_JOB_ID = None

# Endpoint templates read from the environment at import time.
# NOTE: str() turns an unset variable into the literal string 'None'.

# Job-fetch URL with its '$ID' placeholder replaced by WORKER_ID.
JOB_GET_URL = str(os.getenv('RUNPOD_WEBHOOK_GET_JOB')).replace('$ID', WORKER_ID)

# Result URL template with '$RUNPOD_POD_ID' replaced by WORKER_ID.
JOB_DONE_URL_TEMPLATE = str(os.getenv('RUNPOD_WEBHOOK_POST_OUTPUT')).replace(
    '$RUNPOD_POD_ID', WORKER_ID)

# Stream URL template with '$RUNPOD_POD_ID' replaced by WORKER_ID.
JOB_STREAM_URL_TEMPLATE = str(os.getenv('RUNPOD_WEBHOOK_POST_STREAM')).replace(
    '$RUNPOD_POD_ID', WORKER_ID)

WEBHOOK_PING = os.environ.get('RUNPOD_WEBHOOK_PING', None)
if WEBHOOK_PING is not None:
PING_URL = WEBHOOK_PING.replace('$RUNPOD_POD_ID', WORKER_ID)
Expand All @@ -33,6 +29,17 @@ def get_auth_header():
PING_INTERVAL = int(os.environ.get('RUNPOD_PING_INTERVAL', 10000))


# ----------------------------------- Flags ---------------------------------- #
# True when RUNPOD_WEBHOOK_GET_JOB is absent from the environment at import time.
IS_LOCAL_TEST = "RUNPOD_WEBHOOK_GET_JOB" not in os.environ


def get_auth_header():
    '''
    Build the authorization header from the RUNPOD_AI_API_KEY variable.

    The environment is read on every call. When the variable is unset the
    header value is the literal string "None".
    '''
    api_key = os.environ.get('RUNPOD_AI_API_KEY')
    return {"Authorization": str(api_key)}


def get_current_job_id():
'''
Returns the current job id.
Expand All @@ -47,6 +54,13 @@ def get_done_url():
return JOB_DONE_URL_TEMPLATE.replace('$ID', CURRENT_JOB_ID)


def get_stream_url():
    '''
    Build the stream URL for the current job.

    Replaces the '$ID' placeholder in JOB_STREAM_URL_TEMPLATE with
    CURRENT_JOB_ID.
    '''
    stream_url = JOB_STREAM_URL_TEMPLATE.replace('$ID', CURRENT_JOB_ID)
    return stream_url


def set_job_id(new_job_id):
'''
Sets the current job id.
Expand Down
10 changes: 9 additions & 1 deletion runpod/serverless/work_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

import os
import sys
import types

import aiohttp

import runpod.serverless.modules.logging as log
from .modules.heartbeat import HeartbeatSender
from .modules.job import get_job, run_job, send_result
from .modules.job import get_job, run_job
from .modules.rp_http import send_result, stream_result
from .modules.worker_state import REF_COUNT_ZERO, set_job_id
from .utils import rp_debugger

Expand Down Expand Up @@ -61,6 +63,12 @@ async def start_worker(config):
else:
job_result = run_job(config["handler"], job)

# check if job result is a generator
if isinstance(job_result, types.GeneratorType):
for job_stream in job_result:
await stream_result(session, job_stream, job)
job_result = None

# If refresh_worker is set, pod will be reset after job is complete.
if config.get("refresh_worker", False):
log.info(f"Refresh worker flag set, stopping pod after job {job['id']}.")
Expand Down