diff --git a/runpod/serverless/modules/rp_job.py b/runpod/serverless/modules/rp_job.py
index 3498bacb..c5181a31 100644
--- a/runpod/serverless/modules/rp_job.py
+++ b/runpod/serverless/modules/rp_job.py
@@ -8,6 +8,7 @@
 import os
 import json
+import asyncio
 import traceback
 
 from aiohttp import ClientSession
 
@@ -53,54 +54,62 @@ async def get_job(session: ClientSession, retry=True) -> Optional[Dict[str, Any]
             async with session.get(_job_get_url()) as response:
                 if response.status == 204:
                     log.debug("No content, no job to process.")
-                    if not retry:
-                        return None
+                    if retry is False:
+                        break
                     continue
 
                 if response.status == 400:
                     log.debug("Received 400 status, expected when FlashBoot is enabled.")
-                    if not retry:
-                        return None
+                    if retry is False:
+                        break
                     continue
 
                 if response.status != 200:
                     log.error(f"Failed to get job, status code: {response.status}")
-                    if not retry:
-                        return None
+                    if retry is False:
+                        break
                     continue
 
-                next_job = await response.json()
-                log.debug(f"Request Received | {next_job}")
+                received_request = await response.json()
+                log.debug(f"Request Received | {received_request}")
 
-            # Check if the job is valid
-            job_id = next_job.get("id", None)
-            job_input = next_job.get("input", None)
+                # Check if the job is valid
+                job_id = received_request.get("id", None)
+                job_input = received_request.get("input", None)
 
-            if None in [job_id, job_input]:
-                missing_fields = []
-                if job_id is None:
-                    missing_fields.append("id")
-                if job_input is None:
-                    missing_fields.append("input")
+                if None in [job_id, job_input]:
+                    missing_fields = []
+                    if job_id is None:
+                        missing_fields.append("id")
+                    if job_input is None:
+                        missing_fields.append("input")
 
-                log.error(f"Job has missing field(s): {', '.join(missing_fields)}.")
-                next_job = None
+                    log.error(f"Job has missing field(s): {', '.join(missing_fields)}.")
+                else:
+                    next_job = received_request
 
         except Exception as err:  # pylint: disable=broad-except
-            log.error(f"Error while getting job: {err}")
+            err_type = type(err).__name__
+            err_message = str(err)
+            err_traceback = traceback.format_exc()
+            log.error(f"Failed to get job, error type: {err_type}, error message: {err_message}")
+            log.error(f"Traceback: {err_traceback}")
 
         if next_job is None:
             log.debug("No job available, waiting for the next one.")
-            if not retry:
-                return None
+            if retry is False:
+                break
 
-    log.debug("Confirmed valid request.", next_job['id'])
+            await asyncio.sleep(1)
+    else:
+        log.debug("Confirmed valid request.", next_job['id'])
 
-    if next_job:
         job_list.add_job(next_job["id"])
         log.debug("Request ID added.", next_job['id'])
 
-        return next_job
+        return next_job
+
+    return None
 
 
 async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/runpod/serverless/modules/rp_scale.py b/runpod/serverless/modules/rp_scale.py
index 39b4efa3..fe7a9199 100644
--- a/runpod/serverless/modules/rp_scale.py
+++ b/runpod/serverless/modules/rp_scale.py
@@ -13,6 +13,7 @@
 log = RunPodLogger()
 job_list = Jobs()
 
+
 class JobScaler():
     """
     A class for automatically retrieving new jobs from the server and processing them concurrently.
@@ -100,7 +101,7 @@ async def get_jobs(self, session):
                 break
 
             for _ in range(self.num_concurrent_get_job_requests):
-                job = await get_job(session, retry=False)
+                job = await get_job(session)
                 self.job_history.append(1 if job else 0)
                 if job:
                     yield job
@@ -128,8 +129,6 @@ async def get_jobs(self, session):
             f"{self.num_concurrent_get_job_requests}."
         )
 
-
-
     def upscale_rate(self) -> None:
         """
         Upscale the job retrieval rate by adjusting the number of concurrent requests.
diff --git a/tests/test_serverless/test_modules/test_job.py b/tests/test_serverless/test_modules/test_job.py
index 588bfd76..1de1b4d7 100644
--- a/tests/test_serverless/test_modules/test_job.py
+++ b/tests/test_serverless/test_modules/test_job.py
@@ -155,7 +155,7 @@ async def test_get_job_exception(self):
         job = await rp_job.get_job(mock_session_exception, retry=False)
 
         assert job is None
-        assert mock_log.error.call_count == 1
+        assert mock_log.error.call_count == 2
 
 
 class TestRunJob(IsolatedAsyncioTestCase):