52 changes: 16 additions & 36 deletions runpod/serverless/modules/rp_scale.py
@@ -82,34 +82,15 @@ def handle_shutdown(self, signum, frame):
 
     async def run(self):
         # Create an async session that will be closed when the worker is killed.
         async with AsyncClientSession() as session:
             # Create tasks for getting and running jobs.
             jobtake_task = asyncio.create_task(self.get_jobs(session))
             jobrun_task = asyncio.create_task(self.run_jobs(session))
 
             tasks = [jobtake_task, jobrun_task]
 
-            try:
-                # Concurrently run both tasks and wait for both to finish.
-                await asyncio.gather(*tasks)
-            except asyncio.CancelledError:  # worker is killed
-                log.debug("Worker tasks cancelled.")
-                self.kill_worker()
-            finally:
-                # Handle the task cancellation gracefully
-                for task in tasks:
-                    if not task.done():
-                        task.cancel()
-                await asyncio.gather(*tasks, return_exceptions=True)
-                await self.cleanup()  # Ensure resources are cleaned up
-
-    async def cleanup(self):
-        # Perform any necessary cleanup here, such as closing connections
-        log.debug("Cleaning up resources before shutdown.")
-        # TODO: stop heartbeat or close any open connections
-        await asyncio.sleep(0)  # Give a chance for other tasks to run (optional)
-        log.debug("Cleanup complete.")
+            # Concurrently run both tasks and wait for both to finish.
+            await asyncio.gather(*tasks)
 
     def is_alive(self):
         """
@@ -121,6 +102,7 @@ def kill_worker(self):
         """
         Whether to kill the worker.
         """
+        log.info("Kill worker.")
         self._shutdown_event.set()
 
     async def get_jobs(self, session: ClientSession):
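
A note on the log.info addition above: kill_worker() only sets self._shutdown_event, and is_alive() presumably reads it, so the while self.is_alive() loop in get_jobs below exits on its next pass rather than immediately. A self-contained sketch of that gate; only the _shutdown_event attribute name is taken from the diff, the rest is illustrative.

import asyncio

class Worker:
    def __init__(self):
        self._shutdown_event = asyncio.Event()   # attribute name mirrors the diff

    def is_alive(self) -> bool:
        # Loop guard: True until kill_worker() is called.
        return not self._shutdown_event.is_set()

    def kill_worker(self):
        self._shutdown_event.set()               # loops observe this on their next pass

async def main():
    worker = Worker()
    print(worker.is_alive())   # True
    worker.kill_worker()
    print(worker.is_alive())   # False

asyncio.run(main())
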
@@ -142,42 +124,40 @@ async def get_jobs(self, session: ClientSession):
             jobs_needed = self.current_concurrency - job_progress.get_job_count()
             if jobs_needed <= 0:
                 log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
-                await asyncio.sleep(0.1)  # don't go rapidly
+                await asyncio.sleep(1)  # don't go rapidly
                 continue
 
             try:
                 # Keep the connection to the blocking call up to 30 seconds
                 acquired_jobs = await asyncio.wait_for(
                     get_job(session, jobs_needed), timeout=30
                 )
 
+                if not acquired_jobs:
+                    log.debug("JobScaler.get_jobs | No jobs acquired.")
+                    continue
+
+                for job in acquired_jobs:
+                    await job_list.add_job(job)
+
+                log.info(f"Jobs in queue: {job_list.get_job_count()}")
+
             except TooManyRequests:
                 log.debug(f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds.")
                 await asyncio.sleep(5)  # debounce for 5 seconds
-                continue
             except asyncio.CancelledError:
                 log.debug("JobScaler.get_jobs | Request was cancelled.")
-                continue
             except TimeoutError:
                 log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.")
-                continue
             except TypeError as error:
                 log.debug(f"JobScaler.get_jobs | Unexpected error: {error}.")
-                continue
             except Exception as error:
                 log.error(
                     f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
                 )
-                continue
-
-            if not acquired_jobs:
-                log.debug("JobScaler.get_jobs | No jobs acquired.")
-                continue
-
-            for job in acquired_jobs:
-                await job_list.add_job(job)
-
-            log.info(f"Jobs in queue: {job_list.get_job_count()}")
+            finally:
+                # Yield control back to the event loop
+                await asyncio.sleep(0)
 
     async def run_jobs(self, session: ClientSession):
         """