Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Change Log

## Release 1.3.8 (12/1/23)

### Added

- Stream support for calling endpoints.

---

## Release 1.3.7 (11/29/23)

### Fixed
Expand Down
12 changes: 12 additions & 0 deletions examples/endpoints/health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
""" Example of getting the health of an endpoint. """

import runpod

# Authenticate first: either store a global API key once with `runpod config`,
# or uncomment the assignment below and fill in your key.
# runpod.api_key = "YOUR_RUNPOD_API_KEY"

# Handle for the endpoint to inspect, addressed by its endpoint ID.
endpoint = runpod.Endpoint("gwp4kx5yd3nur1")

# Ask the endpoint for its health information and print it.
endpoint_health = endpoint.health()
print(endpoint_health)
18 changes: 18 additions & 0 deletions examples/endpoints/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
""" Example of streaming data from an endpoint. """

import runpod

# Authenticate first: either store a global API key once with `runpod config`,
# or uncomment the assignment below and fill in your key.
# runpod.api_key = "YOUR_RUNPOD_API_KEY"

# Handle for the endpoint to call, addressed by its endpoint ID.
endpoint = runpod.Endpoint("gwp4kx5yd3nur1")

# Payload submitted with the run request.
request_payload = {
    "input": {
        "mock_return": ["a", "b", "c", "d", "e", "f", "g"],
        "mock_delay": 1,
    }
}
run_request = endpoint.run(request_payload)

# Print every output as soon as the stream yields it.
for output in run_request.stream():
    print(output)
36 changes: 26 additions & 10 deletions runpod/endpoint/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@
import requests
from requests.adapters import HTTPAdapter, Retry

# Job statuses after which a request is treated as finished.
FINAL_STATES = ["COMPLETED", "FAILED", "TIMED_OUT"]

# Exception Messages
UNAUTHORIZED_MSG = "401 Unauthorized | Make sure Runpod API key is set and valid."
# The package is imported as `runpod`, so the message must name `runpod.api_key`
# (not `run_pod.api_key`) for the suggested fix to work when copied verbatim.
API_KEY_NOT_SET_MSG = ("Expected `runpod.api_key` to be initialized. "
                       "You can solve this by setting `runpod.api_key = 'your-key'`. "
                       "An API key can be generated at "
                       "https://runpod.io/console/user/settings")

def is_completed(status:str)->bool:

def is_completed(status: str) -> bool:
    """Tell whether *status* is a terminal state of a serverless request.

    Args:
        status: The status string reported for the request.

    Returns:
        True for "COMPLETED", "FAILED", "TIMED_OUT" or "CANCELLED";
        False for any other value.
    """
    terminal_states = ("COMPLETED", "FAILED", "TIMED_OUT", "CANCELLED")
    return status in terminal_states


# ---------------------------------------------------------------------------- #
# Client #
# ---------------------------------------------------------------------------- #
Expand All @@ -35,7 +40,7 @@ def __init__(self):
raise RuntimeError(API_KEY_NOT_SET_MSG)

self.rp_session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[429])
retries = Retry(total=5, backoff_factor=1, status_forcelist=[408, 429])
self.rp_session.mount('http://', HTTPAdapter(max_retries=retries))

