@@ -51,7 +51,7 @@ def huggingface_api_completions(
         token=constants.HUGGINGFACEHUB_API_TOKEN,
     )

-    default_kwargs = dict(do_sample=do_sample, options=dict(wait_for_model=True), return_full_text=False)
+    default_kwargs = dict(do_sample=do_sample, return_full_text=False)
     default_kwargs.update(kwargs)
     logging.info(f"Kwargs to completion: {default_kwargs}")

@@ -72,32 +72,33 @@ def huggingface_api_completions(
     )
     logging.info(f"Time for {n_examples} completions: {t}")

-    completions = [completion["generated_text"] for completion in completions]
-
     # unclear pricing
     price = [np.nan] * len(completions)
     avg_time = [t.duration / n_examples] * len(completions)

     return dict(completions=completions, price_per_example=price, time_per_example=avg_time)


-def inference_helper(prompt: str, inference, params, n_retries=100, waiting_time=2) -> dict:
+def inference_helper(prompt: str, inference, params, n_retries=100, waiting_time=2) -> str:
     for _ in range(n_retries):
-        output = inference(inputs=prompt, params=params)
-        if "error" in output and n_retries > 0:
-            error = output["error"]
-            if "Rate limit reached" in output["error"]:
-                logging.warning(f"Rate limit reached... Trying again in {waiting_time} seconds. Full error: {error}")
-                time.sleep(waiting_time)
-            elif "Input validation error" in error and "max_new_tokens" in error:
-                params["max_new_tokens"] = int(params["max_new_tokens"] * 0.8)
-                logging.warning(
-                    f"`max_new_tokens` too large. Reducing target length to {params['max_new_tokens']}, " f"Retrying..."
-                )
-                if params["max_new_tokens"] == 0:
+        try:
+            # TODO: check why doesn't stop after </s>
+            output = inference(prompt=prompt, **params)
+        except Exception as error:
+            if n_retries > 0:
+                if "Rate limit reached" in error:
+                    logging.warning(f"Rate limit reached... Trying again in {waiting_time} seconds.")
+                    time.sleep(waiting_time)
+                elif "Input validation error" in error and "max_new_tokens" in error:
+                    params["max_new_tokens"] = int(params["max_new_tokens"] * 0.8)
+                    logging.warning(
+                        f"`max_new_tokens` too large. Reducing target length to {params['max_new_tokens']}, "
+                        f"Retrying..."
+                    )
+                    if params["max_new_tokens"] == 0:
+                        raise ValueError(f"Error in inference. Full error: {error}")
+                else:
                     raise ValueError(f"Error in inference. Full error: {error}")
             else:
-                raise ValueError(f"Error in inference. Full error: {error}")
-        else:
-            return output[0]
-    raise ValueError(f"Error in inference. We tried {n_retries} times and failed.")
+                raise ValueError(f"Error in inference. We tried {n_retries} times and failed. Full error: {error}")
+    return output
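
Usage note (not part of the diff above): the reworked inference_helper raises on persistent failures and returns the generated string directly, so it presumably expects `inference` to be a callable that accepts `prompt=` plus generation kwargs. The pairing with huggingface_hub.InferenceClient.text_generation below, and the model id and token placeholders, are illustrative assumptions rather than anything this commit pins down.

from huggingface_hub import InferenceClient

# Placeholders only: substitute a real model id and token.
client = InferenceClient(model="<model-id>", token="<hf-token>")

completion = inference_helper(
    "Write one sentence about retry logic.",
    inference=client.text_generation,  # accepts prompt=... plus generation kwargs
    params=dict(max_new_tokens=64, do_sample=False),
)
print(completion)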