@@ -132,10 +132,10 @@ def openai_completions(
132
132
133
133
prompt_batches = [prompts [batch_id * batch_size : (batch_id + 1 ) * batch_size ] for batch_id in range (n_batches )]
134
134
135
- if isinstance ( max_tokens , int ) :
136
- max_tokens = [ max_tokens ] * n_examples
137
-
138
- inputs = zip (prompt_batches , max_tokens )
135
+ try :
136
+ inputs = zip ( prompt_batches , max_tokens )
137
+ except TypeError :
138
+ inputs = zip (prompt_batches , [ max_tokens ] * n_batches )
139
139
140
140
kwargs = dict (model = model_name , ** decoding_kwargs )
141
141
kwargs_to_log = {k : v for k , v in kwargs .items () if "api_key" not in k }
@@ -216,7 +216,11 @@ def _openai_completion_helper(
216
216
client = all_clients [curr_client_idx ]
217
217
218
218
# copy shared_kwargs to avoid modifying it
219
- kwargs .update (dict (max_tokens = max_tokens , top_p = top_p , temperature = temperature ))
219
+ to_update = dict ()
220
+ for k in ["max_tokens" , "top_p" , "temperature" ]:
221
+ if locals ()[k ] is not None :
222
+ to_update [k ] = locals ()[k ]
223
+ kwargs .update (to_update )
220
224
curr_kwargs = copy .deepcopy (kwargs )
221
225
222
226
# ensure no infinite loop
@@ -242,7 +246,7 @@ def _openai_completion_helper(
242
246
# currently we only use function calls to get a JSON object => return raw text of json
243
247
choices [i ]["text" ] = choice .message .function_call .arguments
244
248
245
- if choice .message .tool_calls is not None :
249
+ if choice .message .tool_calls :
246
250
# currently we only use function calls to get a JSON object => return raw text of json
247
251
choices [i ]["text" ] = choice .message .tool_calls [0 ].function .arguments
248
252
@@ -273,10 +277,11 @@ def _openai_completion_helper(
273
277
return choices
274
278
275
279
else :
276
- if "rate limit " in str (e ).lower ():
280
+ if "rate " in str (e ).lower ():
277
281
logging .warning (f"Hit request rate limit; retrying..." )
278
282
else :
279
- logging .warning (f"Unknown error. \n It's likely a rate limit so we are retrying..." )
283
+ logging .exception ("Unknown error:" )
284
+ raise e
280
285
if len (all_clients ) > 1 :
281
286
curr_client_idx = random .choice ([idx for idx in client_idcs if idx != curr_client_idx ])
282
287
client = all_clients [curr_client_idx ]
0 commit comments