Skip to content

Commit 2222320

Browse files
committed
add: streaming
1 parent 34ac99c commit 2222320

File tree

3 files changed

+49
-6
lines changed

3 files changed

+49
-6
lines changed

examples/endpoints/health.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import runpod

# Set your global API key with `runpod config` or uncomment the line below:
# runpod.api_key = "YOUR_RUNPOD_API_KEY"

# Bind to the serverless endpoint, then print whatever its health call returns.
endpoint = runpod.Endpoint("gwp4kx5yd3nur1")
print(endpoint.health())

examples/endpoints/streaming.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import runpod

# Set your global API key with `runpod config` or uncomment the line below:
# runpod.api_key = "YOUR_RUNPOD_API_KEY"

# Request body submitted to the endpoint.
# NOTE(review): presumably the worker echoes `mock_return` item by item,
# pausing `mock_delay` between items - confirm against the worker.
job_payload = {
    "input": {
        "mock_return": ["a", "b", "c", "d", "e", "f", "g"],
        "mock_delay": 1,
    }
}

endpoint = runpod.Endpoint("gwp4kx5yd3nur1")
run_request = endpoint.run(job_payload)

# Print every piece of output as the job's stream generator yields it.
for piece in run_request.stream():
    print(piece)

runpod/endpoint/runner.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import requests
77
from requests.adapters import HTTPAdapter, Retry
88

9+
# Job statuses treated as terminal in this module: polling loops stop once a
# job reports one of these.
FINAL_STATES = ["COMPLETED", "FAILED", "TIMED_OUT"]
10+
911
# Exception Messages
1012
UNAUTHORIZED_MSG = "401 Unauthorized | Make sure Runpod API key is set and valid."
1113
API_KEY_NOT_SET_MSG = ("Expected `run_pod.api_key` to be initialized. "
@@ -33,7 +35,7 @@ def __init__(self):
3335
raise RuntimeError(API_KEY_NOT_SET_MSG)
3436

3537
self.rp_session = requests.Session()
36-
retries = Retry(total=5, backoff_factor=1, status_forcelist=[429])
38+
retries = Retry(total=5, backoff_factor=1, status_forcelist=[408, 429])
3739
self.rp_session.mount('http://', HTTPAdapter(max_retries=retries))
3840

3941
self.headers = {
@@ -102,12 +104,12 @@ def __init__(self, endpoint_id: str, job_id: str, client: RunPodClient):
102104
self.job_status = None
103105
self.job_output = None
104106

105-
def _fetch_job(self):
107+
def _fetch_job(self, source: str = "status") -> Dict[str, Any]:
106108
""" Returns the raw json of the status, raises an exception if invalid """
107-
status_url = f"{self.endpoint_id}/status/{self.job_id}"
109+
status_url = f"{self.endpoint_id}/{source}/{self.job_id}"
108110
job_state = self.rp_client.get(endpoint=status_url)
109111

110-
if job_state["status"] in ["COMPLETED", "FAILED", "TIMEOUT"]:
112+
if job_state["status"] in FINAL_STATES:
111113
self.job_status = job_state["status"]
112114
self.job_output = job_state.get("output", None)
113115

@@ -128,7 +130,7 @@ def output(self, timeout: int = 0) -> Any:
128130
timeout: The number of seconds to wait for the server to send data before giving up.
129131
"""
130132
if timeout > 0:
131-
while self.status() not in ["COMPLETED", "FAILED", "TIMEOUT"]:
133+
while self.status() not in FINAL_STATES:
132134
time.sleep(1)
133135
timeout -= 1
134136
if timeout <= 0:
@@ -139,6 +141,17 @@ def output(self, timeout: int = 0) -> Any:
139141

140142
return self._fetch_job().get("output", None)
141143

144+
def stream(self) -> Any:
    """ Returns a generator that yields the output of the job request.

    Polls the job's `stream` resource once per second through
    `_fetch_job(source="stream")` and yields the `"output"` value of every
    chunk in the response's `"stream"` list.

    The generator finishes on the first poll that reports a status in
    FINAL_STATES and carries no chunks. A final poll that still carries
    chunks is yielded and followed by one more poll, so nothing is dropped.
    """
    while True:
        time.sleep(1)
        stream_partial = self._fetch_job(source="stream")

        # A response may omit "stream" entirely (or send null); treat that as
        # "no chunks" instead of raising KeyError on a finished job.
        chunks = stream_partial.get("stream") or []
        for chunk in chunks:
            yield chunk["output"]

        # Stop only once the job is final AND this poll was empty.
        if stream_partial["status"] in FINAL_STATES and not chunks:
            break
154+
142155

143156
# ---------------------------------------------------------------------------- #
144157
# Endpoint #
@@ -191,7 +204,11 @@ def run_sync(self, request_input: Dict[str, Any], timeout: int = 86400) -> Dict[
191204
job_request = self.rp_client.post(
192205
f"{self.endpoint_id}/runsync", request_input, timeout=timeout)
193206

194-
if job_request["status"] in ["COMPLETED", "FAILED", "TIMEOUT"]:
207+
if job_request["status"] in FINAL_STATES:
195208
return job_request.get("output", None)
196209

197210
return Job(self.endpoint_id, job_request["id"], self.rp_client).output(timeout=timeout)
211+
212+
def health(self) -> Dict[str, Any]:
    """ Fetches this endpoint's health report.

    Issues a GET through the client for `<endpoint_id>/health` and returns
    the client's result unchanged.
    """
    health_route = f"{self.endpoint_id}/health"
    return self.rp_client.get(health_route)

0 commit comments

Comments
 (0)