66import requests
77from requests .adapters import HTTPAdapter , Retry
88
9+ FINAL_STATES = ["COMPLETED" , "FAILED" , "TIMED_OUT" ]
10+
911# Exception Messages
1012UNAUTHORIZED_MSG = "401 Unauthorized | Make sure Runpod API key is set and valid."
1113API_KEY_NOT_SET_MSG = ("Expected `run_pod.api_key` to be initialized. "
@@ -33,7 +35,7 @@ def __init__(self):
3335 raise RuntimeError (API_KEY_NOT_SET_MSG )
3436
3537 self .rp_session = requests .Session ()
36- retries = Retry (total = 5 , backoff_factor = 1 , status_forcelist = [429 ])
38+ retries = Retry (total = 5 , backoff_factor = 1 , status_forcelist = [408 , 429 ])
3739 self .rp_session .mount ('http://' , HTTPAdapter (max_retries = retries ))
3840
3941 self .headers = {
@@ -102,12 +104,12 @@ def __init__(self, endpoint_id: str, job_id: str, client: RunPodClient):
102104 self .job_status = None
103105 self .job_output = None
104106
105- def _fetch_job (self ) :
107+ def _fetch_job (self , source : str = "status" ) -> Dict [ str , Any ] :
106108 """ Returns the raw json of the status, raises an exception if invalid """
107- status_url = f"{ self .endpoint_id } /status /{ self .job_id } "
109+ status_url = f"{ self .endpoint_id } /{ source } /{ self .job_id } "
108110 job_state = self .rp_client .get (endpoint = status_url )
109111
110- if job_state ["status" ] in [ "COMPLETED" , "FAILED" , "TIMEOUT" ] :
112+ if job_state ["status" ] in FINAL_STATES :
111113 self .job_status = job_state ["status" ]
112114 self .job_output = job_state .get ("output" , None )
113115
@@ -128,7 +130,7 @@ def output(self, timeout: int = 0) -> Any:
128130 timeout: The number of seconds to wait for the server to send data before giving up.
129131 """
130132 if timeout > 0 :
131- while self .status () not in [ "COMPLETED" , "FAILED" , "TIMEOUT" ] :
133+ while self .status () not in FINAL_STATES :
132134 time .sleep (1 )
133135 timeout -= 1
134136 if timeout <= 0 :
@@ -139,6 +141,17 @@ def output(self, timeout: int = 0) -> Any:
139141
140142 return self ._fetch_job ().get ("output" , None )
141143
def stream(self) -> Any:
    """ Returns a generator that yields the output of the job request.

    Sleeps one second, fetches the job through the ``stream`` source and
    yields the ``output`` entry of every chunk in the response, then repeats.
    The generator ends on the first poll that both reports a status listed in
    FINAL_STATES and carries no chunks.

    Yields:
        The ``output`` value of each chunk, in the order the chunks appear
        in each response.
    """
    while True:
        time.sleep(1)
        stream_partial = self._fetch_job(source="stream")

        # Read the chunk list once and tolerate a missing (or null) "stream"
        # entry, so a final response without chunks ends the generator
        # instead of raising KeyError/TypeError.
        chunks = stream_partial.get("stream") or []

        for chunk in chunks:
            yield chunk["output"]

        # A final status alone is not enough to stop: a final poll that still
        # carries chunks is followed by another poll, and only an empty one
        # terminates the loop.
        if not chunks and stream_partial["status"] in FINAL_STATES:
            break
154+
142155
143156# ---------------------------------------------------------------------------- #
144157# Endpoint #
@@ -191,7 +204,11 @@ def run_sync(self, request_input: Dict[str, Any], timeout: int = 86400) -> Dict[
191204 job_request = self .rp_client .post (
192205 f"{ self .endpoint_id } /runsync" , request_input , timeout = timeout )
193206
194- if job_request ["status" ] in [ "COMPLETED" , "FAILED" , "TIMEOUT" ] :
207+ if job_request ["status" ] in FINAL_STATES :
195208 return job_request .get ("output" , None )
196209
197210 return Job (self .endpoint_id , job_request ["id" ], self .rp_client ).output (timeout = timeout )
211+
def health(self) -> Dict[str, Any]:
    """ Requests this endpoint's health route and returns the client's response. """
    health_route = f"{self.endpoint_id}/health"
    return self.rp_client.get(health_route)
0 commit comments