self.headers = {
Expand Down Expand Up @@ -104,9 +109,9 @@ def __init__(self, endpoint_id: str, job_id: str, client: RunPodClient):
self.job_status = None
self.job_output = None

def _fetch_job(self):
def _fetch_job(self, source: str = "status") -> Dict[str, Any]:
""" Returns the raw json of the status, raises an exception if invalid """
status_url = f"{self.endpoint_id}/status/{self.job_id}"
status_url = f"{self.endpoint_id}/{source}/{self.job_id}"
job_state = self.rp_client.get(endpoint=status_url)

if is_completed(job_state["status"]):
Expand Down Expand Up @@ -149,8 +154,18 @@ def cancel(self, timeout: int = 3) -> Any:
timeout: The number of seconds to wait for the server to respond before giving up.
"""
return self.rp_client.post(f"{self.endpoint_id}/cancel/{self.job_id}",
data=None,timeout=timeout)
data=None, timeout=timeout)

def stream(self) -> Any:
    """Return a generator that yields the output of the job request.

    Waits one second, fetches the job through ``_fetch_job(source="stream")``
    and yields the ``"output"`` value of every chunk in the response, then
    repeats. The generator finishes on the first response whose status is in
    ``FINAL_STATES`` and which carries no chunks.

    Yields:
        The ``"output"`` entry of each chunk, in the order received.
    """
    # NOTE(review): FINAL_STATES does not list "CANCELLED" although
    # is_completed() does; confirm a cancelled job cannot keep this loop polling.
    while True:
        # One-second pause before every poll.
        time.sleep(1)
        stream_partial = self._fetch_job(source="stream")

        # A response may lack the "stream" key (or carry None); treat that as
        # "no chunks" instead of raising KeyError on a final-status response.
        chunks = stream_partial.get("stream") or []

        # Final status with nothing left to deliver: the stream is exhausted.
        if stream_partial["status"] in FINAL_STATES and not chunks:
            break

        for chunk in chunks:
            yield chunk["output"]


# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -204,24 +219,25 @@ def run_sync(self, request_input: Dict[str, Any], timeout: int = 86400) -> Dict[
job_request = self.rp_client.post(
f"{self.endpoint_id}/runsync", request_input, timeout=timeout)

if job_request["status"] in ["COMPLETED", "FAILED", "TIMEOUT"]:
if job_request["status"] in FINAL_STATES:
return job_request.get("output", None)

return Job(self.endpoint_id, job_request["id"], self.rp_client).output(timeout=timeout)

def health(self,timeout: int = 3) -> Dict[str, Any]:
def health(self, timeout: int = 3) -> Dict[str, Any]:
"""
Check the health of the endpoint (number/state of workers, number/state of requests).

Args:
timeout: The number of seconds to wait for the server to respond before giving up.

Returns:
The value returned by the client's GET request to `<endpoint_id>/health`.
"""
return self.rp_client.get(f"{self.endpoint_id}/health",timeout=timeout)
def purge_queue(self,timeout: int = 3) -> Dict[str, Any]:
return self.rp_client.get(f"{self.endpoint_id}/health", timeout=timeout)

def purge_queue(self, timeout: int = 3) -> Dict[str, Any]:
"""
Purges the endpoint's job queue and returns the result of the purge request.

Args:
timeout: The number of seconds to wait for the server to respond before giving up.

Returns:
The value returned by the client's POST request (sent with `data=None`)
to `<endpoint_id>/purge-queue`.
"""
return self.rp_client.post(f"{self.endpoint_id}/purge-queue",data=None,timeout=timeout)
return self.rp_client.post(f"{self.endpoint_id}/purge-queue", data=None, timeout=timeout)
26 changes: 24 additions & 2 deletions tests/test_endpoint/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,14 @@ def test_output_timeout(self, mock_client):

@patch('runpod.endpoint.runner.RunPodClient')
def test_cancel(self, mock_client):
''' Test the cancel method of Job with a successful job initiation. '''
''' Test the cancel method of Job with a successful job initiation. '''
# Build the Job around the patched client so the call below is recorded on the mock.
job = runner.Job("endpoint_id", "job_id", mock_client)

job.cancel()

# cancel() must POST to "<endpoint_id>/cancel/<job_id>" with no payload
# and a timeout of 3 when called without arguments.
mock_client.post.assert_called_with("endpoint_id/cancel/job_id",
data=None, timeout=3)


@patch('runpod.endpoint.runner.RunPodClient')
def test_job_status(self, mock_client):
'''
Expand All @@ -321,3 +320,26 @@ def test_job_status(self, mock_client):
self.assertEqual(job.status(), "IN_PROGRESS")
self.assertEqual(job.status(), "COMPLETED")
self.assertEqual(job.status(), "COMPLETED")

@patch('runpod.endpoint.runner.RunPodClient')
def test_job_stream(self, mock_client):
    '''
    Job.stream yields the output of every streamed chunk and stops once a
    final status arrives with an empty stream.
    '''
    # First poll: job still running, two chunks available.
    in_progress_response = {
        "status": "IN_PROGRESS",
        "stream": [
            {"output": "Job output 1"},
            {"output": "Job output 2"},
        ],
    }
    # Second poll: job finished and nothing left to stream.
    completed_response = {"status": "COMPLETED", "stream": []}
    mock_client.get.side_effect = [in_progress_response, completed_response]

    streamed = list(runner.Job("endpoint_id", "job_id", mock_client).stream())

    self.assertEqual(streamed, ['Job output 1', 'Job output 2'])