From cb0beaf5570780ac16d67fe485ad0fa8ef6e0a16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 11 Jun 2025 21:32:10 -0700 Subject: [PATCH 1/4] refactor: track JobsProgress state using shared memory --- runpod/serverless/modules/worker_state.py | 157 ++++++++++++++++------ 1 file changed, 113 insertions(+), 44 deletions(-) diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index 5e1a2f98..be5dc9db 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -5,6 +5,8 @@ import os import time import uuid +from multiprocessing import Manager +from multiprocessing.managers import SyncManager from typing import Any, Dict, Optional from .rp_logger import RunPodLogger @@ -61,82 +63,149 @@ def __str__(self) -> str: # ---------------------------------------------------------------------------- # # Tracker # # ---------------------------------------------------------------------------- # -class JobsProgress(set): - """Track the state of current jobs in progress.""" - - _instance = None +class JobsProgress: + """Track the state of current jobs in progress using shared memory.""" + + _instance: Optional['JobsProgress'] = None + _manager: SyncManager + _shared_data: Any + _lock: Any def __new__(cls): - if JobsProgress._instance is None: - JobsProgress._instance = set.__new__(cls) - return JobsProgress._instance + if cls._instance is None: + instance = object.__new__(cls) + # Initialize instance variables + instance._manager = Manager() + instance._shared_data = instance._manager.dict() + instance._shared_data['jobs'] = instance._manager.list() + instance._lock = instance._manager.Lock() + cls._instance = instance + return cls._instance + + def __init__(self): + # Everything is already initialized in __new__ + pass def __repr__(self) -> str: return f"<{self.__class__.__name__}>: {self.get_job_list()}" def clear(self) -> None: - return super().clear() + with 
self._lock: + self._shared_data['jobs'][:] = [] def add(self, element: Any): """ Adds a Job object to the set. + """ + if isinstance(element, str): + job_dict = {'id': element} + elif isinstance(element, dict): + job_dict = element + elif hasattr(element, 'id'): + job_dict = {'id': element.id} + else: + raise TypeError("Only Job objects can be added to JobsProgress.") - If the added element is a string, then `Job(id=element)` is added + with self._lock: + # Check if job already exists + job_list = self._shared_data['jobs'] + for existing_job in job_list: + if existing_job['id'] == job_dict['id']: + return # Job already exists + + # Add new job + job_list.append(job_dict) + log.debug(f"JobsProgress | Added job: {job_dict['id']}") + + def get(self, element: Any) -> Optional[Job]: + """ + Retrieves a Job object from the set. - If the added element is a dict, that `Job(**element)` is added + If the element is a string, searches for Job with that id. """ if isinstance(element, str): - element = Job(id=element) - - if isinstance(element, dict): - element = Job(**element) - - if not isinstance(element, Job): - raise TypeError("Only Job objects can be added to JobsProgress.") + search_id = element + elif isinstance(element, Job): + search_id = element.id + else: + raise TypeError("Only Job objects can be retrieved from JobsProgress.") - return super().add(element) + with self._lock: + for job_dict in self._shared_data['jobs']: + if job_dict['id'] == search_id: + log.debug(f"JobsProgress | Retrieved job: {job_dict['id']}") + return Job(**job_dict) + + return None def remove(self, element: Any): """ Removes a Job object from the set. 
- - If the element is a string, then `Job(id=element)` is removed - - If the element is a dict, then `Job(**element)` is removed """ if isinstance(element, str): - element = Job(id=element) - - if isinstance(element, dict): - element = Job(**element) - - if not isinstance(element, Job): + job_id = element + elif isinstance(element, dict): + job_id = element.get('id') + elif hasattr(element, 'id'): + job_id = element.id + else: raise TypeError("Only Job objects can be removed from JobsProgress.") - return super().discard(element) - - def get(self, element: Any) -> Job: - if isinstance(element, str): - element = Job(id=element) - - if not isinstance(element, Job): - raise TypeError("Only Job objects can be retrieved from JobsProgress.") - - for job in self: - if job == element: - return job + with self._lock: + job_list = self._shared_data['jobs'] + # Find and remove the job + for i, job_dict in enumerate(job_list): + if job_dict['id'] == job_id: + del job_list[i] + log.debug(f"JobsProgress | Removed job: {job_dict['id']}") + break - def get_job_list(self) -> str: + def get_job_list(self) -> Optional[str]: """ Returns the list of job IDs as comma-separated string. """ - if not len(self): + with self._lock: + job_list = list(self._shared_data['jobs']) + + if not job_list: return None - return ",".join(str(job) for job in self) + log.debug(f"JobsProgress | Jobs in progress: {job_list}") + return ",".join(str(job_dict['id']) for job_dict in job_list) def get_job_count(self) -> int: """ Returns the number of jobs. 
""" - return len(self) + with self._lock: + return len(self._shared_data['jobs']) + + def __iter__(self): + """Make the class iterable - returns Job objects""" + with self._lock: + # Create a snapshot of jobs to avoid holding lock during iteration + job_dicts = list(self._shared_data['jobs']) + + # Return an iterator of Job objects + return iter(Job(**job_dict) for job_dict in job_dicts) + + def __len__(self): + """Support len() operation""" + return self.get_job_count() + + def __contains__(self, element: Any) -> bool: + """Support 'in' operator""" + if isinstance(element, str): + search_id = element + elif isinstance(element, Job): + search_id = element.id + elif isinstance(element, dict): + search_id = element.get('id') + else: + return False + + with self._lock: + for job_dict in self._shared_data['jobs']: + if job_dict['id'] == search_id: + return True + return False From b467ac75a71285f13dfb99222878b108207f7c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 11 Jun 2025 21:33:22 -0700 Subject: [PATCH 2/4] refactor: Heartbeat uses separate process for independent GIL --- runpod/serverless/modules/rp_ping.py | 31 +- .../test_serverless/test_modules/test_ping.py | 432 ++++++++++++------ 2 files changed, 303 insertions(+), 160 deletions(-) diff --git a/runpod/serverless/modules/rp_ping.py b/runpod/serverless/modules/rp_ping.py index 88fa1049..ae1499f7 100644 --- a/runpod/serverless/modules/rp_ping.py +++ b/runpod/serverless/modules/rp_ping.py @@ -4,9 +4,8 @@ """ import os -import threading import time - +from multiprocessing import Process import requests from urllib3.util.retry import Retry @@ -22,7 +21,7 @@ class Heartbeat: """Sends heartbeats to the Runpod server.""" - _thread_started = False + _process_started = False def __init__(self, pool_connections=10, retries=3) -> None: """ @@ -32,9 +31,10 @@ def __init__(self, pool_connections=10, retries=3) -> None: self.PING_URL = self.PING_URL.replace("$RUNPOD_POD_ID", WORKER_ID) 
self.PING_INTERVAL = int(os.environ.get("RUNPOD_PING_INTERVAL", 10000)) // 1000 + # Create a new HTTP session self._session = SyncClientSession() self._session.headers.update( - {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"} + {"Authorization": os.environ.get("RUNPOD_AI_API_KEY", "")} ) retry_strategy = Retry( @@ -52,9 +52,18 @@ def __init__(self, pool_connections=10, retries=3) -> None: self._session.mount("http://", adapter) self._session.mount("https://", adapter) + @staticmethod + def process_loop(test=False): + """ + Static helper to run the ping loop in a separate process. + Creates a new Heartbeat instance to avoid pickling issues. + """ + hb = Heartbeat() + hb.ping_loop(test) + def start_ping(self, test=False): """ - Sends heartbeat pings to the Runpod server. + Sends heartbeat pings to the Runpod server in a separate process. """ if not os.environ.get("RUNPOD_AI_API_KEY"): log.debug("Not deployed on RunPod serverless, pings will not be sent.") @@ -68,18 +77,19 @@ def start_ping(self, test=False): log.error("Ping URL not set, cannot start ping.") return - if not Heartbeat._thread_started: - threading.Thread(target=self.ping_loop, daemon=True, args=(test,)).start() - Heartbeat._thread_started = True + if not Heartbeat._process_started: + process = Process(target=Heartbeat.process_loop, args=(test,)) + process.daemon = True + process.start() + Heartbeat._process_started = True def ping_loop(self, test=False): """ - Sends heartbeat pings to the Runpod server. + Sends heartbeat pings to the Runpod server until interrupted. 
""" while True: self._send_ping() time.sleep(self.PING_INTERVAL) - if test: return @@ -98,6 +108,5 @@ def _send_ping(self): log.debug( f"Heartbeat Sent | URL: {result.url} | Status: {result.status_code}" ) - except requests.RequestException as err: log.error(f"Ping Request Error: {err}, attempting to restart ping.") diff --git a/tests/test_serverless/test_modules/test_ping.py b/tests/test_serverless/test_modules/test_ping.py index 0a5517a3..3a5447d3 100644 --- a/tests/test_serverless/test_modules/test_ping.py +++ b/tests/test_serverless/test_modules/test_ping.py @@ -1,169 +1,303 @@ -""" Tests for runpod.serverless.modules.rp_ping """ - -import importlib import os -import unittest -from unittest.mock import patch, MagicMock - +import pytest +from unittest.mock import MagicMock, patch import requests -from runpod.serverless.modules import rp_ping from runpod.serverless.modules.rp_ping import Heartbeat -from runpod.serverless.modules.worker_state import JobsProgress - - -class MockResponse: - """Mock response for aiohttp""" - url = "" - status_code = 200 - - -def mock_get(*args, **kwargs): - """ - Mock get function for aiohttp - """ - return MockResponse() -class TestPing(unittest.TestCase): - """Tests for rp_ping""" - - def test_default_variables(self): - """ - Tests that the variables are set with default values - """ +class TestHeartbeat: + """Test suite for the Heartbeat class""" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Reset class state before and after each test""" + # Store original state + original_process_started = Heartbeat._process_started + + # Reset before test + Heartbeat._process_started = False + + yield + + # Reset after test + Heartbeat._process_started = original_process_started + + @pytest.fixture + def mock_env(self): + """Fixture to set up environment variables""" + env_vars = { + "RUNPOD_WEBHOOK_PING": "https://test.com/ping/$RUNPOD_POD_ID", + "RUNPOD_AI_API_KEY": "test_api_key", + "RUNPOD_POD_ID": "test_pod_id", + 
"RUNPOD_PING_INTERVAL": "5000" + } + with patch.dict(os.environ, env_vars): + yield env_vars + + @pytest.fixture + def mock_worker_id(self): + """Mock the WORKER_ID constant""" + with patch("runpod.serverless.modules.rp_ping.WORKER_ID", "test_worker_123"): + yield "test_worker_123" + + @pytest.fixture + def mock_session(self): + """Mock the SyncClientSession""" + with patch("runpod.serverless.modules.rp_ping.SyncClientSession") as mock: + session_instance = MagicMock() + mock.return_value = session_instance + yield session_instance + + @pytest.fixture + def mock_jobs(self): + """Mock the JobsProgress instance""" + with patch("runpod.serverless.modules.rp_ping.jobs") as mock: + mock.get_job_list.return_value = "job1,job2,job3" + yield mock + + @pytest.fixture + def mock_logger(self): + """Mock the logger""" + with patch("runpod.serverless.modules.rp_ping.log") as mock: + yield mock + + def test_heartbeat_initialization(self, mock_env, mock_worker_id, mock_session): + """Test Heartbeat initialization with various configurations""" heartbeat = Heartbeat() - assert heartbeat.PING_URL == "PING_NOT_SET" - assert heartbeat.PING_INTERVAL == 10 - - @patch.dict(os.environ, {"RUNPOD_WEBHOOK_PING": "https://test.com/ping"}) - @patch.dict(os.environ, {"RUNPOD_PING_INTERVAL": "1000"}) - def test_variables(self): - """ - Tests that the variables are set correctly - """ - importlib.reload(rp_ping) - - heartbeat = Heartbeat() - assert heartbeat.PING_URL == "https://test.com/ping" - assert heartbeat.PING_INTERVAL == 1 - - @patch.dict(os.environ, {"RUNPOD_PING_INTERVAL": "1000"}) - @patch( - "runpod.serverless.modules.rp_ping.SyncClientSession.get", side_effect=mock_get - ) - def test_start_ping(self, mock_get_return): - """ - Tests that the start_ping function works correctly - """ - # No RUNPOD_AI_API_KEY case - with patch("threading.Thread.start") as mock_thread_start: - rp_ping.Heartbeat().start_ping(test=True) - assert mock_thread_start.call_count == 0 - - 
os.environ["RUNPOD_AI_API_KEY"] = "test_key" - - # No RUNPOD_POD_ID case - with patch("threading.Thread.start") as mock_thread_start: - rp_ping.Heartbeat().start_ping(test=True) - assert mock_thread_start.call_count == 0 - - os.environ["RUNPOD_POD_ID"] = "test_pod_id" - - # No RUNPOD_WEBHOOK_PING case - with patch("threading.Thread.start") as mock_thread_start: - rp_ping.Heartbeat().start_ping(test=True) - assert mock_thread_start.call_count == 0 - - os.environ["RUNPOD_WEBHOOK_PING"] = "https://test.com/ping" - - importlib.reload(rp_ping) - - # Success case - with patch("threading.Thread.start") as mock_thread_start: - rp_ping.Heartbeat().start_ping(test=True) - assert mock_thread_start.call_count == 1 - - rp_ping.Heartbeat.PING_URL = "https://test.com/ping" - rp_ping.Heartbeat().ping_loop(test=True) - - self.assertEqual(rp_ping.Heartbeat.PING_URL, "https://test.com/ping") - - # Exception case - mock_get_return.side_effect = requests.RequestException("Test Error") - - with patch("runpod.serverless.modules.rp_ping.log.error") as mock_log_error: - rp_ping.Heartbeat().ping_loop(test=True) - assert mock_log_error.call_count == 1 - + + # Check URL construction + expected_url = "https://test.com/ping/test_worker_123" + assert heartbeat.PING_URL == expected_url + + # Check interval calculation + assert heartbeat.PING_INTERVAL == 5 # 5000 // 1000 + + # Check session setup + mock_session.headers.update.assert_called_once_with( + {"Authorization": "test_api_key"} + ) -@patch.dict(os.environ, {"RUNPOD_PING_INTERVAL": "1000"}) -class TestHeartbeat(unittest.IsolatedAsyncioTestCase): + def test_heartbeat_initialization_defaults(self, mock_worker_id, mock_session): + """Test Heartbeat initialization with default values""" + with patch.dict(os.environ, {}, clear=True): + heartbeat = Heartbeat() + + # Should use default values + assert heartbeat.PING_URL == "PING_NOT_SET" + assert heartbeat.PING_INTERVAL == 10 # 10000 // 1000 + + # Authorization should be None + 
mock_session.headers.update.assert_called_once_with( + {"Authorization": ""} + ) - @patch.dict(os.environ, {"RUNPOD_AI_API_KEY": ""}) - @patch("runpod.serverless.modules.rp_ping.log") - def test_start_ping_no_api_key(self, mock_logger): - """Test start_ping method when RUNPOD_AI_API_KEY is missing.""" + def test_start_ping_missing_api_key(self, mock_logger, mock_worker_id): + """Test start_ping when API key is missing""" + with patch.dict(os.environ, {"RUNPOD_POD_ID": "test", "RUNPOD_WEBHOOK_PING": "test"}, clear=True): + with patch("multiprocessing.Process") as mock_process: + heartbeat = Heartbeat() + heartbeat.start_ping() + + # Process should not be created + mock_process.assert_not_called() + mock_logger.debug.assert_called_with( + "Not deployed on RunPod serverless, pings will not be sent." + ) + + def test_start_ping_missing_pod_id(self, mock_logger, mock_worker_id): + """Test start_ping when POD_ID is missing""" + with patch.dict(os.environ, {"RUNPOD_AI_API_KEY": "test"}, clear=True): + with patch("multiprocessing.Process") as mock_process: + heartbeat = Heartbeat() + heartbeat.start_ping() + + # Process should not be created + mock_process.assert_not_called() + mock_logger.info.assert_called_with( + "Not running on RunPod, pings will not be sent." + ) + + def test_start_ping_missing_webhook_url(self, mock_logger, mock_worker_id): + """Test start_ping when webhook URL is not set""" + with patch.dict(os.environ, {"RUNPOD_AI_API_KEY": "test", "RUNPOD_POD_ID": "test"}, clear=True): + with patch("multiprocessing.Process") as mock_process: + heartbeat = Heartbeat() + heartbeat.start_ping() + + # Process should not be created + mock_process.assert_not_called() + mock_logger.error.assert_called_with( + "Ping URL not set, cannot start ping." 
+ ) + + @patch("runpod.serverless.modules.rp_ping.Process") + @patch("runpod.serverless.modules.rp_ping.SyncClientSession") + @patch("runpod.serverless.modules.rp_ping.WORKER_ID", "test_worker_123") + @patch.dict(os.environ, { + "RUNPOD_WEBHOOK_PING": "https://test.com/ping/$RUNPOD_POD_ID", + "RUNPOD_AI_API_KEY": "test_api_key", + "RUNPOD_POD_ID": "test_pod_id", + "RUNPOD_PING_INTERVAL": "5000" + }) + def test_start_ping_success(self, mock_session_class, mock_process_class): + """Test successful start_ping""" + # Reset the class variable + Heartbeat._process_started = False + + mock_process = MagicMock() + mock_process_class.return_value = mock_process + heartbeat = Heartbeat() - heartbeat.start_ping() - mock_logger.debug.assert_called_once_with( - "Not deployed on RunPod serverless, pings will not be sent." + heartbeat.start_ping(test=True) + + # Verify process was created correctly + mock_process_class.assert_called_once_with( + target=Heartbeat.process_loop, + args=(True,) ) - - @patch.dict(os.environ, {"RUNPOD_POD_ID": ""}) - @patch("runpod.serverless.modules.rp_ping.log") - def _test_start_ping_no_pod_id(self, mock_logger): - """Test start_ping method when RUNPOD_POD_ID is missing.""" + + # Verify daemon and start + assert mock_process.daemon is True + mock_process.start.assert_called_once() + + # Verify flag is set + assert Heartbeat._process_started is True + + def test_start_ping_already_started(self, mock_env, mock_worker_id, mock_session): + """Test start_ping when process is already started""" + Heartbeat._process_started = True + + with patch("multiprocessing.Process") as mock_process: + heartbeat = Heartbeat() + heartbeat.start_ping() + + # Process should not be created again + mock_process.assert_not_called() + + def test_process_loop(self, mock_env, mock_worker_id, mock_session): + """Test the process_loop static method""" + with patch.object(Heartbeat, 'ping_loop') as mock_ping_loop: + Heartbeat.process_loop(test=True) + + # Should create new 
instance and call ping_loop + mock_ping_loop.assert_called_once_with(True) + + def test_ping_loop_test_mode(self, mock_env, mock_worker_id, mock_session): + """Test ping_loop in test mode (single iteration)""" heartbeat = Heartbeat() - heartbeat.start_ping() - mock_logger.info.assert_called_once_with( - "Not running on RunPod, pings will not be sent." - ) - - @patch("runpod.serverless.modules.rp_ping.Heartbeat._send_ping") - def test_ping_loop(self, mock_send_ping): - """Test ping_loop runs and exits correctly in test mode.""" - heartbeat = rp_ping.Heartbeat() - heartbeat.ping_loop(test=True) - mock_send_ping.assert_called_once() - - @patch("runpod.serverless.modules.rp_ping.SyncClientSession.get") - async def test_send_ping(self, mock_get): - """Test _send_ping method sends the correct request.""" + + with patch.object(heartbeat, '_send_ping') as mock_send: + heartbeat.ping_loop(test=True) + + # Should send ping once and return + mock_send.assert_called_once() + + def test_ping_loop_continuous(self, mock_env, mock_worker_id, mock_session): + """Test ping_loop in continuous mode""" + heartbeat = Heartbeat() + + # Mock time.sleep to break the loop after 3 iterations + call_count = 0 + def side_effect(interval): + nonlocal call_count + call_count += 1 + if call_count >= 3: + raise KeyboardInterrupt() + + with patch.object(heartbeat, '_send_ping') as mock_send: + with patch('time.sleep', side_effect=side_effect): + with pytest.raises(KeyboardInterrupt): + heartbeat.ping_loop(test=False) + + # Should have sent 3 pings + assert mock_send.call_count == 3 + + def test_send_ping_success(self, mock_env, mock_worker_id, mock_session, mock_jobs, mock_logger): + """Test successful ping send""" + heartbeat = Heartbeat() + + # Mock successful response mock_response = MagicMock() - mock_response.url = "http://localhost/ping" + mock_response.url = "https://test.com/ping/test_worker_123" mock_response.status_code = 200 - mock_get.return_value = mock_response - - jobs = 
JobsProgress() - jobs.add("job1") - jobs.add("job2") + mock_session.get.return_value = mock_response + + # Mock version + with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): + heartbeat._send_ping() + + # Verify request was made correctly + mock_session.get.assert_called_once_with( + "https://test.com/ping/test_worker_123", + params={"job_id": "job1,job2,job3", "runpod_version": "1.0.0"}, + timeout=10 # PING_INTERVAL * 2 + ) + + # Verify debug log + mock_logger.debug.assert_called_once() + def test_send_ping_no_jobs(self, mock_env, mock_worker_id, mock_session, mock_logger): + """Test ping send with no jobs""" heartbeat = Heartbeat() - heartbeat._send_ping() - - mock_get.assert_called_once() - - # Extract the arguments passed to the mock_get call - _, kwargs = mock_get.call_args - - # Check that job_id is correct in params, ignoring other params - assert 'params' in kwargs - assert 'job_id' in kwargs['params'] - assert kwargs['params']['job_id'] in ["job1,job2", "job2,job1"] + + # Mock no jobs + with patch("runpod.serverless.modules.rp_ping.jobs.get_job_list", return_value=None): + mock_response = MagicMock() + mock_response.url = "https://test.com/ping/test_worker_123" + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): + heartbeat._send_ping() + + # Verify request params + mock_session.get.assert_called_once_with( + "https://test.com/ping/test_worker_123", + params={"job_id": None, "runpod_version": "1.0.0"}, + timeout=10 + ) - @patch("runpod.serverless.modules.rp_ping.log") - def test_send_ping_exception(self, mock_logger): - """Test _send_ping logs an error on exception.""" + def test_send_ping_request_exception(self, mock_env, mock_worker_id, mock_session, mock_jobs, mock_logger): + """Test ping send with request exception""" heartbeat = Heartbeat() - - with patch.object( - heartbeat._session, - "get", - 
side_effect=requests.RequestException("Error"), - ): + + # Mock request exception + mock_session.get.side_effect = requests.RequestException("Connection error") + + with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): heartbeat._send_ping() + + # Verify error was logged + mock_logger.error.assert_called_once_with( + "Ping Request Error: Connection error, attempting to restart ping." + ) - mock_logger.error.assert_called_once_with( - "Ping Request Error: Error, attempting to restart ping." - ) + def test_custom_pool_connections(self, mock_env, mock_worker_id, mock_session): + """Test initialization with custom pool connections and retries""" + heartbeat = Heartbeat(pool_connections=20, retries=5) + + # Should still initialize properly + assert heartbeat.PING_URL == "https://test.com/ping/test_worker_123" + + @patch("requests.adapters.HTTPAdapter") + def test_http_adapter_configuration(self, mock_adapter, mock_env, mock_worker_id, mock_session): + """Test that HTTP adapter is configured correctly""" + mock_adapter_instance = MagicMock() + mock_adapter.return_value = mock_adapter_instance + + Heartbeat(pool_connections=15, retries=4) + + # Verify adapter was created + assert mock_adapter.called + + # Verify it was called with expected pool settings + call_kwargs = mock_adapter.call_args[1] + assert call_kwargs['pool_connections'] == 15 + assert call_kwargs['pool_maxsize'] == 15 + assert 'max_retries' in call_kwargs + + # Verify adapter was mounted on both protocols + assert mock_session.mount.call_count == 2 + mock_session.mount.assert_any_call("http://", mock_adapter_instance) + mock_session.mount.assert_any_call("https://", mock_adapter_instance) From a6b30275f8b28d38b2cf4eaf188a8e91fa89f660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 11 Jun 2025 23:00:46 -0700 Subject: [PATCH 3/4] update: run_scale test supports new JobsProgress --- .../test_serverless/test_modules/run_scale.py | 114 +++++++++--------- 1 file 
changed, 57 insertions(+), 57 deletions(-) diff --git a/tests/test_serverless/test_modules/run_scale.py b/tests/test_serverless/test_modules/run_scale.py index 5983c7a6..1310e463 100644 --- a/tests/test_serverless/test_modules/run_scale.py +++ b/tests/test_serverless/test_modules/run_scale.py @@ -3,61 +3,61 @@ from faker import Faker from typing import Any, Dict, Optional, List -from runpod.serverless.modules.rp_scale import JobScaler, RunPodLogger, JobsProgress -fake = Faker() -log = RunPodLogger() -job_progress = JobsProgress() - - -# Change this number to your desired concurrency -start = 1 - - -# sample concurrency modifier that loops -def collatz_conjecture(current_concurrency): - if current_concurrency == 1: - return start - - if current_concurrency % 2 == 0: - return math.floor(current_concurrency / 2) - else: - return current_concurrency * 3 + 1 - - -def fake_job(): - # Change this number to your desired delay - delay = fake.random_digit_above_two() - return { - "id": fake.uuid4(), - "input": fake.sentence(), - "mock_delay": delay, - } - - -async def fake_get_job(session, num_jobs: int = 1) -> Optional[List[Dict[str, Any]]]: - # Change this number to your desired delay - delay = fake.random_digit_above_two() - 1 - - log.info(f"... artificial delay ({delay}s)") - await asyncio.sleep(delay) # Simulates a blocking process - - jobs = [fake_job() for _ in range(num_jobs)] - log.info(f"... Generated # jobs: {len(jobs)}") - return jobs - - -async def fake_handle_job(session, config, job) -> dict: - await asyncio.sleep(job["mock_delay"]) # Simulates a blocking process - log.info(f"... 
Job handled ({job['mock_delay']}s)", job["id"]) - - -job_scaler = JobScaler( - { - # "concurrency_modifier": collatz_conjecture, - # "jobs_fetcher_timeout": 5, - "jobs_fetcher": fake_get_job, - "jobs_handler": fake_handle_job, - } -) -job_scaler.start() +def main(start=1): + """Main function to run the job scaler""" + from runpod.serverless.modules.rp_scale import JobScaler, RunPodLogger + + fake = Faker() + log = RunPodLogger() + + # sample concurrency modifier that loops + def collatz_conjecture(current_concurrency): + if current_concurrency == 1: + return start + + if current_concurrency % 2 == 0: + return math.floor(current_concurrency / 2) + else: + return current_concurrency * 3 + 1 + + def fake_job(): + # Change this number to your desired delay + delay = fake.random_digit_above_two() + return { + "id": fake.uuid4(), + "input": fake.sentence(), + "mock_delay": delay, + } + + async def fake_get_job(session, num_jobs: int = 1) -> Optional[List[Dict[str, Any]]]: + # Change this number to your desired delay + delay = fake.random_digit_above_two() - 1 + + log.info(f"... artificial delay ({delay}s)") + await asyncio.sleep(delay) # Simulates a blocking process + + jobs = [fake_job() for _ in range(num_jobs)] + log.info(f"... Generated # jobs: {len(jobs)}") + return jobs + + async def fake_handle_job(session, config, job) -> dict: + await asyncio.sleep(job["mock_delay"]) # Simulates a blocking process + log.info(f"... 
Job handled ({job['mock_delay']}s)", job["id"]) + + job_scaler = JobScaler( + { + "concurrency_modifier": collatz_conjecture, + # "jobs_fetcher_timeout": 5, + "jobs_fetcher": fake_get_job, + "jobs_handler": fake_handle_job, + } + ) + job_scaler.start() + + +if __name__ == '__main__': + # This is required for multiprocessing on macOS/Windows + import multiprocessing + multiprocessing.set_start_method('spawn', force=True) + main(start=10) From eefe8f3b330d69cf263c00bfd9dfe5a9d8a8ffe6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dean=20Qui=C3=B1anola?= Date: Wed, 11 Jun 2025 23:56:04 -0700 Subject: [PATCH 4/4] fix: broken tests --- tests/test_cli/test_cli_groups/test_config_commands.py | 8 +++++++- tests/test_cli/test_cli_groups/test_exec_commands.py | 7 +++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/test_cli/test_cli_groups/test_config_commands.py b/tests/test_cli/test_cli_groups/test_config_commands.py index d9dbe3ba..70959cc0 100644 --- a/tests/test_cli/test_cli_groups/test_config_commands.py +++ b/tests/test_cli/test_cli_groups/test_config_commands.py @@ -100,9 +100,15 @@ def test_output_messages(self): def test_api_key_prompt(self): """Tests the API key prompt.""" - with patch("click.prompt", return_value="KEY") as mock_prompt: + with patch("click.prompt", return_value="KEY") as mock_prompt, patch( + "runpod.cli.groups.config.commands.set_credentials" + ) as mock_set_credentials, patch( + "runpod.cli.groups.config.commands.check_credentials", + return_value=(False, None) + ): result = self.runner.invoke(runpod_cli, ["config", "--profile", "test"]) mock_prompt.assert_called_with( " > RunPod API Key", hide_input=False, confirmation_prompt=False ) # pylint: disable=line-too-long + mock_set_credentials.assert_called_with("KEY", "test", overwrite=True) assert result.exit_code == 0 diff --git a/tests/test_cli/test_cli_groups/test_exec_commands.py b/tests/test_cli/test_cli_groups/test_exec_commands.py index 9b04528e..d8a0edf6 100644 --- 
a/tests/test_cli/test_cli_groups/test_exec_commands.py
+++ b/tests/test_cli/test_cli_groups/test_exec_commands.py
@@ -7,14 +7,15 @@
 from unittest.mock import patch
 
 import click
+from click.testing import CliRunner
 
 from runpod.cli.entry import runpod_cli
 
 
 class TestExecCommands(unittest.TestCase):
     """Tests for Runpod CLI exec commands."""
 
     def setUp(self):
-        self.runner = click.testing.CliRunner()
+        self.runner = CliRunner()
 
     def test_remote_python_with_provided_pod_id(self):
@@ -49,9 +50,7 @@ def test_remote_python_without_provided_pod_id_prompt(self):
             "runpod.cli.groups.exec.commands.python_over_ssh"
         ) as mock_python_over_ssh, patch(
             "runpod.cli.groups.exec.commands.get_session_pod",
-            side_effect=lambda: click.prompt(
-                "Please provide the pod ID", "prompted_pod_id"
-            ),
+            return_value="prompted_pod_id",
         ) as mock_get_pod_id:  # pylint: disable=line-too-long
             mock_python_over_ssh.return_value = None
             result = self.runner.invoke(runpod_cli, ["exec", "python", temp_file.name])