diff --git a/runpod/serverless/__init__.py b/runpod/serverless/__init__.py index a7cbe659..6b3abcbf 100644 --- a/runpod/serverless/__init__.py +++ b/runpod/serverless/__init__.py @@ -15,6 +15,7 @@ from . import worker from .modules import rp_fastapi from .modules.rp_logger import RunPodLogger +from .modules.rp_progress import progress_update log = RunPodLogger() diff --git a/runpod/serverless/modules/rp_progress.py b/runpod/serverless/modules/rp_progress.py new file mode 100644 index 00000000..91f6b138 --- /dev/null +++ b/runpod/serverless/modules/rp_progress.py @@ -0,0 +1,33 @@ +""" +Provides a method to update the progress of a currently running job. +""" + +import os +import aiohttp +from .rp_http import send_result + + +def _create_session(): + """ + Creates an aiohttp session. + """ + auth_header = {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"} + timeout = aiohttp.ClientTimeout(total=300, connect=2, sock_connect=2) + + return aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=None), + headers=auth_header, timeout=timeout + ) + +def progress_update(job, progress): + """ + Updates the progress of a currently running job. + """ + session = _create_session() + + job_data = { + "status": "IN_PROGRESS", + "output": progress + } + + send_result(session, job_data, job) diff --git a/tests/test_serverless/test_modules/test_progress.py b/tests/test_serverless/test_modules/test_progress.py new file mode 100644 index 00000000..4e89281a --- /dev/null +++ b/tests/test_serverless/test_modules/test_progress.py @@ -0,0 +1,38 @@ +""" +Tests for the rp_progress.py module. +""" + + +import unittest +from unittest.mock import patch, Mock +from runpod.serverless import progress_update + +class TestProgressUpdate(unittest.IsolatedAsyncioTestCase): + """ Tests for the progress_update function. """ + + @patch("runpod.serverless.modules.rp_progress.os.environ.get") + @patch("runpod.serverless.modules.rp_progress.aiohttp.ClientSession") + @patch("runpod.serverless.modules.rp_progress.send_result") + async def test_progress_update(self, mock_send_result, mock_client_session, mock_os_get): + """ + Tests that the progress_update function calls the send_result function with the correct + """ + # Set mock values + mock_os_get.return_value = "fake_api_key" + fake_session = Mock() + mock_client_session.return_value = fake_session + + # Call the function + job = "fake_job" + progress = "50%" + progress_update(job, progress) + + # Assertions + mock_os_get.assert_called_once_with('RUNPOD_AI_API_KEY') + mock_client_session.assert_called_once() + + expected_job_data = { + "status": "IN_PROGRESS", + "output": progress + } + mock_send_result.assert_called_once_with(fake_session, expected_job_data, job)