Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions runpod/serverless/modules/rp_progress.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""
Provides a method to update the progress of a currently running job.
RunPod Progress Module
"""

import os
import asyncio
import threading
from typing import Dict, Any

import aiohttp
from .rp_http import send_result

from .rp_http import send_result

def _create_session():
async def _create_session_async():
"""
Creates an aiohttp session.
"""
Expand All @@ -19,15 +23,33 @@ def _create_session():
headers=auth_header, timeout=timeout
)

def progress_update(job, progress):
async def _async_progress_update(session, job, progress):
"""
Updates the progress of a currently running job.
The actual asynchronous function that sends the update.
"""
session = _create_session()

job_data = {
"status": "IN_PROGRESS",
"output": progress
}

send_result(session, job_data, job)
await send_result(session, job_data, job)

def _thread_target(job: Dict[str, Any], progress: str):
"""
A wrapper around _async_progress_update to handle the event loop.
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

session = loop.run_until_complete(_create_session_async())
loop.run_until_complete(_async_progress_update(session, job, progress))

session.close()


def progress_update(job, progress):
"""
Updates the progress of a currently running job in a separate thread.
"""
thread = threading.Thread(target=_thread_target, args=(job, progress), daemon=True)
thread.start()
45 changes: 31 additions & 14 deletions tests/test_serverless/test_modules/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,54 @@
Tests for the rp_progress.py module.
"""


import unittest
from unittest.mock import patch, Mock
from runpod.serverless import progress_update
from unittest.mock import ANY, patch
from threading import Event

from runpod.serverless.modules.rp_progress import progress_update, _thread_target

class TestProgressUpdate(unittest.IsolatedAsyncioTestCase):
class TestProgressUpdate(unittest.TestCase):
""" Tests for the progress_update function. """

@patch("runpod.serverless.modules.rp_progress.os.environ.get")
@patch("runpod.serverless.modules.rp_progress.aiohttp.ClientSession")
@patch("runpod.serverless.modules.rp_progress.send_result")
async def test_progress_update(self, mock_send_result, mock_client_session, mock_os_get):
@patch("runpod.serverless.modules.rp_progress._thread_target")
def test_progress_update(self, mock_thread_target, mock_result, mock_os_get):
"""
Tests that the progress_update function calls the send_result function with the correct
Tests that the progress_update function.
"""
# Create an event to track thread completion
thread_event = Event()

def mock_thread_function(job, progress):
try:
assert job == "fake_job", "Job ID was not passed correctly"
assert progress == "50%", "Progress was not passed correctly"
except Exception as err: # pylint: disable=broad-except
print(f"Exception in mocked function: {err}")
finally:
thread_event.set()

mock_thread_target.side_effect = mock_thread_function

# Set mock values
mock_os_get.return_value = "fake_api_key"
fake_session = Mock()
mock_client_session.return_value = fake_session

# Call the function
job = "fake_job"
job = {"id": "fake_job"}
progress = "50%"
progress_update(job, progress)
_thread_target(job, progress)

# Assertions
mock_os_get.assert_called_once_with('RUNPOD_AI_API_KEY')
mock_client_session.assert_called_once()

assert mock_thread_target.called, "Thread function was not started"
mock_thread_target.assert_called_once_with(job, progress)
assert thread_event.wait(timeout=30), "Thread did not complete within expected time"

# Assertions
mock_os_get.assert_called_with('RUNPOD_AI_API_KEY')
expected_job_data = {
"status": "IN_PROGRESS",
"output": progress
}
mock_send_result.assert_called_once_with(fake_session, expected_job_data, job)
mock_result.assert_called_once_with(ANY, expected_job_data, job)