@@ -383,6 +383,15 @@ class CreatePredictionParams(TypedDict):
383
383
stream : NotRequired [bool ]
384
384
"""Enable streaming of prediction output."""
385
385
386
+ wait : NotRequired [Union [int , bool ]]
387
+ """
388
+ Wait until the prediction is completed before returning.
389
+
390
+ If `True`, wait a predetermined number of seconds until the prediction
391
+ is completed before returning.
392
+ If an `int`, wait for the specified number of seconds.
393
+ """
394
+
386
395
file_encoding_strategy : NotRequired [FileEncodingStrategy ]
387
396
"""The strategy to use for encoding files in the prediction input."""
388
397
@@ -463,6 +472,7 @@ def create( # type: ignore
463
472
client = self ._client ,
464
473
file_encoding_strategy = file_encoding_strategy ,
465
474
)
475
+ headers = _create_prediction_headers (wait = params .pop ("wait" , None ))
466
476
body = _create_prediction_body (
467
477
version ,
468
478
input ,
@@ -472,6 +482,7 @@ def create( # type: ignore
472
482
resp = self ._client ._request (
473
483
"POST" ,
474
484
"/v1/predictions" ,
485
+ headers = headers ,
475
486
json = body ,
476
487
)
477
488
@@ -554,6 +565,7 @@ async def async_create( # type: ignore
554
565
client = self ._client ,
555
566
file_encoding_strategy = file_encoding_strategy ,
556
567
)
568
+ headers = _create_prediction_headers (wait = params .pop ("wait" , None ))
557
569
body = _create_prediction_body (
558
570
version ,
559
571
input ,
@@ -563,6 +575,7 @@ async def async_create( # type: ignore
563
575
resp = await self ._client ._async_request (
564
576
"POST" ,
565
577
"/v1/predictions" ,
578
+ headers = headers ,
566
579
json = body ,
567
580
)
568
581
@@ -603,6 +616,20 @@ async def async_cancel(self, id: str) -> Prediction:
603
616
return _json_to_prediction (self ._client , resp .json ())
604
617
605
618
619
+ def _create_prediction_headers (
620
+ * ,
621
+ wait : Optional [Union [int , bool ]] = None ,
622
+ ) -> Dict [str , Any ]:
623
+ headers = {}
624
+
625
+ if wait :
626
+ if isinstance (wait , bool ):
627
+ headers ["Prefer" ] = "wait"
628
+ elif isinstance (wait , int ):
629
+ headers ["Prefer" ] = f"wait={ wait } "
630
+ return headers
631
+
632
+
606
633
def _create_prediction_body ( # pylint: disable=too-many-arguments
607
634
version : Optional [Union [Version , str ]],
608
635
input : Optional [Dict [str , Any ]],
0 commit comments