Skip to content

Commit 6a07787

Browse files
Merge pull request #227 from DireLines/alignSdk
Align sdk skeleton
2 parents 34b0a5f + 295c0f2 commit 6a07787

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

runpod/endpoint/runner.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
"An API key can be generated at "
1414
"https://runpod.io/console/user/settings")
1515

16-
16+
def is_completed(status:str)->bool:
17+
"""Returns true if status is one of the possible final states for a serverless request."""
18+
return status in ["COMPLETED", "FAILED", "TIMED_OUT", "CANCELLED"]
1719
# ---------------------------------------------------------------------------- #
1820
# Client #
1921
# ---------------------------------------------------------------------------- #
@@ -107,7 +109,7 @@ def _fetch_job(self):
107109
status_url = f"{self.endpoint_id}/status/{self.job_id}"
108110
job_state = self.rp_client.get(endpoint=status_url)
109111

110-
if job_state["status"] in ["COMPLETED", "FAILED", "TIMEOUT"]:
112+
if is_completed(job_state["status"]):
111113
self.job_status = job_state["status"]
112114
self.job_output = job_state.get("output", None)
113115

@@ -128,7 +130,7 @@ def output(self, timeout: int = 0) -> Any:
128130
timeout: The number of seconds to wait for the server to send data before giving up.
129131
"""
130132
if timeout > 0:
131-
while self.status() not in ["COMPLETED", "FAILED", "TIMEOUT"]:
133+
while not is_completed(self.status()):
132134
time.sleep(1)
133135
timeout -= 1
134136
if timeout <= 0:
@@ -139,6 +141,17 @@ def output(self, timeout: int = 0) -> Any:
139141

140142
return self._fetch_job().get("output", None)
141143

144+
def cancel(self, timeout: int = 3) -> Any:
145+
"""
146+
Cancels the job and returns the result of the cancellation request.
147+
148+
Args:
149+
timeout: The number of seconds to wait for the server to respond before giving up.
150+
"""
151+
return self.rp_client.post(f"{self.endpoint_id}/cancel/{self.job_id}",
152+
data=None,timeout=timeout)
153+
154+
142155

143156
# ---------------------------------------------------------------------------- #
144157
# Endpoint #
@@ -195,3 +208,20 @@ def run_sync(self, request_input: Dict[str, Any], timeout: int = 86400) -> Dict[
195208
return job_request.get("output", None)
196209

197210
return Job(self.endpoint_id, job_request["id"], self.rp_client).output(timeout=timeout)
211+
212+
def health(self,timeout: int = 3) -> Dict[str, Any]:
213+
"""
214+
Check the health of the endpoint (number/state of workers, number/state of requests).
215+
216+
Args:
217+
timeout: The number of seconds to wait for the server to respond before giving up.
218+
"""
219+
return self.rp_client.get(f"{self.endpoint_id}/health",timeout=timeout)
220+
def purge_queue(self,timeout: int = 3) -> Dict[str, Any]:
221+
"""
222+
Purges the endpoint's job queue and returns the result of the purge request.
223+
224+
Args:
225+
timeout: The number of seconds to wait for the server to respond before giving up.
226+
"""
227+
return self.rp_client.post(f"{self.endpoint_id}/purge-queue",data=None,timeout=timeout)

tests/test_endpoint/test_runner.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,23 @@ def test_endpoint_run_sync(self, mock_client_request):
132132
{'input': {'YOUR_MODEL_INPUT_JSON': 'YOUR_MODEL_INPUT_VALUE'}}, 86400
133133
)
134134

135+
@patch('runpod.endpoint.runner.RunPodClient._request')
136+
def test_endpoint_health(self, mock_client_request):
137+
''' Test the health method of Endpoint '''
138+
self.endpoint.health()
139+
140+
mock_client_request.assert_called_once_with('GET', f"{self.ENDPOINT_ID}/health", timeout=3)
141+
142+
@patch('runpod.endpoint.runner.RunPodClient._request')
143+
def test_endpoint_purge_queue(self, mock_client_request):
144+
''' Test the health method of Endpoint '''
145+
self.endpoint.purge_queue()
146+
147+
mock_client_request.assert_called_once_with(
148+
'POST', f"{self.ENDPOINT_ID}/purge-queue",
149+
None, 3
150+
)
151+
135152
def test_missing_api_key(self):
136153
'''
137154
Tests Endpoint.run without api_key
@@ -213,6 +230,7 @@ def test_run_sync_with_timeout(self, mock_client_request):
213230

214231
class TestJob(unittest.TestCase):
215232
''' Tests for Job '''
233+
MODEL_OUTPUT = {"result": "YOUR_MODEL_OUTPUT_VALUE"}
216234

217235
@patch('runpod.endpoint.runner.RunPodClient')
218236
def test_status(self, mock_client):
@@ -274,6 +292,17 @@ def test_output_timeout(self, mock_client):
274292
with self.assertRaises(TimeoutError):
275293
job.output(timeout=1)
276294

295+
@patch('runpod.endpoint.runner.RunPodClient')
296+
def test_cancel(self, mock_client):
297+
''' Test the cancel method of Job with a successful job initiation. '''
298+
job = runner.Job("endpoint_id", "job_id", mock_client)
299+
300+
job.cancel()
301+
302+
mock_client.post.assert_called_with("endpoint_id/cancel/job_id",
303+
data=None, timeout=3)
304+
305+
277306
@patch('runpod.endpoint.runner.RunPodClient')
278307
def test_job_status(self, mock_client):
279308
'''

0 commit comments

Comments
 (0)