1 change: 1 addition & 0 deletions pyproject.toml
@@ -54,6 +54,7 @@ runpod = "runpod.cli.entry:runpod_cli"
test = [
"asynctest",
"nest_asyncio",
"faker",
"pytest-asyncio",
"pytest-cov",
"pytest-timeout",
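> Note: `faker` joins the test extras here; a minimal sketch of the kind of pytest case it enables (the test name and fields are illustrative, not from this PR):

```python
from faker import Faker

def test_job_payload_shape():
    # Generate throwaway-but-realistic request data instead of hardcoding it.
    fake = Faker()
    payload = {"id": fake.uuid4(), "webhook": fake.url()}
    assert payload["id"]
    assert payload["webhook"].startswith("http")
```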
43 changes: 1 addition & 42 deletions runpod/http_client.py
@@ -8,7 +8,6 @@
from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError

from .cli.groups.config.functions import get_credentials
from .tracer import create_aiohttp_tracer, create_request_tracer
from .user_agent import USER_AGENT


@@ -37,13 +36,11 @@ def AsyncClientSession(*args, **kwargs): # pylint: disable=invalid-name
"""
Deprecation from aiohttp.ClientSession forbids inheritance.
This is now a factory method
TODO: use httpx
"""
return ClientSession(
connector=TCPConnector(limit=0),
headers=get_auth_header(),
timeout=ClientTimeout(600, ceil_threshold=400),
trace_configs=[create_aiohttp_tracer()],
*args,
**kwargs,
)
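
> Note: with the tracer dropped, `AsyncClientSession` is now a thin factory over `aiohttp.ClientSession`. A minimal usage sketch (the endpoint URL is a placeholder, not a real RunPod route):

```python
import asyncio
from runpod.http_client import AsyncClientSession

async def main():
    # The factory returns a regular aiohttp.ClientSession preconfigured with
    # auth headers, an unbounded TCPConnector, and a 600-second timeout.
    async with AsyncClientSession() as session:
        async with session.get("https://example.com/health") as resp:  # placeholder URL
            print(resp.status)

asyncio.run(main())
```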
@@ -52,43 +49,5 @@ def AsyncClientSession(*args, **kwargs): # pylint: disable=invalid-name
class SyncClientSession(requests.Session):
"""
Inherits requests.Session to override `request()` method for tracing
TODO: use httpx
"""

def request(self, method, url, **kwargs): # pylint: disable=arguments-differ
"""
Override for tracing. Not using super().request()
to capture metrics for connection and transfer times
"""
with create_request_tracer() as tracer:
# Separate out the kwargs that are not applicable to `requests.Request`
request_kwargs = {
k: v
for k, v in kwargs.items()
# contains the names of the arguments
if k in requests.Request.__init__.__code__.co_varnames
}

# Separate out the kwargs that are applicable to `requests.Request`
send_kwargs = {k: v for k, v in kwargs.items() if k not in request_kwargs}

# Create a PreparedRequest object to hold the request details
req = requests.Request(method, url, **request_kwargs)
prepped = self.prepare_request(req)
tracer.request = prepped # Assign the request to the tracer

# Merge environment settings
settings = self.merge_environment_settings(
prepped.url,
send_kwargs.get("proxies"),
send_kwargs.get("stream"),
send_kwargs.get("verify"),
send_kwargs.get("cert"),
)
send_kwargs.update(settings)

# Send the request
response = self.send(prepped, **send_kwargs)
tracer.response = response # Assign the response to the tracer

return response
pass
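
> Note: the deleted override split arbitrary `request()` kwargs between `requests.Request(...)` and `Session.send(...)` by introspecting `Request.__init__`. A self-contained illustration of that trick:

```python
import requests

kwargs = {"params": {"q": 1}, "timeout": 5, "stream": False}

# Kwargs accepted by requests.Request(...) — found by introspection.
request_kwargs = {
    k: v for k, v in kwargs.items()
    if k in requests.Request.__init__.__code__.co_varnames
}
# Everything else belongs to Session.send(...).
send_kwargs = {k: v for k, v in kwargs.items() if k not in request_kwargs}

print(request_kwargs)  # {'params': {'q': 1}}
print(send_kwargs)     # {'timeout': 5, 'stream': False}
```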
7 changes: 7 additions & 0 deletions runpod/serverless/__init__.py
@@ -23,6 +23,13 @@

log = RunPodLogger()


def handle_uncaught_exception(exc_type, exc_value, exc_traceback):
log.error(f"Uncaught exception | {exc_type}; {exc_value}; {exc_traceback};")

sys.excepthook = handle_uncaught_exception


# ---------------------------------------------------------------------------- #
# Run Time Arguments #
# ---------------------------------------------------------------------------- #
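> Note: `sys.excepthook` only covers uncaught exceptions on the main thread; worker threads and asyncio tasks route errors elsewhere. A self-contained demo of the hook added above (using `print` in place of `RunPodLogger`):

```python
import sys

def handle_uncaught_exception(exc_type, exc_value, exc_traceback):
    # Mirrors the hook added in this diff.
    print(f"Uncaught exception | {exc_type}; {exc_value}; {exc_traceback};")

sys.excepthook = handle_uncaught_exception

raise RuntimeError("boom")  # routed through the hook instead of the default traceback
```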
14 changes: 7 additions & 7 deletions runpod/serverless/modules/rp_fastapi.py
@@ -286,12 +286,12 @@ async def _realtime(self, job: Job):
Performs model inference on the input data using the provided handler.
If handler is not provided, returns an error message.
"""
await job_list.add(job.id)
job_list.add(job.id)

# Process the job using the provided handler, passing in the job input.
job_results = await run_job(self.config["handler"], job.__dict__)

await job_list.remove(job.id)
job_list.remove(job.id)

# Return the results of the job processing.
return jsonable_encoder(job_results)
@@ -304,7 +304,7 @@ async def _realtime(self, job: Job):
async def _sim_run(self, job_request: DefaultRequest) -> JobOutput:
"""Development endpoint to simulate run behavior."""
assigned_job_id = f"test-{uuid.uuid4()}"
await job_list.add({
job_list.add({
"id": assigned_job_id,
"input": job_request.input,
"webhook": job_request.webhook
@@ -345,7 +345,7 @@ async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput:
# ---------------------------------- stream ---------------------------------- #
async def _sim_stream(self, job_id: str) -> StreamOutput:
"""Development endpoint to simulate stream behavior."""
stashed_job = await job_list.get(job_id)
stashed_job = job_list.get(job_id)
if stashed_job is None:
return jsonable_encoder(
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -367,7 +367,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
}
)

await job_list.remove(job.id)
job_list.remove(job.id)

if stashed_job.webhook:
thread = threading.Thread(
@@ -384,7 +384,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
# ---------------------------------- status ---------------------------------- #
async def _sim_status(self, job_id: str) -> JobOutput:
"""Development endpoint to simulate status behavior."""
stashed_job = await job_list.get(job_id)
stashed_job = job_list.get(job_id)
if stashed_job is None:
return jsonable_encoder(
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
@@ -400,7 +400,7 @@ async def _sim_status(self, job_id: str) -> JobOutput:
else:
job_output = await run_job(self.config["handler"], job.__dict__)

await job_list.remove(job.id)
job_list.remove(job.id)

if job_output.get("error", None):
return jsonable_encoder(
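> Note: these endpoints now call `job_list.add/get/remove` synchronously (the `await`s are gone). A hypothetical stand-in showing the shape of the synchronous container they assume; the real class is `worker_state.JobsProgress`, not this sketch:

```python
class InMemoryJobList:
    """Hypothetical sketch; not the actual worker_state implementation."""

    def __init__(self):
        self._jobs = {}

    def add(self, job):
        # Accepts a bare job id or a dict with an "id" key.
        job = job if isinstance(job, dict) else {"id": job}
        self._jobs[job["id"]] = job

    def get(self, job_id):
        return self._jobs.get(job_id)

    def remove(self, job):
        job_id = job["id"] if isinstance(job, dict) else job
        self._jobs.pop(job_id, None)

    def get_job_count(self) -> int:
        return len(self._jobs)
```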
109 changes: 75 additions & 34 deletions runpod/serverless/modules/rp_scale.py
@@ -10,10 +10,9 @@
from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
from .rp_job import get_job, handle_job
from .rp_logger import RunPodLogger
from .worker_state import JobsQueue, JobsProgress
from .worker_state import JobsProgress, IS_LOCAL_TEST

log = RunPodLogger()
job_list = JobsQueue()
job_progress = JobsProgress()


@@ -38,16 +37,50 @@ class JobScaler:
"""

def __init__(self, config: Dict[str, Any]):
concurrency_modifier = config.get("concurrency_modifier")
if concurrency_modifier is None:
self.concurrency_modifier = _default_concurrency_modifier
else:
self.concurrency_modifier = concurrency_modifier

self._shutdown_event = asyncio.Event()
self.current_concurrency = 1
self.config = config

self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)

self.concurrency_modifier = _default_concurrency_modifier
self.jobs_fetcher = get_job
self.jobs_fetcher_timeout = 90
self.jobs_handler = handle_job

if concurrency_modifier := config.get("concurrency_modifier"):
self.concurrency_modifier = concurrency_modifier

if not IS_LOCAL_TEST:
# below cannot be changed unless local
return

if jobs_fetcher := self.config.get("jobs_fetcher"):
self.jobs_fetcher = jobs_fetcher

if jobs_fetcher_timeout := self.config.get("jobs_fetcher_timeout"):
self.jobs_fetcher_timeout = jobs_fetcher_timeout

if jobs_handler := self.config.get("jobs_handler"):
self.jobs_handler = jobs_handler

async def set_scale(self):
self.current_concurrency = self.concurrency_modifier(self.current_concurrency)

if self.jobs_queue and (self.current_concurrency == self.jobs_queue.maxsize):
# no need to resize
return

while self.current_occupancy() > 0:
# not safe to scale when jobs are in flight
await asyncio.sleep(1)
continue

self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)
log.debug(
f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}"
)
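
> Note: `asyncio.Queue` has no way to change `maxsize` in place, so `set_scale` waits for zero occupancy and recreates the queue. The core idea, in isolation:

```python
import asyncio

async def resize_queue(old: asyncio.Queue, new_maxsize: int) -> asyncio.Queue:
    # Only safe to swap once nothing is buffered or in flight.
    while old.qsize() > 0:
        await asyncio.sleep(1)
    return asyncio.Queue(maxsize=new_maxsize)
```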

def start(self):
"""
This is required for the worker to be able to shut down gracefully
@@ -105,6 +138,15 @@ def kill_worker(self):
log.info("Kill worker.")
self._shutdown_event.set()

def current_occupancy(self) -> int:
current_queue_count = self.jobs_queue.qsize()
current_progress_count = job_progress.get_job_count()

log.debug(
f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}"
)
return current_progress_count + current_queue_count

async def get_jobs(self, session: ClientSession):
"""
Retrieve multiple jobs from the server in batches using blocking requests.
@@ -114,45 +156,42 @@ async def get_jobs(self, session: ClientSession):
Adds jobs to the JobsQueue
"""
while self.is_alive():
log.debug("JobScaler.get_jobs | Starting job acquisition.")

self.current_concurrency = self.concurrency_modifier(
self.current_concurrency
)
log.debug(f"JobScaler.get_jobs | current Concurrency set to: {self.current_concurrency}")
await self.set_scale()

current_progress_count = await job_progress.get_job_count()
log.debug(f"JobScaler.get_jobs | current Jobs in progress: {current_progress_count}")

current_queue_count = job_list.get_job_count()
log.debug(f"JobScaler.get_jobs | current Jobs in queue: {current_queue_count}")

jobs_needed = self.current_concurrency - current_progress_count - current_queue_count
jobs_needed = self.current_concurrency - self.current_occupancy()
if jobs_needed <= 0:
log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
await asyncio.sleep(1) # don't go rapidly
continue

try:
# Keep the connection to the blocking call up to 30 seconds
log.debug("JobScaler.get_jobs | Starting job acquisition.")

# Keep the connection to the blocking call with timeout
acquired_jobs = await asyncio.wait_for(
get_job(session, jobs_needed), timeout=30
self.jobs_fetcher(session, jobs_needed),
timeout=self.jobs_fetcher_timeout,
)

if not acquired_jobs:
log.debug("JobScaler.get_jobs | No jobs acquired.")
continue

for job in acquired_jobs:
await job_list.add_job(job)
await self.jobs_queue.put(job)
job_progress.add(job)
log.debug("Job Queued", job["id"])

log.info(f"Jobs in queue: {job_list.get_job_count()}")
log.info(f"Jobs in queue: {self.jobs_queue.qsize()}")

except TooManyRequests:
log.debug(f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds.")
log.debug(
f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds."
)
await asyncio.sleep(5) # debounce for 5 seconds
except asyncio.CancelledError:
log.debug("JobScaler.get_jobs | Request was cancelled.")
raise # CancelledError is a BaseException
except TimeoutError:
log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.")
except TypeError as error:
@@ -173,10 +212,10 @@ async def run_jobs(self, session: ClientSession):
"""
tasks = [] # Store the tasks for concurrent job processing

while self.is_alive() or not job_list.empty():
while self.is_alive() or not self.jobs_queue.empty():
# Fetch as many jobs as the concurrency allows
while len(tasks) < self.current_concurrency and not job_list.empty():
job = await job_list.get_job()
while len(tasks) < self.current_concurrency and not self.jobs_queue.empty():
job = await self.jobs_queue.get()

# Create a new task for each job and add it to the task list
task = asyncio.create_task(self.handle_job(session, job))
Expand Down Expand Up @@ -204,9 +243,9 @@ async def handle_job(self, session: ClientSession, job: dict):
Process an individual job. This function is run concurrently for multiple jobs.
"""
try:
await job_progress.add(job)
log.debug("Handling Job", job["id"])

await handle_job(session, self.config, job)
await self.jobs_handler(session, self.config, job)

if self.config.get("refresh_worker", False):
self.kill_worker()
@@ -216,8 +255,10 @@ async def handle_job(self, session: ClientSession, job: dict):
raise err

finally:
# Inform JobsQueue of a task completion
job_list.task_done()
# Inform Queue of a task completion
self.jobs_queue.task_done()

# Job is no longer in progress
await job_progress.remove(job["id"])
job_progress.remove(job)

log.debug("Finished Job", job["id"])