Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions runpod/serverless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from . import worker
from .modules import rp_fastapi
from .modules.rp_logger import RunPodLogger
from .modules.rp_progress import progress_update

log = RunPodLogger()

Expand Down
33 changes: 33 additions & 0 deletions runpod/serverless/modules/rp_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Provides a method to update the progress of a currently running job.
"""

import os
import aiohttp
from .rp_http import send_result


def _create_session():
"""
Creates an aiohttp session.
"""
auth_header = {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"}
timeout = aiohttp.ClientTimeout(total=300, connect=2, sock_connect=2)

return aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=None),
headers=auth_header, timeout=timeout
)

def progress_update(job, progress):
"""
Updates the progress of a currently running job.
"""
session = _create_session()

job_data = {
"status": "IN_PROGRESS",
"output": progress
}

send_result(session, job_data, job)
38 changes: 38 additions & 0 deletions tests/test_serverless/test_modules/test_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Tests for the rp_progress.py module.
"""


import unittest
from unittest.mock import patch, Mock
from runpod.serverless import progress_update

class TestProgressUpdate(unittest.IsolatedAsyncioTestCase):
""" Tests for the progress_update function. """

@patch("runpod.serverless.modules.rp_progress.os.environ.get")
@patch("runpod.serverless.modules.rp_progress.aiohttp.ClientSession")
@patch("runpod.serverless.modules.rp_progress.send_result")
async def test_progress_update(self, mock_send_result, mock_client_session, mock_os_get):
"""
Tests that the progress_update function calls the send_result function with the correct
"""
# Set mock values
mock_os_get.return_value = "fake_api_key"
fake_session = Mock()
mock_client_session.return_value = fake_session

# Call the function
job = "fake_job"
progress = "50%"
progress_update(job, progress)

# Assertions
mock_os_get.assert_called_once_with('RUNPOD_AI_API_KEY')
mock_client_session.assert_called_once()

expected_job_data = {
"status": "IN_PROGRESS",
"output": progress
}
mock_send_result.assert_called_once_with(fake_session, expected_job_data, job)